In [1]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import Trainer, TrainingArguments
from datasets import Dataset
import torch
import numpy as np
from sklearn.metrics import accuracy_score

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import pandas as pd 
from datasets import load_dataset, Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model
from sklearn.metrics import accuracy_score
import numpy as np
from sklearn.model_selection import train_test_split

In [3]:
def prepare_datasets(tokenizer, dataset_name="financial_phrasebank", subset_name="sentences_50agree", max_length=128, random_state=42):
    dataset = load_dataset(dataset_name, subset_name, trust_remote_code=True)
    
    df = pd.DataFrame(dataset['train'])

    # Stratify split into train, validation, and test
    train_texts, test_texts, train_labels, test_labels = train_test_split(
        df['sentence'], df['label'], test_size=0.2, stratify=df['label'], random_state=random_state
    )
    train_texts, val_texts, train_labels, val_labels = train_test_split(
        train_texts, train_labels, test_size=0.1, stratify=train_labels, random_state=random_state
    )

    # Create DataFrames for each split
    train_df = pd.DataFrame({'sentence': train_texts, 'label': train_labels})
    val_df = pd.DataFrame({'sentence': val_texts, 'label': val_labels})
    test_df = pd.DataFrame({'sentence': test_texts, 'label': test_labels})

    # Convert DataFrames to Hugging Face Dataset format
    train_dataset = Dataset.from_pandas(train_df)
    val_dataset = Dataset.from_pandas(val_df)
    test_dataset = Dataset.from_pandas(test_df)

    # Define tokenization function
    def tokenize_function(example):
        return tokenizer(
            example["sentence"], 
            padding="max_length", 
            truncation=True, 
            max_length=max_length
        )
     # Tokenize datasets
    train_dataset = train_dataset.map(tokenize_function, batched=True)
    val_dataset = val_dataset.map(tokenize_function, batched=True)
    test_dataset = test_dataset.map(tokenize_function, batched=True)

    # Remove raw text and prepare for Hugging Face Trainer
    train_dataset = train_dataset.remove_columns(["sentence"])
    val_dataset = val_dataset.remove_columns(["sentence"])
    test_dataset = test_dataset.remove_columns(["sentence"])

    train_dataset = train_dataset.rename_column("label", "labels")
    val_dataset = val_dataset.rename_column("label", "labels")
    test_dataset = test_dataset.rename_column("label", "labels")

    train_dataset.set_format("torch")
    val_dataset.set_format("torch")
    test_dataset.set_format("torch")

    return train_dataset, val_dataset, test_dataset

In [13]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Load the tokenizer for BERT
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Load the Teacher model (BERT)
teacher_model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=3)

# Load the Student model (DistilBERT)
student_model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=3)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [19]:
train_dataset, val_dataset, test_dataset = prepare_datasets(tokenizer)

Map: 100%|██████████| 3488/3488 [00:00<00:00, 24431.47 examples/s]
Map: 100%|██████████| 388/388 [00:00<00:00, 22940.69 examples/s]
Map: 100%|██████████| 970/970 [00:00<00:00, 22919.05 examples/s]


In [31]:
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, true_labels, temperature=2.0, alpha=0.7):
    # Soft labels from teacher model
    soft_labels = F.softmax(teacher_logits / temperature, dim=-1)

    # Hard loss using true labels
    hard_loss = F.cross_entropy(student_logits, true_labels)

    # Soft loss between teacher and student logits
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        soft_labels,
        reduction='batchmean'
    )

    # Weighted combination of soft and hard loss
    return alpha * soft_loss + (1.0 - alpha) * hard_loss



In [32]:
from torch.utils.data import DataLoader

def get_teacher_logits(model, dataset, batch_size=16, device='cpu'):
    dataloader = DataLoader(dataset, batch_size=batch_size)
    logits = []

    model.eval()  # Set teacher model to evaluation mode
    with torch.no_grad():
        for batch in dataloader:
            inputs = {k: v.to(device) for k, v in batch.items() if k in ['input_ids', 'attention_mask']}
            outputs = model(**inputs)
            logits.append(outputs.logits.cpu())

    return torch.cat(logits, dim=0)  # Concatenate all logits into one tensor


