In [1]:
# --- Model -----------------------------------------------------------------
model_name = "BAAI/bge-base-en-v1.5"
# Retrieval-query instruction string and max sequence length; no cell shown in
# this notebook consumes either one yet.
query_prefix = "Represent this sentence for searching relevant passages: "
max_len = 512

# --- Data (JSON Lines: one {"query", "pos", "neg"} record per line) ----------
training_hn_file, eval_file = "./data/hn-output.jsonl", "./data/eval-output.jsonl"

# --- Training / output -------------------------------------------------------
batch_size = 1_280
output_model_path = "./bge-base-custom"

In [2]:
%matplotlib inline

from functools import partial
import itertools as it
import os
import random

from datasets import load_from_disk
from FlagEmbedding import FlagModel
import jsonlines as jsonl
from lion_pytorch import Lion
import matplotlib.pyplot as plt
import numpy as np
from numpy import dot
from numpy.linalg import norm
from sentence_transformers import InputExample, SentenceTransformer, losses as ls, models as ml, util
from sentence_transformers.evaluation import SimilarityFunction, TripletEvaluator
import torch
from torch.utils.data import DataLoader, IterableDataset
from tqdm.auto import tqdm

# Restrict which GPUs this kernel can see. CUDA reads this variable when it is first
# initialised, not when torch is imported, so it only takes effect if nothing above has
# touched CUDA yet. NOTE(review): setting it before `import torch` would be more robust.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2"

# Seed every RNG the notebook relies on (Python, NumPy, and torch, which drives DataLoader
# shuffling and the projection-head weight init). This makes Restart & Run All repeatable,
# up to GPU-kernel nondeterminism.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)  # seeds the RNG on all devices, CUDA included

In [3]:
def hn_output(file, query_prefix: str = ""):
    """Yield (query, positive, hard-negative) triplets from a hard-negative JSON Lines file.

    Each line is read as an object with a ``"query"`` value and ``"pos"`` / ``"neg"`` lists of
    dialog strings. Every (positive, negative) combination of a line becomes one
    ``InputExample(texts=[{"fact": query}, {"dialog": pos}, {"dialog": neg}])``. The dict keys
    select which branch of the ``Asym`` projection head encodes each text. A line whose
    ``"pos"`` or ``"neg"`` list is empty yields nothing.

    Args:
        file: Path of the ``.jsonl`` file, opened with ``jsonlines.open``.
        query_prefix: Optional instruction prepended to every query, e.g. the config's
            ``query_prefix``. The default ``""`` leaves queries untouched.

    Yields:
        One ``InputExample`` per (positive, negative) pair, in file order.
    """
    with jsonl.open(file) as reader:
        for entry in reader:
            query = entry["query"]
            if query_prefix:
                query = query_prefix + query
            # The same anchor dict is shared by every triplet of this line (read-only use).
            anchor = {"fact": query}
            positives = [{"dialog": dialog} for dialog in entry["pos"]]
            negatives = [{"dialog": dialog} for dialog in entry["neg"]]

            for pos, neg in it.product(positives, negatives):
                yield InputExample(texts=[anchor, pos, neg])

In [12]:
# Materialise the triplet generators once so the DataLoader and the evaluator can index and
# re-iterate them. The generators have no length, so tqdm shows a running count.
training_data = list(tqdm(hn_output(training_hn_file), desc="train triplets"))
eval_data = list(tqdm(hn_output(eval_file), desc="eval triplets"))

# Fail fast on missing or empty inputs rather than deep inside fit() or the evaluator.
assert training_data, f"no training triplets produced from {training_hn_file}"
assert eval_data, f"no eval triplets produced from {eval_file}"

0it [00:00, ?it/s]

0it [00:00, ?it/s]

