In [None]:
!pip install transformers
!pip install datasets

Collecting datasets
  Downloading datasets-2.16.1-py3-none-any.whl (507 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m507.1/507.1 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.8,>=0.3.0 (from datasets)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: dill, multiprocess, datasets
Successfully installed datasets-2.16.1 dill-0.3.7 multiprocess-0.70.15


In [None]:
from datasets import load_dataset
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch
from transformers import AutoModelForSequenceClassification
from torch import nn
from torch import optim
from torch.nn import functional as F
from transformers import AutoTokenizer
from tqdm import tqdm
from time import perf_counter
import pandas as pd

In [None]:
dataset_ckpt = 'ag_news'
teacher_model_ckpt = 'google/bert_uncased_L-2_H-128_A-2' #our already finetuned teacher model
student_model_ckpt = 'google/bert_uncased_L-12_H-256_A-4'



In [None]:
tokenizer = AutoTokenizer.from_pretrained(student_model_ckpt)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/384 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

In [None]:
data = load_dataset(dataset_ckpt)
train_test = data['train'].train_test_split(test_size = 0.2)
valid_data = train_test['test']
train_data = train_test['train']
test_data = data['test']

def get_num_rows(dataset):
  return dataset.num_rows

print(f'Train set has {get_num_rows(train_data)} texts')
print(f'Valid set has {get_num_rows(valid_data)} texts')
print(f'Test set has {get_num_rows(test_data)} texts')

Downloading data:   0%|          | 0.00/18.6M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.23M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/120000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/7600 [00:00<?, ? examples/s]

Train set has 96000 texts
Valid set has 24000 texts
Test set has 7600 texts


In [None]:
#now we would utilise pytorch's Dataset andDataloader classes to create our dataset

class MyData(Dataset):
  def __init__(self, data):
    targets = data['label']
    texts = data['text']

    tokens = tokenizer(texts, return_tensors = 'pt', truncation = True, padding = 'max_length', max_length = 150)
    self.input_ids = tokens['input_ids']
    self.attention_mask = tokens['attention_mask']
    self.targets = torch.tensor(targets)
    self.length = len(texts)
  def __len__(self):
    return self.length
  def __getitem__(self, index):
    return self.input_ids[index], self.attention_mask[index], self.targets[index]


train_data = MyData(train_data)
valid_data = MyData(valid_data)
test_data = MyData(test_data)

# now we build our loaders
batch_size = 64
train_loader = DataLoader(train_data,batch_size = batch_size)
valid_loader = DataLoader(valid_data, batch_size = batch_size)
test_loader = DataLoader(test_data, batch_size = batch_size)

In [None]:
#first we install define our device and download our teacher model from huggingface
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
teacher_model = AutoModelForSequenceClassification.from_pretrained(teacher_model_ckpt).to(device)


# we define a function to help us compute accuracy as we train, we would also define another function to measure time ellapsed
def accuracy_score(batch, model):
  with torch.no_grad():
    outputs = model(
        batch[0].to(device),
        batch[1].to(device)
    )
    logits = outputs.logits
    probabilities = torch.softmax(logits, dim = 1)
    class_predictions = torch.argmax(probabilities, dim = 1)
    acc = torch.mean((class_predictions == batch[2].to(device)).to(torch.float)).data.item()
    return acc

#now let us test!

accuracy = 0.0
time_taken = 0.0
count = 0
for batch in tqdm(test_loader):
  start_time = perf_counter()
  score = accuracy_score(batch, teacher_model)
  end_time = perf_counter()
  accuracy += score
  time_taken += end_time - start_time

print('\n\n')
print(f"number of samples in each batch is {len(batch[0])}")
print(f'number of batch is {len(test_loader)}')
print(f"accuracy is {accuracy / len(test_loader):.2f}")
print(f'time taken per batch is {time_taken / len(test_loader):.6f}')

config.json:   0%|          | 0.00/382 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/17.7M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google/bert_uncased_L-2_H-128_A-2 and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 119/119 [00:00<00:00, 145.47it/s]




number of samples in each batch is 48
number of batch is 119
accuracy is 0.24
time taken per batch is 0.006155





In [None]:
student_model = AutoModelForSequenceClassification.from_pretrained(student_model_ckpt, num_labels = 20).to(device)
student_model.dropout = nn.Dropout(0.3) #Increase dropout to improve generalization.
epochs = 5#we train for5epochs
learning_rate = 2e-5
entropy_loss = nn.CrossEntropyLoss() #cross entropy loss
temperature = 2.0
alpha = 0.5
criterion = nn.KLDivLoss(reduction = 'batchmean') #KL Divergence Loss
optimizer = optim.Adam(student_model.parameters(), lr = learning_rate)

pytorch_model.bin:   0%|          | 0.00/70.4M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google/bert_uncased_L-12_H-256_A-4 and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
teacher_model


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-1): 2 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=128, out_features=128, bias=True)
              (key): Linear(in_features=128, out_features=128, bias=True)
              (value): Linear(in_features=128, out_features=128, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=128, out_features=128, bias=True)
              (LayerNorm): LayerNorm((128,), eps=1e-12, e

In [None]:
student_model

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 256, padding_idx=0)
      (position_embeddings): Embedding(512, 256)
      (token_type_embeddings): Embedding(2, 256)
      (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=256, out_features=256, bias=True)
              (key): Linear(in_features=256, out_features=256, bias=True)
              (value): Linear(in_features=256, out_features=256, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=256, out_features=256, bias=True)
              (LayerNorm): LayerNorm((256,), eps=1e-12,

In [None]:
def get_parameter_count(model):
  num_params = sum(p.numel() for p in model.parameters())
  return num_params

print(f'teacher model has {(get_parameter_count(teacher_model)/1000000):.2f} parameters')
print(f'student model has {(get_parameter_count(student_model)/1000000):.2f} parameters')

teacher model has 109.49 parameters
student model has 17.49 parameters


In [None]:
epochs = 4
learning_rate = 2e-4
entropy_loss = nn.CrossEntropyLoss()
temperature = 1.0
alpha = 0.5 #test this
criterion = nn.KLDivLoss()
optimizer = optim.AdamW(student_model.parameters(), lr = learning_rate)

In [None]:
import torch.nn.functional as F
kl_loss = nn.KLDivLoss(reduction="batchmean")

input = F.log_softmax(torch.randn(64, 20, requires_grad=True), dim=1)
target = F.softmax(torch.rand(64, 4), dim=1)

print(input.shape)
print(target.shape)

output = kl_loss(input, target)

print(output)


64, 20
# kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=True)
# log_target = F.log_softmax(torch.rand(3, 5), dim=1)
# output = kl_loss(input, log_target)

torch.Size([64, 20])
torch.Size([64, 4])


RuntimeError: The size of tensor a (4) must match the size of tensor b (20) at non-singleton dimension 1

In [None]:
# Example of target with class indices
loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(input, target)

print(input.shape)
print(target.shape)


input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5).softmax(dim=1)
output = loss(input, target)



