In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
import wandb

from kaggle_secrets import UserSecretsClient
# Fetch the Weights & Biases API key from the Kaggle secret store so the key never appears in the notebook.
user_secrets = UserSecretsClient()
# NOTE(review): the secret label is spelled "wanda-api"; it has to match the label configured in Kaggle — confirm.
my_secret = user_secrets.get_secret("wanda-api")

# Authenticate this session with W&B using the retrieved key.
wandb.login(key=my_secret)

In [None]:
# Hugging Face checkpoint used for both the tokenizer and the model.
MODEL = 'google-t5/t5-base'
# Per-device batch size, used for training and evaluation alike.
BATCH_SIZE = 5
# Number of training epochs.
EPOCHS = 6
# Output directory handed to TrainingArguments.
OUT_DIR = 't5_base_distractors_with_synthesized_dataset_custom_loss'
# Prompts are padded/truncated to this many tokens.
MAX_SOURCE_LENGTH = 600
# Target distractor strings are padded/truncated to this many tokens.
MAX_TARGET_LENGTH = 256
# Learning rate handed to TrainingArguments.
LEARNING_RATE = 2e-4

In [None]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

import torch

# Load the pretrained tokenizer and seq2seq model named by MODEL.
tokenizer = T5Tokenizer.from_pretrained(MODEL)
model = T5ForConditionalGeneration.from_pretrained(MODEL)

# Use the GPU when one is visible and fall back to the CPU otherwise, instead of
# failing outright on a CPU-only kernel. Assigning the result also keeps the cell
# from printing the full module repr.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(DEVICE)

## Get Dataset

Dataset preparation script: https://colab.research.google.com/drive/10yEybyBtr_bab3VRAc1ko3HUOU43S5Bx?usp=sharing

In [None]:
# Load the synthesized multiple-choice-question splits from the attached Kaggle dataset.
train_data = pd.read_csv('/kaggle/input/synthesized-multiple-choice-questions-dataset/trainset.csv')
# NOTE(review): the validation frame is read from testset.csv, so the "test" split doubles as the validation set during training — confirm this is intended.
val_data = pd.read_csv('/kaggle/input/synthesized-multiple-choice-questions-dataset/testset.csv')

In [None]:
# Shuffle the rows of both splits. A fixed seed keeps the row order (and the
# preview cells that follow) identical across Restart & Run All.
SHUFFLE_SEED = 42
train_data = train_data.sample(frac=1, ignore_index=True, random_state=SHUFFLE_SEED)
val_data = val_data.sample(frac=1, ignore_index=True, random_state=SHUFFLE_SEED)

In [None]:
# Sanity check: (rows, columns) of each split.
train_data.shape, val_data.shape

In [None]:
# Preview a single row of the training frame.
train_data.head(1)

In [None]:
# Register '<sep>' as a special token; it separates the three distractors in the target string.
num_added_tokens = tokenizer.add_special_tokens({'sep_token': '<sep>'})

# The new token gets an id past the end of the original vocabulary. Grow the
# model's embedding matrix when it is too small to hold that id, so the embedding
# lookup cannot go out of range if MODEL is swapped for another checkpoint.
if len(tokenizer) > model.get_input_embeddings().num_embeddings:
    model.resize_token_embeddings(len(tokenizer))

num_added_tokens

In [None]:
# Look up the tokens behind a few ids; 32100 is the id for sep.
ids_to_inspect = [1, 2, 32100, 0]
tokenizer.convert_ids_to_tokens(ids_to_inspect)

In [None]:
# NOTE(review): this repeats the earlier `train_data.head(1)` preview cell; one of the two can be dropped.
train_data.head(1)

In [None]:
prefix = "make 3 distractors:"
def preprocess_data(dataset, tokenizer):
    """Turn a question/answer/context table into tokenized model inputs and labels.

    Args:
        dataset: table-like object with 'question', 'answer', 'context',
            'dis1', 'dis2' and 'dis3' columns.
        tokenizer: tokenizer used to encode both the prompts and the targets.

    Returns:
        dict: 'input_ids' and 'attention_mask' for the prompts, plus 'labels'
        and 'decoder_attention_mask' for the '<sep>'-joined distractors. Every
        entry is padded/truncated to its configured maximum length.
    """
    # One prompt per row: task prefix followed by question, answer and context.
    source_texts = []
    for question, answer, context in zip(dataset['question'], dataset['answer'], dataset['context']):
        source_texts.append(f"{prefix} question: {question}, answer: {answer}, context: {context}")

    # One target per row: the three distractors joined by the separator token.
    target_texts = []
    for dis1, dis2, dis3 in zip(dataset['dis1'], dataset['dis2'], dataset['dis3']):
        target_texts.append(f"{dis1} <sep> {dis2} <sep> {dis3}")

    # Options shared by both tokenizer calls; only the length cap differs.
    shared_options = dict(padding='max_length', truncation=True, return_tensors='pt')
    encoded_sources = tokenizer(text=source_texts, max_length=MAX_SOURCE_LENGTH, **shared_options)
    encoded_targets = tokenizer(text_target=target_texts, max_length=MAX_TARGET_LENGTH, **shared_options)

    features = {}
    features['input_ids'] = encoded_sources['input_ids']
    features['attention_mask'] = encoded_sources['attention_mask']
    features['labels'] = encoded_targets['input_ids']
    features['decoder_attention_mask'] = encoded_targets['attention_mask']
    return features