In [13]:
# Training triplets are reshuffled every epoch.
dataloader = DataLoader(training_data, shuffle=True, batch_size=batch_size)
# Evaluation order carries no signal, so keep it deterministic. NOTE(review): the
# TripletEvaluator reads eval_data directly; no cell shown here consumes this loader.
eval_dataloader = DataLoader(eval_data, shuffle=False, batch_size=batch_size // 10)

In [4]:
# Base model: the BGE encoder (Transformer followed by CLS-token Pooling, as the module
# listing printed further down shows). Apply the configured max_len explicitly instead of
# relying on whatever sequence length the checkpoint ships with.
base_model = SentenceTransformer(model_name)
base_model.max_seq_length = max_len

Downloading (…)3ac3f/.gitattributes:   0%|          | 0.00/1.52k [00:00<?, ?B/s]

Downloading (…)_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading (…)440fe3ac3f/README.md:   0%|          | 0.00/89.0k [00:00<?, ?B/s]

Downloading (…)0fe3ac3f/config.json:   0%|          | 0.00/777 [00:00<?, ?B/s]

Downloading (…)ce_transformers.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

Downloading (…)nce_bert_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Downloading (…)3ac3f/tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/366 [00:00<?, ?B/s]

Downloading (…)440fe3ac3f/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)fe3ac3f/modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

In [5]:
# Freeze everything currently in the model (the Transformer and Pooling modules). Modules
# attached later, such as the Asym projection head, keep requires_grad=True and are the
# only parts that train.
for frozen in base_model.parameters():
    frozen.requires_grad_(False)

In [6]:
device = torch.device("cuda:0")  # first GPU visible to this kernel

# Move the weights first, then record the same device in the private `_target_device`.
# SentenceTransformer.fit() relocates the model to that attribute, so if the two disagreed
# fit() would undo this placement.
base_model = base_model.to(device)
base_model._target_device = device

In [7]:
# Hidden width of the first module (the Transformer): 768 for bge-base. The model is an
# nn.Sequential, so index 0 is the same module `_first_module()` returns.
emb_dims = base_model[0].get_word_embedding_dimension()

def dense_projector(dims: int, expansion: int = 2, dropout: float = 0.0):
    """Build a three-layer MLP that maps ``dims`` back to ``dims`` through a wider hidden layer.

    Layout: Dense(dims -> h) -> Dense(h -> h) -> Dense(h -> dims), with ``h = dims * expansion``.
    With bge-base (768) and the default expansion this is 768 -> 1536 -> 1536 -> 768. Each
    ``ml.Dense`` keeps its default activation, which the printed Asym listing shows as Tanh.

    Args:
        dims: Width of the incoming and outgoing embedding.
        expansion: Hidden-width multiplier. The default 2 reproduces the original head.
        dropout: If > 0, insert ``ml.Dropout(dropout)`` after each hidden layer (never after
            the output layer). The default 0.0 adds no dropout modules.

    Returns:
        A list of sentence-transformers modules in forward order, usable as one Asym branch.
    """
    hidden = dims * expansion
    layer_shapes = [(dims, hidden), (hidden, hidden), (hidden, dims)]
    last = len(layer_shapes) - 1

    layers = []
    for idx, (in_features, out_features) in enumerate(layer_shapes):
        layers.append(ml.Dense(in_features, out_features))
        # Regularise only between layers; the output projection stays undisturbed.
        if dropout > 0 and idx < last:
            layers.append(ml.Dropout(dropout))
    return layers

def asym_module(dims: int, keys: list[str], allow_empty_key: bool = False):
    """Wrap one freshly built ``dense_projector(dims)`` head per key in an ``ml.Asym`` router.

    Each input text reaches the head registered under its dict key. Heads are constructed in
    the order ``keys`` are given, and each key gets its own independent head.

    Args:
        dims: Embedding width passed to every ``dense_projector``.
        keys: Routing keys, one head each.
        allow_empty_key: Forwarded unchanged to ``ml.Asym``.
    """
    heads = {}
    for key in keys:
        heads[key] = dense_projector(dims)
    return ml.Asym(heads, allow_empty_key=allow_empty_key)

