In [None]:
!pip install datasets
!pip install -U transformers

In [None]:
# Mount Google Drive at /content/drive; later cells read and write model
# artifacts, CSVs and plots under /content/drive/MyDrive/...
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import (
    BertForSequenceClassification,
    BertTokenizer,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    DistilBertConfig,
    DataCollatorWithPadding,
    TrainingArguments,
    Trainer
)
from datasets import load_dataset



# Distillation using Masked Language Modelling

In [None]:
import torch
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForMaskedLM,
    DataCollatorForLanguageModeling,
    AutoModelForSequenceClassification,
)
import torch.optim as opt
from datasets import load_dataset
from tqdm import tqdm
import torch.nn.functional as F

In [None]:
# Config for the MLM distillation stage.
# Note: DEVICE, BATCH_SIZE, LR and EPOCHS are reassigned by the CoLA
# fine-tuning config cell further down the notebook.
TEACHER_MODEL = "bert-base-uncased"  # checkpoint for the teacher and for the tokenizer
STUDENT_MODEL = "distilbert-base-uncased"  # checkpoint for the student
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 8  # DataLoader batch size
LR = 5e-5  # AdamW learning rate for the student
EPOCHS = 3  # passes over the distillation dataset
MAX_LEN = 128  # every sequence is padded/truncated to this many tokens
MLM_PROB = 0.15  # masking probability handed to the MLM data collator

In [None]:
# Load dataset
# Training split of WikiText-2 (raw); the teacher's tokenizer is used for all
# distillation inputs, i.e. for both teacher and student.
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
tokenizer = AutoTokenizer.from_pretrained(TEACHER_MODEL)

In [None]:
def tokenize_function(examples):
    """Tokenize the ``text`` field of a batch of examples.

    Every sequence is truncated and padded to exactly ``MAX_LEN`` tokens using
    the module-level ``tokenizer``.
    """
    encode_options = {
        "truncation": True,
        "padding": "max_length",
        "max_length": MAX_LEN,
    }
    return tokenizer(examples["text"], **encode_options)

# Drop rows whose text is empty or whitespace-only before tokenizing. Such rows
# encode to special tokens plus padding only, so no position can be masked and
# they add no MLM/distillation signal while still costing a full forward pass.
non_empty_dataset = dataset.filter(lambda example: example["text"].strip() != "")

# Tokenize in batches and keep only the tensors the models consume.
tokenized_dataset = non_empty_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
tokenized_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])

# Data collator will mask tokens for us (and emit the matching `labels`)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=MLM_PROB)
dataloader = DataLoader(tokenized_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=data_collator)

In [None]:
# Load teacher and student
# Both are loaded with a masked-LM head and moved to DEVICE.
teacher = AutoModelForMaskedLM.from_pretrained(TEACHER_MODEL).to(DEVICE)
student = AutoModelForMaskedLM.from_pretrained(STUDENT_MODEL).to(DEVICE)
# The teacher is put in evaluation mode; the training loop only ever runs it
# inside torch.no_grad().
teacher.eval()

