In [None]:
import os
from dataclasses import dataclass
from typing import Optional, Union

import numpy as np
import pandas as pd
import torch
# The `datasets` package has no `dataset` submodule, so `from datasets import dataset`
# raises ImportError; alias the package instead so `dataset.Dataset.from_pandas` resolves.
import datasets as dataset
from transformers import AutoTokenizer
from transformers import EarlyStoppingCallback
from transformers import TrainingArguments, Trainer, AutoModelForMultipleChoice
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy



In [None]:
class Cfg:
    """Notebook-wide settings, read directly as class attributes (e.g. ``Cfg.model``)."""

    # Hugging Face checkpoint used for both the tokenizer and the model.
    model: str = 'microsoft/deberta-v3-large'
    # Token length of every (context, prompt + option) pair; used for truncation and padding.
    max_input: int = 256
    # Wrap the model with LoRA adapters from `peft`.
    use_peft: bool = False
    # Disable gradients for the embedding layer.
    freeze_embeddings: bool = True
    # Number of leading encoder layers to freeze (0 disables layer freezing).
    freeze_layers: int = 18

    

### Load Your Dataset
* Each CSV must have the columns `prompt`, `context`, `A`, `B`, `C`, `D`, `E` (the five answer options), and `answer` (the letter `A`–`E` of the correct option).

In [None]:
# Replace these placeholders with the real CSV locations.
TRAIN_CSV_PATH = "Your path to train.csv"
VALID_CSV_PATH = "Your path to val.csv"
# Columns consumed by `preprocess` (and removed after tokenization).
REQUIRED_COLUMNS = ['prompt', 'context', 'A', 'B', 'C', 'D', 'E', 'answer']

df_train = pd.read_csv(TRAIN_CSV_PATH)
df_valid = pd.read_csv(VALID_CSV_PATH)

# Fail fast with a readable message instead of a KeyError / TypeError inside `.map(preprocess)`.
# Note: read_csv turns texts such as "NA" or "None" into NaN by default, which would later
# break the string concatenation in `preprocess`.
for split_name, split_df in (('train', df_train), ('valid', df_valid)):
    missing_columns = [col for col in REQUIRED_COLUMNS if col not in split_df.columns]
    assert not missing_columns, f"{split_name} CSV is missing columns: {missing_columns}"
    null_counts = split_df[REQUIRED_COLUMNS].isna().sum()
    assert not null_counts.any(), f"{split_name} CSV has missing values: {null_counts[null_counts > 0].to_dict()}"
    bad_answers = ~split_df['answer'].isin(list('ABCDE'))
    assert not bad_answers.any(), f"{split_name} CSV has {int(bad_answers.sum())} answers outside A-E"


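### Data Loader
Tokenize each question into five (context, prompt + option) pairs and batch them with a multiple-choice collator.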


In [None]:
# Tokenizer matching the backbone checkpoint configured in `Cfg.model`.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=Cfg.model)

# Answer letter <-> class index; the letter order is the order options are fed to the model.
option_to_index = dict(zip('ABCDE', range(5)))
index_to_option = dict(enumerate('ABCDE'))

def preprocess(example):
    """Tokenize one multiple-choice row into one sequence pair per answer option.

    ``example`` is a single row with ``context``, ``prompt``, one column per option letter
    (the keys of ``option_to_index``) and ``answer`` (one of those letters).

    Each pair is ``("[CLS] " + context, " #### " + prompt + " [SEP] " + option + " [SEP] ")``.
    The special tokens are written into the text, so the tokenizer is called with
    ``add_special_tokens=False``. Only the context (first sequence) is truncated, and every
    pair is padded to ``Cfg.max_input`` tokens.

    Returns the tokenizer output for the option pairs, plus ``labels``: the integer index of
    the correct option.
    """
    # Iterate options in class-index order so the choice order always matches the labels.
    options = list(option_to_index)
    first_sentence = ["[CLS] " + example['context']] * len(options)
    second_sentence = [
        " #### " + example['prompt'] + " [SEP] " + example[option] + " [SEP] "
        for option in options
    ]
    tokenized_examples = tokenizer(
        first_sentence,
        second_sentence,
        truncation="only_first",
        max_length=Cfg.max_input,
        padding='max_length',
        add_special_tokens=False,
    )
    # A scalar index, not a one-element list: the collator turns the labels into a tensor,
    # and list-wrapped labels would give shape (batch, 1) instead of the (batch,) the
    # multiple-choice loss expects.
    tokenized_examples["labels"] = option_to_index[example["answer"]]
    return tokenized_examples

@dataclass
class DataCollatorForMultipleChoice:
    """
    Data collator that will dynamically pad the inputs for multiple choice received.

    Each incoming feature holds, for every tokenizer key (e.g. ``input_ids``), one entry per
    answer choice, plus a ``label`` or ``labels`` class index. The choices are flattened,
    padded together with ``tokenizer.pad`` and reshaped back to
    ``(batch_size, num_choices, seq_len)``.

    Note: the label key is popped from each feature dict, so the input features are mutated.
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features):
        """Collate a list of multiple-choice features into a dict of tensors.

        Returns every padded tokenizer key as a ``(batch_size, num_choices, seq_len)`` tensor,
        plus ``labels`` as an int64 tensor of shape ``(batch_size,)``.
        """
        label_name = 'label' if 'label' in features[0].keys() else 'labels'
        labels = [feature.pop(label_name) for feature in features]
        batch_size = len(features)
        num_choices = len(features[0]['input_ids'])
        # One dict per (example, choice), example-major, so the `.view` below restores the
        # grouping. A flat comprehension avoids the quadratic `sum(list_of_lists, [])`.
        flattened_features = [
            {k: v[i] for k, v in feature.items()}
            for feature in features
            for i in range(num_choices)
        ]

        batch = self.tokenizer.pad(
            flattened_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors='pt',
        )
        batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
        # The multiple-choice loss expects one class index per example, shape (batch_size,).
        # `.view(batch_size)` also accepts labels stored as one-element lists, and raises
        # if a feature carries more than one label.
        batch['labels'] = torch.tensor(labels, dtype=torch.int64).view(batch_size)
        return batch

In [None]:
# Preprocessing the datasets: wrap each pandas split in a Hugging Face `Dataset` so it can be `.map`-ed.
dataset_valid, dataset_train = (
    dataset.Dataset.from_pandas(split_df) for split_df in (df_valid, df_train)
)

In [None]:
# Tokenize both splits and drop the raw question/option/answer columns that `preprocess` consumed.
tokenized_dataset_valid, tokenized_dataset_train = (
    split.map(preprocess, remove_columns=['prompt', 'context', 'A', 'B', 'C', 'D', 'E', 'answer'])
    for split in (dataset_valid, dataset_train)
)


### Build The Model

In [None]:
# Multiple-choice head on top of the configured backbone.
model = AutoModelForMultipleChoice.from_pretrained(Cfg.model)

# Optional LoRA adapters on the attention query/value projections. `modules_to_save` keeps
# the classifier and pooler fully trainable. LoRA scales updates by lora_alpha / r (0.5 here).
# NOTE(review): `peft` has no multiple-choice task type; SEQ_CLS is used as the closest fit.
if Cfg.use_peft:
    print('Using PEFT')
    from peft import LoraConfig, get_peft_model, TaskType

    peft_config = LoraConfig(
        task_type=TaskType.SEQ_CLS,
        inference_mode=False,
        r=8,
        lora_alpha=4,
        lora_dropout=0.1,
        bias="none",
        target_modules=["query_proj", "value_proj"],
        modules_to_save=["classifier", "pooler"],
    )
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

# Disable gradient updates for the token embedding layer.
if Cfg.freeze_embeddings:
    print('Freezing Embeddings')
    # `for`, not `from`: the original `from param in ...` was a SyntaxError that
    # prevented this whole cell from running.
    for param in model.deberta.embeddings.parameters():
        param.requires_grad = False
    
    
# Disable gradient updates for the first `Cfg.freeze_layers` encoder layers; the later
# layers are left untouched. `Module.requires_grad_(False)` covers every parameter of a layer.
if Cfg.freeze_layers:
    print('Freezing Layers')
    for layer in model.deberta.encoder.layer[:Cfg.freeze_layers]:
        layer.requires_grad_(False)