<!--<badge>--><a href="https://colab.research.google.com/github/ankur-98/BERT_GLUE/blob/main/multi_task.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a><!--</badge>-->

# For colab run:

If the project files were uploaded to Google Drive:

In [None]:
import os
from google.colab import drive

# Make the user's Google Drive visible under /content/drive.
drive.mount('/content/drive')

# Work from the Drive folder that holds the project files
# (adjust the path below if the files live somewhere else).
os.chdir('/content/drive/MyDrive/Colab_Notebooks/training_GLUE')

To train directly from the GitHub repository:

In [2]:
! git clone https://github.com/blanchardtom/PrivateTinyBertGlue.git
os.chdir('PrivateTinyBertGlue')

In [None]:
! pip install datasets transformers torch tqdm opacus peft evaluate

# Imports

In [5]:
import torch
import numpy as np
from tqdm import tqdm
from transformers import BertForSequenceClassification, get_linear_schedule_with_warmup, AdamW
from peft import LoraConfig, PromptTuningConfig, PrefixTuningConfig, TaskType, get_peft_model
import evaluate
from opacus import PrivacyEngine


from util import *
from train import *
from dataloader import get_dataloader

# Configs
### Tasks under scrutiny : {"mnli", "qnli", "qqp", "sst2"}
### Types of fine-tuning : {"lora", "prefix", "prompt", "linear", "full"}

In [7]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_checkpoint = "prajjwal1/bert-tiny"

# `tasks` and `batch_sizes` are index-aligned: batch_sizes[k] is used for tasks[k].
tasks = ["sst2", "qnli", "qqp", "mnli"]
batch_sizes = [128, 512, 1024, 1024] # adaptive batch sizes for datasets of different sizes
lr = 5e-4
epochs = 100

# learning config : to be changed for each type of fine-tuning
learning_way = "lora"

# Fail fast on inconsistent settings: the lists above are consumed with zip(),
# which would silently drop tasks on a length mismatch, and a mistyped
# fine-tuning name would otherwise only surface in a later cell.
if len(tasks) != len(batch_sizes):
    raise ValueError(
        f"tasks ({len(tasks)}) and batch_sizes ({len(batch_sizes)}) must have the same length"
    )
if learning_way not in ("lora", "prefix", "prompt", "linear", "full"):
    raise ValueError(f"Unknown learning_way: {learning_way!r}")

# Load Dataloader and Pre-trained BERT Model

In [None]:
def _loaders_for_split(split):
    """Build one dataloader per task for the given dataset split.

    The k-th loader uses tasks[k] together with batch_sizes[k].
    """
    return [
        get_dataloader(task_name, model_checkpoint, split, batch_size=size)
        for task_name, size in zip(tasks, batch_sizes)
    ]


# Tasks whose name starts with "mnli" get a 3-way head; every other task is binary.
num_labels = [3 if task_name.startswith("mnli") else 2 for task_name in tasks]

train_epoch_iterator = _loaders_for_split("train")
eval_epoch_iterator = _loaders_for_split("validation")

# One independent pre-trained classifier per task, moved to the selected device.
BERT_models = []
for label_count in num_labels:
    classifier = BertForSequenceClassification.from_pretrained(model_checkpoint, num_labels=label_count)
    BERT_models.append(classifier.to(device))

# Choice of the learning way for the selected tasks

In [9]:
# Build `models` (one entry per task) according to `learning_way`:
#   'lora' / 'prefix' / 'prompt' -> wrap each base model in a PEFT adapter
#   'linear'                     -> freeze everything except the classification head
#   'full'                       -> train the base models unchanged
peft_configs = []
if learning_way in ('lora', 'prefix', 'prompt'):
    for task in tasks:
        # A fresh config object is created for every task.
        if learning_way == 'lora':
            peft_configs.append(LoraConfig(task_type=TaskType.SEQ_CLS, r=8, lora_alpha=16, lora_dropout=0.1))
        elif learning_way == 'prefix':
            # NOTE(review): confirm that the installed peft version accepts
            # `dropout`, `task_embedding` and `cls_token` for PrefixTuningConfig.
            peft_configs.append(PrefixTuningConfig(task_type=TaskType.SEQ_CLS, dropout=0.1, task_embedding=True, cls_token=True, num_virtual_tokens=10))
        else:
            peft_configs.append(PromptTuningConfig(task_type=TaskType.SEQ_CLS, num_virtual_tokens=10))
    models = [get_peft_model(BERT_model, peft_config) for BERT_model, peft_config in zip(BERT_models, peft_configs)]
elif learning_way == 'linear':
    for model in BERT_models:
        # Freeze the whole network first (the attribute is `requires_grad`) ...
        for param in model.parameters():
            param.requires_grad = False
        # ... then re-enable the last layer, i.e. both the weight and the bias
        # of the classification head.
        for param in model.classifier.parameters():
            param.requires_grad = True
    # Copy the list so later in-place replacement of models[i] leaves BERT_models intact.
    models = list(BERT_models)
elif learning_way == 'full':
    # Full fine-tuning: every parameter of the base models stays trainable.
    models = list(BERT_models)
else:
    raise ValueError(
        f"Unknown learning_way {learning_way!r}; expected one of 'lora', 'prefix', 'prompt', 'linear', 'full'"
    )

models

