# Part 2- Fine tuning

In [3]:
import os
import kagglehub
import pandas as pd
from copy import deepcopy
import optuna
import wandb
from optuna.pruners import MedianPruner
import torch
from torch.utils.data import DataLoader, Dataset, random_split
from torch import nn, optim
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from tqdm import tqdm
import gc
from sklearn.model_selection import train_test_split
import random
from torch.utils.data import Subset


ModuleNotFoundError: No module named 'optuna'

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

In [None]:
torch.amp.autocast('cuda')   # FP16
torch.backends.cudnn.benchmark = True

Dataset creation and formatting:

In [None]:
data_path = kagglehub.dataset_download("datatattle/covid-19-nlp-text-classification")
print("Path to dataset files:", data_path)
train_df = pd.read_csv(os.path.join(data_path, "Corona_NLP_train.csv"), encoding="latin1")
test_df = pd.read_csv(os.path.join(data_path, "Corona_NLP_test.csv"), encoding="latin1")

In [None]:
def delete_words_with_http(text, to_remove="http"):
    if not isinstance(text, str):
        return text
    words = text.split()
    kept_words = [word for word in words if to_remove not in word]
    return ' '.join(kept_words)

train_df['clean_text'] = train_df['OriginalTweet'].apply(lambda x: delete_words_with_http(x))
train_df.head()

Unnamed: 0,UserName,ScreenName,Location,TweetAt,OriginalTweet,Sentiment,clean_text
0,3799,48751,London,16-03-2020,@MeNyrbie @Phil_Gahan @Chrisitv https://t.co/i...,Neutral,@MeNyrbie @Phil_Gahan @Chrisitv and and
1,3800,48752,UK,16-03-2020,advice Talk to your neighbours family to excha...,Positive,advice Talk to your neighbours family to excha...
2,3801,48753,Vagabonds,16-03-2020,Coronavirus Australia: Woolworths to give elde...,Positive,Coronavirus Australia: Woolworths to give elde...
3,3802,48754,,16-03-2020,My food stock is not the only one which is emp...,Positive,My food stock is not the only one which is emp...
4,3803,48755,,16-03-2020,"Me, ready to go at supermarket during the #COV...",Extremely Negative,"Me, ready to go at supermarket during the #COV..."


In [None]:
test_df['clean_text']= test_df['OriginalTweet'].apply(lambda x: delete_words_with_http(x))
test_df.head()

Unnamed: 0,UserName,ScreenName,Location,TweetAt,OriginalTweet,Sentiment,clean_text
0,1,44953,NYC,02-03-2020,TRENDING: New Yorkers encounter empty supermar...,Extremely Negative,TRENDING: New Yorkers encounter empty supermar...
1,2,44954,"Seattle, WA",02-03-2020,When I couldn't find hand sanitizer at Fred Me...,Positive,When I couldn't find hand sanitizer at Fred Me...
2,3,44955,,02-03-2020,Find out how you can protect yourself and love...,Extremely Positive,Find out how you can protect yourself and love...
3,4,44956,Chicagoland,02-03-2020,#Panic buying hits #NewYork City as anxious sh...,Negative,#Panic buying hits #NewYork City as anxious sh...
4,5,44957,"Melbourne, Victoria",03-03-2020,#toiletpaper #dunnypaper #coronavirus #coronav...,Neutral,#toiletpaper #dunnypaper #coronavirus #coronav...


In [89]:
train_df.shape

(41157, 7)

In [None]:
train_df, val_df = train_test_split(train_df, test_size=0.2, random_state=42)
val_df.head()

