In [None]:
# !pip install datasets transformers scikit-learn pandas numpy

import os

# OpenMP / BLAS runtimes read OMP_NUM_THREADS when their native libraries are loaded,
# so it must be set before numpy / torch are imported to have any effect.
NUM_THREADS = 14
os.environ["OMP_NUM_THREADS"] = str(NUM_THREADS)

import random
import numpy as np
import torch
torch.set_num_threads(NUM_THREADS)
from datasets import load_dataset, DatasetDict
from transformers import (
    set_seed,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
)
from torch.utils.data import DataLoader
from sklearn.metrics import (
    confusion_matrix,
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    average_precision_score,
    matthews_corrcoef,
    precision_recall_fscore_support,
)

SEED = 42
# transformers.set_seed seeds Python's `random`, NumPy and torch; the explicit calls
# below re-seed the first two so the intent is visible in this cell.
set_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

In [None]:
class IMDbDistillationPipeline:
    """
    The class sets up a pipeline for:
      1) Loading & splitting IMDb data (train, validation, test).
      2) Fine-tuning a teacher model (BERT).
      3) Distilling to a student model (DistilBERT).
      4) Comparing the results of the Teacher vs. Student.

    The methods are meant to be called in this order: load_and_split_imdb,
    setup_teacher, prepare_teacher_data, train_teacher, setup_student,
    prepare_student_data, create_distil_dataset, train_student, compare_accuracy.
    """

    def __init__(self, val_size=0.2, max_length=256):
        """
        Args:
            val_size: Fraction of the original IMDb train split held out as validation.
            max_length: Number of tokens every review is padded/truncated to
                (used for both the teacher and the student tokenization).
        """
        self.val_size = val_size
        self.max_length = max_length
        self.raw_data = None
        self.imdb_data = None

        self.teacher_model = None
        self.teacher_tokenizer = None
        self.teacher_dataset = None
        self.teacher_trainer = None
        self.teacher_test_metrics = None

        self.student_model = None
        self.student_tokenizer = None
        self.student_dataset = None
        self.distil_dataset = None
        self.distil_trainer = None
        self.student_test_metrics = None

    # ------------------------------------------------------------------ data

    def load_and_split_imdb(self):
        """
        Load the IMDb dataset from Hugging Face and build train/validation/test splits.

        The original train split is divided into train/validation (``val_size``
        fraction, seeded with SEED); the original test split is kept as the test set.
        """
        raw = load_dataset("imdb")
        self.raw_data = raw
        # Split original train into (train + val)
        split_train = raw["train"].train_test_split(test_size=self.val_size, seed=SEED)
        self.imdb_data = DatasetDict({
            "train": split_train["train"],
            "validation": split_train["test"],
            "test": raw["test"],
        })
        print(f"Train size: {len(self.imdb_data['train'])}, "
              f"Val size: {len(self.imdb_data['validation'])}, "
              f"Test size: {len(self.imdb_data['test'])}")

    def _tokenize(self, tokenizer, examples):
        """Tokenize a batch of reviews, padding/truncating to ``self.max_length`` tokens."""
        return tokenizer(
            examples["text"],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        )

    def _prepare_tokenized(self, tokenize_fn):
        """
        Map ``tokenize_fn`` over every split (14 worker processes), rename
        'label' -> 'labels', drop 'text' and switch the format to torch tensors.
        """
        data = self.imdb_data.map(tokenize_fn, batched=True, num_proc=14)
        data = data.rename_column("label", "labels")
        data = data.remove_columns(["text"])
        data.set_format("torch")
        return data

    @staticmethod
    def _binary_metrics(eval_pred):
        """
        Accuracy / precision / recall / F1 (positive class = 1) from a Trainer
        ``(logits, labels)`` evaluation pair. Keys already carry the 'eval_' prefix.
        """
        logits, labels = eval_pred
        preds = np.argmax(logits, axis=-1)
        # zero_division=0: report 0 instead of warning when a class is never predicted.
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, preds, average="binary", zero_division=0
        )
        return {
            "eval_accuracy": accuracy_score(labels, preds),
            "eval_precision": precision,
            "eval_recall": recall,
            "eval_f1": f1,
        }

    # --------------------------------------------------------------- teacher

    def setup_teacher(self, model_name="bert-base-uncased"):
        """
        This function loads the teacher tokenizer & a 2-label classification model.
        """
        self.teacher_tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.teacher_model = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=2
        )

    def tokenizer_for_teacher(self, examples):
        """
        This is the tokenizer function for the Teacher model.
        """
        return self._tokenize(self.teacher_tokenizer, examples)

    def prepare_teacher_data(self):
        """
        This function maps tokenization over the train, validation, test sets with multi-processing.
        It renames 'label' -> 'labels', removes 'text', and sets the format to torch.
        """
        self.teacher_dataset = self._prepare_tokenized(self.tokenizer_for_teacher)

    def compute_metrics_teacher(self, eval_pred):
        """
        This function computes the Teacher's metrics: accuracy, precision, recall and f1.
        """
        return self._binary_metrics(eval_pred)

    def train_teacher(self, output_dir="./teacher_model_final", epochs=1):
        """
        Fine-tune the Teacher, report validation/test metrics and save model + tokenizer.

        The Trainer is kept on ``self.teacher_trainer`` so later cells can reuse it.
        """
        teacher_args = TrainingArguments(
            output_dir=output_dir,
            eval_strategy="epoch",
            save_strategy="no",
            num_train_epochs=epochs,
            per_device_train_batch_size=8,
            per_device_eval_batch_size=8,
            learning_rate=5e-5,
            logging_steps=100,
            disable_tqdm=False,  # show progress bar
            seed=SEED,
        )
        teacher_trainer = Trainer(
            model=self.teacher_model,
            args=teacher_args,
            train_dataset=self.teacher_dataset["train"],
            eval_dataset=self.teacher_dataset["validation"],
            compute_metrics=self.compute_metrics_teacher,
        )
        self.teacher_trainer = teacher_trainer

        teacher_trainer.train()
        val_metrics = teacher_trainer.evaluate(self.teacher_dataset["validation"])
        test_metrics = teacher_trainer.evaluate(self.teacher_dataset["test"])
        self.teacher_test_metrics = test_metrics
        self.teacher_model.save_pretrained(output_dir)
        self.teacher_tokenizer.save_pretrained(output_dir)

        print("\nTeacher Validation Metrics:", val_metrics)
        print("Teacher Test Metrics:", test_metrics)

    # --------------------------------------------------------------- student

    def setup_student(self, model_name="distilbert-base-uncased"):
        """
        This function loads the Student tokenizer & a 2-label classification model.
        """
        self.student_tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.student_model = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=2
        )

    def tokenizer_for_student(self, examples):
        """
        This is the tokenizer function for the Student model.
        """
        return self._tokenize(self.student_tokenizer, examples)

    def prepare_student_data(self):
        """
        This function maps tokenization over the train, validation, test sets with multi-processing.
        It renames 'label' -> 'labels', removes 'text', and sets the format to torch.
        """
        self.student_dataset = self._prepare_tokenized(self.tokenizer_for_student)

    def get_teacher_logits(self, dataset, batch_size=8):
        """
        Run the teacher over 'dataset' in order (no shuffling) and collect its logits.

        Returns:
            A 2D CPU torch.Tensor: [num_samples, num_labels].
        """
        loader = DataLoader(dataset, batch_size=batch_size)
        # Trainer leaves the model on its training device (e.g. a GPU) while the
        # DataLoader yields CPU tensors, so inputs are moved to the model's device.
        device = self.teacher_model.device
        all_logits = []
        self.teacher_model.eval()
        with torch.no_grad():
            for batch in loader:
                model_inputs = {
                    key: batch[key].to(device)
                    for key in ("input_ids", "attention_mask", "token_type_ids")
                    if key in batch
                }
                out = self.teacher_model(**model_inputs)
                all_logits.append(out.logits.cpu())
        return torch.cat(all_logits, dim=0)

    def create_distil_dataset(self):
        """
        Build the dataset dict (train + val + test) used for knowledge distillation.

        Only the train split gets a 'teacher_logits' column (the soft targets);
        validation and test are the student's tokenized splits unchanged, so
        DistillationTrainer uses plain cross-entropy on them.
        """
        student_train = self.student_dataset["train"]
        # Feed the teacher its own tokenization when available. Rows line up with the
        # student's train split because both were mapped from self.imdb_data["train"].
        # The fallback (student tokenization) assumes both tokenizers share a vocabulary.
        if self.teacher_dataset is not None:
            teacher_inputs = self.teacher_dataset["train"]
        else:
            teacher_inputs = student_train
        if len(teacher_inputs) != len(student_train):
            raise ValueError(
                f"Teacher train split has {len(teacher_inputs)} rows but the student "
                f"train split has {len(student_train)}; they must be aligned."
            )

        teacher_logits_train = self.get_teacher_logits(teacher_inputs, batch_size=8)
        logits_list = teacher_logits_train.numpy().tolist()
        train_with_logits = student_train.add_column("teacher_logits", logits_list)

        self.distil_dataset = DatasetDict({
            "train": train_with_logits,
            "validation": self.student_dataset["validation"],  # no teacher_logits
            "test": self.student_dataset["test"],  # no teacher_logits
        })

    class DistillationTrainer(Trainer):
        """
        A Trainer that does knowledge distillation if 'teacher_logits' is in the batch,
        and standard cross entropy otherwise (the val/test sets).

        Distillation loss:
            alpha * T^2 * KL(softmax(teacher / T) || softmax(student / T))
            + (1 - alpha) * CE(student, labels)

        'teacher_logits' is not an argument of the model's forward(), so the
        TrainingArguments must set remove_unused_columns=False; otherwise Trainer
        strips the column before batching and distillation silently never happens.
        """
        def __init__(self, alpha=0.5, temperature=2.0, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.alpha = alpha  # weight of the KD term; (1 - alpha) weights the CE term
            self.temperature = temperature

        def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
            labels = inputs["labels"]
            outputs = model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
            )
            student_logits = outputs.logits
            ce_loss = torch.nn.functional.cross_entropy(student_logits, labels)

            if "teacher_logits" in inputs:
                teacher_logits = inputs["teacher_logits"].to(
                    device=student_logits.device, dtype=student_logits.dtype
                )
                T = self.temperature
                student_soft = torch.nn.functional.log_softmax(student_logits / T, dim=-1)
                teacher_soft = torch.nn.functional.softmax(teacher_logits / T, dim=-1)
                # T**2 keeps the soft-target gradient scale comparable to the CE term.
                kd_loss = torch.nn.functional.kl_div(
                    student_soft, teacher_soft, reduction="batchmean"
                ) * (T ** 2)
                loss = self.alpha * kd_loss + (1.0 - self.alpha) * ce_loss
            else:
                loss = ce_loss
            return (loss, outputs) if return_outputs else loss

    def compute_metrics_student(self, eval_pred):
        """
        This function computes the Student's metrics: accuracy, precision, recall and f1.
        """
        return self._binary_metrics(eval_pred)

    def train_student(self, output_dir="./student_model_final", epochs=1, alpha=0.5, temperature=2.0):
        """
        Train the Student with knowledge distillation, report validation/test metrics
        and save model + tokenizer.

        The Trainer is kept on ``self.distil_trainer`` so later cells can reuse it.
        """
        student_args = TrainingArguments(
            output_dir=output_dir,
            eval_strategy="epoch",
            save_strategy="no",
            num_train_epochs=epochs,
            per_device_train_batch_size=8,
            per_device_eval_batch_size=8,
            learning_rate=5e-5,
            logging_steps=100,
            disable_tqdm=False,
            seed=SEED,
            # Keep 'teacher_logits': by default Trainer drops every column that is not
            # an argument of model.forward(), which would disable distillation.
            remove_unused_columns=False,
        )
        # No tokenizer is passed: inputs are already padded to max_length, so the
        # default collator just stacks them (same as the teacher trainer). The
        # tokenizer is saved explicitly below.
        distil_trainer = self.DistillationTrainer(
            alpha=alpha,
            temperature=temperature,
            model=self.student_model,
            args=student_args,
            train_dataset=self.distil_dataset["train"],
            eval_dataset=self.distil_dataset["validation"],
            compute_metrics=self.compute_metrics_student,
        )
        self.distil_trainer = distil_trainer

        distil_trainer.train()
        val_metrics = distil_trainer.evaluate(self.distil_dataset["validation"])
        test_metrics = distil_trainer.evaluate(self.distil_dataset["test"])
        self.student_test_metrics = test_metrics

        self.student_model.save_pretrained(output_dir)
        self.student_tokenizer.save_pretrained(output_dir)

        print("\nStudent Validation Metrics:", val_metrics)
        print("Student Test Metrics:", test_metrics)

    # ------------------------------------------------------------ comparison

    def compare_accuracy(self):
        """
        This function compares teacher and student test performance,
        using 'eval_accuracy' from both metric dictionaries.
        """
        if self.teacher_test_metrics is None or self.student_test_metrics is None:
            print("Train teacher & student first.")
            return

        t_acc = self.teacher_test_metrics["eval_accuracy"]
        s_acc = self.student_test_metrics["eval_accuracy"]
        ratio = 0.0 if t_acc == 0 else (s_acc / t_acc) * 100

        print("Teacher Test Results:", self.teacher_test_metrics)
        print("Student Test Results:", self.student_test_metrics)
        print(f"The Student model retains about {ratio:.1f}% of the Teacher's accuracy.")

def main():
    """
    Run the entire teacher -> student pipeline on IMDb.

    Returns:
        The IMDbDistillationPipeline instance, so later cells can reach the trained
        models, tokenized datasets and test metrics stored on it.
    """
    pipeline = IMDbDistillationPipeline(val_size=0.2)
    # Load & split the data.
    pipeline.load_and_split_imdb()
    # Teacher setup & training (BERT).
    pipeline.setup_teacher("bert-base-uncased")
    pipeline.prepare_teacher_data()
    pipeline.train_teacher(epochs=1)
    # Student setup & distillation (DistilBERT).
    pipeline.setup_student("distilbert-base-uncased")
    pipeline.prepare_student_data()
    pipeline.create_distil_dataset()
    pipeline.train_student(epochs=1, alpha=0.5, temperature=2.0)
    # Compare the results.
    pipeline.compare_accuracy()
    return pipeline


# Guarded entry point: when run as a script on platforms that spawn worker processes
# (the dataset map calls use multiprocessing), an unguarded call would be re-executed
# in every worker. Inside a notebook __name__ is "__main__", so this still runs.
if __name__ == "__main__":
    pipeline = main()

In [None]:
EXAMPLE OUTPUT: 

Train size: 20000, Val size: 5000, Test size: 25000
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Map (num_proc=14): 100%|██████████| 20000/20000 [00:31<00:00, 644.04 examples/s] 
Map (num_proc=14): 100%|██████████| 5000/5000 [00:29<00:00, 168.90 examples/s]
Map (num_proc=14): 100%|██████████| 25000/25000 [00:32<00:00, 777.77 examples/s] 
  0%|          | 0/250 [3:04:09<?, ?it/s]
  4%|▍         | 100/2500 [08:28<2:58:56,  4.47s/it]{'loss': 0.5869, 'grad_norm': 11.853706359863281, 'learning_rate': 4.8e-05, 'epoch': 0.04}
  8%|▊         | 200/2500 [16:54<3:43:54,  5.84s/it]{'loss': 0.4258, 'grad_norm': 9.650397300720215, 'learning_rate': 4.600000000000001e-05, 'epoch': 0.08}
 12%|█▏        | 300/2500 [25:48<2:45:48,  4.52s/it]{'loss': 0.4155, 'grad_norm': 16.848670959472656, 'learning_rate': 4.4000000000000006e-05, 'epoch': 0.12}
 16%|█▌        | 400/2500 [35:02<3:51:42,  6.62s/it]{'loss': 0.3737, 'grad_norm': 13.249711036682129, 'learning_rate': 4.2e-05, 'epoch': 0.16}
 20%|██        | 500/2500 [46:08<3:41:54,  6.66s/it]{'loss': 0.3945, 'grad_norm': 13.04289722442627, 'learning_rate': 4e-05, 'epoch': 0.2}
 24%|██▍       | 600/2500 [57:26<3:34:35,  6.78s/it]{'loss': 0.364, 'grad_norm': 10.95732307434082, 'learning_rate': 3.8e-05, 'epoch': 0.24}
 28%|██▊       | 700/2500 [1:08:49<3:28:37,  6.95s/it]{'loss': 0.3982, 'grad_norm': 0.9471220374107361, 'learning_rate': 3.6e-05, 'epoch': 0.28}
 32%|███▏      | 800/2500 [1:20:20<3:16:22,  6.93s/it]{'loss': 0.3302, 'grad_norm': 8.628236770629883, 'learning_rate': 3.4000000000000007e-05, 'epoch': 0.32}
 36%|███▌      | 900/2500 [1:31:58<3:06:08,  6.98s/it]{'loss': 0.3724, 'grad_norm': 14.761881828308105, 'learning_rate': 3.2000000000000005e-05, 'epoch': 0.36}
 40%|████      | 1000/2500 [1:40:18<1:57:36,  4.70s/it]{'loss': 0.3738, 'grad_norm': 29.557193756103516, 'learning_rate': 3e-05, 'epoch': 0.4}
 44%|████▍     | 1100/2500 [1:49:56<2:43:09,  6.99s/it]{'loss': 0.3462, 'grad_norm': 23.52771759033203, 'learning_rate': 2.8000000000000003e-05, 'epoch': 0.44}
 48%|████▊     | 1200/2500 [1:58:51<1:39:12,  4.58s/it]{'loss': 0.3688, 'grad_norm': 14.034320831298828, 'learning_rate': 2.6000000000000002e-05, 'epoch': 0.48}
 52%|█████▏    | 1300/2500 [2:08:30<2:20:24,  7.02s/it]{'loss': 0.2986, 'grad_norm': 7.83740234375, 'learning_rate': 2.4e-05, 'epoch': 0.52}
 56%|█████▌    | 1400/2500 [2:17:30<1:20:40,  4.40s/it]{'loss': 0.3507, 'grad_norm': 3.1949851512908936, 'learning_rate': 2.2000000000000003e-05, 'epoch': 0.56}
 60%|██████    | 1500/2500 [2:26:53<1:52:19,  6.74s/it]{'loss': 0.292, 'grad_norm': 0.4668058156967163, 'learning_rate': 2e-05, 'epoch': 0.6}
 64%|██████▍   | 1600/2500 [2:36:36<1:09:43,  4.65s/it]{'loss': 0.2971, 'grad_norm': 8.485663414001465, 'learning_rate': 1.8e-05, 'epoch': 0.64}
 68%|██████▊   | 1700/2500 [2:45:45<1:17:54,  5.84s/it]{'loss': 0.3364, 'grad_norm': 0.9123271107673645, 'learning_rate': 1.6000000000000003e-05, 'epoch': 0.68}
 72%|███████▏  | 1800/2500 [2:55:56<1:23:49,  7.19s/it]{'loss': 0.2954, 'grad_norm': 16.355836868286133, 'learning_rate': 1.4000000000000001e-05, 'epoch': 0.72}
 76%|███████▌  | 1900/2500 [3:05:11<45:41,  4.57s/it]  {'loss': 0.2849, 'grad_norm': 5.070781707763672, 'learning_rate': 1.2e-05, 'epoch': 0.76}
 80%|████████  | 2000/2500 [3:14:58<57:50,  6.94s/it]  {'loss': 0.2646, 'grad_norm': 10.113204956054688, 'learning_rate': 1e-05, 'epoch': 0.8}
 84%|████████▍ | 2100/2500 [3:24:44<27:56,  4.19s/it]{'loss': 0.2952, 'grad_norm': 29.44683265686035, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.84}
 88%|████████▊ | 2200/2500 [3:34:09<31:48,  6.36s/it]{'loss': 0.272, 'grad_norm': 0.19008371233940125, 'learning_rate': 6e-06, 'epoch': 0.88}
 92%|█████████▏| 2300/2500 [3:44:11<16:44,  5.02s/it]{'loss': 0.2581, 'grad_norm': 36.615272521972656, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.92}
 96%|█████████▌| 2400/2500 [3:53:24<09:49,  5.90s/it]{'loss': 0.2475, 'grad_norm': 38.397850036621094, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.96}
100%|██████████| 2500/2500 [4:03:38<00:00,  6.79s/it]{'loss': 0.3127, 'grad_norm': 13.887683868408203, 'learning_rate': 0.0, 'epoch': 1.0}

100%|██████████| 2500/2500 [4:22:57<00:00,  6.31s/it]
{'eval_accuracy': 0.909, 'eval_precision': 0.8967117988394584, 'eval_recall': 0.9249800478850758, 'eval_f1': 0.9106265959536437, 'eval_loss': 0.29959747195243835, 'eval_runtime': 1159.081, 'eval_samples_per_second': 4.314, 'eval_steps_per_second': 0.539, 'epoch': 1.0}
{'train_runtime': 15777.7598, 'train_samples_per_second': 1.268, 'train_steps_per_second': 0.158, 'train_loss': 0.34220633087158203, 'epoch': 1.0}
100%|██████████| 625/625 [19:31<00:00,  1.87s/it]
100%|██████████| 3125/3125 [1:37:04<00:00,  1.86s/it]

Teacher Validation Metrics: {'eval_accuracy': 0.909, 'eval_precision': 0.8967117988394584, 'eval_recall': 0.9249800478850758, 'eval_f1': 0.9106265959536437, 'eval_loss': 0.29959747195243835, 'eval_runtime': 1174.0596, 'eval_samples_per_second': 4.259, 'eval_steps_per_second': 0.532, 'epoch': 1.0}
Teacher Test Metrics: {'eval_accuracy': 0.91036, 'eval_precision': 0.8895724158882053, 'eval_recall': 0.93704, 'eval_f1': 0.9126894455916157, 'eval_loss': 0.29378241300582886, 'eval_runtime': 5827.2636, 'eval_samples_per_second': 4.29, 'eval_steps_per_second': 0.536, 'epoch': 1.0}
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Map (num_proc=14): 100%|██████████| 20000/20000 [01:11<00:00, 278.53 examples/s]
Map (num_proc=14): 100%|██████████| 5000/5000 [01:02<00:00, 80.19 examples/s] 
Map (num_proc=14): 100%|██████████| 25000/25000 [00:48<00:00, 515.84 examples/s]
C:\Users\maril\AppData\Local\Temp\ipykernel_24392\2441323602.py:263: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `DistillationTrainer.__init__`. Use `processing_class` instead.
  super().__init__(*args, **kwargs)
  4%|▍         | 100/2500 [05:18<2:07:45,  3.19s/it]{'loss': 0.5498, 'grad_norm': 11.526269912719727, 'learning_rate': 4.8e-05, 'epoch': 0.04}
  8%|▊         | 200/2500 [10:39<2:02:04,  3.18s/it]{'loss': 0.3948, 'grad_norm': 14.726052284240723, 'learning_rate': 4.600000000000001e-05, 'epoch': 0.08}
 12%|█▏        | 300/2500 [15:55<2:01:15,  3.31s/it]{'loss': 0.3972, 'grad_norm': 6.584609031677246, 'learning_rate': 4.4000000000000006e-05, 'epoch': 0.12}
 16%|█▌        | 400/2500 [21:16<1:51:50,  3.20s/it]{'loss': 0.3456, 'grad_norm': 15.88803482055664, 'learning_rate': 4.2e-05, 'epoch': 0.16}
 20%|██        | 500/2500 [26:36<1:46:08,  3.18s/it]{'loss': 0.4028, 'grad_norm': 15.243449211120605, 'learning_rate': 4e-05, 'epoch': 0.2}
 24%|██▍       | 600/2500 [31:54<1:40:25,  3.17s/it]{'loss': 0.362, 'grad_norm': 4.344246864318848, 'learning_rate': 3.8e-05, 'epoch': 0.24}
 28%|██▊       | 700/2500 [37:13<1:35:25,  3.18s/it]{'loss': 0.4089, 'grad_norm': 25.762807846069336, 'learning_rate': 3.6e-05, 'epoch': 0.28}
 32%|███▏      | 800/2500 [42:32<1:29:54,  3.17s/it]{'loss': 0.341, 'grad_norm': 9.24248218536377, 'learning_rate': 3.4000000000000007e-05, 'epoch': 0.32}
 36%|███▌      | 900/2500 [47:52<1:25:31,  3.21s/it]{'loss': 0.4602, 'grad_norm': 7.068543910980225, 'learning_rate': 3.2000000000000005e-05, 'epoch': 0.36}
 40%|████      | 1000/2500 [53:11<1:19:51,  3.19s/it]{'loss': 0.3297, 'grad_norm': 11.032055854797363, 'learning_rate': 3e-05, 'epoch': 0.4}
 44%|████▍     | 1100/2500 [57:42<1:02:42,  2.69s/it]{'loss': 0.3472, 'grad_norm': 13.0895414352417, 'learning_rate': 2.8000000000000003e-05, 'epoch': 0.44}
 48%|████▊     | 1200/2500 [1:02:27<49:28,  2.28s/it]  {'loss': 0.3801, 'grad_norm': 6.4442033767700195, 'learning_rate': 2.6000000000000002e-05, 'epoch': 0.48}
 52%|█████▏    | 1300/2500 [1:07:20<40:38,  2.03s/it]  {'loss': 0.3041, 'grad_norm': 0.9118344187736511, 'learning_rate': 2.4e-05, 'epoch': 0.52}
 56%|█████▌    | 1400/2500 [1:11:49<1:03:30,  3.46s/it]{'loss': 0.3498, 'grad_norm': 12.24240779876709, 'learning_rate': 2.2000000000000003e-05, 'epoch': 0.56}
 60%|██████    | 1500/2500 [1:16:14<48:23,  2.90s/it]  {'loss': 0.2927, 'grad_norm': 14.98419189453125, 'learning_rate': 2e-05, 'epoch': 0.6}
 64%|██████▍   | 1600/2500 [1:20:56<36:50,  2.46s/it]{'loss': 0.3233, 'grad_norm': 10.465763092041016, 'learning_rate': 1.8e-05, 'epoch': 0.64}
 68%|██████▊   | 1700/2500 [1:25:52<27:45,  2.08s/it]{'loss': 0.306, 'grad_norm': 1.3848687410354614, 'learning_rate': 1.6000000000000003e-05, 'epoch': 0.68}
 72%|███████▏  | 1800/2500 [1:30:49<41:19,  3.54s/it]{'loss': 0.2774, 'grad_norm': 14.78300666809082, 'learning_rate': 1.4000000000000001e-05, 'epoch': 0.72}
 76%|███████▌  | 1900/2500 [1:34:57<31:37,  3.16s/it]{'loss': 0.3097, 'grad_norm': 17.732311248779297, 'learning_rate': 1.2e-05, 'epoch': 0.76}
 80%|████████  | 2000/2500 [1:39:35<22:35,  2.71s/it]{'loss': 0.2797, 'grad_norm': 9.0259428024292, 'learning_rate': 1e-05, 'epoch': 0.8}
 84%|████████▍ | 2100/2500 [1:44:24<14:56,  2.24s/it]{'loss': 0.2965, 'grad_norm': 21.90822410583496, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.84}
 88%|████████▊ | 2200/2500 [1:49:23<17:34,  3.52s/it]{'loss': 0.2851, 'grad_norm': 0.12936966121196747, 'learning_rate': 6e-06, 'epoch': 0.88}
 92%|█████████▏| 2300/2500 [1:53:39<11:09,  3.35s/it]{'loss': 0.2524, 'grad_norm': 6.887401580810547, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.92}
 96%|█████████▌| 2400/2500 [1:58:18<04:43,  2.83s/it]{'loss': 0.2391, 'grad_norm': 23.580942153930664, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.96}
100%|██████████| 2500/2500 [2:03:04<00:00,  2.30s/it]{'loss': 0.3137, 'grad_norm': 10.13525390625, 'learning_rate': 0.0, 'epoch': 1.0}

100%|██████████| 2500/2500 [2:13:48<00:00,  3.21s/it]
{'eval_accuracy': 0.9, 'eval_precision': 0.893025078369906, 'eval_recall': 0.9094173982442139, 'eval_f1': 0.9011466982997232, 'eval_loss': 0.30669721961021423, 'eval_runtime': 644.0439, 'eval_samples_per_second': 7.763, 'eval_steps_per_second': 0.97, 'epoch': 1.0}
{'train_runtime': 8028.3603, 'train_samples_per_second': 2.491, 'train_steps_per_second': 0.311, 'train_loss': 0.34195306549072263, 'epoch': 1.0}
100%|██████████| 625/625 [10:26<00:00,  1.00s/it]
100%|██████████| 3125/3125 [50:24<00:00,  1.03it/s]

Student Validation Metrics: {'eval_accuracy': 0.9, 'eval_precision': 0.893025078369906, 'eval_recall': 0.9094173982442139, 'eval_f1': 0.9011466982997232, 'eval_loss': 0.30669721961021423, 'eval_runtime': 627.0832, 'eval_samples_per_second': 7.973, 'eval_steps_per_second': 0.997, 'epoch': 1.0}
Student Test Metrics: {'eval_accuracy': 0.90392, 'eval_precision': 0.8891629412671497, 'eval_recall': 0.92288, 'eval_f1': 0.9057077804820601, 'eval_loss': 0.2904188930988312, 'eval_runtime': 3025.1447, 'eval_samples_per_second': 8.264, 'eval_steps_per_second': 1.033, 'epoch': 1.0}

Teacher Test: {'eval_accuracy': 0.91036, 'eval_precision': 0.8895724158882053, 'eval_recall': 0.93704, 'eval_f1': 0.9126894455916157, 'eval_loss': 0.29378241300582886, 'eval_runtime': 5827.2636, 'eval_samples_per_second': 4.29, 'eval_steps_per_second': 0.536, 'epoch': 1.0}
Student Test: {'eval_accuracy': 0.90392, 'eval_precision': 0.8891629412671497, 'eval_recall': 0.92288, 'eval_f1': 0.9057077804820601, 'eval_loss': 0.2904188930988312, 'eval_runtime': 3025.1447, 'eval_samples_per_second': 8.264, 'eval_steps_per_second': 1.033, 'epoch': 1.0}
Student retains about 99.3% of the Teacher's accuracy.

In [None]:
# teacher_trainer / distil_trainer are local variables of train_teacher() /
# train_student(), and the tokenized datasets are attributes of
# IMDbDistillationPipeline, so none of them exist as notebook globals. Build
# prediction-only trainers from the pipeline object of the training cell instead.
if "pipeline" not in globals():
    raise RuntimeError(
        "No `pipeline` object found: run the training cell as `pipeline = main()` first."
    )

teacher_dataset = pipeline.teacher_dataset
distil_dataset = pipeline.distil_dataset
predict_args = TrainingArguments(
    output_dir="./predictions_tmp",
    per_device_eval_batch_size=8,
    disable_tqdm=True,
    report_to="none",
    seed=SEED,
)
teacher_trainer = Trainer(model=pipeline.teacher_model, args=predict_args)
# A plain Trainer is enough for the student: predictions do not depend on the
# training loss, and the test split carries no teacher_logits anyway.
distil_trainer = Trainer(model=pipeline.student_model, args=predict_args)

# Predictions for Teacher
teacher_result = teacher_trainer.predict(teacher_dataset["test"])
teacher_logits = teacher_result.predictions
teacher_labels = teacher_result.label_ids
teacher_preds = np.argmax(teacher_logits, axis=-1)
# Probability of class 1, used for ROC-AUC / PR-AUC.
teacher_probs = torch.softmax(torch.tensor(teacher_logits), dim=-1).numpy()[:, 1]

# Predictions for Student
student_result = distil_trainer.predict(distil_dataset["test"])
student_logits = student_result.predictions
student_labels = student_result.label_ids
student_preds = np.argmax(student_logits, axis=-1)
# Probability of class 1, used for ROC-AUC / PR-AUC.
student_probs = torch.softmax(torch.tensor(student_logits), dim=-1).numpy()[:, 1]

def compute_metrics(labels, preds, probs):
    """
    Compute detailed binary-classification metrics for one model's test predictions.

    Args:
        labels: True class ids (0/1), one per sample.
        preds: Predicted class ids (0/1), one per sample.
        probs: Predicted probability of class 1 per sample (for ROC-AUC / PR-AUC).

    Returns:
        dict with the 2x2 confusion matrix as nested lists (sklearn layout: rows are
        true classes, columns are predicted classes), accuracy, precision, recall,
        f1, roc_auc, pr_auc and mcc. AUC values are NaN when they are undefined.
    """
    # labels=[0, 1] keeps the matrix 2x2 even if one class never appears.
    cm = confusion_matrix(labels, preds, labels=[0, 1])
    acc = accuracy_score(labels, preds)
    # zero_division=0: return 0 instead of warning when a ratio has a zero denominator.
    prec = precision_score(labels, preds, zero_division=0)
    rec = recall_score(labels, preds, zero_division=0)
    f1 = f1_score(labels, preds, zero_division=0)

    # ROC-AUC / PR-AUC need both classes present in `labels`; sklearn raises otherwise.
    try:
        auc = roc_auc_score(labels, probs)
        pr_auc = average_precision_score(labels, probs)
    except ValueError:
        auc = float("nan")
        pr_auc = float("nan")

    mcc = matthews_corrcoef(labels, preds)
    return {
        "confusion_matrix": cm.tolist(),
        "accuracy": acc,
        "precision": prec,
        "recall": rec,
        "f1": f1,
        "roc_auc": auc,
        "pr_auc": pr_auc,
        "mcc": mcc,
    }

# Detailed test-set metrics for both models, printed one "name: value" per line.
teacher_metrics = compute_metrics(teacher_labels, teacher_preds, teacher_probs)
student_metrics = compute_metrics(student_labels, student_preds, student_probs)

for heading, metric_table in (
    ("\n TEACHER TEST METRICS", teacher_metrics),
    ("\n STUDENT TEST METRICS", student_metrics),
):
    print(heading)
    for metric_name, metric_value in metric_table.items():
        print(f"{metric_name}: {metric_value}")

In [None]:
EXAMPLE OUTPUT: 

3162it [2:25:05,  2.75s/it]                         
100%|██████████| 3125/3125 [49:14<00:00,  1.06it/s]  

TEACHER TEST METRICS 
confusion_matrix: [[10824, 1676], [1165, 11335]]
accuracy: 0.91036
precision: 0.8711859196064868
recall: 0.9068
f1: 0.8886362745482341
roc_auc: 0.9570383999999998
pr_auc: 0.956745012420531
mcc: 0.7733664853464441

STUDENT TEST METRICS
confusion_matrix: [[10898, 1602], [1529, 10971]]
accuracy: 0.90392
precision: 0.8725841088045813
recall: 0.87768
f1: 0.875124636062697
roc_auc: 0.9481288064
pr_auc: 0.9481079798749482
mcc: 0.7495327817416035

## Discussion of the Results

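The final results show that the teacher network (BERT) reaches a slightly higher test accuracy (approximately 91.04\%) than 
the student model (DistilBERT, about 90.39\%). This gap of roughly 0.6 percentage points is expected, since the larger teacher 
network is often better at capturing subtle patterns in the training data. More importantly, the student's accuracy stays close 
to the teacher's, which suggests that most of the teacher's predictive ability carries over to a lighter, more efficient model: 
as calculated above, the Student model retains about 99.3% of the Teacher's accuracy. Note, however, that in the run shown 
above the `teacher_logits` column is dropped by the Hugging Face `Trainer` (its default `remove_unused_columns=True` removes 
dataset columns that are not arguments of the model's `forward` method), so `DistillationTrainer` only ever applies the 
cross-entropy loss. These numbers therefore compare plain DistilBERT fine-tuning with BERT fine-tuning; measuring the actual 
effect of distillation requires re-running with `remove_unused_columns=False`.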

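Looking at the other metrics, the teacher's precision (87.12\%) and recall (90.68\%) combine into an F1 score of 
approximately 88.86\%, while the student has a precision of 87.26\% and a recall of 87.77\%, for an F1 score of around 
87.51\%. Although the teacher has a higher recall than the student, the student's precision is comparable (in fact marginally 
higher), which again shows that the smaller model retains much of the teacher's capability. The corresponding confusion 
matrices show that the teacher commits fewer false negatives (1,165 vs. 1,529), while the student makes slightly fewer false 
positives (1,602 vs. 1,676). Be aware that these detailed numbers do not agree with the accuracies reported above: the 
teacher's confusion matrix implies an accuracy of (10,824 + 11,335) / 25,000 ≈ 88.6\% rather than 91.04\%, and the Trainer's 
own test evaluation reports a teacher precision of 88.96\%, recall of 93.70\% and F1 of 91.27\%. The detailed-metrics output 
therefore appears to come from a different run and should be regenerated before drawing firm conclusions from it.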

Both networks also achieve high ROC-AUC and PR-AUC values: roughly 0.957 for the teacher and about 0.948 for the student on 
both metrics. This shows that both models rank positive reviews above negative ones reliably across decision thresholds. 
Likewise, the teacher reaches an MCC of around 0.773, while the student reaches about 0.750. The teacher thus keeps a small 
advantage in predictive capability, but the student still performs in a robust and reliable manner.

Overall, these metrics indicate that the smaller student model retains the core strength of the teacher model, achieving 
similar performance on most metrics with reduced model complexity.