In [35]:
def distillation_train_loop(student_model, teacher_model, train_dataset, val_dataset, tokenizer, epochs=3, batch_size=16, learning_rate=5e-5, device='cpu'):
    student_model = student_model.to(device)
    teacher_model = teacher_model.to(device)

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size)

    optimizer = torch.optim.AdamW(student_model.parameters(), lr=learning_rate)

    # Teacher logits for the entire training dataset
    teacher_logits = get_teacher_logits(teacher_model, train_dataset, batch_size=batch_size, device=device)
    print(f"Teacher logits: {teacher_logits}")

    for epoch in range(epochs):
        student_model.train()
        total_loss = 0.0

        for i, batch in enumerate(train_dataloader):
            inputs = {k: v.to(device) for k, v in batch.items() if k in ['input_ids', 'attention_mask']}
            labels = batch['labels'].to(device)

            # Select corresponding teacher logits for this batch
            batch_teacher_logits = teacher_logits[i * batch_size:(i + 1) * batch_size].to(device)

            # Forward pass of student model
            student_outputs = student_model(**inputs)
            print(f"Student outputs: {student_outputs}")
            student_logits = student_outputs.logits
            print(f"Student logits: {student_logits}")

            # Compute loss
            loss = distillation_loss(student_logits, batch_teacher_logits, labels)
            total_loss += loss.item()

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(train_dataloader):.4f}")

    return student_model



In [23]:
# Set up the training arguments for the teacher model
training_args = TrainingArguments(
    output_dir="./results_knowledge_distill/teacher_model",
    evaluation_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    logging_dir="./logs_knowledge_distill/teacher_model",
    logging_steps=10,
)

# Trainer for the teacher model
trainer_teacher = Trainer(
    model=teacher_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer
)