torch.Size([3, 5])
torch.Size([3])


In [None]:
import pandas as pd

# Lists to store training and validation metrics
training_loss_list = []
training_kd_loss_list = []
training_accuracy_list = []
valid_loss_list = []
valid_accuracy_list = []

#starting loop
for epoch in tqdm(range(epochs), total=epochs):
    student_model.train()
    train_loss = 0.0
    train_kd_loss = 0.0
    train_accuracy = 0.0
    valid_loss = 0.0
    valid_accuracy = 0.0

    for batch in train_loader:
        optimizer.zero_grad()
        input_ids = batch[0].to(device)
        attention_mask = batch[1].to(device)
        target_tensors = batch[2].to(device)

        # Student model predictions
        student_logits = student_model(input_ids=input_ids, attention_mask=attention_mask).logits
        ce_loss = entropy_loss(student_logits, target_tensors).data.item()

        # We extract teacher model logits
        with torch.no_grad():
            teacher_outputs = teacher_model(input_ids=input_ids, attention_mask=attention_mask)
            teacher_logits = teacher_outputs.logits


        print("student_logits: ", student_logits.shape)
        print("teacher_logits: ", teacher_logits.shape)

        # Knowledge distillation loss (KD divergence)
        kd_loss = temperature ** 2 * criterion(
            F.log_softmax(student_logits / temperature, dim=-1),
            F.softmax(teacher_logits / temperature, dim=-1)
        )

        # Combined loss
        loss = alpha * ce_loss + (1. - alpha) * kd_loss
        loss.backward()
        optimizer.step()

        # Update training metrics
        train_kd_loss += kd_loss.data.item()
        train_loss += loss
        accuracy = accuracy_score(batch, student_model)
        train_accuracy += accuracy

    student_model.eval()
    for batch in valid_loader:
        input_ids = batch[0].to(device)
        attention_mask = batch[1].to(device)
        target_tensors = batch[2].to(device)

        # Validation loss
        output = student_model(input_ids=input_ids, attention_mask=attention_mask)
        val_loss = entropy_loss(output.logits, target_tensors)
        valid_loss += val_loss.data.item()

        # Update validation accuracy
        accuracy = accuracy_score(batch, student_model)
        valid_accuracy += accuracy

    # Calculate average metrics
    train_accuracy /= len(train_loader)
    valid_accuracy /= len(valid_loader)
    train_loss /= len(train_loader)
    train_kd_loss /= len(train_loader)
    valid_loss /= len(valid_loader)

    # Append metrics to lists
    training_kd_loss_list.append(train_kd_loss)
    training_loss_list.append(train_loss.cpu().detach().numpy())
    training_accuracy_list.append(train_accuracy)
    valid_loss_list.append(valid_loss)
    valid_accuracy_list.append(valid_accuracy)

    # Print and store metrics
    print(f"""
    After epoch {epoch + 1}:
    Training loss (entropy): {train_loss}
    Kullback-Leibler (KL) divergence loss: {train_kd_loss}
    Validation loss (entropy): {valid_loss}
    Training accuracy: {train_accuracy}
    Validation accuracy: {valid_accuracy}
    """)

