In [132]:
import torch
import rich
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from torch.nn import CrossEntropyLoss, MSELoss
from transformers import DataCollatorWithPadding
from torch.utils.data import DataLoader
from tqdm import tqdm
from safetensors.torch import save, save_file, load_file
from copy import deepcopy
from transformers import AutoConfig, AutoModelForSequenceClassification
from sklearn.metrics import accuracy_score
from torch.optim import AdamW, Adam
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
import lightning as L
import torchmetrics
from lightning.pytorch.callbacks import RichProgressBar

# Load the dataset


In [2]:
dataset = load_dataset("dair-ai/emotion")

In [3]:
model_bert = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=6)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
tokenizer_bert = AutoTokenizer.from_pretrained("bert-base-uncased")

In [5]:
data_train = dataset['train']
data_dev = dataset['validation']
data_test = dataset['test']

In [6]:
def tokenize(x, tokenizer):
    return tokenizer(x['text'])

In [7]:
data_train = data_train.map(lambda x: tokenize(x, tokenizer_bert), batched=True)
data_dev = data_dev.map(lambda x: tokenize(x, tokenizer_bert), batched=True)
data_test = data_test.map(lambda x: tokenize(x, tokenizer_bert), batched=True)
data_train = data_train.remove_columns(['text'])
data_dev = data_dev.remove_columns(['text'])
data_test = data_test.remove_columns(['text'])

In [8]:
from torchmetrics import F1Score

## Define the teacher model

We use Pytorch Lightning here

