## Custom sbert to retrieve beliefs for dialogs

**Goal:** Train a custom sentence-transformer to match dialogs with beliefs and facts
**Method:**
- [x] Use a BGE v1.5 model as the base (originally `bge-small-en-v1.5`; `model_name` below now selects `BAAI/bge-base-en-v1.5`)
- [x] evaluate embedding similarity of the base model on the test split
    - [x] calc embeddings
    - [x] calc cosine scores
    - [x] visualize using matplotlib
- [x] define new model arch
    - [x] add custom modules
    - [x] freeze transformer layers
    - [x] define loss
- [ ] finetune the model on train dataset
- [ ] Evaluate the final model on the test-set
- [ ] Compare results

### Constants

In [1]:
# Base encoder, its retrieval-query instruction, and sequence-length cap.
model_name = "BAAI/bge-base-en-v1.5"
query_prefix = "Represent this sentence for searching relevant passages: "
max_len = 512

# Hard-negative triplet files (JSONL): training split, evaluation split.
training_hn_file, eval_file = "./data/hn-output.jsonl", "./data/eval-output.jsonl"

# Training batch size and where the fine-tuned model is written.
batch_size = 1280
output_model_path = "./bge-base-custom"

### Imports and utils

In [2]:
%matplotlib inline

from functools import partial
import itertools as it
import os
import random

from datasets import load_from_disk
from FlagEmbedding import FlagModel
import jsonlines as jsonl
from lion_pytorch import Lion
import matplotlib.pyplot as plt
import numpy as np
from numpy import dot
from numpy.linalg import norm
from sentence_transformers import InputExample, SentenceTransformer, losses as ls, models as ml, util
from sentence_transformers.evaluation import SimilarityFunction, TripletEvaluator
import torch
from torch.utils.data import DataLoader, IterableDataset
from tqdm.auto import tqdm

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2"

### Datasets

- `combined`: concatenation of the `stacked_samsum` and `dialogsum` datasets
- `hn_output`: (fact, positive dialog, negative dialog) triplets with mined hard negatives

In [3]:
def hn_output(file, fact_prefix: str = ""):
    """Yield (fact, positive dialog, negative dialog) triplets from a hard-negative JSONL file.

    Each line of ``file`` is expected to be an object with a ``"query"`` string and ``"pos"`` /
    ``"neg"`` lists of dialog strings. For every line, one ``InputExample`` is produced per
    (positive, negative) pair, i.e. ``len(pos) * len(neg)`` examples. A line with an empty
    ``pos`` or ``neg`` list contributes nothing.

    Texts are wrapped as ``{"fact": ...}`` / ``{"dialog": ...}`` dicts. These keys must match the
    keys of the ``Asym`` head so each side is routed to its own projector.

    Args:
        file: path to the JSONL file.
        fact_prefix: string prepended to every query, e.g. the BGE retrieval instruction stored
            in ``query_prefix``. Defaults to ``""``, which uses the query verbatim.

    Yields:
        ``InputExample`` with ``texts=[fact, positive, negative]``.
    """
    with jsonl.open(file) as reader:
        for entry in reader:
            # One fact dict per line, shared by all of that line's triplets.
            fact = dict(fact=fact_prefix + entry["query"])
            positives = [dict(dialog=dialog) for dialog in entry["pos"]]
            negatives = [dict(dialog=dialog) for dialog in entry["neg"]]

            # Positives vary in the outer loop and negatives in the inner loop.
            for pos, neg in it.product(positives, negatives):
                yield InputExample(texts=[fact, pos, neg])

In [3]:
# Materialise both splits in memory; tqdm displays a running count of triplets read.
training_data = [*tqdm(hn_output(training_hn_file))]
eval_data = [*tqdm(hn_output(eval_file))]

In [4]:
# The previous cell already loaded both splits; re-reading the JSONL files here only doubled
# the load time. Instead, check that the loaded data is usable before building the loaders.
assert training_data, f"no training triplets read from {training_hn_file}"
assert eval_data, f"no eval triplets read from {eval_file}"
assert all(len(example.texts) == 3 for example in it.islice(training_data, 1000)), "expected (fact, pos, neg) triplets"