# Fine-tune the teacher model
trainer_teacher.train()


  trainer_teacher = Trainer(
  2%|▏         | 10/654 [00:22<15:48,  1.47s/it] 

{'loss': 0.8957, 'grad_norm': 4.335817337036133, 'learning_rate': 4.923547400611621e-05, 'epoch': 0.05}


  3%|▎         | 20/654 [00:35<11:03,  1.05s/it]

{'loss': 0.7797, 'grad_norm': 3.726905584335327, 'learning_rate': 4.847094801223242e-05, 'epoch': 0.09}


  5%|▍         | 30/654 [00:55<19:18,  1.86s/it]

{'loss': 0.7911, 'grad_norm': 3.6768481731414795, 'learning_rate': 4.7706422018348626e-05, 'epoch': 0.14}


  6%|▌         | 40/654 [01:14<20:44,  2.03s/it]

{'loss': 0.7659, 'grad_norm': 6.559085845947266, 'learning_rate': 4.694189602446483e-05, 'epoch': 0.18}


  8%|▊         | 50/654 [01:34<14:28,  1.44s/it]

{'loss': 0.6494, 'grad_norm': 5.637365341186523, 'learning_rate': 4.617737003058104e-05, 'epoch': 0.23}


  9%|▉         | 60/654 [02:01<34:59,  3.54s/it]

{'loss': 0.5958, 'grad_norm': 7.2245049476623535, 'learning_rate': 4.541284403669725e-05, 'epoch': 0.28}


 11%|█         | 70/654 [02:44<33:00,  3.39s/it]  

{'loss': 0.5212, 'grad_norm': 9.817317008972168, 'learning_rate': 4.4648318042813456e-05, 'epoch': 0.32}


 12%|█▏        | 80/654 [03:20<37:59,  3.97s/it]

{'loss': 0.5809, 'grad_norm': 14.302302360534668, 'learning_rate': 4.3883792048929664e-05, 'epoch': 0.37}


 14%|█▍        | 90/654 [03:51<25:06,  2.67s/it]

{'loss': 0.4667, 'grad_norm': 4.403275012969971, 'learning_rate': 4.311926605504588e-05, 'epoch': 0.41}


 15%|█▌        | 100/654 [04:09<13:33,  1.47s/it]

{'loss': 0.4744, 'grad_norm': 2.98917293548584, 'learning_rate': 4.235474006116208e-05, 'epoch': 0.46}


 17%|█▋        | 110/654 [04:27<14:55,  1.65s/it]

{'loss': 0.4692, 'grad_norm': 13.180750846862793, 'learning_rate': 4.159021406727829e-05, 'epoch': 0.5}


 18%|█▊        | 120/654 [04:39<08:44,  1.02it/s]

{'loss': 0.451, 'grad_norm': 4.209173679351807, 'learning_rate': 4.0825688073394495e-05, 'epoch': 0.55}


 20%|█▉        | 130/654 [04:47<07:08,  1.22it/s]

{'loss': 0.3658, 'grad_norm': 13.522235870361328, 'learning_rate': 4.00611620795107e-05, 'epoch': 0.6}


 21%|██▏       | 140/654 [04:56<08:09,  1.05it/s]

{'loss': 0.4707, 'grad_norm': 6.981060981750488, 'learning_rate': 3.929663608562692e-05, 'epoch': 0.64}


 23%|██▎       | 150/654 [05:05<07:03,  1.19it/s]

{'loss': 0.394, 'grad_norm': 3.078388214111328, 'learning_rate': 3.8532110091743125e-05, 'epoch': 0.69}


 24%|██▍       | 160/654 [05:15<08:03,  1.02it/s]

{'loss': 0.4116, 'grad_norm': 11.815357208251953, 'learning_rate': 3.7767584097859326e-05, 'epoch': 0.73}


 26%|██▌       | 170/654 [05:26<07:30,  1.07it/s]

{'loss': 0.4293, 'grad_norm': 8.353646278381348, 'learning_rate': 3.7003058103975534e-05, 'epoch': 0.78}


 28%|██▊       | 180/654 [05:34<06:24,  1.23it/s]

{'loss': 0.4091, 'grad_norm': 3.852647542953491, 'learning_rate': 3.623853211009174e-05, 'epoch': 0.83}


 29%|██▉       | 190/654 [05:47<11:29,  1.48s/it]

{'loss': 0.3439, 'grad_norm': 8.242115020751953, 'learning_rate': 3.5474006116207956e-05, 'epoch': 0.87}


 31%|███       | 200/654 [05:57<07:18,  1.03it/s]

{'loss': 0.4478, 'grad_norm': 11.006133079528809, 'learning_rate': 3.4709480122324164e-05, 'epoch': 0.92}


 32%|███▏      | 210/654 [06:07<07:36,  1.03s/it]

{'loss': 0.4146, 'grad_norm': 3.8564493656158447, 'learning_rate': 3.394495412844037e-05, 'epoch': 0.96}


                                                 
 33%|███▎      | 218/654 [06:22<06:04,  1.20it/s]

{'eval_loss': 0.39951249957084656, 'eval_runtime': 7.9863, 'eval_samples_per_second': 48.583, 'eval_steps_per_second': 3.13, 'epoch': 1.0}


 34%|███▎      | 220/654 [06:25<20:41,  2.86s/it]

{'loss': 0.4293, 'grad_norm': 5.919196605682373, 'learning_rate': 3.318042813455658e-05, 'epoch': 1.01}


 35%|███▌      | 230/654 [06:35<06:36,  1.07it/s]

{'loss': 0.2745, 'grad_norm': 12.205646514892578, 'learning_rate': 3.241590214067278e-05, 'epoch': 1.06}


 37%|███▋      | 240/654 [06:44<05:46,  1.19it/s]

{'loss': 0.2492, 'grad_norm': 19.635948181152344, 'learning_rate': 3.1651376146788995e-05, 'epoch': 1.1}


 38%|███▊      | 250/654 [06:52<05:59,  1.12it/s]

{'loss': 0.2927, 'grad_norm': 13.9088716506958, 'learning_rate': 3.08868501529052e-05, 'epoch': 1.15}


 40%|███▉      | 260/654 [07:01<05:22,  1.22it/s]

{'loss': 0.2574, 'grad_norm': 12.238587379455566, 'learning_rate': 3.012232415902141e-05, 'epoch': 1.19}


 41%|████▏     | 270/654 [07:09<05:22,  1.19it/s]

{'loss': 0.2473, 'grad_norm': 4.494692802429199, 'learning_rate': 2.9357798165137618e-05, 'epoch': 1.24}


 43%|████▎     | 280/654 [07:18<05:12,  1.20it/s]

{'loss': 0.2933, 'grad_norm': 5.556410789489746, 'learning_rate': 2.8593272171253826e-05, 'epoch': 1.28}


 44%|████▍     | 290/654 [07:26<05:01,  1.21it/s]

{'loss': 0.2478, 'grad_norm': 15.016253471374512, 'learning_rate': 2.782874617737003e-05, 'epoch': 1.33}


 46%|████▌     | 300/654 [07:35<04:48,  1.23it/s]

{'loss': 0.2413, 'grad_norm': 10.48573112487793, 'learning_rate': 2.7064220183486238e-05, 'epoch': 1.38}


 47%|████▋     | 310/654 [07:43<04:30,  1.27it/s]

{'loss': 0.2933, 'grad_norm': 8.468277931213379, 'learning_rate': 2.629969418960245e-05, 'epoch': 1.42}


 49%|████▉     | 320/654 [07:51<04:33,  1.22it/s]

{'loss': 0.2082, 'grad_norm': 11.91784954071045, 'learning_rate': 2.5535168195718656e-05, 'epoch': 1.47}


 50%|█████     | 330/654 [08:00<04:21,  1.24it/s]

{'loss': 0.1717, 'grad_norm': 6.597239017486572, 'learning_rate': 2.4770642201834864e-05, 'epoch': 1.51}


 52%|█████▏    | 340/654 [08:08<04:10,  1.25it/s]

{'loss': 0.2109, 'grad_norm': 6.263698577880859, 'learning_rate': 2.4006116207951072e-05, 'epoch': 1.56}


 54%|█████▎    | 350/654 [08:16<04:01,  1.26it/s]

{'loss': 0.2668, 'grad_norm': 3.51605486869812, 'learning_rate': 2.324159021406728e-05, 'epoch': 1.61}


 55%|█████▌    | 360/654 [08:25<03:58,  1.23it/s]

{'loss': 0.2358, 'grad_norm': 6.929905891418457, 'learning_rate': 2.2477064220183487e-05, 'epoch': 1.65}


 57%|█████▋    | 370/654 [08:32<03:38,  1.30it/s]

{'loss': 0.1737, 'grad_norm': 13.150646209716797, 'learning_rate': 2.1712538226299695e-05, 'epoch': 1.7}


 58%|█████▊    | 380/654 [08:59<08:46,  1.92s/it]

{'loss': 0.1887, 'grad_norm': 2.7542884349823, 'learning_rate': 2.0948012232415903e-05, 'epoch': 1.74}


 60%|█████▉    | 390/654 [09:17<08:21,  1.90s/it]

{'loss': 0.2438, 'grad_norm': 7.993503093719482, 'learning_rate': 2.018348623853211e-05, 'epoch': 1.79}


 61%|██████    | 400/654 [09:43<07:55,  1.87s/it]

{'loss': 0.2273, 'grad_norm': 1.731151819229126, 'learning_rate': 1.9418960244648318e-05, 'epoch': 1.83}


 63%|██████▎   | 410/654 [09:51<03:25,  1.19it/s]

{'loss': 0.189, 'grad_norm': 10.786277770996094, 'learning_rate': 1.8654434250764526e-05, 'epoch': 1.88}


 64%|██████▍   | 420/654 [10:00<03:27,  1.13it/s]

{'loss': 0.2369, 'grad_norm': 24.436796188354492, 'learning_rate': 1.7889908256880737e-05, 'epoch': 1.93}


 66%|██████▌   | 430/654 [10:09<03:13,  1.15it/s]

{'loss': 0.2379, 'grad_norm': 10.953409194946289, 'learning_rate': 1.712538226299694e-05, 'epoch': 1.97}


                                                 
 67%|██████▋   | 436/654 [10:20<03:09,  1.15it/s]

{'eval_loss': 0.42221856117248535, 'eval_runtime': 5.1733, 'eval_samples_per_second': 75.001, 'eval_steps_per_second': 4.833, 'epoch': 2.0}


 67%|██████▋   | 440/654 [10:25<05:57,  1.67s/it]

{'loss': 0.1247, 'grad_norm': 0.48542287945747375, 'learning_rate': 1.636085626911315e-05, 'epoch': 2.02}


 69%|██████▉   | 450/654 [10:35<03:32,  1.04s/it]

{'loss': 0.1219, 'grad_norm': 4.878985404968262, 'learning_rate': 1.559633027522936e-05, 'epoch': 2.06}


 70%|███████   | 460/654 [10:44<02:38,  1.22it/s]

{'loss': 0.0529, 'grad_norm': 0.5471055507659912, 'learning_rate': 1.4831804281345565e-05, 'epoch': 2.11}


 72%|███████▏  | 470/654 [10:56<04:36,  1.50s/it]

{'loss': 0.094, 'grad_norm': 2.234342336654663, 'learning_rate': 1.4067278287461774e-05, 'epoch': 2.16}


 73%|███████▎  | 480/654 [11:10<03:11,  1.10s/it]

{'loss': 0.0941, 'grad_norm': 2.1173996925354004, 'learning_rate': 1.3302752293577984e-05, 'epoch': 2.2}


 75%|███████▍  | 490/654 [11:23<03:31,  1.29s/it]

{'loss': 0.1095, 'grad_norm': 8.039071083068848, 'learning_rate': 1.253822629969419e-05, 'epoch': 2.25}


 76%|███████▋  | 500/654 [11:50<09:03,  3.53s/it]

{'loss': 0.0603, 'grad_norm': 0.2917783856391907, 'learning_rate': 1.1773700305810397e-05, 'epoch': 2.29}


 78%|███████▊  | 510/654 [12:28<06:34,  2.74s/it]

{'loss': 0.058, 'grad_norm': 3.5893125534057617, 'learning_rate': 1.1009174311926607e-05, 'epoch': 2.34}


 80%|███████▉  | 520/654 [12:48<04:15,  1.90s/it]

{'loss': 0.0454, 'grad_norm': 0.35865363478660583, 'learning_rate': 1.0244648318042814e-05, 'epoch': 2.39}


 81%|████████  | 530/654 [12:58<01:46,  1.16it/s]

{'loss': 0.0929, 'grad_norm': 1.6767597198486328, 'learning_rate': 9.480122324159022e-06, 'epoch': 2.43}


 83%|████████▎ | 540/654 [13:06<01:33,  1.21it/s]

{'loss': 0.065, 'grad_norm': 0.4124130606651306, 'learning_rate': 8.71559633027523e-06, 'epoch': 2.48}


 84%|████████▍ | 550/654 [13:59<06:46,  3.91s/it]

{'loss': 0.1427, 'grad_norm': 7.561567783355713, 'learning_rate': 7.951070336391438e-06, 'epoch': 2.52}


 86%|████████▌ | 560/654 [14:19<03:21,  2.15s/it]

{'loss': 0.1231, 'grad_norm': 0.09226620942354202, 'learning_rate': 7.186544342507645e-06, 'epoch': 2.57}


 87%|████████▋ | 570/654 [14:48<03:34,  2.56s/it]

{'loss': 0.0902, 'grad_norm': 2.3611152172088623, 'learning_rate': 6.422018348623854e-06, 'epoch': 2.61}


 89%|████████▊ | 580/654 [15:07<02:04,  1.69s/it]

{'loss': 0.1019, 'grad_norm': 0.11030992120504379, 'learning_rate': 5.657492354740062e-06, 'epoch': 2.66}


 90%|█████████ | 590/654 [15:25<01:34,  1.47s/it]

{'loss': 0.0148, 'grad_norm': 0.17759627103805542, 'learning_rate': 4.892966360856269e-06, 'epoch': 2.71}


 92%|█████████▏| 600/654 [15:37<00:56,  1.05s/it]

{'loss': 0.127, 'grad_norm': 0.4788254201412201, 'learning_rate': 4.128440366972477e-06, 'epoch': 2.75}


 93%|█████████▎| 610/654 [16:06<01:00,  1.37s/it]

{'loss': 0.1061, 'grad_norm': 13.228991508483887, 'learning_rate': 3.363914373088685e-06, 'epoch': 2.8}


 95%|█████████▍| 620/654 [16:22<00:40,  1.19s/it]

{'loss': 0.1632, 'grad_norm': 16.87421417236328, 'learning_rate': 2.599388379204893e-06, 'epoch': 2.84}


 96%|█████████▋| 630/654 [16:33<00:28,  1.18s/it]

{'loss': 0.1102, 'grad_norm': 2.434553861618042, 'learning_rate': 1.8348623853211011e-06, 'epoch': 2.89}


 98%|█████████▊| 640/654 [16:47<00:24,  1.78s/it]

{'loss': 0.0989, 'grad_norm': 0.21832436323165894, 'learning_rate': 1.0703363914373088e-06, 'epoch': 2.94}


 99%|█████████▉| 650/654 [16:58<00:04,  1.02s/it]

{'loss': 0.1114, 'grad_norm': 0.08660129457712173, 'learning_rate': 3.0581039755351683e-07, 'epoch': 2.98}


                                                 
100%|██████████| 654/654 [17:20<00:00,  1.59s/it]

{'eval_loss': 0.6311530470848083, 'eval_runtime': 7.7797, 'eval_samples_per_second': 49.873, 'eval_steps_per_second': 3.213, 'epoch': 3.0}
{'train_runtime': 1040.5886, 'train_samples_per_second': 10.056, 'train_steps_per_second': 0.628, 'train_loss': 0.2857098263611487, 'epoch': 3.0}





TrainOutput(global_step=654, training_loss=0.2857098263611487, metrics={'train_runtime': 1040.5886, 'train_samples_per_second': 10.056, 'train_steps_per_second': 0.628, 'total_flos': 688304700776448.0, 'train_loss': 0.2857098263611487, 'epoch': 3.0})

#### This function ensures the student model is trained using the logits from the teacher model in addition to the ground truth labels, combining hard and soft losses.

In [36]:
# Now, distill the knowledge into the student model
student_model = distillation_train_loop(student_model, teacher_model, train_dataset, val_dataset, tokenizer, epochs=3, batch_size=16)

Teacher logits: tensor([[-2.6213,  4.1635, -2.3194],
        [-1.5190, -1.7753,  4.8557],
        [-2.5256,  4.1303, -2.2848],
        ...,
        [-1.7173,  3.6361, -2.9695],
        [-1.8634, -1.4155,  4.7208],
        [-2.6480,  4.1266, -2.2814]])
Student outputs: SequenceClassifierOutput(loss=None, logits=tensor([[-1.4568,  1.6938, -0.6216],
        [-1.4918,  1.5309, -0.0303],
        [-1.6417,  1.1014,  0.4621],
        [-1.2778,  1.1750, -0.1223],
        [-1.7693,  0.3149,  1.4295],
        [-1.8186,  1.1852,  0.1105],
        [ 0.1959,  0.1122, -0.8044],
        [-2.2644,  0.2283,  1.6057],
        [-1.8654,  0.4816,  0.9782],
        [-0.0396,  0.2007, -0.8431],
        [-1.1631,  1.3751, -0.3029],
        [-1.7715,  1.5458,  0.2369],
        [-1.1664,  1.3201, -0.6355],
        [-1.6168, -0.0067,  1.1921],
        [-1.7620,  0.3772,  1.1147],
        [-1.8191,  1.2931,  0.3181]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)
Student logits: tensor([[-1.456

#### After the student model is trained, HuggingFace Trainer is used for evaluation. evaluate() method in Trainer does not perform training again, it only uses the student model in evaluation mode

In [37]:
trainer_student = Trainer(
    model=student_model,
    args=training_args,
    eval_dataset=test_dataset,
    tokenizer=tokenizer
)

# Evaluate the student model
results = trainer_student.evaluate()
print(f"Evaluation Results: {results}")


  trainer_student = Trainer(
100%|██████████| 61/61 [00:07<00:00,  7.97it/s]

Evaluation Results: {'eval_loss': 0.5431520938873291, 'eval_model_preparation_time': 0.0042, 'eval_runtime': 8.9747, 'eval_samples_per_second': 108.081, 'eval_steps_per_second': 6.797}





In [38]:
predictions = trainer_student.predict(test_dataset)

100%|██████████| 61/61 [00:07<00:00,  8.04it/s]


In [39]:
import numpy as np

predicted_labels = np.argmax(predictions.predictions, axis=-1)

true_labels = predictions.label_ids

In [41]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

accuracy = accuracy_score(true_labels, predicted_labels)
precision = precision_score(true_labels, predicted_labels, average='weighted')
recall = recall_score(true_labels, predicted_labels, average='weighted')
f1 = f1_score(true_labels, predicted_labels, average='weighted')

print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")

Accuracy: 0.8062
Precision: 0.8145
Recall: 0.8062
F1 Score: 0.7975