# Create a DataFrame to store the metrics
metrics = pd.DataFrame({
    'training_loss': training_loss_list,
    'training_kd_loss': training_kd_loss_list,
    'training_accuracy': training_accuracy_list,
    'valid_loss': valid_loss_list,
    'valid_accuracy': valid_accuracy_list
})

  0%|          | 0/4 [00:00<?, ?it/s]


student_logits:  torch.Size([64, 20])
teacher_logits:  torch.Size([64, 4])


RuntimeError: The size of tensor a (4) must match the size of tensor b (20) at non-singleton dimension 1

In [None]:
accuracy_teacher = 0.0
time_taken_teacher = 0.0

accuracy_student = 0.0
time_taken_student = 0.0
count = 0
for batch in tqdm(test_loader):
  start_time = perf_counter()
  score = accuracy_score(batch, teacher_model)
  end_time = perf_counter()
  accuracy_teacher += score
  time_taken_teacher += end_time - start_time

  start_time = perf_counter()
  score = accuracy_score(batch, student_model)
  end_time = perf_counter()
  accuracy_student += score
  time_taken_student += end_time - start_time


print('\n\n')
print(f"number of samples in each batch is {len(batch[0])}")
print(f'total number of batches is {len(test_loader)}')
print(f"teacher accuracy is {accuracy_teacher / len(test_loader):.2f}")
print(f'time taken per batch for teacher is {time_taken_teacher / len(test_loader):.6f}')
print('\n\n\n')
print(f"student accuracy is {accuracy_student / len(test_loader):.2f}")
print(f'time taken per batch for student is {time_taken_student / len(test_loader):.6f}')Running this code on a T5 GPU we get the following scores.