Unnamed: 0,UserName,ScreenName,Location,TweetAt,OriginalTweet,Sentiment,clean_text
31089,34888,79840,Lagos,06-04-2020,Without the there would not be any problem wh...,Neutral,Without the there would not be any problem wha...
35564,39363,84315,,09-04-2020,Rice &amp; wheat prices surge amid fears Covid...,Extremely Negative,Rice &amp; wheat prices surge amid fears Covid...
144,3943,48895,,16-03-2020,When the government says to start social dista...,Positive,When the government says to start social dista...
8202,12001,56953,irlande du nord,19-03-2020,What the shops are doing is obeying the law of...,Positive,What the shops are doing is obeying the law of...
31720,35519,80471,"Zaria, Nigeria",07-04-2020,Kaduna State Task Force on Covid 19 led by the...,Negative,Kaduna State Task Force on Covid 19 led by the...


## Models

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, RobertaForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment", use_fast=False)
model1 = AutoModelForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment", num_labels=5, ignore_mismatched_sizes=True).to(device)
model2 = AutoModelForSequenceClassification.from_pretrained('cardiffnlp/twitter-xlm-roberta-base-sentiment', num_labels=5, ignore_mismatched_sizes=True).to(device) # initialize RoBerta large from HF, num_labels=2 -> 2 classes.
models=[model1, model2]

