## Required Imports

In [6]:
from transformers import TrainingArguments
import torch.nn as nn
import torch.nn.functional as F
from transformers import Trainer
from datasets import load_dataset
from transformers import AutoTokenizer
import numpy as np
from datasets import load_metric
from transformers import AutoConfig, AutoModelForSequenceClassification
import os
from torch.nn.utils.rnn import pad_sequence
from nlpaug.augmenter.word import SynonymAug
from torch.utils.data import DataLoader,Dataset
from transformers import pipeline
import torch
from transformers import AutoModelForSequenceClassification

# Defining the training classes

In [7]:
class KnowledgeDistillationTrainingArguments(TrainingArguments):
  """`TrainingArguments` extended with two knowledge-distillation settings.

  Args:
    *args: Positional arguments forwarded unchanged to `TrainingArguments`.
    alpha: Stored as `self.alpha` (default 0.5). The distillation trainer in
      this notebook uses it as the weight of the student's own loss.
    temperature: Stored as `self.temperature` (default 2.0). The distillation
      trainer divides both sets of logits by it before the softmax.
    **kwargs: Keyword arguments forwarded unchanged to `TrainingArguments`.
  """

  def __init__(self, *args, alpha=0.5, temperature=2.0, **kwargs):
    # The base class consumes everything except the two extra settings.
    super().__init__(*args, **kwargs)
    self.temperature = temperature
    self.alpha = alpha

In [8]:
class KnowledgeDistillationTrainer(Trainer):
  """`Trainer` whose loss blends the student's loss with a distillation term.

  The training loss is `alpha * loss_ce + (1 - alpha) * loss_kd`, where
  `loss_ce` is the loss reported by the student model and `loss_kd` is the
  temperature-scaled KL divergence between the student and teacher output
  distributions. `alpha` and `temperature` are read from `self.args`.

  Args:
    *args: Positional arguments forwarded unchanged to `Trainer`.
    teacher_model: Model called with the same inputs as the student to
      produce the soft targets. Required whenever `alpha` is not 1.
    **kwargs: Keyword arguments forwarded unchanged to `Trainer`.
  """

  def __init__(self, *args, teacher_model=None, **kwargs):
    super().__init__(*args, **kwargs)
    self.teacher_model = teacher_model

  def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
    """Compute the blended cross-entropy / distillation loss for one batch.

    Args:
      model: The student model; called as `model(**inputs)`.
      inputs: Keyword arguments for the forward pass. The student's output
        must expose `.loss` and `.logits`.
      return_outputs: When true, return `(loss, student_outputs)`.
      num_items_in_batch: Accepted only so that `Trainer` versions which pass
        this keyword do not fail; it is not used here.

    Returns:
      The scalar loss, or `(loss, student_outputs)` if `return_outputs`.

    Raises:
      ValueError: If the distillation term has non-zero weight but no
        teacher model was supplied.
    """
    outputs_student = model(**inputs)
    loss_ce = outputs_student.loss

    alpha = self.args.alpha
    kd_weight = 1. - alpha

    if kd_weight == 0:
      # The distillation term would be multiplied by zero, so the (much
      # larger) teacher does not need to run at all.
      loss = alpha * loss_ce
    else:
      if self.teacher_model is None:
        raise ValueError(
            "teacher_model must be provided when alpha != 1 "
            "(the distillation term has non-zero weight)")

      temperature = self.args.temperature

      # The teacher only supplies targets; no gradient flows into it, so
      # skip building its autograd graph.
      with torch.no_grad():
        logits_teacher = self.teacher_model(**inputs).logits

      loss_fct = nn.KLDivLoss(reduction="batchmean")
      # Multiplying by T^2 keeps the gradient scale of the soft-target term
      # comparable across temperatures.
      loss_kd = temperature ** 2 * loss_fct(
          F.log_softmax(outputs_student.logits / temperature, dim=-1),
          F.softmax(logits_teacher / temperature, dim=-1))

      loss = alpha * loss_ce + kd_weight * loss_kd

    return (loss, outputs_student) if return_outputs else loss