In [None]:
# Tokenize both splits up front; each call returns a dict of padded tensors.
train_set = preprocess_data(train_data, tokenizer)
val_set = preprocess_data(val_data, tokenizer)

In [None]:
import datasets
# Wrap the tokenized dicts as `datasets.Dataset` objects, the form later handed to the trainer.
train_dataset = datasets.Dataset.from_dict(train_set)
val_dataset = datasets.Dataset.from_dict(val_set)

In [None]:
# Sanity check: (rows, columns) of the tokenized datasets.
train_dataset.shape, val_dataset.shape

# Set up Training

In [None]:
import torch

def preprocess_logits_for_metrics(logits, labels):
    """Reduce raw logits to predicted token ids before they are accumulated for metrics.

    Keeping only the argmax ids instead of the full vocabulary-sized logits keeps
    the tensors stored during evaluation small.

    Args:
        logits: either the logits tensor itself, or a tuple/list of model outputs
            whose first element is the logits tensor.
        labels: unused; part of the callback signature.

    Returns:
        torch.Tensor: index of the highest-scoring entry along the last dimension.
    """
    # The caller may hand over a tuple of outputs or a bare tensor. Indexing a bare
    # tensor with [0] would silently keep only the first example of the batch, so
    # unwrap only when the value really is a sequence of outputs.
    if isinstance(logits, (tuple, list)):
        logits = logits[0]
    return torch.argmax(logits, dim=-1)

In [None]:
from nltk.translate import bleu_score

def compute_metrics(eval_pred): # eval_preds: tuple(preds, labels)
    """Compute corpus-level BLEU between decoded predictions and decoded labels.

    Args:
        eval_pred: pair ``(pred_ids, labels_ids)`` of token-id arrays. The
            predictions are expected to be ids already (argmax applied by
            ``preprocess_logits_for_metrics``), not logits.

    Returns:
        dict: ``{"bleu": <corpus BLEU score>}``.
    """
    pred_ids, labels_ids = eval_pred

    # -100 is not a valid token id, so swap it for the pad id before decoding.
    # np.where builds new arrays, leaving the caller's arrays untouched.
    pred_ids = np.where(pred_ids != -100, pred_ids, tokenizer.pad_token_id)
    labels_ids = np.where(labels_ids != -100, labels_ids, tokenizer.pad_token_id)

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    # corpus_bleu works on token sequences. Passing raw strings makes it iterate
    # character by character and report a character-level score, so split on
    # whitespace first.
    references = [[label.split()] for label in label_str]
    hypotheses = [pred.split() for pred in pred_str]
    bleu = bleu_score.corpus_bleu(list_of_references=references, hypotheses=hypotheses)
    return {"bleu": bleu}

In [None]:
import torch.nn.functional as F

def distinct_loss(dis_embds):
    """Mean pairwise cosine similarity between mean-pooled distractor embeddings.

    A higher value means the distractors are more alike, so the result can be
    added to a loss as a similarity penalty.

    Args:
        dis_embds (list): embeddings, one tensor per distractor, each of shape
            (1, dis_len, emb_dim).

    Returns:
        torch.Tensor: scalar tensor with the average cosine similarity over all
        unordered pairs. With fewer than two embeddings there is no pair to
        compare and a zero tensor is returned.
    """
    # mean pooling along the sequence axis: one (1, emb_dim) vector per distractor
    pooled_embds = [torch.mean(embd, dim=1) for embd in dis_embds]

    if len(pooled_embds) < 2:
        # No pairs to compare: return zero instead of dividing by a pair count of 0.
        if pooled_embds:
            return pooled_embds[0].new_zeros(())
        return torch.zeros(())

    pair_sims = [
        F.cosine_similarity(pooled_embds[i], pooled_embds[j], dim=-1).mean()
        for i in range(len(pooled_embds))
        for j in range(i + 1, len(pooled_embds))
    ]
    # Stacking keeps every pair attached to the autograd graph.
    return torch.stack(pair_sims).mean()


In [None]:
# custom trainer
import torch 
from transformers import Trainer