Some weights of XLMRobertaForSequenceClassification were not initialized from the model checkpoint at cardiffnlp/twitter-xlm-roberta-base-sentiment and are newly initialized because the shapes did not match:
- classifier.out_proj.weight: found shape torch.Size([3, 768]) in the checkpoint and torch.Size([5, 768]) in the model instantiated
- classifier.out_proj.bias: found shape torch.Size([3]) in the checkpoint and torch.Size([5]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [91]:
class Tweets(Dataset): # Dataset Class
    def __init__(self, dataframe, tokenizer):
        # Map sentiment strings to numerical labels
        sentiment_mapping = {
            'Extremely Negative': 0,
            'Negative': 1,
            'Neutral': 2,
            'Positive': 3,
            'Extremely Positive': 4
        }
        self.texts = dataframe['clean_text'].tolist()
        self.labels = dataframe['Sentiment'].map(sentiment_mapping).tolist() # Map sentiment strings to numerical labels
        self.tokenizer = tokenizer # Tokenizer for text processing

    def __len__(self): #Returns the total number of samples in the dataset.
        # This method is required for PyTorch's DataLoader to work !!
        return len(self.texts)

    def __getitem__(self, idx): #Retrieves a single data sample and its label at the specified index.
        text = self.texts[idx]
        label = self.labels[idx]
        label = torch.tensor(label, dtype=torch.long) 
        # Tokenize the text using the provided tokenizer
        encoding = self.tokenizer(
            text,
            padding='max_length', # Add padding to reach the maximum length
            truncation=True, # Trim if the text is longer than max_length
            max_length=512, # Maximum sequence length allowed
            return_tensors='pt')

        return {
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze(),
            'labels': label.detach().clone() if isinstance(label, torch.Tensor) else torch.tensor(label, dtype=torch.long)
        }


In [92]:
def early_stop_check(patience, best_val_accuracy, best_val_accuracy_epoch, current_val_accuracy, current_val_accuracy_epoch):
    early_stop_flag = False
    if current_val_accuracy > best_val_accuracy:
        best_val_accuracy = current_val_accuracy
        best_val_accuracy_epoch = current_val_accuracy_epoch
    else:
        if current_val_accuracy_epoch - best_val_accuracy_epoch > patience:
            early_stop_flag = True
    return best_val_accuracy, best_val_accuracy_epoch, early_stop_flag

In [None]:
def train_model_with_hyperparams(model, train_loader, val_loader, optimizer, criterion, epochs, patience, trial):
    best_val_accuracy = 0.0
    best_val_accuracy_epoch = 0
    early_stop_flag = False
    best_model_state = None

    for epoch in range(1, epochs + 1):
        model.train() # Enable training mode
        train_loss = 0.0
        total_train_samples = 0
        correct_train_predictions = 0

        # Wrap train_loader with tqdm for progress visualization
        train_loader_tqdm = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs} [Training]")

        for batch in train_loader_tqdm: #Iterates over the train_loader, which is a DataLoader object containing batches of training data.
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad() # Reset gradients
            outputs = model(input_ids, attention_mask=attention_mask) # Forward pass
            logits = outputs.logits # save the logits (the raw output of the model)
            loss = criterion(logits, labels) # Calculate loss

            loss.backward() # Backward pass
            optimizer.step() # Update weights using the optimizer

            # Accumulate training loss and predictions
            train_loss += loss.item() * input_ids.size(0)
            total_train_samples += input_ids.size(0)
            correct_train_predictions += (logits.argmax(dim=1) == labels).sum().item()

            # Update tqdm description with current loss and accuracy
            train_loader_tqdm.set_postfix(loss=loss.item())


        train_loss /= total_train_samples
        train_accuracy = correct_train_predictions / total_train_samples

        ###  Validation loop  ###
        model.eval() # Enable evaluation mode
        val_loss = 0.0
        total_val_samples = 0
        correct_val_predictions = 0

        all_val_labels = []
        all_val_preds = []

        # Wrap val_loader with tqdm for progress visualization
        val_loader_tqdm = tqdm(val_loader, desc=f"Epoch {epoch}/{epochs} [Validation]")

        with torch.no_grad(): # Disable gradient computation
            for batch in val_loader_tqdm: # iterate on the val_loader's batches
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)

                outputs = model(input_ids, attention_mask=attention_mask)
                logits = outputs.logits
                loss = criterion(logits, labels)

                val_loss += loss.item() * input_ids.size(0)
                total_val_samples += input_ids.size(0)
                correct_val_predictions += (logits.argmax(dim=1) == labels).sum().item()

                all_val_labels.extend(labels.cpu().numpy())
                all_val_preds.extend(logits.argmax(dim=1).cpu().numpy())

                # Update tqdm description with current loss and accuracy
                val_loader_tqdm.set_postfix(loss=loss.item())


        # calculate metrics
        val_loss /= total_val_samples
        val_accuracy = correct_val_predictions / total_val_samples
        val_precision = precision_score(all_val_labels, all_val_preds, average='macro')
        val_recall = recall_score(all_val_labels, all_val_preds, average='macro')
        val_f1 = f1_score(all_val_labels, all_val_preds, average='macro')
        
        # Check for early stopping
        best_val_accuracy, best_val_accuracy_epoch, early_stop_flag = early_stop_check(patience, best_val_accuracy, best_val_accuracy_epoch, val_accuracy, epoch)
        
        if trial is not None:
            trial.report(val_f1, step=epoch)

            if trial.should_prune():
                print(f"[Optuna] Trial pruned at epoch {epoch}")
                raise optuna.TrialPruned()
            
        # Save the best model under the best_model_state parameter
        if val_accuracy == best_val_accuracy:
            best_model_state = model.state_dict()

        # Log metrics to Weights & Biases - THIS IS WHERE WE TRACK THE RESULTS AND THE PROCESS
        wandb.log({ #log == logging of the training process (e.g. results) - will be done each epoch
            "Epoch": epoch,
            "Train Loss": train_loss,
            "Train Accuracy": train_accuracy,
            "Validation Loss": val_loss,
            "Validation Accuracy": val_accuracy,
            "Validation Precision": val_precision,
            "Validation Recall": val_recall,
            "Validation F1": val_f1})

        if early_stop_flag:  # Checks whether the early stopping condition has been met, as indicated by the early_stop_flag
            break# Exits the training loop immediately if the early stopping condition is satisfied

    if best_model_state is not None: # Save the best model as a .pt file
        torch.save(best_model_state, f"best_model_trial_{trial.number+1}.pt")
        artifact = wandb.Artifact("model", type="model")
        artifact.add_file(f"best_model_trial_{trial.number+1}.pt")
        wandb.log_artifact(artifact)

    return best_val_accuracy