In [None]:
def copy_every_other_layer(student, teacher):
    """
    Copy every other teacher layer into the student.
    Assumes student has half as many encoder layers as teacher.

    BERT and DistilBERT name the parameters of a transformer block differently
    (e.g. ``attention.self.query`` vs ``attention.q_lin``), so each teacher
    state dict is renamed to the student's naming scheme before it is loaded.
    Loading the raw teacher state dict with ``strict=False`` matches no keys
    and silently copies nothing.

    Args:
        student: model exposing its blocks at ``student.distilbert.transformer.layer``.
        teacher: model exposing its blocks at ``teacher.bert.encoder.layer``.

    Raises:
        ValueError: if, after renaming, a teacher layer shares no parameter
            name with the corresponding student layer (nothing could be copied).
    """
    # (teacher key prefix, student key prefix) for one transformer block.
    bert_to_distilbert = (
        ("attention.self.query.", "attention.q_lin."),
        ("attention.self.key.", "attention.k_lin."),
        ("attention.self.value.", "attention.v_lin."),
        ("attention.output.dense.", "attention.out_lin."),
        ("attention.output.LayerNorm.", "sa_layer_norm."),
        ("intermediate.dense.", "ffn.lin1."),
        ("output.dense.", "ffn.lin2."),
        ("output.LayerNorm.", "output_layer_norm."),
    )

    def _to_student_key(key):
        """Translate one teacher parameter name into the student's naming."""
        for teacher_prefix, student_prefix in bert_to_distilbert:
            if key.startswith(teacher_prefix):
                return student_prefix + key[len(teacher_prefix):]
        # Unknown prefix: keep the name so identically-named params still load.
        return key

    teacher_layers = teacher.bert.encoder.layer
    student_layers = student.distilbert.transformer.layer

    # Copy every other teacher layer into student
    for i, layer in enumerate(student_layers):
        teacher_state = teacher_layers[i * 2].state_dict()
        student_keys = set(layer.state_dict().keys())

        renamed = {_to_student_key(k): v for k, v in teacher_state.items()}
        matched = {k: v for k, v in renamed.items() if k in student_keys}
        if not matched:
            raise ValueError(
                f"No parameters of teacher layer {i * 2} match student layer {i}; "
                "nothing would be copied."
            )

        # strict=False tolerates student-only parameters; every key passed in
        # is known to exist in the student layer.
        layer.load_state_dict(matched, strict=False)

# Initialise the student's transformer blocks from teacher blocks 0, 2, 4, ...
copy_every_other_layer(student, teacher)

In [None]:
# Sanity check: show the concrete class of the student model.
print(type(student))


In [None]:
from torch.nn import CosineEmbeddingLoss

# Combined distillation objective: masked-LM cross-entropy + soft-target KL
# divergence + cosine alignment of intermediate hidden states.
def distill_loss(student_logits,
                 teacher_logits,
                 student_hidden,
                 teacher_hidden,
                 labels,
                 temperature=2.0,
                 alpha=0.5,
                 beta=0.5,
                 gamma=1.0):
    """Return ``alpha * CE + beta * KL + gamma * cosine`` for one batch.

    Args:
        student_logits: student vocabulary logits, indexed like ``labels`` plus
            a trailing vocabulary dimension.
        teacher_logits: teacher vocabulary logits, same shape as ``student_logits``.
        student_hidden: sequence of student hidden states; entry 0 (the
            embedding output) is skipped.
        teacher_hidden: sequence of teacher hidden states; entries 1, 3, 5, ...
            are paired in order with the student's entries 1, 2, 3, ...
        labels: MLM labels, ``-100`` at every position that was not masked.
        temperature: softmax temperature for the soft targets.
        alpha: weight of the masked-LM cross-entropy term.
        beta: weight of the KL term.
        gamma: weight of the hidden-state cosine term.

    Returns:
        Scalar loss tensor. If the batch contains no masked position, a
        detached zero (with ``requires_grad=True``) is returned instead.
    """
    # Only compute loss on masked tokens
    mask = labels != -100
    if mask.sum() == 0:
        print("mask sum is 0")
        return torch.tensor(0.0, device=student_logits.device, requires_grad=True)

    #### Masked LM Loss ####
    ce_loss = F.cross_entropy(
        student_logits.reshape(-1, student_logits.size(-1)),
        labels.reshape(-1),
        ignore_index=-100,
    )

    #### Soft-target KL ####
    # Select the masked positions first -> (num_masked, vocab). Without this the
    # KL would also cover padding/unmasked tokens, and "batchmean" would divide
    # by the batch size only, so the term would scale with sequence length.
    student_masked = torch.clamp(student_logits[mask], -1e4, 1e4)
    teacher_masked = torch.clamp(teacher_logits[mask], -1e4, 1e4)

    student_log_probs = F.log_softmax(student_masked / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_masked / temperature, dim=-1)

    # "batchmean" now averages over masked tokens; T^2 keeps the gradient
    # magnitude comparable across temperatures.
    kl_loss = F.kl_div(
        student_log_probs,
        teacher_probs,
        reduction="batchmean"
    ) * (temperature ** 2)

    #### Cosine embedding loss between intermediate hidden states ####
    losses = []
    teacher_layers = teacher_hidden[1::2]  # take every 2nd layer starting from index 1
    student_layers = student_hidden[1:]    # skip the student input embedding layer

    for t, s in zip(teacher_layers, student_layers):
        # Flatten to one vector per token: (..., H) -> (num_tokens, H)
        t_flat = t.reshape(-1, t.size(-1))
        s_flat = s.reshape(-1, s.size(-1))

        # Target of +1 for every token: pull student and teacher vectors together.
        target = torch.ones(t_flat.size(0), device=t.device)
        losses.append(F.cosine_embedding_loss(s_flat, t_flat, target))

    # torch.stack raises on an empty list; treat "no paired layers" as zero loss.
    cos_loss = torch.stack(losses).mean() if losses else student_logits.new_zeros(())

    return alpha * ce_loss + beta * kl_loss + gamma * cos_loss