In [119]:
class FTModel(L.LightningModule):

    def __init__(self, model, lr: float):
        super().__init__()
        self.model = model
        self.lr = lr
        self.ce_loss = CrossEntropyLoss()
        self.train_f1 = F1Score('multiclass', num_classes=6, average='macro')
        self.val_f1 = F1Score('multiclass', num_classes=6, average='macro')
        self.test_f1 = F1Score('multiclass', num_classes=6, average='macro')

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        y = batch.pop('labels')
        y_hat = self(**batch).logits
        loss = self.ce_loss(y_hat,y)
        self.train_f1(y_hat, y)
        # self.log("f1_step", self.train_f1, on_step=True, on_epoch=True)
        self.log("train_loss", loss, prog_bar=True)
        return loss


    def validation_step(self, batch, batch_idx):
        y = batch.pop('labels')
        y_hat = self(**batch).logits
        loss = self.ce_loss(y_hat,y)
        self.val_f1(y_hat, y)
        self.log("f1_val_step", self.val_f1, on_epoch=True, prog_bar=True)
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        y = batch.pop('labels')
        y_hat = self(**batch).logits
        loss = self.ce_loss(y_hat,y)
        self.test_f1(y_hat, y)
        self.log("f1_test_step", self.test_f1, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        return optimizer

In [10]:
lit_teacher_model = FTModel(model_bert, 1e-5)

In [35]:
train_dl = DataLoader(data_train, batch_size=64, collate_fn=DataCollatorWithPadding(tokenizer=tokenizer_bert), shuffle=True)
dev_dl = DataLoader(data_dev, batch_size=64,  collate_fn=DataCollatorWithPadding(tokenizer=tokenizer_bert))
test_dl = DataLoader(data_test, 64, collate_fn=DataCollatorWithPadding(tokenizer=tokenizer_bert))

In [12]:
batch = next(iter(train_dl))
batch.pop('labels')
lit_teacher_model(**batch)

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `_

SequenceClassifierOutput(loss=None, logits=tensor([[-0.3868, -0.0174,  0.2346,  0.2838, -0.4430, -0.0506],
        [-0.6515, -0.4207,  0.3611,  0.3695, -0.0037,  0.3715],
        [-0.4106, -0.2959,  0.4098,  0.2875, -0.0127,  0.1310],
        [-0.3918, -0.2167,  0.2953,  0.3553, -0.1297, -0.0403],
        [-0.4535, -0.1855,  0.3604,  0.2629, -0.2811,  0.1525],
        [-0.5081, -0.4296,  0.4088,  0.3135, -0.0986,  0.3215],
        [-0.3434, -0.1394,  0.2728,  0.3082, -0.3253, -0.0564],
        [-0.4396, -0.2799,  0.3979,  0.2847, -0.0443,  0.1602],
        [-0.3713, -0.2373,  0.3243,  0.2871, -0.2001,  0.0391],
        [-0.5784, -0.4244,  0.3802,  0.2692,  0.0708,  0.2732],
        [-0.5308, -0.3879,  0.3763,  0.3239, -0.0937,  0.1885],
        [-0.4805, -0.0769,  0.3745,  0.3049, -0.3659,  0.0009],
        [-0.5932, -0.2624,  0.3935,  0.3139, -0.1337,  0.1156],
        [-0.3756, -0.1220,  0.3192,  0.2737, -0.4083, -0.0438],
        [-0.3672, -0.0963,  0.3253,  0.2101, -0.2910, -0.0528

In [13]:
checkpoint_callback = ModelCheckpoint(dirpath="teacher/",
                                      filename='{epoch}-{val_loss:.2f}-{f1_val_step:.2f}',
                                      save_top_k=19,
                                      monitor="f1_val_step", mode="max")
ea_stop = EarlyStopping(patience=3, monitor="f1_val_step", mode="max")
rich = RichProgressBar()
trainer = Trainer(callbacks=[checkpoint_callback, ea_stop, rich],accelerator='gpu', 
                  devices=1, precision=16, val_check_interval=100, check_val_every_n_epoch=None)

/usr/local/lib/python3.10/dist-packages/lightning/fabric/connector.py:565: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [36]:
trainer.fit(lit_teacher_model, train_dl, dev_dl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


TypeError: '<=' not supported between instances of 'float' and 'NoneType'

In [15]:
checkpoint_callback.best_model_path

'teacher/epoch=3-val_loss=0.16-f1_val_step=0.92.ckpt'

In [120]:
lit_teacher_model = FTModel.load_from_checkpoint("teacher/epoch=3-val_loss=0.16-f1_val_step=0.92.ckpt", model=model_bert, lr=None)

In [121]:
trainer.test(lit_teacher_model, test_dl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

[{'f1_test_step': 0.8809715509414673}]

## Output our model

In [133]:
train_dl_for_out = DataLoader(data_train, batch_size=64, collate_fn=DataCollatorWithPadding(tokenizer=tokenizer_bert), shuffle=False)

In [134]:
lit_teacher_model.eval()
lit_teacher_model = lit_teacher_model.cuda()
out_aggregated = []
with torch.no_grad():
    for x in tqdm(train_dl_for_out):
        out = lit_teacher_model(**x.to("cuda"))
        out_aggregated.append(out.logits.cpu().detach())
out_aggregated = torch.cat(out_aggregated)
save_file({'proba': out_aggregated}, "teacher_proba.safetensors")

100%|██████████| 250/250 [00:15<00:00, 16.28it/s]


In [135]:
out_aggregated[3]

tensor([-1.8089,  0.8797,  4.7062, -1.1240, -1.8250, -1.0949])

## Our Teacher is 88.09

## Train a small model!

In [28]:
model_smol_cfg = deepcopy(model_bert.config)

In [29]:
model_smol_cfg = deepcopy(model_bert.config)
model_smol_cfg.num_hidden_layers = 3
model_smol_cfg.num_attention_heads = 2
small_scratch_model = AutoModelForSequenceClassification.from_config(model_smol_cfg)

In [30]:
lit_small_model = FTModel(small_scratch_model, 1e-5)

In [31]:
checkpoint_callback = ModelCheckpoint(dirpath="student_scratch/",
                                      filename='{epoch}-{val_loss:.2f}-{f1_val_step:.2f}',
                                      save_top_k=1,
                                      monitor="f1_val_step", mode="max")
ea_stop = EarlyStopping(patience=3, monitor="f1_val_step", mode="max")
rich = RichProgressBar()
trainer = Trainer(callbacks=[checkpoint_callback, ea_stop, rich],accelerator='gpu', 
                  devices=1, precision=16, val_check_interval=100, check_val_every_n_epoch=None)


Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [37]:
trainer.fit(lit_small_model, train_dl, dev_dl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [38]:
checkpoint_callback.best_model_path

'student_scratch/epoch=5-val_loss=0.28-f1_val_step=0.87.ckpt'

In [39]:
lit_small_model = FTModel.load_from_checkpoint(checkpoint_callback.best_model_path, model=small_scratch_model, lr=None)
trainer.test(lit_small_model, test_dl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=95` in the `DataLoader` to improve performance.


[{'f1_test_step': 0.8355348706245422}]

## Output our teacher

100%|██████████| 250/250 [00:20<00:00, 11.94it/s]


## Lets create a KD class

In [107]:
class KDModel(L.LightningModule):

    def __init__(self, model, lr: float, weight_default_loss=1, weight_kd_loss=1, teacher_model=None, use_hidden=False):
        super().__init__()
        self.model = model
        self.lr = lr
        self.weight_default_loss = weight_default_loss
        self.weight_kd_loss = weight_kd_loss
        self.ce_loss = CrossEntropyLoss()
        self.mse_loss = MSELoss()
        self.train_f1 = F1Score('multiclass', num_classes=6, average='macro')
        self.val_f1 = F1Score('multiclass', num_classes=6, average='macro')
        self.test_f1 = F1Score('multiclass', num_classes=6, average='macro')
        self.teacher_model = None
        self.use_hidden = use_hidden
        if teacher_model is not None and use_hidden:
            self.teacher_model = teacher_model
            self.teacher_model.eval()
            self.teacher_model.requires_grad_(False)
            
    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        y = batch.pop('labels')
        teacher_proba = batch.pop("teacher_proba", default=None)
        out = self.model(**batch, output_hidden_states=True)
        # 1.21 -0.11 1.44
        # 0 0 1
        # [GENERATE A JOKE] CHatGPT: [0 0 0 1 0 0 ] is funny jokes
        # [GENERATE A JOKE] Student: askdfodsakofdsak
        y_hat = out.logits
        loss_kd = self.mse_loss(y_hat, teacher_proba)
        if self.teacher_model is not None and self.use_hidden:
            y_hat_hidden_states = out.hidden_states
            out_teacher_model = self.teacher_model(**batch, output_hidden_states=True)
            sum_hidden_loss = 0
            for std_hid, tch_hid in zip(y_hat_hidden_states, out_teacher_model.hidden_states[::3]):
                sum_hidden_loss += self.mse_loss(std_hid, tch_hid)
            loss_kd += sum_hidden_loss
        loss = self.ce_loss(y_hat,y)
        # total_loss = loss * self.weight_default_loss + loss_kd * self.weight_kd_loss\
        total_loss = loss + loss_kd 

        self.train_f1(y_hat, y)
        self.log("f1_step", self.train_f1, on_step=True, on_epoch=True)
        self.log("train_loss", total_loss, prog_bar=True)
        return total_loss

    def validation_step(self, batch, batch_idx):
        y = batch.pop('labels')
        y_hat = self(**batch).logits
        loss = self.ce_loss(y_hat,y)
        self.val_f1(y_hat, y)
        self.log("f1_val_step", self.val_f1, on_epoch=True, prog_bar=True)
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        y = batch.pop('labels')
        y_hat = self(**batch).logits
        loss = self.ce_loss(y_hat,y)
        self.test_f1(y_hat, y)
        self.log("f1_test_step", self.test_f1, on_epoch=True, prog_bar=True)
        return loss

    def on_save_checkpoint(self, checkpoint):
        state_dict = checkpoint['state_dict']
        if 'teacher_model' in state_dict:
            state_dict.pop('teacher_model')
        checkpoint['state_dict'] = state_dict

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr)
        return optimizer