In [None]:
def objective(trial):
    model_name = trial.suggest_categorical("model_name", [
    "cardiffnlp/twitter-roberta-base-sentiment",
    "cardiffnlp/twitter-xlm-roberta-base-sentiment"
])

    model = RobertaForSequenceClassification.from_pretrained(model_name, num_labels=5, ignore_mismatched_sizes=True).to(device) # Initialize RoBERTa model for sequence classification with 5 labels
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    # Hyperparameter suggestions
    learning_rate = trial.suggest_loguniform("learning_rate", 1e-5, 5e-5)
    weight_decay = trial.suggest_loguniform("weight_decay", 1e-7, 0.2)
    patience = trial.suggest_int("patience", 7, 10)
    batch_size = trial.suggest_categorical("batch_size", [16, 32, 64]) # Batch size can be 16, 32, or 64
    train_dataset = Tweets(train_df.sample(frac=0.2, random_state=trial.number), tokenizer) 
    val_dataset = Tweets(val_df, tokenizer) # Create dataset objects for training, validation, and testing 
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) # insert into a DataLoader
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) # insert into a DataLoader

    
    num_layers = trial.suggest_int("num_layers", 1, 4)
    for param in model.roberta.parameters():    # Freeze layers
        param.requires_grad = False
    for param in model.roberta.encoder.layer[-num_layers:].parameters():     # unfreeze the last "num_layers" of the encoder
        param.requires_grad = True
    for param in model.classifier.parameters():    #unfreeze the classifier
        param.requires_grad = True

    # Define optimizer and loss function
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # Initialize Weights & Biases - the values in the config are the properties of each trial.
    wandb.init(project="moti-matan-tel-aviv-university",
            config={
        "model_name": model_name,
        "learning_rate": learning_rate,
        "weight_decay": weight_decay,
        "patience": patience,
        "batch_size": batch_size,
        "num_layers": num_layers,
        "architecture": "RoBERTa",
        "dataset": "datatattle/covid-19-nlp-text-classification"},
        name=f"{model_name}+trial_{trial.number+1}") # The name that will be saved in the W&B platform

    # Train the model and get the best validation accuracy
    best_val_accuracy = train_model_with_hyperparams(model, train_loader, val_loader, optimizer, criterion, epochs=5, patience=patience, trial=trial)

    wandb.finish() # Finish the Weights & Biases run

    return best_val_accuracy # Return best validation acc as the objective to maximize

In [95]:
# Optuna Study
study = optuna.create_study(direction="maximize",pruner=MedianPruner(n_warmup_steps=1))  # Specifies that the goal of the optimization is to maximize the objective function
study.optimize(objective, n_trials=5)