In [None]:
# Only the student's parameters are optimised; the teacher is never updated.
optimizer = opt.AdamW(student.parameters(), lr=LR)

In [None]:
# Training loop
# For every batch: run the teacher without gradients, run the student, combine
# the outputs with distill_loss and take one clipped AdamW step on the student.
for epoch in range(EPOCHS):
    student.train()
    total_loss = 0.0
    # Number of batches that actually contributed to total_loss. Batches
    # skipped for NaN/Inf must not dilute the reported epoch average.
    counted_batches = 0
    for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}"):
        input_ids = batch["input_ids"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)

        # Teacher forward pass: targets only, no gradient tracking.
        with torch.no_grad():
            teacher_outputs = teacher(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True
            )
            teacher_logits = teacher_outputs.logits
            teacher_hidden_states = teacher_outputs.hidden_states

        student_outputs = student(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True
        )
        student_logits = student_outputs.logits
        student_hidden_states = student_outputs.hidden_states
        loss = distill_loss(student_logits, teacher_logits,
                            student_hidden_states, teacher_hidden_states,
                            labels,
                            temperature=2.0, alpha=0.5, beta=0.5, gamma=1.0)

        # Do not let a single non-finite loss corrupt the weights.
        if torch.isnan(loss) or torch.isinf(loss):
            print("Skipping batch due to NaN/Inf loss")
            continue

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(student.parameters(), max_norm=1.0)
        optimizer.step()

        total_loss += loss.item()
        counted_batches += 1

    # max(..., 1) avoids a ZeroDivisionError if every batch was skipped.
    avg_loss = total_loss / max(counted_batches, 1)
    print(f"Epoch {epoch+1} | Loss: {avg_loss:.4f}")



# Save the distilled student and its tokenizer to a directory relative to the
# current working directory ...
# NOTE(review): the directory name mentions "cola", but at this point only the
# WikiText distillation has run — confirm the name is intended.
student.save_pretrained("distilled-bert-small-cola-wikitext-used")
tokenizer.save_pretrained("distilled-bert-small-cola-wikitext-used")

# ... and again to Google Drive; the fine-tuning stage loads from this path.
save_path = "/content/drive/MyDrive/CS4782FinalProject/artifacts/cola-used/distilled-bert-small-wikitext"

student.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)

# Finetune for CoLA



In [None]:
# Config for CoLA fine-tuning. DEVICE, BATCH_SIZE, EPOCHS and LR overwrite the
# values set for the distillation stage earlier in the notebook.
MODEL_PATH = "/content/drive/MyDrive/CS4782FinalProject/artifacts/cola-used/distilled-bert-small-wikitext"  # distilled student checkpoint
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16  # per-device batch size for training and evaluation
EPOCHS = 3  # fine-tuning epochs
LR = 2e-5  # fine-tuning learning rate

In [None]:
# GLUE CoLA dataset; the cells below use its "train" and "validation" splits.
cola = load_dataset("glue", "cola")