In [8]:
# Inspect the pipeline before the projection head is attached (expect Transformer -> Pooling).
base_model._modules

OrderedDict([('0',
              Transformer({'max_seq_length': 512, 'do_lower_case': True}) with Transformer model: BertModel ),
             ('1',
              Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False}))])

In [9]:
# Attach one projection head per input key as module "2", after Pooling. The keys must match
# the dict keys hn_output() wraps each text in ("fact" for queries, "dialog" for passages).
# add_module is the public nn.Module API for a direct `_modules[...]` write. The new head is
# moved to `device` so the whole pipeline sits on one device even before fit() runs.
base_model.add_module("2", asym_module(emb_dims, ["dialog", "fact"]).to(device))

In [10]:
# Confirm the Asym head is registered as module "2" with separate "dialog" and "fact" branches.
base_model._modules

OrderedDict([('0',
              Transformer({'max_seq_length': 512, 'do_lower_case': True}) with Transformer model: BertModel ),
             ('1',
              Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})),
             ('2',
              Asym(
                (dialog-0): Dense({'in_features': 768, 'out_features': 1536, 'bias': True, 'activation_function': 'torch.nn.modules.activation.Tanh'})
                (dialog-1): Dense({'in_features': 1536, 'out_features': 1536, 'bias': True, 'activation_function': 'torch.nn.modules.activation.Tanh'})
                (dialog-2): Dense({'in_features': 1536, 'out_features': 768, 'bias': True, 'activation_function': 'torch.nn.modules.activation.Tanh'})
                (fact-0): Dense({'in_features': 768, 'out_features': 1536, 'bias': True, 'activation_function': 'torch.nn.modules.activation.Tanh'})
     

In [11]:
# TripletLoss defaults to Euclidean distance with margin 5. The TripletEvaluator below uses
# cosine as its main distance, and that score drives save_best_model, so train on the same
# geometry we select on. Cosine distance (1 - cos) lies in [0, 2], so the Euclidean-scale
# default margin of 5 could never be met. Use a cosine-scale margin instead; it is tunable.
train_loss = ls.TripletLoss(
    model=base_model,
    distance_metric=ls.TripletDistanceMetric.COSINE,
    triplet_margin=0.5,
)

In [14]:
# Each example from hn_output() is [anchor, positive, negative] =
# [{"fact": <query>}, {"dialog": <relevant dialog>}, {"dialog": <hard-negative dialog>}].
# Cosine is the main distance, so it is the score save_best_model tracks during fit().
triplet_evaluator = TripletEvaluator.from_input_examples(
    eval_data,
    batch_size=batch_size // 10,  # same eval batch size as eval_dataloader
    main_distance_function=SimilarityFunction.COSINE,
    show_progress_bar=True,
    write_csv=True,
)

In [None]:
base_model.fit(
    # Objective and optimisation schedule
    train_objectives=[(dataloader, train_loss)],
    epochs=12,
    optimizer_class=Lion,
    optimizer_params={"lr": 1e-4, "weight_decay": 1e-2},
    scheduler="WarmupCosine",
    warmup_steps=100,
    use_amp=True,
    # Model selection: keep the best-scoring model at output_path
    evaluator=triplet_evaluator,
    save_best_model=True,
    output_path=output_model_path,
    # Periodic checkpoints, at most 4 retained
    checkpoint_path=f"{output_model_path}/ckpts",
    checkpoint_save_steps=600,
    checkpoint_save_total_limit=4,
)

Epoch:   0%|          | 0/12 [00:00<?, ?it/s]

Iteration:   0%|          | 0/1321 [00:00<?, ?it/s]

Batches:   0%|          | 0/293 [00:00<?, ?it/s]

Batches:   0%|          | 0/293 [00:00<?, ?it/s]

Batches:   0%|          | 0/293 [00:00<?, ?it/s]

Iteration:   0%|          | 0/1321 [00:00<?, ?it/s]