In [1]:
import torch
import evaluate

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel, AutoConfig
from transformers.modeling_outputs import SequenceClassifierOutput

from torch import nn
from torch.optim import AdamW

from tqdm import tqdm

dataset = load_dataset("glue", "cola")



In [2]:
len(dataset["train"]), len(dataset["validation"])

(8551, 1043)

In [3]:

class Distillation(nn.Module):
  def __init__(self, teacher, student, num_labels=2): 
    super(Distillation, self).__init__() 
    self.teacher = AutoModelForSequenceClassification.from_pretrained(
        teacher,
        config=AutoConfig.from_pretrained(teacher, output_attentions=True,output_hidden_states=True, num_labels=num_labels))
    self.student = AutoModelForSequenceClassification.from_pretrained(
        student,
        config=AutoConfig.from_pretrained(student, output_attentions=True,output_hidden_states=True, num_labels=num_labels))
    
    self.kl_div_loss = nn.KLDivLoss(reduction="batchmean", log_target=True)
    self.mse_loss = nn.MSELoss()
    self.eps = 1e-8

  def forward(self, input_ids=None, attention_mask=None, labels=None, who=None):
    assert who in (None, "teacher", "student"), f"invalid argument {who=}"
    if who == "teacher":
        return self.teacher(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    if who == "student":
        return self.student(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

    t_outputs = self.teacher(input_ids=input_ids, attention_mask=attention_mask, labels=labels)  # type: SequenceClassifierOutput
    s_outputs = self.student(input_ids=input_ids, attention_mask=attention_mask, labels=labels)  # type: SequenceClassifierOutput
    # task loss
    t_task_loss, t_logits, t_hidden_states, t_attentions = t_outputs.to_tuple()
    s_task_loss, s_logits, s_hidden_states, s_attentions = s_outputs.to_tuple()

    # KL-divergence loss
    s_kl_loss = self.kl_div_loss(t_logits.detach(), s_logits) / (s_task_loss.detach() + t_task_loss.detach() + self.eps)
    t_kl_loss = self.kl_div_loss(s_logits.detach(), t_logits) / (s_task_loss.detach() + t_task_loss.detach() + self.eps)

    # attention loss
    attention_layer_ratio = len(t_attentions) // len(s_attentions)

    # TODO if the attention dimensions don't line up for some pair of teacher/student model we want to evaluate (like bert-large vs bert)
    # then we need to apply a learnable matrix to the teacher attention to line up the dimensions before applying the mse loss
    assert t_attentions[0].size() == s_attentions[0].size(), f"teacher and student attention dimensions don't match, {t_attentions[0].size()=}, {s_attentions[0].size()=}"
    mse_loss = 0.
    for s_layer in range(len(s_attentions)):
        t_layer = s_layer * 2
        mse_loss += self.mse_loss(s_attentions[-s_layer - 1], t_attentions[-t_layer - 1])
        # adding the hidden-layer loss blows up the training
        # self.mse_loss(s_hidden_states[-s_layer - 1], t_hidden_states[-t_layer - 1])
    mse_loss = mse_loss / (s_task_loss + t_task_loss + self.eps)
    # outputs = last_hidden_state=[batch_size, sentence_length, features/embeddings] hidden_states=7 attentions=6

    total_loss = s_task_loss + t_task_loss + mse_loss + s_kl_loss + t_kl_loss
#     total_loss = s_task_loss + t_task_loss + s_kl_loss + t_kl_loss + mse_loss
    total_logits = s_outputs.logits + t_outputs.logits
    if torch.any(total_logits.isnan()):
        raise RuntimeError("logits blew up")
    return SequenceClassifierOutput(loss=total_loss, logits=total_logits)

In [4]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
def encode(examples):
    return tokenizer(examples["sentence"], truncation=True, padding="max_length")
dataset = dataset.map(encode, batched=True)
dataset = dataset.map(lambda examples: {"labels": examples["label"]}, batched=True)

dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])  # bert
train_dataset = dataset["train"]
test_dataset = dataset["validation"]
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=8)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=8)

In [5]:
model = Distillation(teacher="bert-base-uncased", student="distilbert-base-uncased")

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.weight', 'pre_classifier.bias', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
optimizer = AdamW(model.parameters(), lr=5e-5)
from transformers import get_scheduler
num_epochs = 3

num_training_steps = num_epochs * len(train_dataloader)

lr_scheduler = get_scheduler(
    name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device);

In [7]:
progress_bar = tqdm(range(num_training_steps))
model.train()
for epoch in range(num_epochs):
    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

100%|██████████| 3207/3207 [24:57<00:00,  2.18it/s]

In [8]:
metric = evaluate.load("accuracy")
model.eval()
for batch in test_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch, who="student")
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    metric.add_batch(predictions=predictions, references=batch["labels"])

metric.compute()

{'accuracy': 0.7574304889741131}

In [9]:

metric = evaluate.load("accuracy")
model.eval()
for batch in test_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch, who="teacher")
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    metric.add_batch(predictions=predictions, references=batch["labels"])

metric.compute()

{'accuracy': 0.8034515819750719}

In [10]:
def create_predictor(model, who="teacher"):
  def predict_proba(text1):
    encodings = tokenizer([text1], truncation=True, padding="max_length")
    encodings = encodings.convert_to_tensors("pt")
    encodings = {k: v.to(device) for k, v in encodings.items()}
    with torch.no_grad():
      outputs = model(input_ids=encodings["input_ids"], attention_mask=encodings["attention_mask"], who=who)
      logits = outputs.logits
      predictions = torch.argmax(logits, dim=-1)
      return predictions.cpu().numpy()[0]

  return predict_proba

In [11]:
predictor = create_predictor(model, who="student")

predictor("He beat the dead horse.")

1