<a href="https://colab.research.google.com/github/americanthinker/vectorsearch-applications/blob/main/notebooks/6-EmbeddingModel_FineTuning.ipynb" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Run these cells if on Colab

In [1]:
# !curl -o preprocessing.py https://raw.githubusercontent.com/americanthinker/vectorsearch-applications/main/src/preprocessor/preprocessing.py


In [2]:
# !curl -o qa_training_triplets.json https://raw.githubusercontent.com/americanthinker/vectorsearch-applications/main/data/qa_training_triplets.json


In [3]:
# !pip install sentence-transformers loguru --quiet


# Fine-Tuning a SentenceTransformers Embedding Model
***

### Fine-tune High-Level Walkthrough

1. Get baseline retrieval scores (vector Hit Rate, MRR, and total misses) using the out-of-the-box baseline model.  Without a baseline measurement you cannot objectively tell whether fine-tuning had any effect.  This may seem obvious, but practitioners sometimes jump straight into model improvement without first establishing their starting point.
2. Collect a training dataset.  This step has already been completed for you, courtesy of `gpt-3.5-turbo`.  The training dataset consists of triplets in the following format:
   - **Anchor**: The context i.e. a random text chunk created by the initial baseline model
   - **Positive**: A query generated by the LLM that can be answered by the anchor context.
   - **Hard Negative**: A query generated by the LLM that is semantically similar to the positive, but cannot be answered by the anchor context.
   These triplets were generated using a prompt written specifically for the Huberman Lab corpus, so the training data is (for the most part) high quality and contextually relevant.
3. Train the model and set a path where the new model will reside.  I created a `models/` directory in the course repo, and included the directory in the `.gitignore` file so that models aren't being pushed with every commit.
4. Create a new dataset (as you learned in Notebook 1) but this time create the embeddings using the new fine-tuned model.
5. Create a new index on Weaviate using the new dataset you just created.
6. Run the `execute_evaluation` function (from the `retrieval_evaluation` module) again, this time instantiating your Weaviate client with the new fine-tuned model while holding all other parameters constant (i.e. don't change any other parameter from the baseline run).
7. Compare the fine-tuned retrieval results to the baseline results 🥳
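
The metrics named in step 1 can be sketched in a few lines. This is a minimal illustration, not the course's `execute_evaluation` implementation: the function name `hit_rate_and_mrr` and the input layout (one ranked list of retrieved IDs per query, plus the single expected ID) are assumptions made for this example.

```python
def hit_rate_and_mrr(retrieved_ids, expected_ids, top_k=5):
    """Return (hit_rate, mrr) over a set of queries.

    retrieved_ids: list of ranked ID lists, one per query (hypothetical layout)
    expected_ids:  list with the single correct ID for each query
    """
    hits = 0
    reciprocal_ranks = []
    for results, expected in zip(retrieved_ids, expected_ids):
        top = results[:top_k]
        if expected in top:
            hits += 1
            # rank is 1-based, so the first position scores 1.0
            reciprocal_ranks.append(1 / (top.index(expected) + 1))
        else:
            # a miss contributes nothing to MRR
            reciprocal_ranks.append(0.0)
    n = len(expected_ids)
    return hits / n, sum(reciprocal_ranks) / n


# Toy data for illustration only
retrieved = [["a", "b", "c"], ["x", "y", "z"], ["m", "n", "o"]]
expected = ["a", "y", "q"]
hit_rate, mrr = hit_rate_and_mrr(retrieved, expected)
print(hit_rate, mrr)
```

Here two of three queries are hits, and the reciprocal ranks are 1, 1/2, and 0, so the MRR is 0.5.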

## Load Model


In [1]:
import sys
sys.path.append('../')
try:
  from src.preprocessor.preprocessing import FileIO
except ModuleNotFoundError:
  from preprocessing import FileIO

from torch import cuda
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, losses, InputExample, models

### Execute Model Loading func

In [2]:
def load_pretrained_model(model_name: str='sentence-transformers/all-MiniLM-L6-v2'):
    '''
    Loads sentence transformer modules and returns a pretrained
    model for finetuning.
    '''
    word_embedding_model = models.Transformer(model_name)
    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
    model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
    return model
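
With its default settings, `models.Pooling` averages the token embeddings from the transformer into one fixed-size sentence vector (mean pooling), skipping padding tokens. The plain-Python sketch below illustrates that idea only; `mean_pool` is a hypothetical helper written for this example and is not part of the library.

```python
def mean_pool(token_embeddings, attention_mask):
    """Average token vectors where attention_mask is 1 (padding is ignored)."""
    dim = len(token_embeddings[0])
    totals = [0.0] * dim
    count = 0
    for vector, keep in zip(token_embeddings, attention_mask):
        if keep:
            count += 1
            for i, value in enumerate(vector):
                totals[i] += value
    count = max(count, 1)  # avoid division by zero on an all-padding input
    return [t / count for t in totals]


# Three token vectors; the last one is padding and must not affect the result
tokens = [[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]
mask = [1, 1, 0]
sentence_vector = mean_pool(tokens, mask)
print(sentence_vector)
```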

In [3]:
model = load_pretrained_model()
model.device



device(type='cpu')

## Prep Data


### Import Training Dataset

In [12]:
# choose the path that matches where you are running (local repo vs. Colab)
data_path = '../data/qa_training_triplets.json'
# data_path = '/content/qa_training_triplets.json'
data = FileIO.load_json(data_path)
len(data)

500

#### Peek at the data

In [14]:
data[0]

{'positive': 'What effects does L-Cetyl-L-Carnitine have on cellular metabolism?',
 'hard_negative': 'What effects does L-Cetyl-L-Carnitine have on hair growth?',
 'anchor': "I confess I have used it in pill form from time to time, but in part because of the fat oxidation effects, but also because of the other effects that it tends to have. So in exploring the effects that L-Cetyl-L-Carnitine has, it has a huge variety of effects on cellular metabolism. It can reduce ammonia in the blood. That is actually a quite strong effect. It can reduce things like C-reactive protein, which is you want C-reactive protein levels to be managed. You do not want them too high. It can slightly reduce blood glucose. It can slightly increase HDLC, the good form of the blood lipid, and slightly reduce overall cholesterol. And as I mentioned, it can slightly modify the pathway involving glucagon such that you get a considerable effect, not a huge effect on fat oxidation, so it can improve fat oxidation rat

### Build list of InputExamples & Create Dataloader

In [15]:
train_examples = [InputExample(texts=[sample['anchor'],
                                      sample['positive'],
                                      sample['hard_negative']
                                     ]) for sample in data]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=32, )
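
Before building the `InputExample` objects, it can help to confirm that every record actually carries the three fields the loss expects. A minimal sketch, assuming each record is a dict with the keys shown in the sample above; `find_bad_triplets` is a hypothetical helper, not part of the course repo.

```python
REQUIRED_KEYS = ("anchor", "positive", "hard_negative")


def find_bad_triplets(records):
    """Return indices of records with a missing, non-string, or empty field."""
    bad = []
    for i, record in enumerate(records):
        for key in REQUIRED_KEYS:
            value = record.get(key)
            if not isinstance(value, str) or not value.strip():
                bad.append(i)
                break  # one problem is enough to flag the record
    return bad


# Made-up records for illustration
sample_records = [
    {"anchor": "some chunk", "positive": "q1?", "hard_negative": "q2?"},
    {"anchor": "another chunk", "positive": "", "hard_negative": "q3?"},
    {"anchor": "third chunk", "positive": "q4?"},
]
bad_indices = find_bad_triplets(sample_records)
print(bad_indices)
```

Running `find_bad_triplets(data)` on the loaded dataset and getting an empty list would indicate that all records are structurally usable.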

#### Training example peek

In [17]:
train_examples[0].__dict__

{'guid': '',
 'texts': ["I confess I have used it in pill form from time to time, but in part because of the fat oxidation effects, but also because of the other effects that it tends to have. So in exploring the effects that L-Cetyl-L-Carnitine has, it has a huge variety of effects on cellular metabolism. It can reduce ammonia in the blood. That is actually a quite strong effect. It can reduce things like C-reactive protein, which is you want C-reactive protein levels to be managed. You do not want them too high. It can slightly reduce blood glucose. It can slightly increase HDLC, the good form of the blood lipid, and slightly reduce overall cholesterol. And as I mentioned, it can slightly modify the pathway involving glucagon such that you get a considerable effect, not a huge effect on fat oxidation, so it can improve fat oxidation rates. It has a number of other effects, some of which I talked about during the month on hormones and that sort of thing. It has strong effects on rates

## Set Loss Function, Epochs, and warm-up


In [18]:
num_epochs = 3
train_loss = losses.MultipleNegativesRankingLoss(model=model)
warmup_steps = int(len(train_dataloader) * num_epochs * 0.1)  # 10% of total training steps
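
`MultipleNegativesRankingLoss` treats every other candidate in the batch as a negative for a given anchor. With triplets, the candidates are all positives plus all hard negatives in the batch, and the loss is a cross-entropy over scaled cosine similarities. The pure-Python sketch below illustrates that idea on already-computed embeddings. It is a simplified illustration rather than the library's code: `mnr_loss` and `cosine` are hypothetical names, and the scale of 20 is assumed to match the library default.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def mnr_loss(anchors, positives, hard_negatives, scale=20.0):
    """Mean cross-entropy where anchor i should rank positive i first."""
    candidates = positives + hard_negatives
    total = 0.0
    for i, anchor in enumerate(anchors):
        scores = [scale * cosine(anchor, c) for c in candidates]
        # numerically stable log-sum-exp over all candidates
        max_score = max(scores)
        log_sum_exp = max_score + math.log(
            sum(math.exp(s - max_score) for s in scores)
        )
        total += log_sum_exp - scores[i]
    return total / len(anchors)


# Toy 2-D embeddings for illustration
anchors = [[1.0, 0.0], [0.0, 1.0]]
positives = [[0.9, 0.1], [0.1, 0.9]]
negatives = [[0.5, 0.5], [0.5, 0.5]]

good = mnr_loss(anchors, positives, negatives)        # positives aligned
bad = mnr_loss(anchors, positives[::-1], negatives)   # positives swapped
print(good < bad)
```

Because every other item in the batch acts as a negative, larger batch sizes give each anchor more negatives to be ranked against.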

## Train model

In [19]:
model.device

device(type='cuda', index=0)

In [20]:
model.fit(train_objectives=[(train_dataloader, train_loss)],
          epochs=num_epochs,
          warmup_steps=warmup_steps)
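
To make the warm-up setting concrete: `len(train_dataloader)` counts batches, not samples, so the 10% applies to the total number of training steps. The arithmetic below reproduces it for the 500 triplets and batch size of 32 used in this notebook (standard library only, no model required).

```python
import math

num_examples = 500
batch_size = 32
num_epochs = 3

# A DataLoader that keeps the last partial batch yields ceil(n / batch_size) batches
steps_per_epoch = math.ceil(num_examples / batch_size)
total_steps = steps_per_epoch * num_epochs
warmup_steps = int(total_steps * 0.1)

print(steps_per_epoch, total_steps, warmup_steps)
```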

### Save model
---
Similar to how you have labeled datasets and collections, it's a good idea to stick to a naming convention that lets you keep track of the fine-tuned models you create.  I would suggest something like the following convention:  

`short-hand model name`-`finetuned`-`dataset size`

For example, a finetuned version of the `all-MiniLM` model could look like this:  
`allminilm-finetuned-500`

If you want to get even more granular, you could add other unique identifiers for experimentation, such as the number of epochs:  
`allminilm-finetuned-500-2` --> an `all-MiniLM` model finetuned on 500 samples over 2 epochs

**I would also recommend creating a dedicated `models` folder in your top-level directory.  The repo `.gitignore` file already includes the `models` folder to avoid pushing large files to GitHub.  After you've created the folder, you should be able to access the model via a path similar to this one:**

`models/allminilm-finetuned-500`
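
The naming convention above is easy to encode so that every experiment gets a consistent path. A small sketch using only the standard library; `build_model_path` and its parameters are hypothetical names chosen for this example.

```python
from pathlib import Path


def build_model_path(short_name, dataset_size, epochs=None, models_dir="models"):
    """Build '<models_dir>/<short_name>-finetuned-<dataset_size>[-<epochs>]'."""
    parts = [short_name, "finetuned", str(dataset_size)]
    if epochs is not None:
        parts.append(str(epochs))  # optional extra identifier
    return Path(models_dir) / "-".join(parts)


model_dir = build_model_path("allminilm", 500)
model_dir_with_epochs = build_model_path("allminilm", 500, epochs=2)
print(model_dir.as_posix())
print(model_dir_with_epochs.as_posix())
```

The resulting path can be passed (as a string) to `model.save` in the next cell.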

In [25]:
# model.save(path='local path', model_name='name of your model')

### COLAB-specific saving and downloading
---
If you are running this notebook on Google Colab then I recommend the following steps.

#### Save the finetuned model in current dir

In [26]:
#define your path
model_path = './allminilm-finetuned-256'
# model.save(model_path, model_name='name of your model')
model.save(model_path, model_name='mymodel')

#### Zip the model folder into a single file

In [27]:
#ensure the paths match
!zip -r /content/model.zip /content/allminilm-finetuned-256/

  adding: content/allminilm-finetuned-256/ (stored 0%)
  adding: content/allminilm-finetuned-256/config_sentence_transformers.json (deflated 31%)
  adding: content/allminilm-finetuned-256/sentence_bert_config.json (deflated 4%)
  adding: content/allminilm-finetuned-256/README.md (deflated 59%)
  adding: content/allminilm-finetuned-256/model.safetensors (deflated 8%)
  adding: content/allminilm-finetuned-256/tokenizer.json (deflated 71%)
  adding: content/allminilm-finetuned-256/1_Pooling/ (stored 0%)
  adding: content/allminilm-finetuned-256/1_Pooling/config.json (deflated 57%)
  adding: content/allminilm-finetuned-256/config.json (deflated 47%)
  adding: content/allminilm-finetuned-256/tokenizer_config.json (deflated 74%)
  adding: content/allminilm-finetuned-256/special_tokens_map.json (deflated 80%)
  adding: content/allminilm-finetuned-256/modules.json (deflated 53%)
  adding: content/allminilm-finetuned-256/vocab.txt (deflated 53%)


Once you have zipped the model, you can download it locally as a single zipped file by right-clicking on the file and selecting "Download".
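
If you would rather stay in Python than shell out to `zip`, the standard library's `shutil.make_archive` produces an equivalent archive. The sketch below demonstrates it on a throwaway directory standing in for a saved model; `zip_model_dir` is a hypothetical helper, and the directory and file names are placeholders.

```python
import shutil
import tempfile
from pathlib import Path


def zip_model_dir(model_dir, archive_base):
    """Zip model_dir into '<archive_base>.zip' and return the archive path."""
    model_dir = Path(model_dir)
    return shutil.make_archive(
        base_name=str(archive_base),
        format="zip",
        root_dir=str(model_dir.parent),  # paths in the zip are relative to this
        base_dir=model_dir.name,         # only this folder is archived
    )


# Throwaway directory standing in for a saved model folder
workdir = Path(tempfile.mkdtemp())
demo_model = workdir / "allminilm-finetuned-256"
demo_model.mkdir()
(demo_model / "config.json").write_text("{}")

archive_path = zip_model_dir(demo_model, workdir / "model")
print(archive_path)
```

On Colab, `google.colab.files.download(archive_path)` can then be used to trigger a browser download instead of right-clicking in the file pane.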

### Model Evaluation
---
Fine-tuning is just the start!  You still have to create a new dataset using the fine-tuned model, index that data on Weaviate, and then evaluate its performance.  This is why having a solid dataset creation and indexing pipeline is key, especially if you plan on running multiple experiments to optimize your results.  Follow this recipe:  
1. Create new dataset (from Notebook 1)
2. Index that dataset and create an easily identifiable collection name, e.g. `Huberman_minilm_finetuned_256` (from Notebook 2)
3. Run the `execute_evaluation` function (from Notebook 4)

Assuming you are in the `notebooks` folder when performing the new evaluation and you have created a `models` folder in the top-level directory, the following code snippet will load the Weaviate client with the fine-tuned model and ensure that you are hitting the right collection for evaluation:

In [2]:
from src.database.database_utils import get_weaviate_client

In [3]:
model_path = '../models/minilm-finetuned-500/'
client = get_weaviate_client(model_name_or_path=model_path)
collection_name = 'Huberman_minilm_finetuned_256'

In [5]:
from src.evaluation.retrieval_evaluation import execute_evaluation

In [6]:
golden_dataset = FileIO.load_json("../data/golden_datasets/golden_256.json")
retrieval_results = execute_evaluation(
    dataset=golden_dataset,
    collection_name=collection_name,
    retriever=client,
    reranker=None,      
    alpha=0.16,
    top_k=5,
    retrieve_limit=20,
)

Queries: 100%|██████████| 100/100 [00:15<00:00,  6.39it/s]


In [7]:
retrieval_results

{'n': 20,
 'top_k': 5,
 'alpha': 0.16,
 'Retriever': '../models/minilm-finetuned-500/',
 'Ranker': 'None',
 'chunk_size': 256,
 'query_props': ['content'],
 'total_misses': 0,
 'total_questions': 100,
 'kw_hit_rate': 0.0,
 'kw_mrr': 0.0,
 'vector_hit_rate': 0.0,
 'vector_mrr': 0.0,
 'hybrid_hit_rate': 0.0,
 'hybrid_mrr': 0.0}
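
To compare a fine-tuned run against the baseline, it is convenient to diff the two result dictionaries metric by metric. A minimal sketch, assuming both runs return dictionaries with the metric keys shown above; `compare_results` is a hypothetical helper, and the numbers in the example are placeholders for illustration, not real results.

```python
METRIC_KEYS = (
    "kw_hit_rate", "kw_mrr",
    "vector_hit_rate", "vector_mrr",
    "hybrid_hit_rate", "hybrid_mrr",
)


def compare_results(baseline, finetuned, keys=METRIC_KEYS):
    """Return finetuned minus baseline for each metric present in both dicts."""
    return {
        key: round(finetuned[key] - baseline[key], 4)
        for key in keys
        if key in baseline and key in finetuned
    }


# Placeholder values purely for illustration
baseline_example = {"vector_hit_rate": 0.50, "vector_mrr": 0.25}
finetuned_example = {"vector_hit_rate": 0.75, "vector_mrr": 0.50}
deltas = compare_results(baseline_example, finetuned_example)
print(deltas)
```

Positive deltas mean the fine-tuned model improved on that metric. The comparison is only meaningful if `alpha`, `top_k`, `retrieve_limit`, and the golden dataset were held constant across both runs.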