Distilling BERT to a Linear Model
=================================

This is a little extension of the work done in _Distilling Task-Specific Knowledge from BERT into Simple Neural Networks_ by Tang et al. 2019. Hopefully this notebook will serve as an easy-to-follow guide to distillation, which is actually really simple. This is based on work I did for [Polecat](polecat.com).

Tang demonstrates that training a lower-complexity student model to predict a teacher model's output logits is more effective than directly training the student model on the dataset. This is a really neat way of improving performance of smaller models (which are much easier to productionize).

In the paper Tang uses BERT to train a BiLSTM. One of the suggestions for future work is to explore to what extent even simpler models can benefit from the technique. This notebook does just that - we'll try and use BERT to train a simple linear model implemented in PyTorch.

The linear model is the FastText model (Joulin et al. 2016) which normally is an excellent compromise between speed and accuracy. The task is document classification. We wouldn't expect to get near BERT-like accuracy because FastText is a bag-of-words model (it ignores word order, although you can give it n-grams) but it will be interesting to see if we can increase its accuracy at all. 

Let's begin with our dependencies: PyTorch, the great Hugging Face transformers library (for a BERT implementation) and the other usual suspects.

In [0]:
!pip install torch transformers pandas numpy joblib tqdm



In [0]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from pathlib import Path
from joblib import Memory
from tqdm import tqdm
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoTokenizer, AutoConfig, DistilBertForSequenceClassification  # only the names used below

In [0]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

Dataset
-------

We'll use the Amazon review dataset. It is freely available and consists of product reviews, each with a star rating; the task is simply to predict the star rating from the review text.

First, some data wrangling. I'm afraid this notebook won't run out-of-the-box, because the data and teacher model are too large to distribute.

In [0]:
DATA = Path("/mnt/gdrive/My Drive/data")

if not DATA.exists():
    from google.colab import drive
    drive.mount("/mnt/gdrive")

assert DATA.exists()

In [0]:
CACHE = Path("/mnt/gdrive/My Drive/cache/distillation")

if not CACHE.exists():
    CACHE.mkdir(parents=True)

memory = Memory(CACHE, verbose=False)

In [0]:
market = "uk"

reviews = (pd.read_csv(DATA / f"amazon_reviews_multilingual_{market.upper()}_v1_00.tsv.gz",
                       sep="\t",
                       usecols=["review_id", "star_rating", "review_headline", "review_body"],
                       dtype={"review_id": "string",
                              "star_rating": "Int32",
                              "review_headline": "string",
                              "review_body": "string"})
            .dropna())

We balance the classes and shuffle the dataset. Ideally we should also remove some low-value reviews, e.g. single-word reviews and reviews in other languages. But there are few enough of these to not make much difference as far as this exploration goes.

In [0]:
classes = {1, 2, 3, 4, 5}
class_examples = [reviews[reviews.star_rating == rating] for rating in classes]

MAX_LEN = 50_000

min_len = min(MAX_LEN // len(classes), *[len(c) for c in class_examples])

balanced_df = pd.concat([c.sample(min_len, random_state=42) for c in class_examples])

shuffled_df = balanced_df.sample(len(balanced_df))
shuffled_df["label"] = shuffled_df.star_rating.astype(int) - 1

len(shuffled_df)

50000
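The per-class sample size above is just a `min` over a cap and the class sizes. A minimal sketch of that arithmetic, where `balanced_size` is a hypothetical helper and the class counts are made up for illustration (they are not the real dataset statistics):

```python
def balanced_size(max_len, class_counts):
    """Return (examples per class, total examples) for a balanced sample."""
    n_classes = len(class_counts)
    # Each class gets an equal share of the cap, unless some class is smaller.
    per_class = min(max_len // n_classes, *class_counts)
    return per_class, per_class * n_classes


# Hypothetical counts: every class is larger than the per-class cap.
plentiful = balanced_size(50_000, [61_000, 38_000, 74_000, 250_000, 900_000])

# Hypothetical counts: the rarest class limits the whole sample.
scarce = balanced_size(50_000, [4_000, 38_000, 74_000, 250_000, 900_000])

print(plentiful, scarce)
```

When every class has at least 10,000 reviews the cap decides the size, which is consistent with the 50,000 rows reported above.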

In [0]:
shuffled_df.head(2)

| | review_id | star_rating | review_headline | review_body | label |
|---|---|---|---|---|---|
| 1640608 | R11NKYEYMR011M | 4 | Good buy | I bought this sleeve after searching on the in... | 3 |
| 297813 | R1DS7T8FXIMZO6 | 3 | A Street cat named Bob | Could not put it down. Really easy to read . S... | 2 |


Split the data into a training set and a test set.

In [0]:
train_frac = 0.8
split_idx = int(train_frac * len(shuffled_df))

train_df = shuffled_df[:split_idx]
test_df = shuffled_df[split_idx:]
len(train_df), len(test_df)

(40000, 10000)

Tokenize the text and convert it to PyTorch tensors. We also need two masking vectors for each example as input to BERT.
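To make the two masks concrete, here is a minimal plain-Python sketch of what padding to a fixed length produces. The token ids are made up, `pad_id=0` is an assumption (it matches the `padding_idx=0` printed for the student model later on), and `pad_and_mask` is a hypothetical helper rather than anything from transformers.

```python
def pad_and_mask(token_ids, max_len, pad_id=0):
    """Truncate or pad a list of token ids to max_len and build both masks."""
    ids = token_ids[:max_len]
    n_real = len(ids)
    n_pad = max_len - n_real
    input_ids = ids + [pad_id] * n_pad
    # 1 marks a real token, 0 marks padding that the model should ignore.
    attention_mask = [1] * n_real + [0] * n_pad
    # A single-segment input has segment id 0 everywhere.
    token_type_ids = [0] * max_len
    return input_ids, attention_mask, token_type_ids


# Made-up token ids for a four-token review, padded to length 8.
input_ids, attention_mask, token_type_ids = pad_and_mask([101, 7592, 2088, 102], max_len=8)
print(input_ids)
print(attention_mask)
```

The linear student only ever looks at `input_ids`; the masks exist for the teacher.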

In [0]:
try:
    tokenizer = tokenizer  # reuse the tokenizer if this cell has already been run
except NameError:
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-multilingual-cased")

In [0]:
def dataframe_to_dataset(df):
    max_len = 128
    features = tokenizer.batch_encode_plus(df.review_body,
                                           max_length=max_len,
                                           pad_to_max_length=True,
                                           return_attention_masks=True,
                                           return_token_type_ids=True,
                                           return_tensors="pt")
    dataset = TensorDataset(features["input_ids"],
                            features["attention_mask"],
                            features["token_type_ids"],
                            torch.tensor(df.label.astype("int").to_numpy(), dtype=torch.long))
    return dataset

Hyperparameters
---------------

These are more-or-less the default hyperparameters for FastText. The embedding dimension is reduced to 50 to speed up processing slightly.

Beware the batch size - we're using a batch size of **1** for training. This has a significant impact on the accuracy of the linear model, and it's lightweight enough that we can get away with it.

In [0]:
N_EPOCHS = 5
EMBEDDING_DIM = 50
LR = 0.5
BATCH_SIZE = 32  # for TESTING
N_LABELS = 5  # num review ratings

In [0]:
train_loader = DataLoader(dataframe_to_dataset(train_df), batch_size=1, shuffle=False)  # optimal training for the linear model
test_loader = DataLoader(dataframe_to_dataset(test_df), batch_size=BATCH_SIZE, shuffle=False)
len(train_loader), len(test_loader)

(40000, 313)

Teacher
-------

The teacher is actually DistilBERT, rather than BERT. So we are distilling from a distilled model! Ideally the teacher should be BERT-proper so that results are more comparable. But this is running on Google Colab with limited GPU time, so a compromise is necessary.

I trained this DistilBERT model on the same dataset previously. Later on we'll check its accuracy.

In [0]:
try:
    config = config
    teacher = teacher
except NameError:
    config = AutoConfig.from_pretrained("distilbert-base-multilingual-cased")
    config.num_labels = N_LABELS
    teacher = DistilBertForSequenceClassification(config)
    teacher.load_state_dict(torch.load(DATA / "distilbert_uk_50000.bin", map_location=device))
_ = teacher.eval()
_ = teacher.to(device)

To reduce training time, we cache the teacher's predictions. Using the batches themselves as cache keys would require computing a hash each time, which would be computationally expensive, so we take advantage of the lack of shuffling and key the results by the dataset name ("train" or "test") and the batch number.

If the dataset is shuffled the cache _must_ be cleared.

This could be improved by simply labelling the whole dataset before adding it to the data loader.
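The difference between the two approaches can be sketched in plain Python. Everything here is hypothetical: `fake_teacher` stands in for the expensive model and the "examples" are just integers. The sketch shows a position-keyed cache going stale once the order changes, while outputs stored next to their examples stay consistent.

```python
def fake_teacher(example):
    """Hypothetical stand-in for the expensive teacher: returns a fake logit."""
    return 10 * example


examples = [1, 2, 3, 4]

# Option 1: cache keyed by position in the (unshuffled) dataset.
positional_cache = {i: fake_teacher(x) for i, x in enumerate(examples)}

# Option 2: label every example once and keep the output next to it.
labelled = [(x, fake_teacher(x)) for x in examples]

# Reorder both views with a fixed permutation so the example is deterministic.
order = [2, 0, 3, 1]
shuffled_examples = [examples[i] for i in order]
shuffled_labelled = [labelled[i] for i in order]

# Positional lookup now pairs examples with the wrong teacher outputs...
stale = [(x, positional_cache[i]) for i, x in enumerate(shuffled_examples)]
# ...whereas the pre-labelled pairs are still correct.
print(stale)
print(shuffled_labelled)
```

In the notebook's terms, option 2 would mean computing the teacher's logits once and storing them as an extra tensor in the dataset, at which point shuffling the loader becomes safe.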

In [0]:
class CachingTeacher:

    def __init__(self, teacher, cache_path=None):
        self.teacher = teacher
        self.base_model_prefix = teacher.base_model_prefix
        if cache_path is None:
            self.cache = {}
        else:
            self.cache = torch.load(cache_path)

    def eval(self):
        return self.teacher.eval()

    def to(self, device):
        return self.teacher.to(device)

    def __call__(self, dataset_id, batch_id, **inputs):
        cache_id = f"{dataset_id}_{batch_id}"
        if cache_id in self.cache:
            return self.cache[cache_id]
        else:
            with torch.no_grad():
                outputs = self.teacher(**inputs)
            self.cache[cache_id] = outputs
            return outputs

    def dump(self):
        torch.save(self.cache, CACHE / "distilbert_teacher_cache.bin")


caching_teacher = CachingTeacher(teacher)

Student
-------

We define an embeddings bag model as implemented in FastText. For simplicity let's leave out n-grams.

We also define `autodidact`, another instance that will be trained on the corpus labels directly, rather than under the teacher's supervision.
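Before the PyTorch version, here is the same computation in plain Python so the arithmetic is visible: average the embeddings of the tokens, then apply one linear layer. The sizes and weights are toy values chosen for illustration, and `bag_of_embeddings_logits` is a hypothetical helper, not the model used in this notebook.

```python
import math


def bag_of_embeddings_logits(token_ids, embeddings, weights, bias):
    """Average the embeddings of the given tokens, then apply one linear layer."""
    dim = len(embeddings[0])
    mean = [sum(embeddings[t][d] for t in token_ids) / len(token_ids)
            for d in range(dim)]
    return [sum(w * x for w, x in zip(row, mean)) + b
            for row, b in zip(weights, bias)]


# Toy sizes (illustrative only): 4-token vocabulary, 2-dim embeddings, 3 labels.
embeddings = [[0.0, 0.0], [0.1, 0.3], [0.2, 0.1], [0.4, 0.2]]
weights = [[1.0, -1.0], [0.5, 0.5], [-2.0, 1.0]]
bias = [0.0, 0.1, -0.1]

logits_forward = bag_of_embeddings_logits([1, 2, 3], embeddings, weights, bias)
logits_reversed = bag_of_embeddings_logits([3, 2, 1], embeddings, weights, bias)

# Word order does not matter: both orderings give the same logits (up to rounding).
same = all(math.isclose(a, b, abs_tol=1e-12)
           for a, b in zip(logits_forward, logits_reversed))
print(same)
```

This order-invariance is exactly why the model is called a bag of words, and why n-grams are the usual way to recover some local context.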

In [0]:
class PTModel(nn.Module):
    
    def __init__(self, n_vocab, embedding_dim, n_labels, padding_idx):
        super(PTModel, self).__init__()
        self.embeddings = nn.Embedding(n_vocab, embedding_dim, padding_idx=padding_idx)
        self.output = nn.Linear(embedding_dim, n_labels)
        with torch.no_grad():
            # FastText initializes embeddings with uniform distribution vs normal in PyTorch
            self.embeddings.weight.uniform_(to=1.0 / embedding_dim)
            self.embeddings.weight[padding_idx] = 0  # but FT doesn't have a padding token
            # FastText initializes output with zeros vs some random dist in PyTorch
            self.output.weight.zero_()

    def forward(self, input_ids, **kwargs):
        """Only input ids are required - kwargs are for API compat with BERT."""
        X = self.embeddings(input_ids)
        X = X.mean(dim=1)
        X = self.output(X)
        return X
    
padding_idx = tokenizer.vocab["[PAD]"]
n_vocab = len(tokenizer.vocab)
student = PTModel(n_vocab, EMBEDDING_DIM, N_LABELS, padding_idx)
student.to(device)

PTModel(
  (embeddings): Embedding(119547, 50, padding_idx=0)
  (output): Linear(in_features=50, out_features=5, bias=True)
)

In [0]:
autodidact = PTModel(n_vocab, EMBEDDING_DIM, N_LABELS, padding_idx)
autodidact.to(device)

PTModel(
  (embeddings): Embedding(119547, 50, padding_idx=0)
  (output): Linear(in_features=50, out_features=5, bias=True)
)

Training
--------

This function trains the model for one epoch. If no teacher is provided it uses cross-entropy loss (i.e. log-softmax then NLL) and compares the model's predictions to the target labels.

If a teacher is provided then model predictions are compared to the teacher's predictions and MSE loss is used.

In the paper Tang defines a cost function that is a balance between the two, i.e. $L = \alpha L_{CE} + (1 - \alpha) L_{MSE}$, but in practice observed that the best value for $\alpha$ was zero.
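As a worked example of the weighted objective, here is a plain-Python sketch, assuming the intended weighting is $\alpha$ on the cross-entropy term and $1 - \alpha$ on the MSE term. The logits are made up and `distillation_loss` is a hypothetical helper; the training function in this notebook uses only one of the two terms at a time.

```python
import math


def cross_entropy(logits, label):
    """Negative log-probability of the true label under softmax(logits)."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def mse(student_logits, teacher_logits):
    """Mean squared error between two logit vectors."""
    n = len(student_logits)
    return sum((s - t) ** 2 for s, t in zip(student_logits, teacher_logits)) / n


def distillation_loss(student_logits, teacher_logits, label, alpha):
    """alpha weights the hard-label term, (1 - alpha) the teacher-matching term."""
    return (alpha * cross_entropy(student_logits, label)
            + (1 - alpha) * mse(student_logits, teacher_logits))


student_logits = [0.0, 0.0, 0.0, 0.0, 0.0]    # an untrained student
teacher_logits = [-1.0, 0.0, 2.0, 1.0, -2.0]  # made-up teacher output
label = 2

pure_mse = distillation_loss(student_logits, teacher_logits, label, alpha=0.0)
pure_ce = distillation_loss(student_logits, teacher_logits, label, alpha=1.0)
mixed = distillation_loss(student_logits, teacher_logits, label, alpha=0.5)
print(pure_mse, pure_ce, mixed)
```

With $\alpha = 0$ the labels drop out entirely and the student is trained only to reproduce the teacher's logits, which is the setting used for the student below.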

The accuracy on the training set is also output for visibility.

In [0]:
def train_epoch(train_iter, model, optim, epoch_num, teacher=None):
    train_loss = 0
    train_acc = 0
    
    model.to(device)
    model.train()
    
    if teacher is not None:
        teacher.to(device)
        teacher.eval()
        cost = nn.MSELoss()
    else:
        cost = nn.CrossEntropyLoss()

    for batch_idx, batch in enumerate(tqdm(train_iter, total=len(train_iter), desc=f"Batch progress for epoch {epoch_num}")):
        
        batch = tuple([t.to(device) for t in batch])
        inputs = {"input_ids": batch[0],
                  "attention_mask": batch[1]}
        labels = batch[3]

        optim.zero_grad()
        output = model(**inputs)

        if teacher is not None:
            if teacher.base_model_prefix == "bert":
                inputs["token_type_ids"] = batch[2]  # assignment, not an annotation

            with torch.no_grad():
                target = teacher("train", batch_idx, **inputs)[0]  # BERT returns a tuple
        else:
            target = labels

        batch_loss = cost(output, target)
        train_loss += batch_loss.item()

        batch_acc = (output.argmax(1) == labels).sum().item()
        train_acc += batch_acc

        #print(f"{batch_idx:03}: {batch_loss.item() / len(labels):.03}\t\t{(batch_acc / len(labels)):.03}")
        
        batch_loss.backward()
        nn.utils.clip_grad_norm_(model.embeddings.parameters(), 1)  # FT only normalizes the embeddings grad
        optim.step()

    return train_loss / len(train_iter.dataset), train_acc / len(train_iter.dataset)

The validation function is similar but in this case there is no option to compare to the teacher's predictions, because that's not the ultimate point of the exercise - at the end of it all we just want a better small model. 

In [0]:
def validate(test_iter, model):
    test_acc = 0 
    test_loss = 0
    
    cost = nn.CrossEntropyLoss()

    model.to(device)
    model.eval()

    for batch in tqdm(test_iter, desc="Validating"):
        
        batch = tuple([t.to(device) for t in batch])
        inputs = {"input_ids": batch[0],
                  "attention_mask": batch[1],
                  "token_type_ids": batch[2]}
        labels = batch[3]

        with torch.no_grad():
            output = model(**inputs)
            
            batch_loss = cost(output, labels)
            test_loss += batch_loss.item()
                    
            batch_acc = (output.argmax(1) == labels).sum().item() 
            test_acc += batch_acc
            
    return test_loss / len(test_iter.dataset), test_acc / len(test_iter.dataset)

Sanity check - before any training, and with five balanced classes, we expect roughly 20% accuracy from each model.

In [0]:
validate(test_loader, student)

Validating: 100%|██████████| 313/313 [00:00<00:00, 1113.43it/s]


(0.050443163526058196, 0.1954)

In [0]:
validate(test_loader, autodidact)

Validating: 100%|██████████| 313/313 [00:00<00:00, 1191.39it/s]


(0.05050554784536362, 0.2024)
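The reported loss of about 0.05 can be checked by hand. `nn.CrossEntropyLoss` averages within a batch by default, and `validate` sums those batch means and then divides by the number of examples, so the figure is roughly the per-example cross-entropy divided by the batch size. A sketch of the arithmetic, assuming an untrained model that spreads probability almost uniformly over the five classes:

```python
import math

n_classes = 5
batch_size = 32      # the test batch size set in the hyperparameters
n_examples = 10_000  # size of the test set

# Cross-entropy of a uniform guess over five classes.
per_example_ce = math.log(n_classes)

# validate() adds one (mean) loss per batch and divides by the example count.
n_batches = math.ceil(n_examples / batch_size)
reported = per_example_ce * n_batches / n_examples

print(n_batches, round(per_example_ce, 4), round(reported, 4))
```

To read the test losses in this notebook on a per-example scale, multiply them by roughly 32.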

Training
--------

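We use SGD with a decaying learning rate because this is empirically best for the linear model (see Joulin et al.). FastText itself decays the rate linearly; here `StepLR` with `gamma=0.5` halves it after every epoch instead.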
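To see what the scheduler in the next cell does to the learning rate, here is the arithmetic in plain Python, next to a per-epoch linear decay for comparison. Both helpers are hypothetical, and the linear version is only a coarse approximation: FastText updates its rate continuously during training rather than once per epoch.

```python
def step_decay(lr0, gamma, n_epochs):
    """Learning rate per epoch when it is multiplied by gamma after every epoch."""
    return [lr0 * gamma ** epoch for epoch in range(n_epochs)]


def linear_decay(lr0, n_epochs):
    """Learning rate per epoch when it falls linearly towards zero."""
    return [lr0 * (1 - epoch / n_epochs) for epoch in range(n_epochs)]


halving = step_decay(0.5, gamma=0.5, n_epochs=5)
linear = linear_decay(0.5, n_epochs=5)

print(halving)
print([round(lr, 3) for lr in linear])
```

Halving shrinks the rate much faster in the early epochs, so most of the learning happens in the first epoch or two.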

In [0]:
optim = torch.optim.SGD(student.parameters(), lr=LR)
sched = torch.optim.lr_scheduler.StepLR(optim, step_size=1, gamma=0.5)

training_results = {"train_loss": [],
                    "train_acc": [],
                    "test_loss": [],
                    "test_acc": []}

student.to(device)
teacher.to(device)

try:
    for i in range(N_EPOCHS):
        train_loss, train_acc = train_epoch(train_loader, student, optim, epoch_num=i, teacher=caching_teacher)
        sched.step()
        test_loss, test_acc = validate(test_loader, student)
        training_results["train_loss"].append(train_loss)
        training_results["train_acc"].append(train_acc)
        training_results["test_loss"].append(test_loss)
        training_results["test_acc"].append(test_acc)
        print(f"{i:02}: {train_loss:.03} {train_acc:.03} {test_loss:.03} {test_acc:.03}")
        torch.save({'state_dict': student.state_dict()}, DATA / f"student_{market}_{len(shuffled_df)}.bin")
except KeyboardInterrupt:
    pass

train_df_student = pd.DataFrame(training_results)  # separate name so the training data in train_df is not overwritten
train_df_student

Batch progress for epoch 0: 100%|██████████| 40000/40000 [04:03<00:00, 164.09it/s]
Validating: 100%|██████████| 313/313 [00:00<00:00, 1261.51it/s]
Batch progress for epoch 1:   0%|          | 74/40000 [00:00<00:54, 734.80it/s]

00: 2.1 0.376 0.0412 0.425


Batch progress for epoch 1: 100%|██████████| 40000/40000 [00:53<00:00, 753.01it/s]
Validating: 100%|██████████| 313/313 [00:00<00:00, 1218.56it/s]
Batch progress for epoch 2:   0%|          | 77/40000 [00:00<00:52, 767.59it/s]

01: 1.53 0.455 0.04 0.449


Batch progress for epoch 2: 100%|██████████| 40000/40000 [00:53<00:00, 753.31it/s]
Validating: 100%|██████████| 313/313 [00:00<00:00, 1186.13it/s]
Batch progress for epoch 3:   0%|          | 74/40000 [00:00<00:54, 738.06it/s]

02: 1.41 0.472 0.0394 0.46


Batch progress for epoch 3: 100%|██████████| 40000/40000 [00:53<00:00, 753.73it/s]
Validating: 100%|██████████| 313/313 [00:00<00:00, 1223.24it/s]
Batch progress for epoch 4:   0%|          | 68/40000 [00:00<00:59, 675.10it/s]

03: 1.35 0.479 0.0391 0.468


Batch progress for epoch 4: 100%|██████████| 40000/40000 [00:52<00:00, 754.75it/s]
Validating: 100%|██████████| 313/313 [00:00<00:00, 1239.38it/s]


04: 1.32 0.482 0.039 0.475


| epoch | train_loss | train_acc | test_loss | test_acc |
|---|---|---|---|---|
| 0 | 2.103212 | 0.3757 | 0.04118 | 0.4255 |
| 1 | 1.530487 | 0.455025 | 0.039959 | 0.4492 |
| 2 | 1.40963 | 0.4723 | 0.039385 | 0.4604 |
| 3 | 1.352989 | 0.479475 | 0.039142 | 0.4676 |
| 4 | 1.323567 | 0.482175 | 0.038977 | 0.4754 |


A second training loop for the "autodidact", the model that is trained without a teacher.

In [0]:
optim_autodidact = torch.optim.SGD(autodidact.parameters(), lr=LR)
sched_autodidact = torch.optim.lr_scheduler.StepLR(optim_autodidact, step_size=1, gamma=0.5)

training_results = {"train_loss": [],
                    "train_acc": [],
                    "test_loss": [],
                    "test_acc": []}

autodidact.to(device)

try:
    for i in range(N_EPOCHS):
        train_loss, train_acc = train_epoch(train_loader, autodidact, optim_autodidact, epoch_num=i)
        sched_autodidact.step()
        test_loss, test_acc = validate(test_loader, autodidact)
        training_results["train_loss"].append(train_loss)
        training_results["train_acc"].append(train_acc)
        training_results["test_loss"].append(test_loss)
        training_results["test_acc"].append(test_acc)
        print(f"{i:02}: {train_loss:.03} {train_acc:.03} {test_loss:.03} {test_acc:.03}")
        torch.save({'state_dict': autodidact.state_dict()}, DATA / f"autodidact_{market}_{len(shuffled_df)}.bin")
except KeyboardInterrupt:
    pass

train_df_autodidact = pd.DataFrame(training_results)
train_df_autodidact

Batch progress for epoch 0: 100%|██████████| 40000/40000 [00:54<00:00, 737.42it/s]
Validating: 100%|██████████| 313/313 [00:00<00:00, 1185.09it/s]


00: 1.52 0.33 0.0441 0.389


Batch progress for epoch 1: 100%|██████████| 40000/40000 [00:54<00:00, 739.09it/s]
Validating: 100%|██████████| 313/313 [00:00<00:00, 1232.12it/s]
Batch progress for epoch 2:   0%|          | 75/40000 [00:00<00:53, 748.50it/s]

01: 1.31 0.437 0.0409 0.433


Batch progress for epoch 2: 100%|██████████| 40000/40000 [00:54<00:00, 738.93it/s]
Validating: 100%|██████████| 313/313 [00:00<00:00, 1197.24it/s]
Batch progress for epoch 3:   0%|          | 75/40000 [00:00<00:53, 746.45it/s]

02: 1.24 0.469 0.0406 0.444


Batch progress for epoch 3: 100%|██████████| 40000/40000 [00:53<00:00, 743.98it/s]
Validating: 100%|██████████| 313/313 [00:00<00:00, 1169.77it/s]
Batch progress for epoch 4:   0%|          | 72/40000 [00:00<00:55, 716.33it/s]

03: 1.21 0.489 0.0405 0.446


Batch progress for epoch 4: 100%|██████████| 40000/40000 [00:53<00:00, 742.16it/s]
Validating: 100%|██████████| 313/313 [00:00<00:00, 1224.77it/s]


04: 1.2 0.498 0.0401 0.451


| epoch | train_loss | train_acc | test_loss | test_acc |
|---|---|---|---|---|
| 0 | 1.521084 | 0.329625 | 0.044074 | 0.3888 |
| 1 | 1.309855 | 0.436775 | 0.040916 | 0.4331 |
| 2 | 1.243494 | 0.46945 | 0.040609 | 0.4442 |
| 3 | 1.21185 | 0.4892 | 0.040522 | 0.446 |
| 4 | 1.196159 | 0.4984 | 0.04013 | 0.4514 |


It's interesting to see that the student converged noticeably faster than the directly trained model (compare the `train_acc` and `test_acc` columns over the first couple of epochs). Note that you cannot directly compare the training losses, because they come from different loss functions: MSE against the teacher's logits versus cross-entropy against the labels.

Results
-------

This is a little bit redundant because it was already output in the training loop, but it's helpful to look at it in isolation.

Let's evaluate the student on the test set.

In [0]:
validate(test_loader, student)

Validating: 100%|██████████| 313/313 [00:00<00:00, 1189.79it/s]


(0.038976528859138486, 0.4754)

And then compare with the accuracy of the model that was directly trained (no distillation).

In [0]:
validate(test_loader, autodidact)

Validating: 100%|██████████| 313/313 [00:00<00:00, 1194.99it/s]


(0.04013044349551201, 0.4514)

The student does score slightly higher accuracy but let's not assume this is statistically significant. (If we were serious, we would whip out a package like `statsmodels` and check.)
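As a rough first check that needs no extra packages, here is a two-proportion z-test in plain Python, using the correct-prediction counts implied by the accuracies above (4,754 and 4,514 out of 10,000). `two_proportion_z_test` is a hypothetical helper. The test treats the two sets of predictions as independent samples, which is a simplifying assumption here because both models were scored on the same test set; a paired test such as McNemar's would be more appropriate.

```python
import math


def two_proportion_z_test(correct_a, correct_b, n):
    """Two-sided z-test for a difference between two accuracies on n examples each."""
    p_a, p_b = correct_a / n, correct_b / n
    pooled = (correct_a + correct_b) / (2 * n)
    standard_error = math.sqrt(pooled * (1 - pooled) * 2 / n)
    z = (p_a - p_b) / standard_error
    # Two-sided p-value under the standard normal distribution.
    p_value = math.erfc(abs(z) / math.sqrt(2))
    return z, p_value


z, p_value = two_proportion_z_test(4754, 4514, 10_000)
print(round(z, 2), p_value)
```

With these counts the statistic comes out at roughly 3.4, but that reflects only sampling noise in the test set. Each model was trained once from a single random initialization, so repeated runs with different seeds would be needed before drawing a firm conclusion.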

For comparison, this is the teacher's result on the test set (it was also trained for 5 epochs).

In [0]:
teacher_test_acc5 = []

for batch_num, batch in enumerate(tqdm(test_loader)):
    batch = tuple([t.to(device) for t in batch])
    inputs = {"input_ids": batch[0],
              "attention_mask": batch[1]}
    if caching_teacher.base_model_prefix == "bert":
        inputs["token_type_ids"] = batch[2]  # assignment, not an annotation
    labels = batch[3]
    with torch.no_grad():
        logits = caching_teacher("test", batch_num, **inputs)[0]
        probs = torch.softmax(logits, dim=1)
        preds_5class = probs.argmax(dim=1)
        acc_5class = (preds_5class == labels).sum().item() / len(batch[0])
        teacher_test_acc5.append(acc_5class)
        
np.mean(teacher_test_acc5)

100%|██████████| 313/313 [00:18<00:00, 16.64it/s]


0.62310303514377

(Really brief) Discussion
-------------------------

So neither model really got close to the teacher's accuracy, and the student was probably not significantly more accurate. Perhaps this is simply as accurate as we can expect that architecture to be on this task (remember it is a BOW model with no n-grams).

But it's not a total bust. It's very interesting to see that the student did converge faster. This supports Tang's suggestion that the information about prediction uncertainty is valuable, and that this even outweighs the error from the teacher's inaccurate predictions.

We might get better results if we implement the data augmentation that Tang suggests. We could also probably do better with a more complex student - you can see that in this [NLP Town blog post](https://www.nlp.town/blog/distilling-bert/), which inspired me to try this. NLP Town trained a CNN, which you would expect to perform better because it is more sophisticated and is able to account for the immediate context around a word (useful for e.g. negation).