# Loading the dataset

In [9]:
# This dataset is built by a loading script hosted on the Hub. The loader's own
# notice states that `trust_remote_code=True` becomes mandatory in the next
# major `datasets` release, so opt in explicitly instead of relying on the
# deprecated implicit behaviour.
dataset = load_dataset("carblacac/twitter-sentiment-analysis", trust_remote_code=True)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


Downloading builder script:   0%|          | 0.00/4.38k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/2.06k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/5.44k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.38M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/2.23M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Map:   0%|          | 0/149985 [00:00<?, ? examples/s]

Map:   0%|          | 0/61998 [00:00<?, ? examples/s]

Creating json from Arrow format:   0%|          | 0/120 [00:00<?, ?ba/s]

Creating json from Arrow format:   0%|          | 0/30 [00:00<?, ?ba/s]

Creating json from Arrow format:   0%|          | 0/62 [00:00<?, ?ba/s]

Generating train split:   0%|          | 0/119988 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/29997 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/61998 [00:00<?, ? examples/s]

# Tokenizer

In [10]:
# Hub checkpoint name of the small model that is trained as the student.
student_checkpoint = "prajjwal1/bert-tiny"
# Tokenizer loaded from the student checkpoint; the same instance is used for
# the augmented training set, the mapped dataset splits and the trainer.
student_tokenizer = AutoTokenizer.from_pretrained(student_checkpoint)

config.json:   0%|          | 0.00/285 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

# Data augmentation

In [30]:
def custom_collate_fn(batch):
    """Merge a list of per-example dicts into one batch dict of tensors.

    Args:
        batch: Sequence of dicts, each holding tensors under the keys
            "input_ids", "attention_mask" and "label".

    Returns:
        A dict with the same three keys. "input_ids" and "attention_mask" are
        padded to the longest example (batch dimension first, `pad_sequence`'s
        default fill value); "label" is the stacked per-example labels.
    """
    def column(field):
        # Collect one field from every example, preserving batch order.
        return [example[field] for example in batch]

    ids_list = column("input_ids")
    mask_list = column("attention_mask")
    label_list = column("label")

    padded_ids = pad_sequence(ids_list, batch_first=True)
    padded_mask = pad_sequence(mask_list, batch_first=True)

    return {
        "input_ids": padded_ids,
        "attention_mask": padded_mask,
        "label": torch.stack(label_list),
    }

