<a href="https://colab.research.google.com/github/jyotidabass/NLP-Projects/blob/main/Knowledge_distillation_and_Adapter_Modules.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install transformers

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from sklearn.metrics import accuracy_score
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class MyDataset(Dataset):
    def __init__(self, texts, labels, tokenizer):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]

        encoding = self.tokenizer.encode_plus(
            text,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

# Load pre-trained teacher model and tokenizer
teacher_model = AutoModelForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=2) # Set num_labels
teacher_tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')

# Define the student model and adapter
class Adapter(torch.nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Adapter, self).__init__()
        self.fc1 = torch.nn.Linear(input_dim, 128)
        self.fc2 = torch.nn.Linear(128, output_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

student_model = AutoModelForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=2) # Set num_labels
adapter = Adapter(input_dim=768, output_dim=2) # Set output_dim to match num_labels

# Define the dataset and data loader
texts = ['This is a positive review.', 'This is a negative review.']
labels = [1, 0]

dataset = MyDataset(texts, labels, teacher_tokenizer)
data_loader = DataLoader(dataset, batch_size=2, shuffle=True)

# Train the student model with knowledge distillation
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
student_model.to(device)
adapter.to(device)
teacher_model.to(device) # Move teacher model to device

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(list(student_model.parameters()) + list(adapter.parameters()), lr=1e-5) # Include adapter parameters

for epoch in range(5):
    student_model.train()
    total_loss = 0
    for batch in data_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()

        teacher_output = teacher_model(input_ids, attention_mask=attention_mask).logits # Get logits from teacher

        # Get hidden states from student model instead of logits
        # Access hidden_states from SequenceClassifierOutput and then last_hidden_state
        student_output = student_model(input_ids, attention_mask=attention_mask, output_hidden_states=True)
        student_hidden_states = student_output.hidden_states[-1] # Get last hidden state


        # Get the CLS token embedding (which represents the sentence embedding)
        student_output = student_hidden_states[:, 0, :]

        adapter_output = adapter(student_output) # Pass student hidden states to adapter

        loss = criterion(adapter_output, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f'Epoch {epoch+1}, Loss: {total_loss / len(data_loader)}')

# Evaluate the student model
student_model.eval()
with torch.no_grad():
    total_correct = 0
    for batch in data_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Get hidden states from student model instead of logits
        # Access hidden_states from SequenceClassifierOutput and then last_hidden_state
        student_output = student_model(input_ids, attention_mask=attention_mask, output_hidden_states=True)
        student_hidden_states = student_output.hidden_states[-1] # Get last hidden state


        # Get the CLS token embedding (which represents the sentence embedding)
        student_output = student_hidden_states[:, 0, :]

        adapter_output = adapter(student_output) # Pass hidden states to adapter

        _, predicted = torch.max(adapter_output, dim=1)
        total_correct += (predicted == labels).sum().item()

    accuracy = total_correct / len(dataset)
    print(f'Accuracy: {accuracy:.4f}')



Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1, Loss: 0.6558558940887451
Epoch 2, Loss: 0.7094675898551941
Epoch 3, Loss: 0.6972615122795105
Epoch 4, Loss: 0.6804457306861877
Epoch 5, Loss: 0.7000874876976013
Accuracy: 0.5000
