In [7]:
import pandas as pd
import numpy as np
from sklearn import metrics
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
import transformers
from transformers import BertTokenizer, BertModel, BertConfig, DistilBertTokenizer, DistilBertForSequenceClassification, DistilBertConfig, DistilBertModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
from transformers import TrainingArguments
import time
import torch.nn as nn
import torch.nn.functional as F
from transformers import TrainingArguments, Trainer
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score
from transformers import EvalPrediction
import torch

In [3]:
# # Setting up the device for GPU usage
from torch import cuda
device = 'cuda' if cuda.is_available() else 'cpu'
print("Using device:", device)

Using device: cuda


## Loading Dataset for multi-class text classification

In [11]:
from datasets import load_dataset

dataset = load_dataset("sem_eval_2018_task_1", "subtask5.english")

In [12]:
len(dataset['train'])

6838

In [13]:
# prepare the labels dataset for inference
labels = [label for label in dataset['train'].features.keys() if label not in ['ID', 'Tweet']]
id2label = {idx:label for idx, label in enumerate(labels)}
label2id = {label:idx for idx, label in enumerate(labels)}
labels

['anger',
 'anticipation',
 'disgust',
 'fear',
 'joy',
 'love',
 'optimism',
 'pessimism',
 'sadness',
 'surprise',
 'trust']

In [14]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

def preprocess_data(examples):
  # take a batch of texts
  text = examples["Tweet"]
  # encode them
  encoding = tokenizer(text, padding="max_length", truncation=True, max_length=128)
  # add labels
  labels_batch = {k: examples[k] for k in examples.keys() if k in labels}
  # create numpy array of shape (batch_size, num_labels)
  labels_matrix = np.zeros((len(text), len(labels)))
  # fill numpy array
  for idx, label in enumerate(labels):
    labels_matrix[:, idx] = labels_batch[label]

  encoding["labels"] = labels_matrix.tolist()
  
  return encoding

In [15]:
# preprocess data
encoded_dataset = dataset.map(preprocess_data, batched=True, remove_columns=dataset['train'].column_names)
encoded_dataset.set_format("torch")

Map: 100%|██████████| 3259/3259 [00:00<00:00, 9402.75 examples/s]


In [16]:
batch_size = 8
metric_name = "f1"

In [17]:
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score
from transformers import EvalPrediction
import torch
    
# source: https://jesusleal.io/2021/04/21/Longformer-multilabel-classification/
def multi_label_metrics(predictions, labels, threshold=0.5):
    # first, apply sigmoid on predictions which are of shape (batch_size, num_labels)
    sigmoid = torch.nn.Sigmoid()
    probs = sigmoid(torch.Tensor(predictions))
    # next, use threshold to turn them into integer predictions
    y_pred = np.zeros(probs.shape)
    y_pred[np.where(probs >= threshold)] = 1
    # finally, compute metrics
    y_true = labels
    f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average='micro')
    roc_auc = roc_auc_score(y_true, y_pred, average = 'micro')
    accuracy = accuracy_score(y_true, y_pred)
    # return as dictionary
    metrics = {'f1': f1_micro_average,
               'roc_auc': roc_auc,
               'accuracy': accuracy}
    return metrics

def compute_metrics(p: EvalPrediction):
    preds = p.predictions[0] if isinstance(p.predictions, 
            tuple) else p.predictions
    result = multi_label_metrics(
        predictions=preds, 
        labels=p.label_ids)
    return result

## Distillation of BERT Model Using DistilBERT Architecture

In [18]:
fine_tuned_path = "../bert_base_model/bert-finetuned-sem_eval-english"

In [19]:
class KnowledgeDistillationTrainingArguments(TrainingArguments):
  def __init__(self, *args, alpha=0.5, temperature=2.0, **kwargs):
    #*args allows us to pass a variable number of non-keyword arguments to a Python function.
    #**kwargs stands for keyword arguments. The only difference from args is that it uses keywords and returns the values in the form of a dictionary.
    super().__init__(*args, **kwargs)
    #The super() function is often used with the __init__() method to initialize the attributes of the parent class.
    self.alpha = alpha
    self.temperature = temperature

