# Fine-tuning

In this notebook, we fine-tune an open-source sentence-transformers embedding model on our synthetically generated dataset.

### Load pretrained model

In [1]:
from sentence_transformers import SentenceTransformer



In [2]:
model_id = "BAAI/bge-large-en-v1.5"
model = SentenceTransformer(model_id)

In [3]:
model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': True}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 1024, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  (2): Normalize()
)
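The printed modules indicate that the sentence embedding is the `[CLS]` token vector (`pooling_mode_cls_token: True`) followed by L2 normalization. The sketch below illustrates those two steps in plain Python on made-up 4-dimensional token vectors; the function name `cls_pool_and_normalize` and the toy numbers are assumptions for illustration, not part of the library.

```python
import math

def cls_pool_and_normalize(token_embeddings):
    """Take the first ([CLS]) token vector and scale it to unit L2 norm."""
    cls_vector = token_embeddings[0]  # CLS pooling: keep only the first token
    norm = math.sqrt(sum(x * x for x in cls_vector))
    return [x / norm for x in cls_vector]

# Hypothetical 3-token sequence with 4-dimensional token embeddings
# (the model above uses 1024 dimensions and up to 512 tokens).
tokens = [
    [3.0, 0.0, 4.0, 0.0],  # [CLS]
    [1.0, 2.0, 0.5, 0.1],
    [0.2, 0.3, 0.9, 1.5],
]
embedding = cls_pool_and_normalize(tokens)
print(embedding)  # [0.6, 0.0, 0.8, 0.0]
```

Because the output has unit length, the dot product of two such embeddings equals their cosine similarity.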

### Define dataloader

In [1]:
import json

from torch.utils.data import DataLoader
from sentence_transformers import InputExample



In [2]:
DATA_PATH = "/home/75y/data_ragMimic/data/"

In [3]:
TRAIN_DATASET_FPATH = DATA_PATH+'train_dataset.json'
VAL_DATASET_FPATH = DATA_PATH+'val_dataset.json'

# We use a very small batch size so this toy example runs on a local machine.
# In practice this should typically be much larger.
BATCH_SIZE = 16

In [4]:
# The files are only read, so open them read-only ('r') rather than 'r+'.
with open(TRAIN_DATASET_FPATH, 'r') as f:
    train_dataset = json.load(f)

with open(VAL_DATASET_FPATH, 'r') as f:
    val_dataset = json.load(f)

In [5]:
dataset = train_dataset

corpus = dataset['corpus']
queries = dataset['queries']
relevant_docs = dataset['relevant_docs']

examples = []
for query_id, query in queries.items():
    node_id = relevant_docs[query_id][0]
    text = corpus[node_id]
    example = InputExample(texts=[query, text])
    examples.append(example)
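The loop above relies on the dataset having three keys: `corpus` (node id to text), `queries` (query id to question), and `relevant_docs` (query id to a list of node ids). As a minimal sketch of that layout, the block below builds the same (query, text) pairs from a made-up miniature dataset using plain tuples instead of `InputExample`; the ids, texts, and the helper name `build_pairs` are hypothetical.

```python
# Hypothetical miniature dataset with the same three keys as train_dataset.json.
toy_dataset = {
    "corpus": {
        "node-1": "discharge diagnosis: displaced left clavicle fx",
        "node-2": "discharge medications: bactrim",
    },
    "queries": {
        "q-1": "What was the patient's discharge diagnosis?",
        "q-2": "Which medications was the patient discharged on?",
    },
    "relevant_docs": {"q-1": ["node-1"], "q-2": ["node-2"]},
}

def build_pairs(dataset):
    """Pair each query with the text of its first relevant node."""
    pairs = []
    for query_id, query in dataset["queries"].items():
        node_id = dataset["relevant_docs"][query_id][0]
        pairs.append((query, dataset["corpus"][node_id]))
    return pairs

pairs = build_pairs(toy_dataset)
print(len(pairs))   # 2
print(pairs[0][1])  # discharge diagnosis: displaced left clavicle fx
```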

In [6]:
loader = DataLoader(
    examples, batch_size=BATCH_SIZE
)

In [11]:
examples[-2].__dict__

{'guid': '',
 'texts': ["What was the patient's discharge diagnosis?",
  'they recommended  sling for the l arm and follow up in their clinic in 2 weeks.  \nhis pain was well controlled with tylenol alone.  he was doing  well and was discharged to home.  he was instructed to stop his \neliquis until he sees his cardiologist.  \n \nmedications on admission: eliquis\nbactrim \n \ndischarge medications: bactrim\n \ndischarge disposition: home\n \ndischarge diagnosis: 1. s/p motor vehicle collision\n2. displaced comminuted left clavicle fx\n3. right abdominal wall hematoma\n4. l5 right transverse process fx\n\n \ndischarge condition: mental status: clear and coherent.\nlevel of consciousness: alert and interactive.\nactivity status: ambulatory - independent.\n\n \ndischarge instructions: you came to the hospital after a car accident.  you were found  to have a broken left collarbone, a broken piece of bone in your  low back, and a blood collection in your right abdomen.  you  were monitore

### Define loss

**MultipleNegativesRankingLoss** is a good choice when you only have positive pairs, for example pairs of paraphrases, pairs of duplicate questions, pairs of (query, response), or pairs of (source_language, target_language).

It works well for training retrieval embeddings from positive pairs such as (query, relevant_doc): within a batch of n pairs, the n-1 documents that belong to the other queries serve as negatives for each query, so no explicit negatives are needed.

Performance usually improves with larger batch sizes, since each query is then contrasted against more negatives.

For more details, see:
* [docs](https://www.sbert.net/docs/package_reference/losses.html)
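To make the in-batch negatives idea concrete, here is a minimal plain-Python sketch of the computation: similarity scores between every query and every document in the batch are scaled, and a cross-entropy loss treats document i as the correct "class" for query i. It assumes unit-length vectors (so the dot product is the cosine similarity) and a scale of 20.0; the function name and toy vectors are illustrative, and this is not the library's implementation.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def in_batch_negatives_loss(query_vecs, doc_vecs, scale=20.0):
    """Mean cross-entropy where doc i is the positive for query i and
    every other doc in the batch acts as a negative."""
    losses = []
    for i, q in enumerate(query_vecs):
        scores = [scale * dot(q, d) for d in doc_vecs]
        # Numerically stable log-sum-exp over all docs in the batch.
        max_score = max(scores)
        log_sum_exp = max_score + math.log(sum(math.exp(s - max_score) for s in scores))
        losses.append(log_sum_exp - scores[i])
    return sum(losses) / len(losses)

# Toy unit vectors: each query matches exactly one document.
queries_toy = [[1.0, 0.0], [0.0, 1.0]]
aligned_docs = [[1.0, 0.0], [0.0, 1.0]]   # positives on the diagonal
swapped_docs = [[0.0, 1.0], [1.0, 0.0]]   # positives paired with the wrong query

aligned_loss = in_batch_negatives_loss(queries_toy, aligned_docs)
swapped_loss = in_batch_negatives_loss(queries_toy, swapped_docs)
print(aligned_loss < swapped_loss)  # True
```

The loss is near zero when each query is closest to its own document and grows when another document in the batch scores higher.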

In [10]:
from sentence_transformers import losses

In [11]:
loss = losses.MultipleNegativesRankingLoss(model)

### Define evaluator 

We set up an evaluator on the validation split of our dataset to monitor how well the embedding model performs during training.

In [12]:
from sentence_transformers.evaluation import InformationRetrievalEvaluator

In [13]:
dataset = val_dataset

corpus = dataset['corpus']
queries = dataset['queries']
relevant_docs = dataset['relevant_docs']

evaluator = InformationRetrievalEvaluator(queries, corpus, relevant_docs)
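The evaluator ranks corpus entries for each query and reports retrieval metrics such as accuracy@k and MRR@k. As a rough sketch of what two of those numbers mean, the block below computes a hit rate and mean reciprocal rank from hand-written rankings; the helper names and the toy rankings are hypothetical, and the library's own computation may differ in detail.

```python
def hit_rate_at_k(ranked, relevant, k):
    """Fraction of queries with at least one relevant doc in the top k."""
    hits = sum(
        1 for qid, doc_ids in ranked.items()
        if any(d in relevant[qid] for d in doc_ids[:k])
    )
    return hits / len(ranked)

def mrr_at_k(ranked, relevant, k):
    """Mean of 1/rank of the first relevant doc within the top k (0 if none)."""
    total = 0.0
    for qid, doc_ids in ranked.items():
        for rank, doc_id in enumerate(doc_ids[:k], start=1):
            if doc_id in relevant[qid]:
                total += 1.0 / rank
                break
    return total / len(ranked)

# Hypothetical retrieval results: doc ids ordered from most to least similar.
ranked_toy = {"q-1": ["node-1", "node-2"], "q-2": ["node-1", "node-2"]}
relevant_toy = {"q-1": ["node-1"], "q-2": ["node-2"]}

print(hit_rate_at_k(ranked_toy, relevant_toy, k=1))  # 0.5
print(hit_rate_at_k(ranked_toy, relevant_toy, k=2))  # 1.0
print(mrr_at_k(ranked_toy, relevant_toy, k=2))       # 0.75
```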

### Run training 

The training loop is straightforward to set up thanks to sentence-transformers' high-level model training API.
All we need to do is plug in the data loader, loss function, and evaluator defined in the previous cells, along with a few additional settings.
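One of the additional settings in the training cell is `warmup_steps`, set to 10% of the total number of training steps. The sketch below works through that arithmetic using the 6157 iterations per epoch shown in the progress output, and illustrates a linear warmup followed by linear decay. It assumes `model.fit` is left on its default `WarmupLinear` scheduler; the multiplier function is a hand-written approximation for illustration, not the library's code.

```python
def warmup_linear_multiplier(step, warmup_steps, total_steps):
    """Learning-rate multiplier: ramps 0 -> 1 during warmup, then decays linearly to 0."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

steps_per_epoch = 6157  # iterations per epoch, as shown in the progress bar
epochs = 3
total_steps = steps_per_epoch * epochs
warmup_steps_example = int(total_steps * 0.1)

print(total_steps)           # 18471
print(warmup_steps_example)  # 1847
print(warmup_linear_multiplier(0, warmup_steps_example, total_steps))                     # 0.0
print(warmup_linear_multiplier(warmup_steps_example, warmup_steps_example, total_steps))  # 1.0
print(warmup_linear_multiplier(total_steps, warmup_steps_example, total_steps))           # 0.0
```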

In [14]:
# We train the model for very few epochs in this toy example.
# This should typically be higher for better performance.
EPOCHS = 3

In [15]:
warmup_steps = int(len(loader) * EPOCHS * 0.1)

model.fit(
    train_objectives=[(loader, loss)],
    epochs=EPOCHS,
    warmup_steps=warmup_steps,
    output_path='exp_finetune',
    show_progress_bar=True,
    evaluator=evaluator, 
    evaluation_steps=1000,
)

Epoch:   0%|          | 0/3 [00:00<?, ?it/s]

Iteration:   0%|          | 0/6157 [00:00<?, ?it/s]