class CustomTrainer(Trainer):
    """Trainer whose loss blends the model loss with a distractor-similarity penalty.

    For every example the greedy (argmax) prediction is split into three
    distractors, each distractor is run through the model's encoder, and the mean
    pairwise cosine similarity of the pooled encodings (``distinct_loss``) is
    blended into the regular loss so that near-identical distractors cost more.
    """

    # Token ids removed from the prediction before it is split (pad, eos, unk).
    SPECIAL_TOKEN_IDS = (0, 1, 2)
    # Id of the '<sep>' token that was added to the tokenizer.
    SEP_TOKEN_ID = 32100
    # Weight of the similarity penalty in the blended loss.
    DISTINCT_WEIGHT = 0.5
    # Penalty used when a prediction cannot be split into three non-empty distractors.
    EMPTY_DISTRACTOR_PENALTY = 0.5

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the blended loss, optionally together with the model outputs.

        Args:
            model: the model being trained (it may arrive wrapped by a parallel wrapper).
            inputs (dict): batch passed to the model as keyword arguments.
            return_outputs (bool): when True return ``(loss, outputs)``.
            num_items_in_batch: accepted and ignored, so the override also works
                with Trainer versions that pass this argument.

        Returns:
            The blended loss, or ``(loss, outputs)`` when ``return_outputs`` is True.

        Raises:
            ValueError: if the model output is a dict without a ``"loss"`` entry.
        """
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        outputs = model(**inputs)
        # Save past state if it exists
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            # These two names are transformers internals that were referenced
            # without being imported; import them where they are needed so the
            # label-smoothing path no longer ends in a NameError.
            from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
            from transformers.trainer import _is_peft_model

            unwrapped_model = self.accelerator.unwrap_model(model)
            if _is_peft_model(unwrapped_model):
                model_name = unwrapped_model.base_model.model._get_name()
            else:
                model_name = unwrapped_model._get_name()
            if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                loss = self.label_smoother(outputs, labels, shift_labels=True)
            else:
                loss = self.label_smoother(outputs, labels)
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        # logits: (batch_size, seq_length, vocab_size)
        d_loss = self._distinctness_penalty(model, outputs['logits'])

        alpha = self.DISTINCT_WEIGHT
        total_loss = ((1 - alpha) * loss) + (alpha * d_loss)

        return (total_loss, outputs) if return_outputs else total_loss

    def _distinctness_penalty(self, model, logits):
        """Average the distractor-similarity penalty over the examples of a batch.

        The per-example penalties are combined with ``torch.stack`` so the result
        stays attached to the autograd graph of the encoder passes; rebuilding
        them with ``torch.Tensor([...])`` would turn the penalty into a constant
        with no gradient. The split is based on argmax ids, so no gradient reaches
        the decoder logits through this term. Each example costs three extra
        encoder forward passes.
        """
        # A parallel wrapper keeps the real model under `.module`; the encoder
        # lives on the real model.
        encoder = getattr(model, "module", model).encoder
        # Built once per batch, on the same device as the logits and with an
        # integer dtype matching the argmax output.
        special_ids = torch.tensor(self.SPECIAL_TOKEN_IDS, dtype=torch.long, device=logits.device)

        penalties = []
        for seq_logits in logits:  # (seq_length, vocab_size)
            seq_ids = torch.argmax(seq_logits, dim=-1)  # (seq_length,)
            seq_ids = seq_ids[~torch.isin(seq_ids, special_ids)]

            distractors = self._split_distractors(seq_ids)
            if any(dis.shape[1] == 0 for dis in distractors):
                penalties.append(
                    torch.tensor(self.EMPTY_DISTRACTOR_PENALTY, dtype=torch.float32, device=logits.device)
                )
            else:
                dis_embs = [
                    encoder(input_ids=dis, return_dict=True).last_hidden_state
                    for dis in distractors
                ]
                penalties.append(distinct_loss(dis_embs).float())

        if not penalties:
            return logits.new_zeros(())
        return torch.stack(penalties).mean()

    def _split_distractors(self, seq_ids):
        """Split a 1-D id sequence into three distractors, each shaped (1, dis_len)."""
        sep_pos = (seq_ids == self.SEP_TOKEN_ID).nonzero().squeeze(-1)

        if sep_pos.shape[0] == 2:
            # Exactly two separators: cut at them and drop the separators themselves.
            first, second = int(sep_pos[0]), int(sep_pos[1])
            pieces = (seq_ids[:first], seq_ids[first + 1:second], seq_ids[second + 1:])
        else:
            # Any other separator count: fall back to three chunks of equal
            # length, with the last chunk taking the remainder.
            dis_len = seq_ids.shape[0] // 3
            pieces = (seq_ids[:dis_len], seq_ids[dis_len:2 * dis_len], seq_ids[2 * dis_len:])

        return [piece.unsqueeze(0) for piece in pieces]

In [None]:
from transformers import TrainingArguments

args = TrainingArguments(
    # where checkpoints are written, and how many are kept
    output_dir=OUT_DIR,
    save_total_limit=2,
    # optimisation settings
    num_train_epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    warmup_steps=500,
    weight_decay=0.01,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    # evaluate and checkpoint on the same step interval
    evaluation_strategy='steps',
    eval_steps=1500,
    save_steps=1500,
    # reload the checkpoint with the lowest eval loss once training ends
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)

In [None]:
# Wire the model, the datasets and the metric hooks into the custom trainer.
trainer = CustomTrainer(
    args=args,
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    compute_metrics=compute_metrics,
)

In [None]:
# Run training; `history` keeps the value returned by `trainer.train()`.
history = trainer.train()