In [32]:
class TwitterSentimentDataset_aug(Dataset):
    """Map-style dataset that augments and tokenizes one tweet per access.

    Each item of `data` must provide the raw text under "text" and the class
    id under "feeling". Augmentation (if any) is applied on every
    `__getitem__` call, so the same index can yield different token ids.

    Args:
        data: Indexable collection of records with "text" and "feeling".
        tokenizer: Callable used to tokenize a single string; must return a
            mapping with "input_ids" and "attention_mask" tensors that have a
            leading batch dimension of 1.
        max_length: Truncation length passed to the tokenizer.
        augmenter: Optional object with an `augment(text)` method.
    """

    def __init__(self, data, tokenizer, max_length=128, augmenter=None):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.augmenter = augmenter

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the tokenized (optionally augmented) example at `idx`.

        Returns:
            Dict with 1-D "input_ids" and "attention_mask" tensors and a
            scalar `torch.long` "label" tensor.
        """
        item = self.data[idx]
        text = item["text"]
        if self.augmenter is not None:
            augmented = self.augmenter.augment(text)
            # Depending on the nlpaug version, `augment` returns either a
            # string or a list of strings. Always hand the tokenizer a single
            # string so the result is one sequence, not a batch.
            if isinstance(augmented, (list, tuple)):
                # Fall back to the untouched text if nothing came back.
                augmented = augmented[0] if augmented else text
            text = augmented

        # No `padding` here: a single sequence has nothing to pad against.
        inputs = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        label = torch.tensor(item["feeling"], dtype=torch.long)
        # squeeze(0) drops only the batch dimension; a bare squeeze() would
        # also collapse a length-1 sequence into a 0-d tensor.
        return {
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
            "label": label,
        }
# Synonym-replacement augmenter with its default settings (the log below shows
# it fetching WordNet data on first use).
augmenter = SynonymAug()

# Training split wrapped so that every access is augmented and then tokenized.
train_dataset_aug = TwitterSentimentDataset_aug(
    dataset["train"],
    student_tokenizer,
    max_length=128,
    augmenter=augmenter,
)

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...


In [11]:
def tokenize_text(batch, max_length=128):
  """Tokenize the "text" column of a batch with the student tokenizer.

  Args:
    batch: Mapping with a "text" entry holding the raw strings.
    max_length: Maximum number of tokens kept per example. Without an explicit
      value the tokenizer reports that it has no predefined maximum and
      silently performs no truncation; the default of 128 matches the limit
      used for the augmented training set.

  Returns:
    The tokenizer's output for the batch.
  """
  return student_tokenizer(batch["text"], truncation=True, max_length=max_length)

# Tokenize every split in batches, drop the raw text column, and rename the
# target column from "feeling" to "labels".
clinc_tokenized = dataset.map(
    tokenize_text,
    batched=True,
    remove_columns=["text"],
).rename_column("feeling", "labels")

Map:   0%|          | 0/119988 [00:00<?, ? examples/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Map:   0%|          | 0/29997 [00:00<?, ? examples/s]

Map:   0%|          | 0/61998 [00:00<?, ? examples/s]

In [46]:
# Accuracy metric object from `datasets`.
# NOTE(review): the notice printed below this cell says loading this metric
# will require `trust_remote_code=True` from the next major `datasets`
# release, so this call needs attention before upgrading.
accuracy_score = load_metric("accuracy")
def compute_metrics(pred):
  """Compute classification accuracy for the trainer's evaluation loop.

  Accuracy is calculated directly with numpy, so this function does not depend
  on a metric script that has to be downloaded and trusted at load time.

  Args:
    pred: Pair `(predictions, labels)`, where `predictions` holds per-class
      scores with the classes along axis 1 and `labels` holds the reference
      class ids.

  Returns:
    `{"accuracy": <fraction of examples whose argmax equals the label>}`.
  """
  scores, labels = pred
  predicted_ids = np.argmax(scores, axis=1)
  return {"accuracy": float(np.mean(predicted_ids == np.asarray(labels)))}

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


# Configuring the student and teacher models

In [13]:
# Per-device batch size; the training arguments use it for both training and
# evaluation.
batch_size = 48
# Output directory passed to the training arguments; the checkpoint-* folders
# mentioned in the training log are created under it.
finetuned_student_ckpt = "tinybert-base-uncased-finetuned-twitter-student"

In [14]:
# One epoch of training with evaluation at the end of each epoch.
student_training_args = KnowledgeDistillationTrainingArguments(
    output_dir=finetuned_student_ckpt,
    evaluation_strategy="epoch",
    num_train_epochs=1,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # NOTE(review): with alpha=1 the trainer's loss is
    # 1 * loss_ce + 0 * loss_kd, i.e. the distillation term contributes
    # nothing. Confirm this is intended for a distillation run.
    alpha=1,
)

In [15]:
# Same checkpoint string as `student_checkpoint`.
bert_ckpt = "prajjwal1/bert-tiny"
# In the cells shown here the pipeline serves only as a source for the label
# maps read from its model config below.
pipe = pipeline("text-classification", model=bert_ckpt)

# Label-id <-> label-name maps, reused when building the student config.
id2label = pipe.model.config.id2label
label2id = pipe.model.config.label2id

pytorch_model.bin:   0%|          | 0.00/17.8M [00:00<?, ?B/s]

  return self.fget.__get__(instance, owner)()
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [16]:
from transformers import AutoConfig

# Student configuration: the checkpoint's config with a two-class head and the
# label maps taken from the pipeline's model config.
student_config = AutoConfig.from_pretrained(
    student_checkpoint,
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
)

In [17]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def student_init():
  """Build a fresh student model from the checkpoint and move it to `device`.

  Passed to the trainer as `model_init`, so every call loads a new copy of the
  student with `student_config` applied.
  """
  student = AutoModelForSequenceClassification.from_pretrained(
      student_checkpoint, config=student_config)
  return student.to(device)

In [18]:
# Hub checkpoint name of the larger model used as the distillation teacher.
teacher_checkpoint = "bert-base-uncased"

In [19]:
# Teacher: the base checkpoint with a two-class classification head, moved to
# the training device.
# NOTE(review): the load log below reports the classifier weights as newly
# initialized, and no cell shown here fine-tunes the teacher. Its logits
# therefore come from an untrained head; confirm before giving the
# distillation term any weight.
teacher_model = (AutoModelForSequenceClassification
                     .from_pretrained(teacher_checkpoint, num_labels=2)
                     .to(device))

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# Training

In [37]:
# Train on the augmented training set; evaluate on the tokenized validation
# split.
tinybert_trainer = KnowledgeDistillationTrainer(
    model_init=student_init,
    teacher_model=teacher_model,
    args=student_training_args,
    train_dataset=train_dataset_aug,
    eval_dataset=clinc_tokenized["validation"],
    compute_metrics=compute_metrics,
    tokenizer=student_tokenizer,
)
tinybert_trainer.train()

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Unzipping taggers/averaged_perceptron_tagger.zip.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.5698,0.529207,0.73804


Checkpoint destination directory tinybert-base-uncased-finetuned-twitter-student/checkpoint-500 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory tinybert-base-uncased-finetuned-twitter-student/checkpoint-1000 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory tinybert-base-uncased-finetuned-twitter-student/checkpoint-1500 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory tinybert-base-uncased-finetuned-twitter-student/checkpoint-2000 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory tinybert-base-uncased-finetuned-twitter-student/checkpoint-2500 already exists and is non-empty.Saving will proceed but saved results may be invalid.


TrainOutput(global_step=2500, training_loss=0.6014998779296875, metrics={'train_runtime': 1230.9117, 'train_samples_per_second': 97.479, 'train_steps_per_second': 2.031, 'total_flos': 16215652895760.0, 'train_loss': 0.6014998779296875, 'epoch': 1.0})

In [22]:
# Write both models to directories relative to the current working directory
# so their sizes can be compared afterwards.
teacher_model.save_pretrained("teacher_model")
tinybert_trainer.save_model('student_model')

In [23]:
def compute_parameters(model_path):
  """Load the sequence-classification model at `model_path` and return the
  value of its `num_parameters()`."""
  return AutoModelForSequenceClassification.from_pretrained(model_path).num_parameters()

In [24]:
# Load from the same relative directory the teacher was saved to, rather than
# a hard-coded absolute path that only resolves when the working directory is
# /content.
teacher_model_parameters = compute_parameters(model_path="teacher_model")
print("Teacher Model: ", teacher_model_parameters)

Teacher Model:  109483778


In [25]:
# Load from the same relative directory the student was saved to, rather than
# a hard-coded absolute path that only resolves when the working directory is
# /content.
student_model_parameters = compute_parameters(model_path="student_model")
print("Student Model: ", student_model_parameters)

Student Model:  4386178


# Model evaluation

In [45]:
# Evaluate the trained student on the trainer's eval dataset (the validation
# split); the returned metrics dict is displayed as the cell output.
tinybert_trainer.evaluate()

{'eval_loss': 0.5292066335678101,
 'eval_accuracy': 0.738040470713738,
 'eval_runtime': 87.7725,
 'eval_samples_per_second': 341.758,
 'eval_steps_per_second': 7.121,
 'epoch': 1.0}