# Same directory as MODEL_PATH in the config cell; the tokenizer saved next to
# the distilled student is loaded from disk only (no hub lookup).
model_path="/content/drive/MyDrive/CS4782FinalProject/artifacts/cola-used/distilled-bert-small-wikitext"
tokenizer = AutoTokenizer.from_pretrained(model_path,local_files_only=True)

def preprocess(example):
    """Tokenize the ``sentence`` field of a batch of CoLA examples.

    Uses the module-level ``tokenizer``; every sequence is truncated and
    padded to a fixed length of 256 tokens.
    """
    sentences = example["sentence"]
    encode_options = dict(truncation=True, padding="max_length", max_length=256)
    return tokenizer(sentences, **encode_options)

# Tokenize every split in batches and expose only the model inputs and the
# label as torch tensors.
encoded_cola = cola.map(preprocess, batched=True)
encoded_cola.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

In [None]:
# Collator that batches (and, if needed, pads) the tokenized examples.
data_collator = DataCollatorWithPadding(tokenizer)
# Manual loaders over the train / validation splits. The Trainer cells below
# are given the datasets and the collator directly rather than these loaders,
# and test_loader is rebuilt later in the comparison section.
train_loader = DataLoader(encoded_cola["train"], shuffle=True, batch_size=BATCH_SIZE, collate_fn=data_collator)
test_loader = DataLoader(encoded_cola["validation"], batch_size=BATCH_SIZE, collate_fn=data_collator)

**Load the distilled student model and add a classification head to it.**

In [None]:
from transformers import AutoModelForSequenceClassification
# Load the distilled checkpoint with a 2-label sequence-classification head.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH, num_labels=2).to(DEVICE)

In [None]:
# Fine-tuning configuration for the student: evaluate and checkpoint once per
# epoch, log every 100 steps, and disable external experiment reporting.
training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",
    learning_rate=LR,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=EPOCHS,
    weight_decay=0.01,
    save_strategy="epoch",
    logging_dir="./logs",
    logging_steps=100,
    report_to="none"
)

# Fine-tune the student on CoLA.
# The tokenizer is passed as `processing_class`: Trainer's `tokenizer=` keyword
# is deprecated in recent transformers releases, and this notebook installs the
# latest version (`pip install -U transformers`).
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_cola["train"],
    eval_dataset=encoded_cola["validation"],
    processing_class=tokenizer,
    data_collator=data_collator,
)

trainer.train()

In [None]:
# Persist the fine-tuned student and its tokenizer; the comparison section
# reloads them from this path.
save_path = "/content/drive/MyDrive/CS4782FinalProject/artifacts/cola-used/finetuned-distilled-cola"
model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)

In [None]:
# Load teacher model with classification head
# (the hub checkpoint "bert-base-uncased" with a 2-label head, as the baseline
# to fine-tune on CoLA).
teacher_model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=2
).to(DEVICE)

# Same hyperparameters as the student's TrainingArguments; only the output and
# logging directories differ.
teacher_training_args = TrainingArguments(
    output_dir="./results-teacher",
    eval_strategy="epoch",
    learning_rate=LR,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=EPOCHS,
    weight_decay=0.01,
    save_strategy="epoch",
    logging_dir="./logs-teacher",
    logging_steps=100,
    report_to="none"
)

# Trainer for the teacher, using the same encoded CoLA splits and collator as
# the student. The tokenizer is passed as `processing_class`: Trainer's
# `tokenizer=` keyword is deprecated in recent transformers releases, and this
# notebook installs the latest version (`pip install -U transformers`).
teacher_trainer = Trainer(
    model=teacher_model,
    args=teacher_training_args,
    train_dataset=encoded_cola["train"],
    eval_dataset=encoded_cola["validation"],
    processing_class=tokenizer,  # Use student tokenizer since it was already applied to data
    data_collator=data_collator,
)

In [None]:
# === Train the teacher ===
teacher_trainer.train()

# === Evaluate the teacher ===
# Runs on the eval_dataset given to the Trainer (the CoLA validation split).
teacher_trainer.evaluate()

