In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AdamW
from datasets import load_dataset

In [2]:
teacher_models = [
    "bigscience/bloomz-560m",
    "EleutherAI/pythia-410m",
    "facebook/opt-350m"
]
student_model_name = "EleutherAI/gpt-neo-125m"

#loading student model and tokenizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
student = AutoModelForSequenceClassification.from_pretrained(student_model_name, num_labels=2).to(device)
student_tokenizer = AutoTokenizer.from_pretrained(student_model_name)

if student_tokenizer.pad_token is None:
    student_tokenizer.pad_token = student_tokenizer.eos_token
    student.config.pad_token_id = student_tokenizer.pad_token_id
    student.resize_token_embeddings(len(student_tokenizer))

config.json:   0%|          | 0.00/1.01k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/526M [00:00<?, ?B/s]

Some weights of GPTNeoForSequenceClassification were not initialized from the model checkpoint at EleutherAI/gpt-neo-125m and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


tokenizer_config.json:   0%|          | 0.00/727 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/357 [00:00<?, ?B/s]

In [3]:
#imdb sentiment dataset loading and keeping only first 5000 rows
dataset = load_dataset("imdb")
train_texts = dataset["train"]["text"][:5000]
train_labels = dataset["train"]["label"][:5000]

README.md:   0%|          | 0.00/7.81k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/21.0M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/20.5M [00:00<?, ?B/s]

unsupervised-00000-of-00001.parquet:   0%|          | 0.00/42.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating unsupervised split:   0%|          | 0/50000 [00:00<?, ? examples/s]

In [4]:
#data tokenization for student model
def tokenize_data(texts, tokenizer):
    return tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=512)

train_encodings = tokenize_data(train_texts, student_tokenizer)

In [5]:
#data preprocessing
class IMDbDataset(torch.utils.data.Dataset):
    def __init__(self, texts, encodings, labels):
        self.texts = texts  #keeping raw text for teacher models
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx])
        item["text"] = self.texts[idx]  #including raw text for teacher models
        return item

train_dataset = IMDbDataset(train_texts, train_encodings, train_labels)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

In [6]:
ce_loss = nn.CrossEntropyLoss()
temperature = 2.0 
optimizer = AdamW(student.parameters(), lr=2e-5)



In [7]:
num_epochs = 3
for epoch in range(num_epochs):
    student.train()
    total_loss = 0

    for batch in train_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        texts = batch["text"]  #extracting raw text for teacher models

        avg_teacher_logits = None
        teacher_count = 0

        #computing teacher predictions one by one to avoid memory overflow
        for teacher_model_name in teacher_models:
            teacher = AutoModelForSequenceClassification.from_pretrained(teacher_model_name, num_labels=2).to(device)
            teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)

            if teacher_tokenizer.pad_token is None:
                teacher_tokenizer.pad_token = teacher_tokenizer.eos_token
                teacher.config.pad_token_id = teacher_tokenizer.pad_token_id
                teacher.resize_token_embeddings(len(teacher_tokenizer))

            with torch.no_grad():
                teacher_inputs = teacher_tokenizer(
                    texts,  #passing raw text here
                    padding=True, truncation=True, return_tensors="pt", max_length=512
                )
                teacher_inputs = {key: val.to(device) for key, val in teacher_inputs.items()}
                teacher_logits = teacher(**teacher_inputs).logits

            if avg_teacher_logits is None:
                avg_teacher_logits = teacher_logits
            else:
                avg_teacher_logits += teacher_logits

            teacher_count += 1

            #free memory
            del teacher
            torch.cuda.empty_cache()

        avg_teacher_logits /= teacher_count
        
        #distillation loss
        student_logits = student(input_ids, attention_mask=attention_mask).logits
        distillation_loss = nn.KLDivLoss(reduction="batchmean")(
            torch.log_softmax(student_logits / temperature, dim=-1),
            torch.softmax(avg_teacher_logits / temperature, dim=-1)
        )

        #student loss with cross-entropy
        ce_loss_value = ce_loss(student_logits, labels)

        #combine losses
        loss = distillation_loss + ce_loss_value
        total_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"epoch {epoch + 1}, loss: {total_loss / len(train_loader)}")

print("complete")

config.json:   0%|          | 0.00/715 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.12G [00:00<?, ?B/s]

Some weights of BloomForSequenceClassification were not initialized from the model checkpoint at bigscience/bloomz-560m and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


tokenizer_config.json:   0%|          | 0.00/222 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/14.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/85.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/911M [00:00<?, ?B/s]

The `GPTNeoXSdpaAttention` class is deprecated in favor of simply modifying the `config._attn_implementation`attribute of the `GPTNeoXAttention` class! It will be removed in v4.48
Some weights of GPTNeoXForSequenceClassification were not initialized from the model checkpoint at EleutherAI/pythia-410m and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


tokenizer_config.json:   0%|          | 0.00/396 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/644 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/663M [00:00<?, ?B/s]

Some weights of OPTForSequenceClassification were not initialized from the model checkpoint at facebook/opt-350m and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


tokenizer_config.json:   0%|          | 0.00/685 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/441 [00:00<?, ?B/s]

Some weights of BloomForSequenceClassification were not initialized from the model checkpoint at bigscience/bloomz-560m and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of GPTNeoXForSequenceClassification were not initialized from the model checkpoint at EleutherAI/pythia-410m and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of OPTForSequenceClassification were not initialized from the model checkpoint at facebook/opt-350m and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BloomForSequenceClassification were not initialized from the model checkpoint at bigscience/bloomz-560m and are newly initialized: ['score.weight']
You should proba

epoch 1, loss: 0.8798576749551791


Some weights of BloomForSequenceClassification were not initialized from the model checkpoint at bigscience/bloomz-560m and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of GPTNeoXForSequenceClassification were not initialized from the model checkpoint at EleutherAI/pythia-410m and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of OPTForSequenceClassification were not initialized from the model checkpoint at facebook/opt-350m and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BloomForSequenceClassification were not initialized from the model checkpoint at bigscience/bloomz-560m and are newly initialized: ['score.weight']
You should proba

epoch 2, loss: 0.8614833957661455


Some weights of BloomForSequenceClassification were not initialized from the model checkpoint at bigscience/bloomz-560m and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of GPTNeoXForSequenceClassification were not initialized from the model checkpoint at EleutherAI/pythia-410m and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of OPTForSequenceClassification were not initialized from the model checkpoint at facebook/opt-350m and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BloomForSequenceClassification were not initialized from the model checkpoint at bigscience/bloomz-560m and are newly initialized: ['score.weight']
You should proba

epoch 3, loss: 0.8379311943396973
complete