[I 2025-06-21 17:45:08,822] A new study created in memory with name: no-name-34afd369-b8af-4a0f-9f47-989fcff36c88
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at cardiffnlp/twitter-roberta-base-sentiment and are newly initialized because the shapes did not match:
- classifier.out_proj.weight: found shape torch.Size([3, 768]) in the checkpoint and torch.Size([5, 768]) in the model instantiated
- classifier.out_proj.bias: found shape torch.Size([3]) in the checkpoint and torch.Size([5]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

suggest_loguniform has been deprecated in v3.0.0. This feature will be removed in v6.0.0. See https://github.com/optuna/optuna/releases/tag/v3.0.0. Use suggest_float(..., log=True) instead.


suggest_loguniform has been deprecated in v3.0.0. This feature will be removed in v6.0.0. See https://github.com/optuna/optuna/re

Epoch 1/5 [Training]:   2%|▏         | 7/412 [00:01<00:52,  7.67it/s, loss=1.54]

Epoch 1/5 [Training]: 100%|██████████| 412/412 [00:52<00:00,  7.83it/s, loss=1.45] 
Epoch 1/5 [Validation]: 100%|██████████| 515/515 [00:54<00:00,  9.37it/s, loss=1.16] 
Epoch 2/5 [Training]: 100%|██████████| 412/412 [00:52<00:00,  7.77it/s, loss=1.5]  
Epoch 2/5 [Validation]: 100%|██████████| 515/515 [00:55<00:00,  9.32it/s, loss=1.03] 
Epoch 3/5 [Training]: 100%|██████████| 412/412 [00:53<00:00,  7.73it/s, loss=1.19] 
Epoch 3/5 [Validation]: 100%|██████████| 515/515 [00:55<00:00,  9.32it/s, loss=1.04] 
Epoch 4/5 [Training]: 100%|██████████| 412/412 [00:53<00:00,  7.72it/s, loss=1.04] 
Epoch 4/5 [Validation]: 100%|██████████| 515/515 [00:55<00:00,  9.32it/s, loss=0.996]
Epoch 5/5 [Training]: 100%|██████████| 412/412 [00:53<00:00,  7.76it/s, loss=1.01] 
Epoch 5/5 [Validation]: 100%|██████████| 515/515 [00:55<00:00,  9.36it/s, loss=1.04] 


0,1
Epoch,▁▃▅▆█
Train Accuracy,▁▄▆██
Train Loss,█▅▃▂▁
Validation Accuracy,▁▄▅▆█
Validation F1,▁▅▆▇█
Validation Loss,█▅▄▃▁
Validation Precision,▄▁█▆▇
Validation Recall,▁▆▆██

0,1
Epoch,5.0
Train Accuracy,0.48352
Train Loss,1.1852
Validation Accuracy,0.4887
Validation F1,0.50448
Validation Loss,1.18387
Validation Precision,0.51307
Validation Recall,0.50589


[I 2025-06-21 17:54:13,396] Trial 0 finished with value: 0.48870262390670555 and parameters: {'model_name': 'cardiffnlp/twitter-roberta-base-sentiment', 'learning_rate': 3.7011997967112637e-05, 'weight_decay': 0.007537992022984844, 'patience': 9, 'batch_size': 16, 'num_layers': 1}. Best is trial 0 with value: 0.48870262390670555.
You are using a model of type xlm-roberta to instantiate a model of type roberta. This is not supported for all configurations of models and can yield errors.
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at cardiffnlp/twitter-xlm-roberta-base-sentiment and are newly initialized because the shapes did not match:
- classifier.out_proj.weight: found shape torch.Size([3, 768]) in the checkpoint and torch.Size([5, 768]) in the model instantiated
- classifier.out_proj.bias: found shape torch.Size([3]) in the checkpoint and torch.Size([5]) in the model instantiated
You should probably TRAIN this model on a down-strea

ValueError: Converting from SentencePiece and Tiktoken failed, if a converter for SentencePiece is available, provide a model path with a SentencePiece tokenizer.model file.Currently available slow->fast converters: ['AlbertTokenizer', 'BartTokenizer', 'BarthezTokenizer', 'BertTokenizer', 'BigBirdTokenizer', 'BlenderbotTokenizer', 'CamembertTokenizer', 'CLIPTokenizer', 'CodeGenTokenizer', 'ConvBertTokenizer', 'DebertaTokenizer', 'DebertaV2Tokenizer', 'DistilBertTokenizer', 'DPRReaderTokenizer', 'DPRQuestionEncoderTokenizer', 'DPRContextEncoderTokenizer', 'ElectraTokenizer', 'FNetTokenizer', 'FunnelTokenizer', 'GPT2Tokenizer', 'HerbertTokenizer', 'LayoutLMTokenizer', 'LayoutLMv2Tokenizer', 'LayoutLMv3Tokenizer', 'LayoutXLMTokenizer', 'LongformerTokenizer', 'LEDTokenizer', 'LxmertTokenizer', 'MarkupLMTokenizer', 'MBartTokenizer', 'MBart50Tokenizer', 'MPNetTokenizer', 'MobileBertTokenizer', 'MvpTokenizer', 'NllbTokenizer', 'OpenAIGPTTokenizer', 'PegasusTokenizer', 'Qwen2Tokenizer', 'RealmTokenizer', 'ReformerTokenizer', 'RemBertTokenizer', 'RetriBertTokenizer', 'RobertaTokenizer', 'RoFormerTokenizer', 'SeamlessM4TTokenizer', 'SqueezeBertTokenizer', 'T5Tokenizer', 'UdopTokenizer', 'WhisperTokenizer', 'XLMRobertaTokenizer', 'XLNetTokenizer', 'SplinterTokenizer', 'XGLMTokenizer', 'LlamaTokenizer', 'CodeLlamaTokenizer', 'GemmaTokenizer', 'Phi3Tokenizer']

In [None]:
torch.save(model.state_dict(), "model_weights.pt")
artifact = wandb.Artifact("model", type="model")
artifact.add_file("model_weights.pt")
wandb.log_artifact(artifact)
# Save the study results
torch.save(model.state_dict(), "model2_weights.pt")
artifact = wandb.Artifact("model2", type="model")
artifact.add_file("model2_weights.pt")
wandb.log_artifact(artifact)

In [None]:
# ניקוי זיכרון מה-GPU
torch.cuda.empty_cache()

# ניקוי זיכרון מה-Python
gc.collect()

In [None]:
# Function to evaluate the model
def evaluate_model(model_path, test_loader):
    # Load the model
    model.load_state_dict(torch.load(model_path)) # loading the trained model
    model = model.to(device)
    model.eval() # eval mode

    all_labels = []
    all_preds = []

    # Wrap test_loader with tqdm for progress visualization
    test_loader_tqdm = tqdm(test_loader, desc="Evaluating Model")

    #same idea... just testing and getting resutls...
    with torch.no_grad():
        for batch in test_loader_tqdm:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            preds = logits.argmax(dim=1)

            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

            # Update tqdm description (optional, can show batch progress)
            test_loader_tqdm.set_postfix(batch=test_loader_tqdm.n)

    # Calculate metrics
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='binary')
    recall = recall_score(all_labels, all_preds, average='binary')
    f1 = f1_score(all_labels, all_preds, average='binary')

    return {"Accuracy": accuracy, "Precision": precision, "Recall": recall, "F1 Score": f1}