In [None]:
# Persist the fine-tuned teacher together with the tokenizer used to encode its
# inputs; the comparison section reloads the model from this path.
teacher_finetuned_path = "/content/drive/MyDrive/CS4782FinalProject/artifacts/cola-used/finetuned-bert-cola"
teacher_model.save_pretrained(teacher_finetuned_path)
tokenizer.save_pretrained(teacher_finetuned_path)

# Comparison of both finetuned models

In [None]:
# Locations of the two fine-tuned checkpoints written by the cells above.
student_model_path = "/content/drive/MyDrive/CS4782FinalProject/artifacts/cola-used/finetuned-distilled-cola"
teacher_model_path = "/content/drive/MyDrive/CS4782FinalProject/artifacts/cola-used/finetuned-bert-cola"

In [None]:
import torch
from transformers import AutoModelForSequenceClassification
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# From here on `student` and `teacher` refer to the fine-tuned classifiers
# loaded from disk, replacing the masked-LM models bound to these names during
# distillation. The tokenizer is the one saved with the student.
student = AutoModelForSequenceClassification.from_pretrained(student_model_path, local_files_only=True).to(DEVICE)
teacher = AutoModelForSequenceClassification.from_pretrained(teacher_model_path, local_files_only=True).to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained(student_model_path, local_files_only=True)

In [None]:
# Re-encode CoLA with `preprocess`, which reads the module-level `tokenizer`
# (now the one reloaded from the fine-tuned student directory).
encoded_cola = cola.map(preprocess, batched=True)
encoded_cola.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

# The validation split serves as the held-out set for the comparison.
test_dataset = encoded_cola["validation"]

In [None]:
def custom_collate(batch):
    """Collate a list of encoded examples into one dict of batched tensors.

    Each item must provide equally-shaped ``input_ids`` and ``attention_mask``
    tensors plus a scalar ``label``. The stacked inputs keep their key names;
    the labels are returned under the key ``labels``.
    """
    def _stack(field):
        return torch.stack([example[field] for example in batch])

    collated = {field: _stack(field) for field in ("input_ids", "attention_mask")}
    collated["labels"] = torch.tensor([example["label"] for example in batch])
    return collated

# Rebuild the validation loader with custom_collate, replacing the earlier
# test_loader that used DataCollatorWithPadding.
test_loader = DataLoader(encoded_cola["validation"], batch_size=BATCH_SIZE, collate_fn=custom_collate)


In [None]:
from tqdm import tqdm
# data_collator = DataCollatorWithPadding(tokenizer)
# test_loader = DataLoader(test_dataset, batch_size=32, collate_fn=data_collator)

def evaluate(model, dataloader):
    """Return the fraction of examples in ``dataloader`` that ``model`` labels correctly.

    The model is switched to eval mode and run without gradient tracking. Each
    batch must provide ``input_ids``, ``attention_mask`` and ``labels``; the
    prediction is the argmax over the model's logits.
    """
    model.eval()
    n_correct, n_seen = 0, 0

    with torch.no_grad():
        for batch in tqdm(dataloader):
            # Move the three tensors to the compute device in one go.
            tensors = {
                key: batch[key].to(DEVICE)
                for key in ("input_ids", "attention_mask", "labels")
            }
            targets = tensors.pop("labels")

            logits = model(**tensors).logits
            hits = logits.argmax(dim=-1) == targets

            n_correct += hits.sum().item()
            n_seen += targets.size(0)

    return n_correct / n_seen

# Evaluate
# Both models are scored on the same validation loader.
student_accuracy = evaluate(student, test_loader)
teacher_accuracy = evaluate(teacher, test_loader)

print(f"Student accuracy: {student_accuracy:.4f}\n")
print(f"Teacher accuracy: {teacher_accuracy:.4f}")


In [None]:
# Total number of parameters of each model, expressed in millions.
student_size = sum(p.numel() for p in student.parameters()) / 1e6
teacher_size = sum(p.numel() for p in teacher.parameters()) / 1e6

print(f"Student model size: {student_size:.2f}M parameters")
print(f"Teacher model size: {teacher_size:.2f}M parameters")

In [None]:
import os
import time
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, matthews_corrcoef


