# Data

In [78]:
!pip install datasets



In [79]:
from datasets import load_dataset

In [80]:
dataset = load_dataset("imdb", split = "train[:25000]")

In [81]:
texts = dataset['text']
labels = dataset['label']

In [82]:
positive_indices = [i for i, label in enumerate(labels) if label == 1]
negative_indices = [i for i, label in enumerate(labels) if label == 0]

In [83]:
from sklearn.model_selection import train_test_split

In [84]:
from datasets import Dataset

In [85]:
train_pos, val_pos = train_test_split(positive_indices, test_size=0.999, random_state=42)
train_neg, val_neg = train_test_split(negative_indices, test_size=0.999, random_state=42)

In [86]:
train_indices = train_pos + train_neg
val_indices = val_pos + val_neg

In [87]:
train_dataset = dataset.select(train_indices)
val_dataset = dataset.select(val_indices)

In [88]:
train_dataset

Dataset({
    features: ['text', 'label'],
    num_rows: 24
})

In [89]:
train_df = train_dataset.to_pandas()
value_counts = train_df['label'].value_counts()
value_counts

1    12
0    12
Name: label, dtype: int64

In [90]:
texts = train_dataset['text']
labels = train_dataset['label']

# Pre Processing

## Removing special characters or HTML tags

In [91]:
from bs4 import BeautifulSoup
import re

In [92]:
#function to clean text and remove special characters and HTML tags

def clean_text(text):

    #Remove HTML tags
    bs = BeautifulSoup(text, "html.parser")
    text = bs.get_text(separator=" ")

    #remove special characters
    text = re.sub(r"[^a-zA-Z0-9]", " ", text)

    #remove unecessary whitespaces
    text = " ".join(text.split())

    return text

In [93]:
train_dataset = train_dataset.map(lambda training_sample: {"text": clean_text(training_sample["text"])})
# val_dataset = val_dataset.map(lambda training_sample: {"text": clean_text(training_sample["text"])})

Map:   0%|          | 0/24 [00:00<?, ? examples/s]

In [94]:
train_dataset

Dataset({
    features: ['text', 'label'],
    num_rows: 24
})

# Knowledge Distillation

In [99]:
import torch

In [100]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cpu


In [101]:
from torch.utils.data import Dataset, DataLoader

In [102]:
from tqdm import tqdm

In [103]:
from transformers import DistilBertForSequenceClassification, DistilBertConfig

In [104]:
config = DistilBertConfig.from_pretrained("distilbert-base-uncased")
student_model = DistilBertForSequenceClassification(config)

In [105]:
from transformers import BertTokenizer, BertForSequenceClassification, DistilBertForSequenceClassification, DistilBertConfig
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.nn.functional import pad
import torch.optim as optim

In [106]:
teacher = BertForSequenceClassification.from_pretrained("bert-base-uncased")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [107]:
class SentimentDataset(Dataset):

    def __init__(self, texts, labels, tokenizer, max_length):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]


        inputs = self.tokenizer(
            text,
            truncation=True,
            padding=True,
            max_length=self.max_length,
            return_tensors="pt"
        )

        return {
            "input_ids": inputs["input_ids"].squeeze(),
            "attention_mask": inputs["attention_mask"].squeeze(),
            "labels": torch.tensor(label)
        }

In [108]:
max_sequence_length = 128
batch_size = 1

In [109]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", max_length=max_sequence_length)
sentiment_dataset = SentimentDataset(texts, labels, tokenizer, max_sequence_length)
dataloader = DataLoader(sentiment_dataset, batch_size=batch_size, shuffle=True)

In [110]:
config = DistilBertConfig.from_pretrained("distilbert-base-uncased")
student_model = DistilBertForSequenceClassification(config)

In [111]:
class KnowledgeDistillationLoss(nn.Module):
    def __init__(self, alpha=0.5, temperature=2):
        super(KnowledgeDistillationLoss, self).__init__()
        self.alpha = alpha
        self.temperature = temperature

    def forward(self, student_logits, teacher_logits):

        loss = nn.KLDivLoss()(
            nn.functional.log_softmax(student_logits / self.temperature, dim=1),
            nn.functional.softmax(teacher_logits / self.temperature, dim=1)
        ) * (self.temperature ** 2) * self.alpha

        return loss

In [112]:
optimizer = optim.AdamW(student_model.parameters(), lr=1e-5)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)

In [113]:
kd_loss = KnowledgeDistillationLoss()

In [114]:
# Train the student model
num_epochs = 4
for epoch in range(num_epochs):
    for batch in tqdm(dataloader):
        # Load data and labels from the batch
        inputs = batch["input_ids"]
        labels = batch["labels"]

        # Forward pass for student model
        student_logits = student_model(inputs).logits

        # Forward pass for teacher model
        with torch.no_grad():
            teacher_logits = teacher(inputs).logits

        # Pad sequences to the maximum length in the batch
        max_length = max(inputs.size(1), teacher_logits.size(1), student_logits.size(1))
        inputs = pad(inputs, (0, max_length - inputs.size(1)))
        student_logits = pad(student_logits, (0, max_length - student_logits.size(1)))
        teacher_logits = pad(teacher_logits, (0, max_length - teacher_logits.size(1)))

        # Calculate knowledge distillation loss
        loss = kd_loss(student_logits, teacher_logits) + nn.CrossEntropyLoss()(student_logits, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


    scheduler.step()

100%|██████████| 24/24 [01:07<00:00,  2.83s/it]
100%|██████████| 24/24 [00:47<00:00,  1.98s/it]
100%|██████████| 24/24 [00:49<00:00,  2.06s/it]
100%|██████████| 24/24 [00:49<00:00,  2.06s/it]


In [115]:
# Save or use the trained student model for inference
# student_model.save_pretrained("")