In [96]:
# Load the test data set
test_dataset = Tweets(test_df, tokenizer.from_pretrained('roberta-large', use_fast=False)) # Create a dataset object for the test set
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [None]:
# Test multiple models
model1_path= 'cardiffnlp/twitter-roberta-base-sentiment'
model2_path = 'cardiffnlp/twitter-xlm-roberta-base-sentiment'

model_paths = [model1_path, model2_path]  # Replace with actual model paths
for model_path in model_paths:
    metrics = evaluate_model(model_path, test_loader)
    print(f"Metrics for {model_path}:")
    for key, value in metrics.items():
        print(f"{key}: {value:.4f}")

### FT using HF trainer

### Model Compression

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from datasets import load_dataset
from torch.nn import functional as F
from torch import nn

#### Metrics definitions

In [None]:
def size_of_model(model) -> int:
    return sum(p.numel() for p in model.parameters())

def non_zero_parameters_count(model):
    return sum((p != 0).sum().item() for p in model.parameters())

#### Quantization

In [None]:
from torch.quantization import quantize_dynamic

def compress_quantize_model(model):
    return quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

#### Pruning

In [None]:
from torch.nn.utils import prune

def compress_prune_model(model, prune_percent = 0.4):
    model_to_prune = model
    parameters_to_prune = [ (module, 'weight') for module in model_to_prune.modules() if isinstance(module, nn.Linear)]
    prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=prune_percent)
    return model_to_prune

#### Knowledge Distilation

