<a href="https://colab.research.google.com/github/lucarinelli/conditional_text_generation/blob/main/CTRL_Distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import utilities

In [None]:
!rm -r conditional_text_generation
!git clone https://github.com/lucarinelli/conditional_text_generation.git

In [None]:
!pip install import-ipynb

%cd conditional_text_generation/notebooks

import import_ipynb

from Conditional_Text_Generation_Skeleton import *

%cd ../..

# WandB (Weights & Biases)

In [None]:
# Identifier of the W&B model artifact holding the checkpoint that is later
# loaded as the teacher model.
TEACHER_ARTIFACT = 'polito_aiml2021_textgen/ctrl_dry_runs/GPT2_supercategories_5_epochs_v1:v0'

# Open a W&B run, register the artifact as an input of that run and download it;
# `artifact_dir` is the local path handed to `from_pretrained` further below.
run = wandb.init()
artifact = run.use_artifact(TEACHER_ARTIFACT, type='model')
artifact_dir = artifact.download()

# Configuration

In [None]:
# Overrides applied to the shared `experiment_parameters` for this distillation run.
experiment_parameters["low_cuda"]= True  # True/False: when True, part of the distillation loss computation is moved to the CPU so that CUDA memory is not exhausted
# Temperature by which both student and teacher logits are divided before the
# (log-)softmax in the distillation loss.
experiment_parameters["training_args"].temperature= 1
experiment_parameters["max_train_set_len"] = 12  # positive integer: maximum number of items used from the training set
experiment_parameters["max_val_set_len"] = 12  # positive integer: maximum number of items used from the validation set

In [None]:
# Build the shared experiment state with the skeleton's `initialize_env` helper.
# It yields the tokenizer, the encoded training and validation datasets handed to
# the trainer, and the `references` later passed to `compute_metrics`.
# NOTE(review): presumably honours the max_*_set_len limits set above - confirm in the skeleton.
tokenizer, dataset_train_encoded, dataset_val_encoded, references = initialize_env()

# Models




In [None]:
from transformers import GPT2LMHeadModel

# Student: starts from the pretrained checkpoint named in the experiment
# parameters, padding with the tokenizer's EOS token. Its embedding matrix is
# resized to the tokenizer's current vocabulary size.
model = GPT2LMHeadModel.from_pretrained(
    experiment_parameters['model'],
    pad_token_id=tokenizer.eos_token_id,
)
model.resize_token_embeddings(len(tokenizer))

# Teacher: loaded from the downloaded artifact directory, placed on the GPU
# and switched to evaluation mode.
teacher = GPT2LMHeadModel.from_pretrained(artifact_dir).cuda()
teacher.eval()

# Classroom: distillation trainer

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Optimizer

class MyDistilTrainer(MyTrainer):
    """Trainer that distils a teacher model into the student being trained.

    The loss is the KL divergence (``reduction='batchmean'``) between the
    student's log-softmax and the teacher's softmax over the last logits
    dimension, both computed on logits divided by ``args.temperature``.
    """

    def __init__(self,
        teacher: Union[PreTrainedModel, torch.nn.Module] = None,
        model: Union[PreTrainedModel, torch.nn.Module] = None,
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Callable[[], PreTrainedModel] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),):
        """Store the teacher and forward every other argument to ``MyTrainer``.

        Args:
            teacher: Model whose output distribution the student should match.
                It is only ever called under ``torch.no_grad()``.
            model, args, data_collator, train_dataset, eval_dataset, tokenizer,
            model_init, compute_metrics, callbacks, optimizers: Passed
                positionally, in this order, to ``MyTrainer.__init__``.
        """
        # Set before the parent constructor runs so the attribute already exists
        # should the parent trigger any hook that needs it.
        self.teacher = teacher
        super().__init__(model, args, data_collator,
                         train_dataset, eval_dataset,
                         tokenizer, model_init, compute_metrics,
                         callbacks, optimizers)

        self.loss = nn.KLDivLoss(reduction='batchmean')

    def compute_loss(self, model, inputs, return_outputs=False):
        """Compute the distillation loss for one batch.

        Args:
            model: The student model as handed over by the training loop.
            inputs: Keyword arguments fed unchanged to both student and teacher.
            return_outputs: When true, also return the student logits.

        Returns:
            The loss tensor, or ``(loss, student_logits)`` when
            ``return_outputs`` is true. With ``low_cuda`` enabled the returned
            logits live on the CPU.

        Raises:
            ValueError: If the trainer was built without a teacher.
        """
        if self.teacher is None:
            raise ValueError(
                "MyDistilTrainer needs a `teacher` model to compute the distillation loss."
            )

        # Use the teacher given to this trainer, not a notebook-level global.
        with torch.no_grad():
            teacher_output = self.teacher(**inputs)

        # Use the model supplied by the training loop (it may be a wrapped
        # version of self.model) rather than reaching for self.model directly.
        student_output = model(**inputs)

        student_logits = student_output.logits
        teacher_logits = teacher_output.logits

        # Remember where the student lives so the loss can be moved back there.
        loss_device = student_logits.device
        low_cuda = experiment_parameters["low_cuda"]

        if low_cuda:
            # Do the softmax / KL computation on the CPU to save GPU memory.
            student_logits = student_logits.cpu()
            teacher_logits = teacher_logits.cpu()

        temperature = self.args.temperature
        student_sm = F.log_softmax(student_logits / temperature, dim=-1)
        teacher_sm = F.softmax(teacher_logits / temperature, dim=-1)

        # NOTE(review): distillation losses are commonly multiplied by
        # temperature**2 to keep gradient magnitudes comparable across
        # temperatures; not applied here (irrelevant while temperature == 1).
        loss = self.loss(input=student_sm, target=teacher_sm)

        if low_cuda:
            loss = loss.to(loss_device)

        return (loss, student_logits) if return_outputs else loss



# Training

In [None]:
def _metrics_with_context(a, b):
    """Forward the trainer's two arguments to `compute_metrics`, preceded by the tokenizer and references."""
    return compute_metrics(tokenizer, references, a, b)


trainer = MyDistilTrainer(
    teacher=teacher,                              # model whose outputs the student imitates
    model=model,                                  # student model that gets trained
    args=experiment_parameters["training_args"],  # training arguments (incl. the temperature set above)
    train_dataset=dataset_train_encoded,          # encoded training split
    eval_dataset=dataset_val_encoded,             # encoded validation split
    compute_metrics=_metrics_with_context,
)

In [None]:
# Run the training loop; every batch loss comes from MyDistilTrainer.compute_loss.
trainer.train()

# Post-training step provided by the imported skeleton notebook.
# NOTE(review): presumably handles evaluation / saving / logging - confirm in the skeleton.
after_training(trainer)

# 