# Fine-tuning MLMs 🤖⚙️

This brief research is intended to explore the different fine-tuning approaches that can be applied for adapting bert-like MLMs to a custom domain datasets.

Important points:
* Dataset: [medical_questions_pairs](https://huggingface.co/datasets/medical_questions_pairs)
* Model: [bert-base-cased](https://huggingface.co/bert-base-cased)
* We will define auxiliar functions in auxiliar.py file
* We will be logging the results in Weight&Biases.
<br>

<figure>
  <img src="../data/images/adaptive_fine-tuning.png">
  
  <figcaption style='text-align:center';>
  Framework for fine-tuning LMs. 
  <a href="https://ruder.io/recent-advances-lm-fine-tuning/">Sebastian Rude's post</a>
  </figcaption>
</figure>

In [1]:
import torch
import config

if torch.cuda.is_available():
   device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device

device(type='cuda', index=0)

## 1. Data preparation

### 1.1. Import and set creation

Import data and create partitions.

In [6]:
from datasets import load_dataset

# Download and extract data
data = load_dataset("medical_questions_pairs")
data = data['train']

# Split it
data = data.train_test_split(test_size=0.07, seed=config.SEED)

Found cached dataset medical_questions_pairs (C:/Users/Juanju/.cache/huggingface/datasets/medical_questions_pairs/default/0.0.0/db30a35b934dceb7abed5ef6b73a432bb59682d00e26f9a1acd960635333bc80)
100%|██████████| 1/1 [00:00<00:00, 167.09it/s]
Loading cached split indices for dataset at C:\Users\Juanju\.cache\huggingface\datasets\medical_questions_pairs\default\0.0.0\db30a35b934dceb7abed5ef6b73a432bb59682d00e26f9a1acd960635333bc80\cache-3a6913e31ee3f147.arrow and C:\Users\Juanju\.cache\huggingface\datasets\medical_questions_pairs\default\0.0.0\db30a35b934dceb7abed5ef6b73a432bb59682d00e26f9a1acd960635333bc80\cache-55366722f45172c0.arrow


In [8]:
data

DatasetDict({
    train: Dataset({
        features: ['dr_id', 'question_1', 'question_2', 'label'],
        num_rows: 2834
    })
    test: Dataset({
        features: ['dr_id', 'question_1', 'question_2', 'label'],
        num_rows: 214
    })
})

As we can see, there is not that much ammount of samples. We will have to take that into consideration when training the models.

### 1.2. Tokenize and encode data

As mentioned, we will use **bert-base-cased** tokenizer

In [7]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(config.checkpoint, use_fast=True)

In [9]:
data = data.map(lambda x: tokenizer(x['question_1'], x['question_1'], truncation=True, padding='max_length'), batched=True)

Loading cached processed dataset at C:\Users\Juanju\.cache\huggingface\datasets\medical_questions_pairs\default\0.0.0\db30a35b934dceb7abed5ef6b73a432bb59682d00e26f9a1acd960635333bc80\cache-e4dd26900600fbe0.arrow
100%|██████████| 1/1 [00:00<00:00,  6.87ba/s]


In [10]:
from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

## 2. Exp 1: Baseline model training

Our first experiment consists on a basic training without any fine-tuning. We will freeze all parameters from the base model and just train the las FC layer. 

In [11]:
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(config.checkpoint, num_labels=2)

# freeze all params
for param in model.bert.parameters():
    param.requires_grad = False

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

### 2.1. Init WandB

In [12]:
import wandb

wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjjceamoran[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [13]:
run_name = 'baseline_training'
notes = "This experiment consists on a basic bert training with all encoder's layers frozen"
run = wandb.init(project='fine-tuning-mlms',
           name=run_name,
           notes=notes,
           job_type='train')


In [14]:
from transformers import Trainer, TrainingArguments
from training_aux import compute_metrics
import sklearn

training_args = TrainingArguments(
    output_dir="./experiments/" + run_name,
    learning_rate=2e-5, # low learning rate.
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    report_to='wandb',
    run_name=run_name
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=data['train'],
    eval_dataset=data['test'],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

In [15]:
trainer.train()

The following columns in the training set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: dr_id, question_1, question_2. If dr_id, question_1, question_2 are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 2834
  Num Epochs = 5
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 890
  Number of trainable parameters = 1538
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
  0%|          | 0/890 [00:00<?, ?it/s]You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
 20%|██        | 178/890 [06:52<20:42,  1.

{'eval_loss': 0.6930452585220337, 'eval_accuracy': 0.53, 'eval_f1': 0, 'eval_runtime': 31.8098, 'eval_samples_per_second': 6.727, 'eval_steps_per_second': 0.44, 'epoch': 1.0}


Model weights saved in ./experiments/baseline_training\checkpoint-178\pytorch_model.bin
tokenizer config file saved in ./experiments/baseline_training\checkpoint-178\tokenizer_config.json
Special tokens file saved in ./experiments/baseline_training\checkpoint-178\special_tokens_map.json
 40%|████      | 356/890 [14:35<14:18,  1.61s/it]  The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: dr_id, question_1, question_2. If dr_id, question_1, question_2 are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 214
  Batch size = 16
                                                 
 40%|████      | 356/890 [15:03<14:18,  1.61s/it]Saving model checkpoint to ./experiments/baseline_training\checkpoint-356
Configuration saved in ./experiments/baseline_training\checkpoint-356\config.json


{'eval_loss': 0.693488597869873, 'eval_accuracy': 0.47, 'eval_f1': 0, 'eval_runtime': 27.9652, 'eval_samples_per_second': 7.652, 'eval_steps_per_second': 0.501, 'epoch': 2.0}


Model weights saved in ./experiments/baseline_training\checkpoint-356\pytorch_model.bin
tokenizer config file saved in ./experiments/baseline_training\checkpoint-356\tokenizer_config.json
Special tokens file saved in ./experiments/baseline_training\checkpoint-356\special_tokens_map.json
 56%|█████▌    | 500/890 [20:18<14:03,  2.16s/it]  

{'loss': 0.7037, 'learning_rate': 8.764044943820226e-06, 'epoch': 2.81}


 60%|██████    | 534/890 [21:30<09:30,  1.60s/it]The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: dr_id, question_1, question_2. If dr_id, question_1, question_2 are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 214
  Batch size = 16
                                                 
 60%|██████    | 534/890 [21:57<09:30,  1.60s/it]Saving model checkpoint to ./experiments/baseline_training\checkpoint-534
Configuration saved in ./experiments/baseline_training\checkpoint-534\config.json


{'eval_loss': 0.691881537437439, 'eval_accuracy': 0.53, 'eval_f1': 1, 'eval_runtime': 27.8375, 'eval_samples_per_second': 7.687, 'eval_steps_per_second': 0.503, 'epoch': 3.0}


Model weights saved in ./experiments/baseline_training\checkpoint-534\pytorch_model.bin
tokenizer config file saved in ./experiments/baseline_training\checkpoint-534\tokenizer_config.json
Special tokens file saved in ./experiments/baseline_training\checkpoint-534\special_tokens_map.json
 80%|████████  | 712/890 [28:22<04:45,  1.60s/it]  The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: dr_id, question_1, question_2. If dr_id, question_1, question_2 are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 214
  Batch size = 16
                                                 
 80%|████████  | 712/890 [28:50<04:45,  1.60s/it]Saving model checkpoint to ./experiments/baseline_training\checkpoint-712
Configuration saved in ./experiments/baseline_training\checkpoint-712\config.json


{'eval_loss': 0.6938235759735107, 'eval_accuracy': 0.47, 'eval_f1': 0, 'eval_runtime': 28.0162, 'eval_samples_per_second': 7.638, 'eval_steps_per_second': 0.5, 'epoch': 4.0}


Model weights saved in ./experiments/baseline_training\checkpoint-712\pytorch_model.bin
tokenizer config file saved in ./experiments/baseline_training\checkpoint-712\tokenizer_config.json
Special tokens file saved in ./experiments/baseline_training\checkpoint-712\special_tokens_map.json
100%|██████████| 890/890 [35:16<00:00,  1.60s/it]The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: dr_id, question_1, question_2. If dr_id, question_1, question_2 are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 214
  Batch size = 16
                                                 
100%|██████████| 890/890 [35:43<00:00,  1.60s/it]Saving model checkpoint to ./experiments/baseline_training\checkpoint-890
Configuration saved in ./experiments/baseline_training\checkpoint-890\config.json


{'eval_loss': 0.6934771537780762, 'eval_accuracy': 0.47, 'eval_f1': 0, 'eval_runtime': 27.8555, 'eval_samples_per_second': 7.683, 'eval_steps_per_second': 0.503, 'epoch': 5.0}


Model weights saved in ./experiments/baseline_training\checkpoint-890\pytorch_model.bin
tokenizer config file saved in ./experiments/baseline_training\checkpoint-890\tokenizer_config.json
Special tokens file saved in ./experiments/baseline_training\checkpoint-890\special_tokens_map.json


Training completed. Do not forget to share your model on huggingface.co/models =)


Loading best model from ./experiments/baseline_training\checkpoint-534 (score: 0.691881537437439).
100%|██████████| 890/890 [35:45<00:00,  2.41s/it]

{'train_runtime': 2145.446, 'train_samples_per_second': 6.605, 'train_steps_per_second': 0.415, 'train_loss': 0.7028394720527563, 'epoch': 5.0}





TrainOutput(global_step=890, training_loss=0.7028394720527563, metrics={'train_runtime': 2145.446, 'train_samples_per_second': 6.605, 'train_steps_per_second': 0.415, 'train_loss': 0.7028394720527563, 'epoch': 5.0})