In [None]:
class DistillationTrainer(Trainer):
    def __init__(self,  teacher_model=None, temperature=2.0, alpha=0.5, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher = teacher_model
        self.temperature = temperature
        self.alpha = alpha

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        outputs_student = model(**inputs)
        with torch.no_grad():
            outputs_teacher = self.teacher(**inputs)

        loss_ce = F.cross_entropy(outputs_student.logits, inputs["labels"])
        loss_kl = F.kl_div(
            F.log_softmax(outputs_student.logits / self.temperature, dim=-1),
            F.softmax(outputs_teacher.logits / self.temperature, dim=-1),
            reduction="batchmean") * (self.temperature ** 2)
        loss = self.alpha * loss_ce + (1 - self.alpha) * loss_kl
        
        return (loss, outputs_student) if return_outputs else loss


def compress_distilled_model(model):
    student_distilled_model = AutoModelForSequenceClassification.from_pretrained(
                            'distilroberta-base',
                             num_labels=5)
    trainer_distill = DistillationTrainer(
    model=student_distilled_model,
    teacher_model=model,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["test"],
    compute_metrics=compute_metrics)

    trainer_distill.train()
    return student_distilled_model

In [None]:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.quantization
import time
import pandas as pd
import copy
import numpy as np


def evaluate_model_to_compressed(model, test_dataloader, device='cpu'):
    model.to(device)
    model.eval()
    correct = 0
    total = 0
    inference_times = []
    
    with torch.no_grad():
        for inputs, labels in test_dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            start_time = time.time()
            outputs = model(inputs)
            inference_times.append(time.time() - start_time)

            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    accuracy = 100 * correct / total
    avg_inference_time = 1000 * sum(inference_times) / len(inference_times)  # ms per batch

    num_params = sum(p.numel() for p in model.parameters())
    nonzero_params = sum(torch.count_nonzero(p).item() for p in model.parameters())
    size_mb = sum(p.element_size() * p.numel() for p in model.parameters()) / 1e6

    return {
        'accuracy (%)': accuracy,
        'inference_time (ms)': avg_inference_time,
        'total_params': num_params,
        'nonzero_params': nonzero_params,
        'model_size (MB)': size_mb,
        'sparsity (%)': 100 * (1 - nonzero_params / num_params)
    }
# Pruning
def apply_pruning(model):
    model = copy.deepcopy(model)
    parameters_to_prune = [(module, 'weight') for module in model.modules() if isinstance(module, nn.Linear)]
    prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.5)
    return model

# Quantization (static)
def apply_quantization(model):
    model = copy.deepcopy(model)
    model.eval()
    model.qconfig = torch.quantization.default_qconfig
    torch.quantization.prepare(model, inplace=True)
    torch.quantization.convert(model, inplace=True)
    return model



def collect_compression_metrics(model):
    results = {}
    results['original'] = evaluate_model_to_compressed(model)
    pruned_model = apply_pruning(model)
    results['pruned'] = evaluate_model_to_compressed(pruned_model)
    try:
        quantized_model = apply_quantization(model)
        results['quantized'] = evaluate_model_to_compressed(quantized_model)
    except Exception:
        results['quantized'] = {k: np.nan for k in results['original'].keys()}  # Fallback if quant fails

    # Distilled (simulated)
    distilled_model = SmallModel()
    results['distilled'] = evaluate_model_to_compressed(distilled_model)

    # Convert to DataFrame
    df = pd.DataFrame(results).T

    # Percent difference from best per metric
    percent_diff = df.apply(lambda col: 100 * (col - col.min()) / col.min() if col.dtype != 'O' else col)
    percent_diff.columns = [f"{col} Δ%" for col in df.columns]
    df = pd.concat([df, percent_diff], axis=1)

    return df

# Create and evaluate
model = SimpleModel()
compression_metrics_df = collect_compression_metrics(model)
compression_metrics_df