class KnowledgeDistillationTrainer(Trainer):
  def __init__(self, *args, teacher_model=None, **kwargs):
    super().__init__(*args, **kwargs)
    self.teacher_model = teacher_model

  def compute_loss(self, model, inputs, return_outputs=False):
    #Extract cross-entropy loss and logits from student
    outputs_student = model(**inputs)
    loss_ce = outputs_student.loss
    logits_student = outputs_student.logits

    # Extract logits from teacher
    outputs_teacher = self.teacher_model(**inputs)
    logits_teacher = outputs_teacher.logits

     #Computing distillation loss by Softening probabilities
    loss_fct = nn.KLDivLoss(reduction="batchmean")
    #The reduction=batchmean argument in nn.KLDivLoss() specifies that we average the losses over the batch dimension.
    loss_kd = self.args.temperature ** 2 * loss_fct(
                F.log_softmax(logits_student / self.args.temperature, dim=-1),
                F.softmax(logits_teacher / self.args.temperature, dim=-1))

    # Return weighted student loss
    loss = self.args.alpha * loss_ce + (1. - self.args.alpha) * loss_kd
    return (loss, outputs_student) if return_outputs else loss

In [20]:
student_model_name = "distilbert-base-uncased"
student_tokenizer = DistilBertTokenizer.from_pretrained(student_model_name)
teacher_tokenizer = AutoTokenizer.from_pretrained(fine_tuned_path)

In [21]:
student_model = DistilBertForSequenceClassification.from_pretrained(student_model_name, 
                                                              problem_type="multi_label_classification", 
                                                              num_labels=len(labels),
                                                              id2label=id2label,
                                                              label2id=label2id).to(device)

teacher_model = AutoModelForSequenceClassification.from_pretrained(fine_tuned_path, 
                                                           problem_type="multi_label_classification", 
                                                           num_labels=len(labels),
                                                           id2label=id2label,
                                                           label2id=label2id).to(device)

pytorch_model.bin: 100%|██████████| 268M/268M [00:01<00:00, 211MB/s] 
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.bias', 'vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly

In [22]:
student_training_args = KnowledgeDistillationTrainingArguments(
    output_dir="./student_model", 
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    num_train_epochs=5, 
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size, 
    alpha=1, 
    weight_decay=0.01,
    metric_for_best_model=metric_name,
    load_best_model_at_end=True,
    )

In [89]:
#Lets start the training
start_time = time.time()
distilbert_trainer = KnowledgeDistillationTrainer(student_model,
        teacher_model=teacher_model, args=student_training_args,
        train_dataset=encoded_dataset["train"],
        eval_dataset=encoded_dataset["validation"],
        compute_metrics=compute_metrics, 
        tokenizer=student_tokenizer)
distilbert_trainer.train()
end_time = time.time()



Epoch,Training Loss,Validation Loss,F1,Roc Auc,Accuracy
1,0.2337,0.326479,0.689673,0.787747,0.265237
2,0.1923,0.340214,0.690739,0.792089,0.27088
3,0.1678,0.352306,0.684952,0.787725,0.244921
4,0.1586,0.364707,0.682781,0.78521,0.242664
5,0.1361,0.369166,0.68412,0.787296,0.238149


In [90]:
print("Training time For Distillation: ", end_time - start_time)

Training time For Distillation:  465.87510895729065


In [104]:
start_time = time.time()
print(distilbert_trainer.evaluate())
end_time = time.time()

{'eval_loss': 0.34021425247192383, 'eval_f1': 0.6907393760746744, 'eval_roc_auc': 0.7920888880212138, 'eval_accuracy': 0.2708803611738149, 'eval_runtime': 3.5856, 'eval_samples_per_second': 247.101, 'eval_steps_per_second': 30.957, 'epoch': 5.0}


In [105]:
print(f"Evaluation time: {end_time - start_time}")

Evaluation time: 3.588000535964966


## Comparing Performance to Baseline Model

In [106]:
param_size = 0
for param in teacher_model.parameters():
    param_size += param.nelement() * param.element_size()
buffer_size = 0
for buffer in teacher_model.buffers():
    buffer_size += buffer.nelement() * buffer.element_size()

teacher_size_all_mb = (param_size + buffer_size) / 1024**2
print('Teacher model size: {:.3f}MB'.format(teacher_size_all_mb))

Teacher model size: 417.682MB


In [108]:
param_size = 0
for param in student_model.parameters():
    param_size += param.nelement() * param.element_size()
buffer_size = 0
for buffer in student_model.buffers():
    buffer_size += buffer.nelement() * buffer.element_size()

student_size_all_mb = (param_size + buffer_size) / 1024**2
print('Student model size: {:.3f}MB'.format(student_size_all_mb))

Student model size: 255.443MB


In [109]:
print("Total compression: {:.1f}x".format(teacher_size_all_mb / student_size_all_mb))

Total compression: 1.6x
