# DistilGPT2 LM Finetuning with 8 TPU Cores

*Prepared by Jan Christian Blaise Cruz*

This notebook shows you how to finetune a pretrained DistilGPT2 model on the [WikiText-2](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/) language modeling dataset with PyTorch, using all 8 cores of a Cloud TPU. For more tutorial notebooks and documentation on the PyTorch XLA project, check out their [GitHub repo](https://github.com/pytorch/xla).

# Setup

First, let's set up PyTorch XLA.

In [1]:
VERSION = "20200325"  #@param ["1.5" , "20200325", "nightly"]
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $VERSION

Updating TPU and VM. This may take around 2 minutes.
Updating TPU runtime to pytorch-dev20200325 ...
Uninstalling torch-1.5.0+cu101:
Done updating TPU runtime: <Response [200]>
  Successfully uninstalled torch-1.5.0+cu101
Uninstalling torchvision-0.6.0+cu101:
  Successfully uninstalled torchvision-0.6.0+cu101
Copying gs://tpu-pytorch/wheels/torch-nightly+20200325-cp36-cp36m-linux_x86_64.whl...
- [1 files][ 83.4 MiB/ 83.4 MiB]                                                
Operation completed over 1 objects/83.4 MiB.                                     
Copying gs://tpu-pytorch/wheels/torch_xla-ni

Next, download and unzip the WikiText-2 dataset, then install the HuggingFace Transformers package.

In [2]:
!wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip
!unzip wikitext-2-v1.zip && rm wikitext-2-v1.zip
!pip install transformers

--2020-05-16 15:58:01--  https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.217.45.86
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.217.45.86|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4475746 (4.3M) [application/zip]
Saving to: ‘wikitext-2-v1.zip’


2020-05-16 15:58:01 (18.2 MB/s) - ‘wikitext-2-v1.zip’ saved [4475746/4475746]

Archive:  wikitext-2-v1.zip
   creating: wikitext-2/
  inflating: wikitext-2/wiki.test.tokens  
  inflating: wikitext-2/wiki.valid.tokens  
  inflating: wikitext-2/wiki.train.tokens  
Collecting transformers
  Downloading https://files.pythonhosted.org/packages/22/97/7db72a0beef1825f82188a4b923e62a146271ac2ced7928baa4d47ef2467/transformers-2.9.1-py3-none-any.whl (641kB)
Collecting tokenizers==0.7.0
  Downloading https://files.pythonhosted.org/packages/14/e5/a26eb4716523808bb0a7

# Preliminaries

We'll start with some imports.

In [0]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as datautils

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl

from transformers import GPT2Tokenizer, GPT2LMHeadModel, TextDataset

import numpy as np
import pandas as pd
import time, os

We'll load up the DistilGPT2 tokenizer and add in special tokens since they're not initialized.

In [4]:
tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
tokenizer.add_special_tokens({'bos_token': '<bos>', 
                              'eos_token': '<eos>', 
                              'unk_token': '<unk>',
                              'pad_token': '<pad>'})

3

HuggingFace Transformers provides ```TextDataset```, a ```torch.utils.data.Dataset``` wrapper for language modeling, so we don't need to worry about manually chunking the data by batches and BPTT lengths.

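We'll set the block size (the number of tokens in each example) to `tokenizer.max_len + 1`, one token more than the GPT2 standard context of 1024. The extra token lets us shift each block by one position later on, giving 1024 input tokens and 1024 target tokens per example.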

In [0]:
train_set = TextDataset(tokenizer=tokenizer, 
                        file_path='wikitext-2/wiki.train.tokens', 
                        block_size=tokenizer.max_len + 1)
valid_set = TextDataset(tokenizer=tokenizer, 
                        file_path='wikitext-2/wiki.valid.tokens', 
                        block_size=tokenizer.max_len + 1)
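
To make the role of the extra token concrete, here is a minimal pure-Python sketch of the chunk-and-shift idea. It assumes the dataset cuts the token stream into consecutive, non-overlapping blocks and drops the short tail; `make_blocks` and `shift` are hypothetical helpers written for illustration, not part of Transformers, and the toy context length of 4 stands in for 1024.

```python
def make_blocks(token_ids, block_size):
    """Cut a flat token stream into consecutive blocks, dropping the short tail."""
    return [token_ids[i:i + block_size]
            for i in range(0, len(token_ids) - block_size + 1, block_size)]


def shift(block):
    """Inputs are every token but the last; targets are every token but the first."""
    return block[:-1], block[1:]


context = 4                   # toy stand-in for the 1024-token context
stream = list(range(12))      # toy stand-in for the tokenized corpus
blocks = make_blocks(stream, context + 1)
x, y = shift(blocks[0])

print(blocks)   # two full blocks of 5 tokens; the 2 leftover tokens are dropped
print(x, y)     # both have length 4, and y is x moved one position ahead
```

This is the same slicing that the training loop performs later with `batch[:, :-1]` and `batch[:, 1:]`.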

# Finetuning

Next, we'll write up a mapping function that will be distributed to 8 TPU cores. For more granular information on how PyTorch XLA maps to specific cores, check out [this notebook](https://colab.research.google.com/github/pytorch/xla/blob/master/contrib/colab/multi-core-alexnet-fashion-mnist.ipynb).

This function is long, so the comments explain everything you need to know.

In [0]:
def map_fn(index, flags):
    # Set the seed and obtain an XLA device
    torch.manual_seed(flags['seed'])
    device = xm.xla_device()
    print("Process", index, "obtained, using device:", xm.xla_real_devices([str(device)])[0]) 

    # Produce distributed samplers
    train_sampler = datautils.distributed.DistributedSampler(
        train_set, 
        num_replicas=xm.xrt_world_size(), 
        rank=xm.get_ordinal(), 
        shuffle=True
    )
    valid_sampler = datautils.distributed.DistributedSampler(
        valid_set, 
        num_replicas=xm.xrt_world_size(), 
        rank=xm.get_ordinal(), 
        shuffle=False
    )

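    # Create dataloaders. Shuffling is handled by the distributed
    # samplers above, so the loaders themselves use shuffle=False
    # (DataLoader does not allow shuffle=True together with a sampler).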
    train_loader = datautils.DataLoader(
        train_set,
        batch_size=flags['batch_size'], 
        sampler=train_sampler, 
        num_workers=flags['num_workers'],
        drop_last=True,
        shuffle=False
    )
    valid_loader = datautils.DataLoader(
        valid_set,
        batch_size=flags['batch_size'], 
        sampler=valid_sampler, 
        num_workers=flags['num_workers'],
        drop_last=True,
        shuffle=False
    )

    # This ensures that the pretrained weights will only be
    # downloaded once (c/o the master process). It also makes
    # sure that the other processes don't attempt to load the
    # weights when downloading isn't finished yet.
    if not xm.is_master_ordinal():
        xm.rendezvous('download_only_once')

    # Configure the model
    model = GPT2LMHeadModel.from_pretrained(flags['pretrained']).to(device)

    if xm.is_master_ordinal():
        xm.rendezvous('download_only_once')

    # Initialize loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=flags['learning_rate'])

    xm.master_print("\nNumber of training batches: {}".format(len(train_loader)))
    xm.master_print("Number of evaluation batches: {}\n".format(len(valid_loader)))

    # Train Model
    model.train()
    train_start = time.time()
  
    for e in range(1, flags['num_epochs'] + 1):
        xm.master_print("=" * 27 + "Epoch {} of {}".format(e, flags['num_epochs']) + "=" * 27)
        para_train_loader = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
        for i, batch in enumerate(para_train_loader):
            x, y = batch[:,:-1], batch[:, 1:]
            out = model(x)[0]
            loss = criterion(out.flatten(0, 1), y.flatten(0))

            if i % flags['print_every'] == 0:
                xm.master_print('[TRAIN] Iteration {:4} | Loss {:.4f} | Time Elapsed {:.2f} seconds'.format(i, loss.item(),time.time() - train_start))

            optimizer.zero_grad()
            loss.backward()
            xm.optimizer_step(optimizer)
    xm.master_print('\nFinished training {} epochs in {:.2f} seconds.\n'.format(flags['num_epochs'], time.time() - train_start))

    # Evaluate Model
    model.eval()
    valid_start = time.time()
    valid_loss = 0
    
    with torch.no_grad():
        xm.master_print('=' * 28 + 'Validation' + '=' * 28)
        para_valid_loader = pl.ParallelLoader(valid_loader, [device]).per_device_loader(device)
        for i, batch in enumerate(para_valid_loader):
            x, y = batch[:,:-1], batch[:, 1:]
            out = model(x)[0]
            loss = criterion(out.flatten(0, 1), y.flatten(0))

            valid_loss += loss.item()
            if i % flags['print_every'] == 0:
                xm.master_print('[VALID] Iteration {:4} | Loss {:.4f} | Time Elapsed {:.2f} seconds'.format(i, loss.item(),time.time() - train_start))

    valid_loss /= len(valid_loader)
    xm.master_print('\nFinished evaluation in {:.2f} seconds. Validation Loss: {:.4f} | Validation PPL {:.4f}\n'.format(time.time() - valid_start, valid_loss, np.exp(valid_loss)))

    # Save the model
    xm.save(model.state_dict(), flags['savedir'] + '/' + flags['modelpath'])
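
The `download_only_once` pattern in `map_fn` can be hard to picture, so here is a rough analogy using only the standard library. It is a sketch, not PyTorch XLA code: `threading.Barrier` stands in for `xm.rendezvous`, threads stand in for TPU processes, and a dictionary stands in for the on-disk model cache. All names in it are hypothetical.

```python
import threading


def simulate_download_once(world_size=4):
    barrier = threading.Barrier(world_size)  # stand-in for xm.rendezvous
    cache = {}                               # stand-in for the on-disk model cache
    log = []
    log_lock = threading.Lock()

    def load_pretrained(rank):
        # Download on a cache miss, otherwise reuse the cached copy.
        if "weights" not in cache:
            cache["weights"] = [0.1, 0.2, 0.3]
            event = "download"
        else:
            event = "cache_hit"
        with log_lock:
            log.append((event, rank))

    def worker(rank):
        is_master = rank == 0
        if not is_master:
            barrier.wait()       # block until the master has finished loading
        load_pretrained(rank)
        if is_master:
            barrier.wait()       # master arrives last and releases the others

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return log


log = simulate_download_once()
downloads = [rank for event, rank in log if event == "download"]
cache_hits = [rank for event, rank in log if event == "cache_hit"]
print(downloads)         # only the master (rank 0) downloads
print(len(cache_hits))   # every other worker reuses the cache
```

The barrier only opens once every participant has reached it, so the non-master workers cannot start loading until the master has finished and made its own call.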

We'll add in our settings as a dictionary, then start the processes. 

Since we're training on 8 cores with a per-core batch size of 8, the effective global batch size is 8 × 8 = 64.
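
As a rough sketch of how the data is divided among cores, the snippet below mimics what a distributed sampler does when shuffling is off: it pads the index list so it divides evenly, then gives each rank every `world_size`-th index. `shard_indices` and `batches_per_core` are hypothetical helpers, and the dataset size of 1000 is made up for illustration; it is not the actual number of WikiText-2 blocks.

```python
import math


def shard_indices(num_examples, world_size, rank):
    """Indices one replica sees when shuffling is off (sketch of a distributed sampler)."""
    per_replica = math.ceil(num_examples / world_size)
    total = per_replica * world_size
    indices = list(range(num_examples))
    indices += indices[:total - num_examples]   # pad by wrapping around to the start
    return indices[rank:total:world_size]


def batches_per_core(num_examples, world_size, batch_size):
    """Number of full batches each core sees per epoch with drop_last=True."""
    per_replica = math.ceil(num_examples / world_size)
    return per_replica // batch_size


world_size, batch_size = 8, 8
print(world_size * batch_size)                        # effective global batch size
print(shard_indices(10, 4, 0))                        # rank 0 of 4 over a 10-example toy set
print(batches_per_core(1000, world_size, batch_size)) # hypothetical 1000-example dataset
```

Each core therefore iterates over roughly one eighth of the examples, which is why the per-core batch count printed during training is much smaller than the dataset size divided by 8 alone would suggest.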

In [7]:
# Set flags
flags = {
    'batch_size': 8,
    'num_workers': 8,
    'num_epochs': 3,
    'seed': 42,
    'pretrained': 'distilgpt2',
    'savedir': 'training_dir',
    'modelpath': 'model.bin',
    'learning_rate': 1e-5,
    'print_every': 20
}

# Start the process
os.makedirs(flags['savedir'], exist_ok=True)
xmp.spawn(map_fn, args=(flags,), nprocs=8, start_method='fork')

Process 0 obtained, using device: TPU:0
Process 1 obtained, using device: TPU:1




Process 5 obtained, using device: TPU:5
Process 7 obtained, using device: TPU:7
Process 2 obtained, using device: TPU:2



Process 6 obtained, using device: TPU:6
Process 3 obtained, using device: TPU:3
Process 4 obtained, using device: TPU:4


Number of training batches: 35
Number of evaluation batches: 3

[TRAIN] Iteration    0 | Loss 4.7338 | Time Elapsed 2.97 seconds
[TRAIN] Iteration   20 | Loss 4.1438 | Time Elapsed 79.53 seconds
[TRAIN] Iteration    0 | Loss 3.6926 | Time Elapsed 89.75 seconds
[TRAIN] Iteration   20 | Loss 3.7793 | Time Elapsed 103.52 seconds
[TRAIN] Iteration    0 | Loss 3.4750 | Time Elapsed 113.67 seconds
[TRAIN] Iteration   20 | Loss 3.6648 | Time Elapsed 127.69 seconds

Finished training 3 epochs in 136.46 seconds.

[VALID] Iteration    0 | Loss 3.3704 | Time Elapsed 141.23 seconds

Finished evaluation in 8.59 seconds. Validation Loss: 3.4767 | Validation PPL 32.3538



We reach a validation perplexity of ~32 after three epochs of finetuning, in only a little over 2 minutes of training (136 seconds)!
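
For reference, perplexity here is simply the exponential of the mean cross-entropy loss, which is what `np.exp(valid_loss)` computes in `map_fn`. A minimal sketch using the rounded validation loss from the log above (`perplexity` is a hypothetical helper name):

```python
import math


def perplexity(mean_cross_entropy):
    """Perplexity is the exponential of the mean per-token cross-entropy (in nats)."""
    return math.exp(mean_cross_entropy)


valid_ppl = perplexity(3.4767)   # rounded validation loss from the training log
print(round(valid_ppl, 2))       # close to the logged 32.3538; the small gap comes from rounding the loss
```

Because of the exponential, small changes in loss translate into noticeably larger changes in perplexity.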

# Testing

Before we test generation, let's measure the finetuned model's performance on the test set. We'll instantiate a fresh copy of the model and load our finetuned weights into it. We'll only use one TPU core for this.

In [0]:
test_set = TextDataset(tokenizer=tokenizer, 
                       file_path='wikitext-2/wiki.test.tokens', 
                       block_size=tokenizer.max_len + 1)
test_loader = datautils.DataLoader(test_set, batch_size=flags['batch_size'], shuffle=False, drop_last=True)

# Acquire a device and instantiate model
device = xm.xla_device()
model = GPT2LMHeadModel.from_pretrained('distilgpt2').to(device)
criterion = nn.CrossEntropyLoss()

# Load the saved weights
with open(flags['savedir'] + '/' + flags['modelpath'], 'rb') as f:
    model.load_state_dict(torch.load(f))

Then start the test loop.

In [9]:
print("Testing batches: {}".format(len(test_loader)))
test_start = time.time()

model.eval()
test_loss = 0
for i, batch in enumerate(test_loader):
    x, y = batch[:,:-1], batch[:, 1:]
    x, y = x.to(device), y.to(device)
    with torch.no_grad():
        out = model(x)[0]  # logits are the first element of the model output
        loss = criterion(out.flatten(0, 1), y.flatten(0))

        if i % 10 == 0: 
            print('[TEST] Iteration {:4} | Loss {:.4f} | Time Elapsed {:.2f} seconds'.format(i, loss.item(), time.time() - test_start))
        test_loss += loss.item()
test_loss /= len(test_loader)

print('\nFinished evaluation in {:.2f} seconds. Test Loss: {:.4f} | Test PPL {:.4f}\n'.format(time.time() - test_start, test_loss, np.exp(test_loss)))

Testing batches: 32
[TEST] Iteration    0 | Loss 3.4673 | Time Elapsed 3.39 seconds
[TEST] Iteration   10 | Loss 3.4809 | Time Elapsed 4.87 seconds
[TEST] Iteration   20 | Loss 3.9778 | Time Elapsed 6.31 seconds
[TEST] Iteration   30 | Loss 3.2503 | Time Elapsed 7.74 seconds

Finished evaluation in 7.98 seconds. Test Loss: 3.4646 | Test PPL 31.9635



The test perplexity (about 32.0) is in line with the validation perplexity (about 32.4).

Now let's move the model to CPU.

In [0]:
model = model.cpu()

And generate some text.

In [11]:
text = "this is a brand new"
nwords = 15

for i in range(nwords):
    ids = torch.tensor(tokenizer.encode(text)).unsqueeze(0)
    with torch.no_grad():
        out = model(ids)[0]
    pred_ix = out.squeeze(0)[-1].argmax().item()
    pred_word = tokenizer.decode([pred_ix])  # decode expects a list of token ids
    text += pred_word

print(text)

this is a brand new product that will be available in the US and Europe in the coming months.
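
The loop above is a form of greedy decoding: at every step it takes the single most likely next token and appends it to the input. Below is a minimal, framework-free sketch of the same idea. The toy "model" (`toy_logits`) is hypothetical and exists only so the example runs without downloading any weights.

```python
def argmax(values):
    """Index of the largest value in a list."""
    return max(range(len(values)), key=values.__getitem__)


def greedy_decode(next_token_logits, prompt_ids, num_new_tokens):
    """Repeatedly append the highest-scoring next token to the sequence."""
    ids = list(prompt_ids)
    for _ in range(num_new_tokens):
        logits = next_token_logits(ids)   # scores for every token in the vocabulary
        ids.append(argmax(logits))
    return ids


def toy_logits(ids, vocab_size=4):
    """Hypothetical stand-in for a language model: prefers (last token + 1) mod vocab_size."""
    preferred = (ids[-1] + 1) % vocab_size
    return [1.0 if token == preferred else 0.0 for token in range(vocab_size)]


generated = greedy_decode(toy_logits, [0], 5)
print(generated)
```

Greedy decoding is deterministic, so the same prompt always yields the same continuation; sampling-based strategies trade that determinism for more varied text.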