### Prepare dataloader

In [5]:
# Training batches are reshuffled every epoch. The eval loader keeps file order: evaluation
# gains nothing from shuffling, and a fixed order makes runs comparable.
# NOTE(review): `eval_dataloader` is not passed to fit() below (TripletEvaluator batches
# `eval_data` itself), so this loader is currently unused.
dataloader = DataLoader(training_data, shuffle=True, batch_size=batch_size)
eval_dataloader = DataLoader(eval_data, shuffle=False, batch_size=batch_size // 10)

### Load base model

In [6]:
# Base model. Apply the configured `max_len` so that changing the constant actually changes
# the sequence-length cap; otherwise the model keeps whatever limit it was loaded with.
base_model = SentenceTransformer(model_name)
base_model.max_seq_length = max_len

In [7]:
# Freeze the pretrained transformer (first pipeline stage) so that only the projection head
# attached later is trained. Targeting that module keeps this cell safe to re-run after the
# head exists; looping over every parameter of `base_model` would then freeze the head too.
# NOTE(review): the Pooling stage is assumed to be parameter-free, so leaving it out changes
# nothing — confirm.
base_model._first_module().requires_grad_(False)

In [8]:
# Put the whole pipeline on the first visible GPU.
device = torch.device("cuda", 0)

# SentenceTransformer.fit() moves the model to `_target_device` before training, so that
# private attribute must point at the same device; otherwise fit() would move the body back.
base_model._target_device = device
base_model = base_model.to(device)

### Define new model architecture

In [9]:
# Token-embedding width of the underlying transformer (pipeline stage 0); the projection heads below are sized from it.
emb_dims = base_model[0].get_word_embedding_dimension()

def dense_projector(dims: int, dropout: float = 0.0):
    """Build a bottleneck head: ``dims -> dims // 2 -> dims // 2 -> dims``.

    Returns a list of sentence-transformers modules, intended as one branch of an ``ml.Asym``
    module. Each ``ml.Dense`` uses the library's default activation; the printed model
    summary shows Tanh.

    Args:
        dims: input and output embedding size.
        dropout: if > 0, an ``ml.Dropout(dropout)`` layer is inserted after each of the first
            two Dense layers. The default ``0.0`` adds no dropout layers.

    Returns:
        List of modules in application order.
    """
    proj_dims = dims // 2
    shapes = [(dims, proj_dims), (proj_dims, proj_dims), (proj_dims, dims)]

    layers = []
    for i, (in_features, out_features) in enumerate(shapes):
        layers.append(ml.Dense(in_features, out_features))
        # Dropout goes only between Dense layers, never after the final projection.
        if dropout > 0 and i < len(shapes) - 1:
            layers.append(ml.Dropout(dropout))
    return layers

def asym_module(dims: int, keys: list[str], allow_empty_key: bool = False):
    """Wrap one independent ``dense_projector`` head per input key in an ``ml.Asym`` module.

    Args:
        dims: embedding size fed to, and produced by, every head.
        keys: text-dict keys to route, e.g. ``["dialog", "fact"]``. Each key gets its own newly
            constructed projector, so no modules are shared between keys.
        allow_empty_key: forwarded unchanged to ``ml.Asym``.

    Returns:
        The constructed ``ml.Asym`` module.
    """
    heads = {}
    for key in keys:
        heads[key] = dense_projector(dims)
    return ml.Asym(heads, allow_empty_key=allow_empty_key)

In [8]:
# Inspect the pipeline stages before the projection head is attached.
# NOTE(review): this cell's execution count repeats an earlier cell's, so the saved output below
# may be stale (e.g. from a run with a different base model). Re-run to refresh it.
base_model._modules

OrderedDict([('0',
              Transformer({'max_seq_length': 512, 'do_lower_case': True}) with Transformer model: BertModel ),
             ('1',
              Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False}))])

In [10]:
# Register the asymmetric projection head as pipeline stage "2", after Transformer and Pooling.
# Its keys must match the text-dict keys produced by `hn_output` ("fact" and "dialog").
base_model.add_module("2", asym_module(emb_dims, ["dialog", "fact"]))

