In [1]:
!nvidia-smi

Mon May 17 18:27:32 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   63C    P0    25W /  75W |      0MiB /  7611MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

### Imports

In [2]:
!pip install -q transformers

In [3]:
!pip install -q datasets

In [4]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, RandomSampler
import torch.functional as F
import numpy as np
from torch.autograd import Variable

In [5]:
from datasets import load_dataset, load_metric

dataset = load_dataset("glue", "sst2")
metric = load_metric("glue", "sst2")

Reusing dataset glue (/root/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


Get pretrained BERT-model on SST2 data

In [6]:
 from transformers import BertTokenizer, BertForSequenceClassification

device = torch.device("cuda")
tokenizer = BertTokenizer.from_pretrained("mfuntowicz/bert-base-cased-finetuned-sst2")
model_bert = BertForSequenceClassification.from_pretrained("mfuntowicz/bert-base-cased-finetuned-sst2").to(device)

### Evaluation

In [7]:
def accuracy(test_model):
    correct = 0
    total = 0
    test_model.eval()

    with torch.no_grad():
        for batch in data_val:
            batch_ = tokenizer(batch["sentence"], truncation=True, padding=True)

            batch_ = {k: torch.tensor(v).to(device) for k, v in batch_.items()}
            labels = batch["label"].to(device)

            output = test_model(**batch_) 
            output = output if isinstance(output, torch.Tensor) else output.logits

            probs = torch.softmax(output, dim=1)
            predictions = torch.argmax(probs, dim=1)

            correct += (predictions == labels).sum()
            total += predictions.shape[0]

    acc = correct / total

    return acc

### LSTM model

In [8]:
class LSTM_model(nn.Module):
    def __init__(self, 
                 input_size, 
                 padding_idx,
                 embedding_size=16, 
                 hidden_size=16, 
                 output_size=2, 
                 num_layers=2, 
                 batch_size=32, 
                 dropout=0.5,
                 device=torch.device("cuda")):
      
        super(LSTM_model, self).__init__()
        self.batch_size = batch_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.embedding = nn.Embedding(input_size, embedding_size, padding_idx=padding_idx)

        self.lstm = nn.LSTM(embedding_size,
                            hidden_size,
                            num_layers=num_layers,
                            dropout=dropout,
                            bidirectional=True,
                            batch_first=True)

        self.fc = nn.Linear(2 * hidden_size, output_size)
        self.dropout = nn.Dropout(dropout)
        self.device = device

    def init_hidden(self):
        return (Variable(torch.zeros(2 * self.num_layers, self.batch_size, self.hidden_size).to(self.device)), 
                Variable(torch.zeros(2 * self.num_layers, self.batch_size, self.hidden_size).to(self.device)))


    def forward(self, input_ids, **params):
        hidden = self.init_hidden()
        x = self.embedding(input_ids)
        _, hidden = self.lstm(x, hidden)
        ht, _ = hidden
        ht = torch.cat((ht[-2, :, :], ht[-1, :, :]), dim=1)
        ht = self.dropout(ht)
        output = self.fc(ht)

        return output

In [9]:
from tqdm.auto import tqdm

### Training

In [20]:
def train(model,
          optimizer,
          trainset,
          epochs, 
          distillation = True, 
          alpha = 0.3):
  
    criterion_ce = nn.CrossEntropyLoss()
    criterion_mse = nn.MSELoss()

    model_bert.eval()

    for epoch in range(epochs):
        model.train()
        
        losses = []
        correct = 0
        total = 0

        for batch in tqdm(trainset):
            batch_ = tokenizer(batch["sentence"], truncation=True, padding=True)
            batch_ = {k: torch.tensor(v).to(device) for k, v in batch_.items()}
            labels = batch["label"].to(device)


            model.zero_grad()
            output = model(**batch_)
            
            with torch.no_grad():
                logits = model_bert(**batch_).logits

                probs = torch.softmax(output, dim=1)
                preds = torch.argmax(probs, dim=1)
                
            if distillation:
                cross_entropy_loss = criterion_ce(output, labels)
                mse_loss = criterion_mse(output, logits)
                loss = alpha * cross_entropy_loss + (1 - alpha) * mse_loss

            else:
                loss = criterion_ce(output, labels)

            loss.backward()
            optimizer.step() 

            losses.append(loss.item())

            correct += (preds == labels).sum()
            total += labels.shape[0]

        train_acc = correct / total
        val_acc = accuracy(model)

        print(f"Epoch {epoch + 1}\n"
              f"loss: {np.mean(losses):.4f} | "
              f"train acc: {train_acc:.4f} | "
              f"val acc: {val_acc:.4f}"
        )

    return model, losses, optimizer

Set training parameters

In [21]:
n_classes = 2
batch_size = 128

data_val = DataLoader(dataset["validation"], batch_size=batch_size, drop_last=True)
data_train = DataLoader(dataset["train"], batch_size=batch_size, drop_last=True)

In [12]:
# Teacher model accuracy
print(f'Accuracy BERT: {accuracy(model_bert):.4f}')

Accuracy BERT: 0.8958


In [13]:
params = {
    'embedding_size': 64, 
    'hidden_size': 32, 
    'output_size': n_classes, 
    'num_layers': 2, 
    'batch_size': batch_size, 
    'dropout': 0.4
}

learning_rate = 1e-3
epochs = 4


Training base student model

In [14]:
lstm = LSTM_model(input_size=tokenizer.vocab_size, padding_idx = tokenizer.pad_token_id, **params).to(device)
optimizer = torch.optim.Adam(lstm.parameters(), lr=learning_rate)

lstm_base, losses_base, _ = train(lstm, optimizer=optimizer, trainset=data_train, epochs=epochs, distillation=False)

HBox(children=(FloatProgress(value=0.0, max=526.0), HTML(value='')))


Epoch 1
loss: 0.5525 | train acc: 0.7022 | val acc: 0.7773


HBox(children=(FloatProgress(value=0.0, max=526.0), HTML(value='')))


Epoch 2
loss: 0.3424 | train acc: 0.8555 | val acc: 0.7891


HBox(children=(FloatProgress(value=0.0, max=526.0), HTML(value='')))


Epoch 3
loss: 0.2557 | train acc: 0.9012 | val acc: 0.8021


HBox(children=(FloatProgress(value=0.0, max=526.0), HTML(value='')))


Epoch 4
loss: 0.2089 | train acc: 0.9235 | val acc: 0.8021


Training model with applying knowledge distillation

In [22]:
lstm = LSTM_model(input_size=tokenizer.vocab_size, padding_idx = tokenizer.pad_token_id, **params).to(device)
optimizer = torch.optim.Adam(lstm.parameters(), lr=learning_rate)

lstm_distil, losses_distil, _ = train(lstm, optimizer=optimizer, trainset=data_train, epochs=epochs, distillation=True)

HBox(children=(FloatProgress(value=0.0, max=526.0), HTML(value='')))


Epoch 1
loss: 6.5782 | train acc: 0.7145 | val acc: 0.7956


HBox(children=(FloatProgress(value=0.0, max=526.0), HTML(value='')))


Epoch 2
loss: 3.2680 | train acc: 0.8600 | val acc: 0.8008


HBox(children=(FloatProgress(value=0.0, max=526.0), HTML(value='')))


Epoch 3
loss: 2.2599 | train acc: 0.8999 | val acc: 0.8112


HBox(children=(FloatProgress(value=0.0, max=526.0), HTML(value='')))


Epoch 4
loss: 1.7511 | train acc: 0.9207 | val acc: 0.8164


### Results

In [23]:
# test data is not available, so we will evaluate on val set

print(f'BERT pretrained (teacher) val acc: {accuracy(model_bert):.4f}')
print(f'Amount of training parameters: {sum(p.numel() for p in model_bert.parameters())}')
print()
print(f'LSTM base (student) val acc: {accuracy(lstm_base):.4f}')
print(f'LSTM distilled (student) val acc: {accuracy(lstm_distil):.4f}')
print(f'Amount of training parameters: {sum(p.numel() for p in lstm_distil.parameters())}')


BERT pretrained (teacher) val acc: 0.8958
Amount of training parameters: 108311810

LSTM base (student) val acc: 0.8021
LSTM distilled (student) val acc: 0.8164
Amount of training parameters: 1906050