[PeftModelForSequenceClassification(
   (base_model): LoraModel(
     (model): BertForSequenceClassification(
       (bert): BertModel(
         (embeddings): BertEmbeddings(
           (word_embeddings): Embedding(30522, 128, padding_idx=0)
           (position_embeddings): Embedding(512, 128)
           (token_type_embeddings): Embedding(2, 128)
           (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
           (dropout): Dropout(p=0.1, inplace=False)
         )
         (encoder): BertEncoder(
           (layer): ModuleList(
             (0-1): 2 x BertLayer(
               (attention): BertAttention(
                 (self): BertSelfAttention(
                   (query): lora.Linear(
                     (base_layer): Linear(in_features=128, out_features=128, bias=True)
                     (lora_dropout): ModuleDict(
                       (default): Dropout(p=0.1, inplace=False)
                     )
                     (lora_A): ModuleDict(
              

In [10]:
# Report how many parameters will be updated for each task's model.
for model in models:
    if hasattr(model, "print_trainable_parameters"):
        # PEFT-wrapped models provide their own report.
        model.print_trainable_parameters()
    else:
        # Plain (non-PEFT) models have no such method: count the parameters directly.
        trainable_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total_count = sum(p.numel() for p in model.parameters())
        print(f"trainable params: {trainable_count:,d} || all params: {total_count:,d} || trainable%: {100 * trainable_count / total_count}")

trainable params: 8,450 || all params: 4,394,628 || trainable%: 0.1922802112033146
trainable params: 8,450 || all params: 4,394,628 || trainable%: 0.1922802112033146
trainable params: 8,450 || all params: 4,394,628 || trainable%: 0.1922802112033146
trainable params: 8,579 || all params: 4,394,886 || trainable%: 0.19520415319077675


# Optimizer, LR Scheduler and metrics

In [11]:
# Optimizer list optional here but can be used for lr customization for different models
optimizers = [torch.optim.AdamW(model.parameters(), lr=lr) for model in models]

# Each task is trained on its own dataloader for `epochs` passes, so each
# scheduler's horizon is that task's own step count. Using the longest
# dataloader for every task would leave the shorter tasks with an unfinished
# learning-rate decay.
lr_schedulers = [
    get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=epochs * len(dataloader))
    for optimizer, dataloader in zip(optimizers, train_epoch_iterator)
]

# One GLUE metric object per task, index-aligned with `tasks`.
metrics = [evaluate.load("glue", task) for task in tasks]

# Privacy engine

In [None]:
# One privacy engine per task; each engine wraps that task's model, optimizer
# and training dataloader, and the wrapped objects replace the originals in place.
privacy_engines = [PrivacyEngine() for _ in range(len(tasks))]

for task_idx in range(len(models)):
    engine = privacy_engines[task_idx]
    # NOTE(review): verify how Opacus resolves the noise level when
    # target_epsilon is infinite.
    private_model, private_optimizer, private_loader = engine.make_private_with_epsilon(
        module=models[task_idx],
        optimizer=optimizers[task_idx],
        data_loader=train_epoch_iterator[task_idx],
        max_grad_norm=0.1,
        target_epsilon=np.inf, # either np.inf or 8.0
        epochs=epochs,
        target_delta=1e-5
    )
    models[task_idx] = private_model
    optimizers[task_idx] = private_optimizer
    train_epoch_iterator[task_idx] = private_loader

# Training loop

In [None]:
# Train and evaluate each task in turn.
#
# Batches are taken directly from the tqdm-wrapped dataloader, so every epoch
# starts a fresh pass over the data. Drawing batches with next() from an
# iterator created once up front runs dry after the first epoch and also loads
# a second, discarded batch at every step.

# Latest validation metrics per task (index-aligned with `tasks`); reported once
# every task has been processed.
final_metrics = [None for _ in tasks]

for i in range(len(tasks)):
    print(f"\nTraining for {tasks[i]} begins in batches of {batch_sizes[i]}.")
    for e in range(epochs):
        print(f"\nEpoch {e+1}/{epochs}")

        # ---- Training pass ----
        tr_loss = [0 for _ in tasks]
        global_steps = 0
        pbar = tqdm(train_epoch_iterator[i], total=len(train_epoch_iterator[i]))
        for batch in pbar:
            global_steps += 1
            inputs = prepare_inputs(batch, device)
            step_loss = training_step(models[i], inputs, optimizers[i], lr_schedulers[i])
            tr_loss[i] += step_loss.item()

            # Running loss normalised by the nominal number of samples seen so far.
            # NOTE(review): assumes training_step returns a loss summed over the batch - confirm.
            batch_loss = tr_loss[i] / (global_steps * batch_sizes[i])
            # tqdm advances by itself while being iterated; no manual update() is needed.
            pbar.set_description(f"Task : {tasks[i]} " + str(batch_loss), refresh=True)

        # ---- Evaluation pass ----
        eval_loss = [0 for _ in tasks]
        global_steps = 0
        pbar_eval = tqdm(eval_epoch_iterator[i], total=len(eval_epoch_iterator[i]))
        for batch in pbar_eval:
            global_steps += 1
            inputs = prepare_inputs(batch, device)
            step_loss, predictions, labels = eval_step(models[i], inputs)
            eval_loss[i] += step_loss.item()

            batch_loss_eval = eval_loss[i] / (global_steps * batch_sizes[i])
            # Only accumulate here: compute() consumes everything added so far,
            # so it is called once per epoch, after the whole validation set.
            metrics[i].add_batch(predictions=predictions, references=labels)
            pbar_eval.set_description(f"Task : {tasks[i]} " + 'Eval_loss : ' + str(batch_loss_eval), refresh=True)

        if global_steps > 0:
            eval_metric = metrics[i].compute()
            final_metrics[i] = eval_metric
            print(f"Task : {tasks[i]} " + 'Eval_loss : ' + str(batch_loss_eval) + ' accuracy : ' + str(eval_metric))

# Summary over all tasks, using the metrics of each task's last evaluated epoch.
print("\n".join([f'Evaluating on {task} : {final_metric}' for task, final_metric in zip(tasks, final_metrics)]))