In [10]:
# Confirm the Asym head is registered as stage "2".
# NOTE(review): the saved output below lists Dropout layers that `dense_projector` does not create
# (its Dropout lines are commented out), so that output predates the current code. Re-run to refresh it.
base_model._modules

OrderedDict([('0',
              Transformer({'max_seq_length': 512, 'do_lower_case': True}) with Transformer model: BertModel ),
             ('1',
              Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})),
             ('2',
              Asym(
                (dialog-0): Dense({'in_features': 384, 'out_features': 192, 'bias': True, 'activation_function': 'torch.nn.modules.activation.Tanh'})
                (dialog-1): Dropout(
                  (dropout_layer): Dropout(p=0.2, inplace=False)
                )
                (dialog-2): Dense({'in_features': 192, 'out_features': 192, 'bias': True, 'activation_function': 'torch.nn.modules.activation.Tanh'})
                (dialog-3): Dropout(
                  (dropout_layer): Dropout(p=0.1, inplace=False)
                )
                (dialog-4): Dense({'in_features': 192, 'out_features

### Define loss

In [11]:
# TripletLoss with its defaults written out, so they are visible next to the evaluator settings.
# NOTE(review): training optimises Euclidean distance, but `triplet_evaluator` below scores with
# cosine similarity. Confirm this mismatch is intended.
train_loss = ls.TripletLoss(
    model=base_model,
    distance_metric=ls.TripletDistanceMetric.EUCLIDEAN,
    triplet_margin=5,
)

### Create evaluator

In [12]:
# Held-out triplet accuracy, scored with cosine similarity; write_csv=True also writes the results to CSV.
triplet_evaluator = TripletEvaluator.from_input_examples(
    eval_data,
    main_distance_function=SimilarityFunction.COSINE,
    batch_size=batch_size // 10,  # same eval batch size as `eval_dataloader`
    write_csv=True,
    show_progress_bar=True,
)

### Start training

In [13]:
# Train the projection heads; the transformer is frozen in an earlier cell.
# NOTE: fit() builds its own optimizer parameter groups and sets each group's `weight_decay` from
# fit()'s `weight_decay` argument (bias/LayerNorm groups get 0.0). Group values override the
# optimizer defaults, so a `weight_decay` inside `optimizer_params` has no effect. Both happen to
# be 0.01 here, but only fit()'s argument actually controls the value, so it is passed there.
base_model.fit(
    train_objectives=[(dataloader, train_loss)],
    evaluator=triplet_evaluator,
    epochs=10,
    scheduler="WarmupCosine",
    warmup_steps=100,
    optimizer_class=Lion,
    optimizer_params=dict(lr=1e-4),
    weight_decay=1e-2,
    use_amp=True,
    output_path=output_model_path,
    save_best_model=True,
    checkpoint_path=f"{output_model_path}/ckpts",
    checkpoint_save_steps=600,
    checkpoint_save_total_limit=4,
)

Epoch:   0%|          | 0/3 [00:00<?, ?it/s]

Iteration:   0%|          | 0/7366 [00:00<?, ?it/s]

Batches:   0%|          | 0/586 [00:00<?, ?it/s]

Batches:   0%|          | 0/586 [00:00<?, ?it/s]

Batches:   0%|          | 0/586 [00:00<?, ?it/s]

Batches:   0%|          | 0/586 [00:00<?, ?it/s]

Batches:   0%|          | 0/586 [00:00<?, ?it/s]

Batches:   0%|          | 0/586 [00:00<?, ?it/s]

Iteration:   0%|          | 0/7366 [00:00<?, ?it/s]


KeyboardInterrupt



In [None]:
# With `save_best_model=True` and `output_path` set, fit() writes the best-scoring model to
# `output_model_path`. `base_model` in memory instead holds the weights from the last training
# step, or partial weights after an interrupted run. Publish the saved best model.
# Loading presumably raises if training never saved one, which prevents pushing an untrained head.
SentenceTransformer(output_model_path).save_to_hub("julep-ai/dialog-bge-base", private=True)