# DistilBERT Classifier Finetuning with 8 TPU Cores

*Prepared by Jan Christian Blaise Cruz*

This notebook shows you how to finetune a pretrained DistilBERT model on the [Toxic Comments Classification Challenge](https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge/overview/description) dataset using a Cloud TPU and all 8 of its cores, using PyTorch. For more information on the PyTorch XLA project, check out their [GitHub repo](https://github.com/pytorch/xla) for more tutorial notebooks and documentation.

# Setup

First, we set up PyTorch XLA.

In [0]:
VERSION = "20200325"  #@param ["1.5" , "20200325", "nightly"]
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $VERSION

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0100  4264  100  4264    0     0  34666      0 --:--:-- --:--:-- --:--:-- 34666
Updating TPU and VM. This may take around 2 minutes.
Updating TPU runtime to pytorch-dev20200325 ...
Uninstalling torch-1.5.0+cu101:
Done updating TPU runtime: <Response [200]>
  Successfully uninstalled torch-1.5.0+cu101
Uninstalling torchvision-0.6.0+cu101:
  Successfully uninstalled torchvision-0.6.0+cu101
Copying gs://tpu-pytorch/wheels/torch-nightly+20200325-cp36-cp36m-linux_x86_64.whl...
- [1 files][ 83.4 MiB/ 83.4 MiB]                                                
Operation completed over 1 objects/83.4 MiB.                                     
Copying gs://tpu-pytorch/wheels/torch_xla-nightly+20200325-cp36-cp36m-linux_x86_64.whl...
- [1 files][114.5 MiB/114.5 MiB] 

**Note: Make sure to upload your ```kaggle.json``` file in order to download the data using the API.** More information about the API can be found [here](https://www.kaggle.com/docs/api).

We'll download the 2018 Toxic Comments Classification Challenge dataset from Kaggle, unzip the files, then install the Transformers library.

In [0]:
# Set up the API
# Make sure to upload your kaggle.json file first!
!mkdir ~/.kaggle
!mv kaggle.json ~/.kaggle/kaggle.json
!chmod 600 ~/.kaggle/kaggle.json

# Download the competition data
!kaggle competitions download -c jigsaw-toxic-comment-classification-challenge
!unzip train.csv && unzip test.csv && unzip sample_submission.csv.zip && unzip test_labels.csv.zip

# Download packages
!pip install transformers

Downloading test.csv.zip to /content
 73% 17.0M/23.4M [00:00<00:00, 56.6MB/s]
100% 23.4M/23.4M [00:00<00:00, 78.2MB/s]
Downloading sample_submission.csv.zip to /content
  0% 0.00/1.39M [00:00<?, ?B/s]
100% 1.39M/1.39M [00:00<00:00, 194MB/s]
Downloading test_labels.csv.zip to /content
  0% 0.00/1.46M [00:00<?, ?B/s]
100% 1.46M/1.46M [00:00<00:00, 199MB/s]
Downloading train.csv.zip to /content
 65% 17.0M/26.3M [00:00<00:00, 176MB/s]
100% 26.3M/26.3M [00:00<00:00, 168MB/s]
Archive:  train.csv.zip
  inflating: train.csv               
Archive:  test.csv.zip
  inflating: test.csv                
Archive:  sample_submission.csv.zip
  inflating: sample_submission.csv   
Archive:  test_labels.csv.zip
  inflating: test_labels.csv         
Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/22/97/7db72a0beef1825f82188a4b923e62a146271ac2ced7928baa4d47ef2467/transformers-2.9.1-py3-none-any.whl (641kB)
[K     |████████████████████████████████| 645kB 3.5MB/s 
[?25hC

# Preliminaries

First, some imports.

In [0]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as datautils

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl

from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

import numpy as np
import pandas as pd
from tqdm import tqdm 
from sklearn.metrics import roc_auc_score
import time, os

For this example run, we'll use DistilBERT to finetune a toxic comment classifier. 

You are free to change this to any other pretrained checkpoint, but be sure to tweak the hyperparameters later on.

In [0]:
pretrained = 'distilbert-base-cased'

Here's a function to encode all the text in the dataset into an array of indices from the tokenizer's vocabulary. We'll also load the pretrained tokenizer.

In [0]:
# Fast encoding function
def regular_encode(texts, tokenizer, maxlen=512):
    enc_di = tokenizer.batch_encode_plus(
        texts, 
        return_attention_masks=False, 
        return_token_type_ids=False,
        pad_to_max_length=True,
        max_length=maxlen
    )
    
    return np.array(enc_di['input_ids'])

tokenizer = AutoTokenizer.from_pretrained(pretrained)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=411.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=213450.0, style=ProgressStyle(descripti…




Load the training dataset. Encode the text and extract the gold labels (remember we have six classes to predict). Split these into training and validation sets. We'll convert the labels into float tensors.

In [0]:
# Read dataset
df = pd.read_csv('train.csv')

# Encode the dataset
s = time.time()
text = regular_encode(list(df['comment_text']), tokenizer)
labels = df[['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']].values
print("Elapsed: {:.2f}s".format(time.time() - s))

# Split into training and testing
tr_sz = int(len(text) * 0.7)
X_train, y_train = torch.tensor(text[:tr_sz]), torch.tensor(labels[:tr_sz]).float()
X_valid, y_valid = torch.tensor(text[tr_sz:]), torch.tensor(labels[tr_sz:]).float()

# Produce datasets
train_set = datautils.TensorDataset(X_train, y_train)
valid_set = datautils.TensorDataset(X_valid, y_valid)

Elapsed: 194.24s


The metric for the competition is "mean column-wise ROC AUC," which isn't directly implemented so we'll implement it here. We're essentially just computing the ROC AUC for each of the six classes, then getting their average. 

We'll try to catch ```ValueError``` exceptions here which will happen if the model assigned (predicted) all the data into one class -- which is very rare -- but still a possibility.

In [0]:
def roc_auc(preds, actuals):
    scores = []
    for i in range(actuals.shape[1]):
        try: score = roc_auc_score(actuals[:, i], preds[:, i])
        except ValueError: score = 0 # In case only one class is present
        scores.append(score)
    return np.array(scores).mean()

# Finetuning

We'll now write the function that will be mapped to all 8 TPU cores. For more granular information on how PyTorch XLA maps to specific cores, check out [this notebook](https://colab.research.google.com/github/pytorch/xla/blob/master/contrib/colab/multi-core-alexnet-fashion-mnist.ipynb).

This function will be long, so we'll comment in everything that you'll need to know.

In [0]:
def map_fn(index, flags):
    # Set the seed and obtain an XLA device
    torch.manual_seed(flags['seed'])
    device = xm.xla_device()
    print("Process", index, "obtained, using device:", xm.xla_real_devices([str(device)])[0]) 

    # Produce distributed samplers
    train_sampler = datautils.distributed.DistributedSampler(
        train_set, 
        num_replicas=xm.xrt_world_size(), 
        rank=xm.get_ordinal(), 
        shuffle=True
    )
    valid_sampler = datautils.distributed.DistributedSampler(
        valid_set, 
        num_replicas=xm.xrt_world_size(), 
        rank=xm.get_ordinal(), 
        shuffle=False
    )

    # Create dataloaders
    train_loader = datautils.DataLoader(
        train_set,
        batch_size=flags['batch_size'], 
        sampler=train_sampler, 
        num_workers=flags['num_workers'],
        drop_last=True
    )
    valid_loader = datautils.DataLoader(
        valid_set,
        batch_size=flags['batch_size'], 
        sampler=valid_sampler, 
        num_workers=flags['num_workers'],
        drop_last=True,
        shuffle=False
    )

    # This ensures that the pretrained weights will only be
    # downloaded once (c/o the master process). It also makes
    # sure that the other processes don't attempt to load the
    # weights when downloading isn't finished yet.
    if not xm.is_master_ordinal():
        xm.rendezvous('download_only_once')

    # Configure the model
    config = AutoConfig.from_pretrained(flags['pretrained'], num_labels=flags['num_labels'])
    model = AutoModelForSequenceClassification.from_pretrained(flags['pretrained'], config=config).to(device)

    if xm.is_master_ordinal():
        xm.rendezvous('download_only_once')

    # Initialize loss and optimizer
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=flags['learning_rate'])

    xm.master_print("\nNumber of training batches: {}".format(len(train_loader)))
    xm.master_print("Number of evaluation batches: {}\n".format(len(valid_loader)))

    # Train Model
    model.train()
    train_start = time.time()
    
    for e in range(1, flags['num_epochs'] + 1):
        xm.master_print("=" * 27 + "Epoch {} of {}".format(e, flags['num_epochs']) + "=" * 27)
        para_train_loader = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
        for i, batch in enumerate(para_train_loader):
            x, y = batch
            out = model(x)[0]
            loss = criterion(out, y)

            if i % flags['print_every'] == 0:
                xm.master_print('[TRAIN] Iteration {:4} | Loss {:.4f} | Time Elapsed {:.2f} seconds'.format(i, loss.item(),time.time() - train_start))

            optimizer.zero_grad()
            loss.backward()
            xm.optimizer_step(optimizer)
    xm.master_print('\nFinished training {} epochs in {:.2f} seconds.\n'.format(flags['num_epochs'], time.time() - train_start))

    # Evaluate Model
    model.eval()
    valid_start = time.time()
    preds, actuals = [], []
    
    with torch.no_grad():
        xm.master_print('=' * 28 + 'Validation' + '=' * 28)
        para_valid_loader = pl.ParallelLoader(valid_loader, [device]).per_device_loader(device)
        for i, batch in enumerate(para_valid_loader):
            x, y = batch
            out = model(x)[0]
            loss = criterion(out, y)

            # Keep track of all outputs and gold labels
            actuals.extend(y.cpu().numpy().tolist())
            preds.extend(out.cpu().detach().numpy().tolist())

            if i % flags['print_every'] == 0:
                xm.master_print('[VALID] Iteration {:4} | Loss {:.4f} | Time Elapsed {:.2f} seconds'.format(i, loss.item(),time.time() - train_start))

    preds, actuals = np.array(preds), np.array(actuals)
    valid_auroc = roc_auc(preds, actuals)
    xm.master_print('\nFinished evaluation in {:.2f} seconds. Validation AUROC: {:.4f}\n'.format(time.time() - valid_start, valid_auroc))

    # Save the model
    xm.save(model.state_dict(), flags['savedir'] + '/' + flags['modelpath'])

We'll set the hyperparamters into a dictionary we call ```flags```. If you're coding this into a script instead, this can conveniently come from command line arguments. A word on the batch size: note that we're *technically* doing batch size 128 here, since we're training 16 batches on 8 cores at once.

We'll also create a directory to store our model weights (and other things later). Then we start the distributed process.

In [0]:
# Set flags
flags = {
    'batch_size': 16,
    'num_workers': 8,
    'num_epochs': 3,
    'seed': 42,
    'num_labels': 6,
    'pretrained': pretrained,
    'savedir': 'training_dir',
    'modelpath': 'model.bin',
    'learning_rate': 1e-5,
    'print_every': 150
}

# Start the process
if flags['savedir'] not in os.listdir('.'): os.mkdir(flags['savedir'])
xmp.spawn(map_fn, args=(flags,), nprocs=8, start_method='fork')

Process 0 obtained, using device: TPU:0
Process 4 obtained, using device: TPU:4
Process 6 obtained, using device: TPU:6
Process 5 obtained, using device: TPU:5


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=263273408.0, style=ProgressStyle(descri…

Process 3 obtained, using device: TPU:3
Process 2 obtained, using device: TPU:2
Process 7 obtained, using device: TPU:7
Process 1 obtained, using device: TPU:1


Number of training batches: 872
Number of evaluation batches: 374

[TRAIN] Iteration    0 | Loss 0.7006 | Time Elapsed 4.41 seconds
[TRAIN] Iteration  150 | Loss 0.0825 | Time Elapsed 98.67 seconds
[TRAIN] Iteration  300 | Loss 0.1211 | Time Elapsed 153.47 seconds
[TRAIN] Iteration  450 | Loss 0.1496 | Time Elapsed 208.38 seconds
[TRAIN] Iteration  600 | Loss 0.0523 | Time Elapsed 263.86 seconds
[TRAIN] Iteration  750 | Loss 0.0276 | Time Elapsed 319.04 seconds
[TRAIN] Iteration    0 | Loss 0.0946 | Time Elapsed 366.67 seconds
[TRAIN] Iteration  150 | Loss 0.0289 | Time Elapsed 423.45 seconds
[TRAIN] Iteration  300 | Loss 0.0364 | Time Elapsed 479.03 seconds
[TRAIN] Iteration  450 | Loss 0.0785 | Time Elapsed 534.68 seconds
[TRAIN] Iteration  600 | Loss 0.0576 | Time Elapsed 593.00 seconds
[TRAIN] Iteration  750 | Loss 0.0168 

TPUs are very fast! We finished finetuning and validation in about ~18 minutes all in all.

In comparison, finetuning for this dataset on a P100 GPU (batch size 32, all other hyperparameters the same) takes about ~2 hours 21 minutes for the full three epochs. That's a very big difference.

# Inference

For inferencing, we'd want to make sure that each prediction will be paired with the correct id from the dataset (so kaggle can score our predictions).  We'll subclass ```torch.utils.data.Dataset``` to allow non tensors to be included in a tensor dataset.

In [0]:
class TestDataset(datautils.Dataset):
    def __init__(self, text, ids):
        self.text = text
        self.ids = ids

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        ix_text = self.text[idx]
        ix_id = self.ids[idx]
        
        return ix_text, ix_id

We'll read the test set, encode the text, and extract the corresponding list of IDs. We'll construct a dataset and a dataloader from this.

In [0]:
df = pd.read_csv('test.csv')
text = regular_encode(list(df['comment_text']), tokenizer)
ids = list(df['id'])

# Produce a test set and loader
test_set = TestDataset(text, ids)
test_loader = datautils.DataLoader(test_set, batch_size=16, shuffle=False)

We'll write another mapping function for distributed inferencing. The details are mostly the same. The only difference is that at the end of inferencing, each process will save its predictions + their corresponding IDs in the folder that we created earlier. We'll collate all these predictions later.

In [0]:
def map_fn(index, flags):
    # Set the seed and obtain an XLA device
    torch.manual_seed(flags['seed'])
    device = xm.xla_device()
    print("Process", index, "obtained, using device:", xm.xla_real_devices([str(device)])[0])

    # Produce a distributed sampler and a data loader
    test_sampler = datautils.distributed.DistributedSampler(
        test_set,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False
    )
    test_loader = datautils.DataLoader(
        test_set,
        batch_size=flags['batch_size'],
        sampler=test_sampler,
        pin_memory=False,
        drop_last=False,
        num_workers=flags['num_workers']
    )

    # Configure the model and load the checkpoint
    config = AutoConfig.from_pretrained(flags['pretrained'], num_labels=flags['num_labels'])
    model = AutoModelForSequenceClassification.from_pretrained(flags['pretrained'], config=config).to(device)
    model.load_state_dict(torch.load(flags['savedir'] + '/' + flags['modelpath']))

    xm.master_print("\nNumber of testing batches: {}\n".format(len(test_loader)))

    # Run inferencing
    model.eval()
    preds, ids = [], []
    test_start = time.time()

    xm.master_print('=' * 25 + 'Inference' + '=' * 25)
    for i, batch in enumerate(test_loader):
        x, idx = batch
        x = x.to(device)
        with torch.no_grad():
            out = torch.sigmoid(model(x)[0])
            preds.extend(out.cpu().detach().numpy().tolist())
            ids.extend(idx)
        if i % flags['print_every'] == 0: 
            xm.master_print('Inferencing on step {:4} | Time elapsed: {:.2f} seconds'.format(i, time.time() - test_start))
    preds = np.array(preds)

    # Save the predictions and associated IDs into a temporary file
    with open('{}/preds_{}.pt'.format(flags['savedir'], xm.xla_real_devices([str(device)])[0]), 'wb') as f:
        torch.save([ids, preds], f)

    xm.master_print('\nFinished inferencing in {:.2f} seconds.\n'.format(time.time() - test_start))

Then we spawn the processes.

In [0]:
# Start the processes
xmp.spawn(map_fn, args=(flags,), nprocs=8, start_method='fork')

Process 0 obtained, using device: TPU:0
Process 4 obtained, using device: TPU:4
Process 3 obtained, using device: TPU:3
Process 7 obtained, using device: TPU:7
Process 5 obtained, using device: TPU:5
Process 6 obtained, using device: TPU:6
Process 1 obtained, using device: TPU:1
Process 2 obtained, using device: TPU:2

Number of testing batches: 1197

Inferencing on step    0 | Time elapsed: 5.58 seconds
Inferencing on step  150 | Time elapsed: 27.76 seconds
Inferencing on step  300 | Time elapsed: 42.71 seconds
Inferencing on step  450 | Time elapsed: 58.69 seconds
Inferencing on step  600 | Time elapsed: 73.84 seconds
Inferencing on step  750 | Time elapsed: 88.64 seconds
Inferencing on step  900 | Time elapsed: 104.22 seconds
Inferencing on step 1050 | Time elapsed: 120.16 seconds

Finished inferencing in 139.38 seconds.



Afterwards, we'll load all the prediction files and collate them into a Pandas DataFrame. There **will** be a number of duplicate ids, which is a consequence of the distributed strategy. We'll just keep the first that Pandas will see and drop the rest. The difference between the duplicates is very *very* miniscule, considering that they were inferred from the same finetuned weights.

In [0]:
# Load all prediction files
labellist = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
all_ids, all_preds = [], []
for i in range(8):
    with open('{}/preds_TPU:{}.pt'.format(flags['savedir'], i), 'rb') as f:
        idx, preds = torch.load(f)
        all_ids.extend(idx)
        all_preds.extend(preds)
preds = np.array(all_preds)

# Combine and remove duplicates
submission = pd.DataFrame(data={'id': all_ids})
for label in labellist:
    submission[label] = 0
submission[labellist] = preds
submission.drop_duplicates(keep='first', subset='id', inplace=True)

We then check if the length of our predictions is the same with the length of the test set, then save the predictions to a submission file.

In [0]:
# Save CSV
assert submission.shape[0] == df.shape[0]
submission.to_csv('submission.csv', index=False)

Sending our submission file to Kaggle got me a score of 0.97847 on the public leaderboard, which is 0.01054 away from the top score! That's a pretty good result.