In [76]:
model_smol_cfg = deepcopy(model_bert.config)
model_smol_cfg.num_hidden_layers = 3
model_smol_cfg.num_attention_heads = 2
small_scratch_model = AutoModelForSequenceClassification.from_config(model_smol_cfg)

In [77]:
kd_model = KDModel(small_scratch_model, 1e-5)

In [78]:
data_train_for_kd = data_train.add_column('teacher_proba', out_aggregated.numpy().tolist())

In [79]:
train_dl_for_kd = DataLoader(data_train_for_kd, batch_size=64, collate_fn=DataCollatorWithPadding(tokenizer=tokenizer_bert))

In [80]:
checkpoint_callback = ModelCheckpoint(dirpath="kd_output_label/",
                                      filename='{epoch}-{val_loss:.2f}-{f1_val_step:.2f}',
                                      save_top_k=1,
                                      monitor="f1_val_step", mode="max")
ea_stop = EarlyStopping(patience=5, monitor="f1_val_step", mode="max")
rich = RichProgressBar()
trainer = Trainer(callbacks=[checkpoint_callback, ea_stop, rich],accelerator='gpu', 
                  devices=1, precision=16, val_check_interval=100, check_val_every_n_epoch=None)


Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [81]:
trainer.fit(kd_model, train_dl_for_kd, dev_dl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

In [82]:
checkpoint_callback.best_model_path

'kd_output_label/epoch=7-val_loss=0.32-f1_val_step=0.86.ckpt'

In [83]:
kd_model = KDModel.load_from_checkpoint(checkpoint_callback.best_model_path, model=small_scratch_model, lr=None)

In [84]:
trainer.test(kd_model, test_dl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

[{'f1_test_step': 0.8450269103050232}]

## Let's use feature matching too!

In [129]:
model_smol_cfg = deepcopy(model_bert.config)
model_smol_cfg.num_hidden_layers = 3
model_smol_cfg.num_attention_heads = 2
small_scratch_model = AutoModelForSequenceClassification.from_config(model_smol_cfg)

In [128]:
kd_model_feature = KDModel(small_scratch_model, 3e-5, teacher_model=lit_teacher_model, use_hidden=True)

In [125]:
checkpoint_callback = ModelCheckpoint(dirpath="kd_output_label_feature/",
                                      filename='{epoch}-{val_loss:.2f}-{f1_val_step:.2f}',
                                      save_top_k=1,
                                      monitor="f1_val_step", mode="max")
ea_stop = EarlyStopping(patience=5, monitor="f1_val_step", mode="max")
rich = RichProgressBar()
trainer = Trainer(callbacks=[checkpoint_callback, ea_stop, rich],accelerator='gpu', 
                  devices=1, precision=16, val_check_interval=100, check_val_every_n_epoch=None)


Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [126]:
trainer.fit(kd_model_feature, train_dl_for_kd, dev_dl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

In [130]:
kd_model_feature = KDModel.load_from_checkpoint(trainer.checkpoint_callback.best_model_path, model=small_scratch_model, lr=3e-5, teacher_model=lit_teacher_model, use_hidden=True)

In [131]:
trainer.test(kd_model_feature, test_dl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=95` in the `DataLoader` to improve performance.


[{'f1_test_step': 0.8488346338272095}]