def evaluate_model(model, dataloader, model_path=None):
    """Run ``model`` over ``dataloader`` and collect classification metrics.

    Args:
        model: classifier whose output exposes ``.logits``.
        dataloader: yields dicts with ``input_ids``, ``attention_mask`` and ``labels``.
        model_path: accepted for call-site compatibility; currently unused.

    Returns:
        dict with accuracy, binary precision/recall/F1, the confusion matrix,
        the Matthews correlation coefficient, the wall-clock time of the
        inference loop in seconds, and the model's parameter count.
    """
    model.eval()
    pred_chunks, label_chunks = [], []

    started = time.time()
    with torch.no_grad():
        for batch in dataloader:
            tensors = {
                key: batch[key].to(DEVICE)
                for key in ("input_ids", "attention_mask", "labels")
            }
            gold = tensors.pop("labels")

            logits = model(**tensors).logits

            # Keep results on the CPU so they can be handed to scikit-learn.
            pred_chunks.append(logits.argmax(dim=-1).cpu())
            label_chunks.append(gold.cpu())
    elapsed = time.time() - started

    y_pred = torch.cat(pred_chunks)
    y_true = torch.cat(label_chunks)

    report = {"accuracy": accuracy_score(y_true, y_pred)}
    (
        report["precision"],
        report["recall"],
        report["f1"],
        _,
    ) = precision_recall_fscore_support(y_true, y_pred, average="binary")
    report["confusion_matrix"] = confusion_matrix(y_true, y_pred)
    report["matthews_corrcoef"] = matthews_corrcoef(y_true, y_pred)
    report["inference_time_sec"] = elapsed
    report["param_count"] = sum(param.numel() for param in model.parameters())

    return report


In [None]:
# Full metric report for each fine-tuned model on the validation loader.
student_results = evaluate_model(student, test_loader, model_path=student_model_path)
teacher_results = evaluate_model(teacher, test_loader, model_path=teacher_model_path)


In [None]:
import pandas as pd

# One row per model, built from the metric dictionaries.
comparison_df = pd.DataFrame(
    [
        {"Model": label, **results}
        for label, results in (
            ("Student", student_results),
            ("Teacher", teacher_results),
        )
    ]
)

# Keep only the scalar columns for the report; the confusion matrix remains
# available in student_results / teacher_results.
comparison_df = comparison_df[
    [
        "Model",
        "accuracy",
        "precision",
        "recall",
        "f1",
        "matthews_corrcoef",
        "inference_time_sec",
        "param_count",
    ]
]

print(comparison_df.to_markdown(index=False))


In [None]:
import os

save_path = "/content/drive/MyDrive/CS4782FinalProject/artifacts/cola-used/plots/comparison_results.csv"
# The target directory is otherwise only created by the plotting cell that
# runs after this one; create it here so a fresh "Run All" does not fail on
# the CSV write.
os.makedirs(os.path.dirname(save_path), exist_ok=True)
comparison_df.to_csv(save_path, index=False)
print(f"Saved CSV to {save_path}")

In [None]:
import matplotlib.pyplot as plt
import os


plots_dir = "/content/drive/MyDrive/CS4782FinalProject/artifacts/cola-used/plots"
os.makedirs(plots_dir, exist_ok=True)


# One bar chart (student vs. teacher) per metric.
metrics = ["accuracy", "f1", "inference_time_sec", "param_count", "matthews_corrcoef"]

for metric in metrics:
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.bar(comparison_df["Model"], comparison_df[metric])
    ax.set_title(f"Comparison: {metric}")
    ax.set_ylabel(metric)

    # Save BEFORE showing: with the notebook's inline backend the figure is
    # finalised by plt.show(), so calling savefig afterwards writes an empty
    # image.
    save_path = os.path.join(plots_dir, f"{metric}_comparison.png")
    fig.savefig(save_path)

    plt.show()

    # Release the figure so repeated runs do not accumulate open figures.
    plt.close(fig)

In [None]:
# Raw metric dictionaries, including the confusion matrices that are not part
# of comparison_df.
print(student_results)
print(teacher_results)