# T5 on TPU 💥🚀

In this notebook we will see how to train T5 model on TPU with Huggingface's awesome new [trainer](https://github.com/huggingface/transformers/blob/master/src/transformers/trainer.py). We will train T5 base model on SQUAD dataset for QA task. We will use the recently released amazing [nlp](https://github.com/huggingface/nlp) package to load and process the dataset in just few lines.

First make sure you are connected to the high RAM instance. This will not work on 12 GB colab instance.

In [0]:
# Crash on purpose to get more ram :
#import torch
#torch.tensor([10.]*10000000000)

Let's install [PyTorch/XLA](https://github.com/pytorch/xla) which enables PyTorch on TPU. Make sure you install the nightly version, as the trainer breaks on other versions.

In [1]:
VERSION = "nightly"  #@param ["1.5" , "20200325", "nightly"]
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $VERSION

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0100  4264  100  4264    0     0  22208      0 --:--:-- --:--:-- --:--:-- 22208
Updating TPU and VM. This may take around 2 minutes.
Updating TPU runtime to pytorch-nightly ...
Uninstalling torch-1.6.0a0+03eca38:
  Successfully uninstalled torch-1.6.0a0+03eca38
Uninstalling torchvision-0.7.0a0+3902140:
  Successfully uninstalled torchvision-0.7.0a0+3902140
Copying gs://tpu-pytorch/wheels/torch-nightly-cp36-cp36m-linux_x86_64.whl...
- [1 files][ 90.6 MiB/ 90.6 MiB]                                                
Operation completed over 1 objects/90.6 MiB.                                     
Copying gs://tpu-pytorch/wheels/torch_xla-nightly-cp36-cp36m-linux_x86_64.whl...
\ [1 files][121.4 MiB/121.4 MiB]                                                
Op

Install transformers and the nlp package. Restart colab after this

In [2]:
!git clone https://github.com/huggingface/transformers.git
!pip install ./transformers


fatal: destination path 'transformers' already exists and is not an empty directory.
Processing ./transformers
Building wheels for collected packages: transformers
  Building wheel for transformers (setup.py) ... [?25l[?25hdone
  Created wheel for transformers: filename=transformers-2.11.0-cp36-none-any.whl size=675545 sha256=6f85d3f1a8327f8696ebed5189ca0d2d16940bbbf2f7dbd24919916fcbe81a28
  Stored in directory: /tmp/pip-ephem-wheel-cache-egzvkmic/wheels/23/19/dd/2561a4e47240cf6b307729d58e56f8077dd0c698f5992216cf
Successfully built transformers
Installing collected packages: transformers
  Found existing installation: transformers 2.11.0
    Uninstalling transformers-2.11.0:
      Successfully uninstalled transformers-2.11.0
Successfully installed transformers-2.11.0


## Load and process data

Let's load and process the dataset using the nlp library. We will process the examples in follwoing way to cast QA task in text-to-text setting

**input**
question: question_text  context: context 

**target**
answer_text

In [0]:
import torch
from transformers import XLMRobertaTokenizer, XLMRobertaModel
import torch.nn as nn 
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup
import json
import pandas as pd

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

In [0]:
tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-base')

In [37]:
from google.colab import files
files.upload()

Saving kaggle.json to kaggle (1).json


{'kaggle.json': b'{"username":"doanquanvietnamca","key":"5c44ad334dfc534e12d04dc8373e0440"}'}

In [38]:
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!pip install kaggle
!kaggle competitions download -c jigsaw-multilingual-toxic-comment-classification

test-processed-seqlen128.csv.zip: Skipping, found more recently modified local copy (use --force to force download)
jigsaw-toxic-comment-train.csv.zip: Skipping, found more recently modified local copy (use --force to force download)
jigsaw-unintended-bias-train.csv.zip: Skipping, found more recently modified local copy (use --force to force download)
sample_submission.csv: Skipping, found more recently modified local copy (use --force to force download)
jigsaw-toxic-comment-train-processed-seqlen128.csv.zip: Skipping, found more recently modified local copy (use --force to force download)
validation-processed-seqlen128.csv.zip: Skipping, found more recently modified local copy (use --force to force download)
validation.csv.zip: Skipping, found more recently modified local copy (use --force to force download)
jigsaw-unintended-bias-train-processed-seqlen128.csv.zip: Skipping, found more recently modified local copy (use --force to force download)
test.csv.zip: Skipping, found more rece

In [0]:
import zipfile as zf
file_name = ["jigsaw-toxic-comment-train.csv.zip", "jigsaw-unintended-bias-train.csv.zip","validation.csv.zip"]
for file in file_name:
  with  zf.ZipFile(file, 'r') as f:
    f.extractall('')

In [0]:
train1 = pd.read_csv("jigsaw-toxic-comment-train.csv")
train2 = pd.read_csv("jigsaw-unintended-bias-train.csv")
train2.toxic = train2.toxic.round().astype(int)
valid = pd.read_csv('validation.csv')
    #test = pd.read_csv('/kaggle/input/jigsaw-multilingual-toxic-comment-classification/test.csv')
    #sub = pd.read_csv('/kaggle/input/jigsaw-multilingual-toxic-comment-classification/sample_submission.csv')
train = pd.concat([
train1[['comment_text', 'toxic']],
train2[['comment_text', 'toxic']].query('toxic==1'),
train2[['comment_text', 'toxic']].query('toxic==0').sample(n=100000, random_state=0)])
train = train.replace('\n',' ', regex=True)
valid = valid.replace('\n',' ', regex=True)

In [0]:
#preprocessing data.
#remove punc, html, stop works,
from bs4 import BeautifulSoup
from nltk.tokenize import word_tokenize

import re
from tqdm import tqdm
from nltk.stem import WordNetLemmatizer
lemmatizer = WordNetLemmatizer()
def clean_sentences(df):
    reviews = []

    for sent in tqdm(df['comment_text']):
        
        #remove html content
        review_text = BeautifulSoup(sent).get_text()
        
        #remove non-alphabetic characters
        review_text = re.sub("[^a-zA-Z]"," ", review_text)
    
        #tokenize the sentences
        words = word_tokenize(review_text.lower())
    
        #lemmatize each word to its lemma
        lemma_words = [lemmatizer.lemmatize(i) for i in words]
    
        reviews.append(" ".join(lemma_words))

    return(reviews)

In [7]:
train['comment_text'] = clean_sentences(train)
valid['comment_text'] = clean_sentences(valid)

  ' that document to Beautiful Soup.' % decoded_markup
  ' that document to Beautiful Soup.' % decoded_markup
  ' that document to Beautiful Soup.' % decoded_markup
  ' that document to Beautiful Soup.' % decoded_markup
  ' that document to Beautiful Soup.' % decoded_markup
  ' that document to Beautiful Soup.' % decoded_markup
  ' that document to Beautiful Soup.' % decoded_markup
  ' Beautiful Soup.' % markup)
  ' that document to Beautiful Soup.' % decoded_markup
  ' that document to Beautiful Soup.' % decoded_markup
  ' that document to Beautiful Soup.' % decoded_markup
  ' that document to Beautiful Soup.' % decoded_markup
  ' that document to Beautiful Soup.' % decoded_markup
  ' that document to Beautiful Soup.' % decoded_markup
  ' that document to Beautiful Soup.' % decoded_markup
  ' that document to Beautiful Soup.' % decoded_markup
  ' that document to Beautiful Soup.' % decoded_markup
  ' that document to Beautiful Soup.' % decoded_markup
  ' that document to Beautiful Sou

In [0]:
from torch.utils.data import Dataset, DataLoader
from transformers import T5Tokenizer
import torch
class XLMRobertaDataset(Dataset):
  def __init__(self, tokenizer,df,  max_len=192):
    self.data_column = df["comment_text"].values
    self.value = df['toxic'].values
    self.max_len = max_len
    self.tokenizer = tokenizer
        
  def __len__(self):
      return len(self.data_column)

  def __getitem__(self, index):
    # tokenize inputs
    input_ = '<s> %s </s>'%(self.data_column[index])
    tokenized_inputs = self.tokenizer.batch_encode_plus( [input_], max_length=self.max_len, pad_to_max_length=True,return_tensors="pt" )
    source_ids = tokenized_inputs["input_ids"].squeeze()
    src_mask    = tokenized_inputs["attention_mask"].squeeze() # might need to squeeze
    return {"input_ids": source_ids, "attention_mask": src_mask, "target": torch.tensor(self.value[index])}

In [0]:
train_dataset = XLMRobertaDataset(tokenizer,train, 128)
valid_dataset = XLMRobertaDataset(tokenizer, valid, 128)

In [10]:
%%timeit
train_dataset[1]

The slowest run took 10.75 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 143 µs per loop


In [11]:
len(train_dataset), len(valid_dataset)

(435775, 8000)

In [0]:
# cach the dataset, so we can load it directly for training

torch.save(train_dataset, 'train_data.pt')
torch.save(valid_dataset, 'valid_data.pt')

For more details on how to use the nlp library check out this [notebook](https://colab.research.google.com/github/huggingface/nlp/blob/master/notebooks/Overview.ipynb).

## Write training script

Using the `Trainer` is pretty straightforward. Here are the 4 basic steps which are needed to use trainer.

1. **Parse the arguments needed**. These are divided in 3 parts for clarity and seperation (TrainingArguments, ModelArguments and DataTrainingArguments).

  1. **TrainingArguments**: These are basicaly the training hyperparameters such as learning rate, batch size, weight decay, gradient accumulation steps etc. See all possible arguments [here](https://github.com/huggingface/transformers/blob/master/src/transformers/training_args.py). These are used by the Trainer.

  2. **ModelArguments**: These are the arguments for the model that you want to use such as the model_name_or_path, tokenizer_name etc. You'll need these to load the model and tokenizer.

  3. **DataTrainingArguments**: These are as the name suggests arguments needed for the dataset. Such as the directory name where your files are stored etc. You'll need these to load/process the dataset.

  TrainingArguments are already defined in the `TrainingArguments` class, you'll need to define `ModelArguments` and `DataTrainingArguments` classes for your task.




2. Load train and eval datasets
3. Initialize the `Trainer`

    These are the mininum parameters which you'll for initializing `Trainer`. For full list check [here](https://github.com/huggingface/transformers/blob/master/src/transformers/trainer.py#L107)

    ```
      model: PreTrainedModel
      args: TrainingArguments
      train_dataset: Optional[Dataset]
      eval_dataset: Optional[Dataset]
    ```
4. Start training with  `trainer.train`

    Call `trainer.train` and let the magic begin!


There are lots of things which the trainer handles for you out of the box such as gradient_accumulation, fp16 training, setting up the optimizer and scheduler, logging with wandb etc. I didn't set-up wandb for this experiment, but will explore it for sure in future experiment.

In [0]:
def loss_fn(outputs, targets):
    return nn.BCEWithLogitsLoss()(outputs, targets.view(-1, 1))


def train_fn(data_loader, model, optimizer, device, scheduler):
    model.train()

    for bi, d in enumerate(data_loader):
        ids = d["input_ids"]
        #token_type_ids = d["token_type_ids"]
        mask = d["attention_mask"]
        targets = d["target"]

        ids = ids.to(device, dtype=torch.long)
        #token_type_ids = token_type_ids.to(device, dtype=torch.long)
        mask = mask.to(device, dtype=torch.long)
        targets = targets.to(device, dtype=torch.float)

        optimizer.zero_grad()
        outputs = model(
            ids,
           mask,
            #token_type_ids=token_type_ids
        )

        loss = loss_fn(outputs, targets)
        loss.backward()
        xm.optimizer_step(optimizer)
        if scheduler is not None:
            scheduler.step()

        if bi % 100 == 0:
            print(f'[xla:{xm.get_ordinal()}]: bi={bi}, loss={loss}')


def eval_fn(data_loader, model, device):
    model.eval()
    fin_targets = []
    fin_outputs = []
    with torch.no_grad():
        for bi, d in enumerate(data_loader):
            ids = d["input_ids"]
           # token_type_ids = d["token_type_ids"]
            mask = d["attention_mask"]
            targets = d["target"]

            ids = ids.to(device, dtype=torch.long)
            #token_type_ids = token_type_ids.to(device, dtype=torch.long)
            mask = mask.to(device, dtype=torch.long)
            targets = targets.to(device, dtype=torch.float)

            outputs = model(
                ids,
                mask,
                #token_type_ids=token_type_ids
            )
            fin_targets.extend(targets.cpu().detach().numpy().tolist())
            fin_outputs.extend(outputs.cpu().detach().numpy().tolist())
    return fin_outputs, fin_targets

In [0]:
class RobertaMultilayerClassification(torch.nn.Module):
    def __init__(self,config= None):
        super(RobertaMultilayerClassification,self).__init__()
        self.model = XLMRobertaModel.from_pretrained('xlm-roberta-base')
        self.dense = torch.nn.Linear(768, 1)
        self.config = config
        torch.nn.init.xavier_normal_(self.dense.weight)

    def forward(self, ids , attention_mask=None, token_type_ids=None):
        last_hidden_state = self.model(input_ids=ids, attention_mask = attention_mask,token_type_ids= token_type_ids)
        mean_last_hidden_state = self.pool_hidden_state(last_hidden_state)
        #mean_last_hidden_state = self.dropout(mean_last_hidden_state)
        logits = self.dense(mean_last_hidden_state)
        return logits
    def pool_hidden_state(self, last_hidden_state):
        last_hidden_state = last_hidden_state[0]
        mean_last_hidden_state = torch.mean(last_hidden_state, 1)
        return mean_last_hidden_state
        #return last_hidden_state

In [0]:

config = {"epochs":2, "train_batch":4,\
          "valid_batch":4, "learning_rate":5e-5,\
          "save_dir":"models/tpu"}
def main():
    train_dataset  = torch.load("train_data.pt")
    train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True)

    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config['train_batch'],
        sampler=train_sampler,
        drop_last=True,
        num_workers=2
    )
    valid_dataset  = torch.load("valid_data.pt")
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
          valid_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=False)

    valid_data_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config['valid_batch'],
        sampler=valid_sampler,
        drop_last=False,
        num_workers=1
    )
    device = xm.xla_device()
    model = RobertaMultilayerClassification().to(device)
    
    print('training started')

    # Initialize our Trainer
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_parameters = [
        {
            'params': [
                p for n, p in param_optimizer if not any(
                    nd in n for nd in no_decay
                )
            ], 
            'weight_decay': 0.001
        },
        {
            'params': [
                p for n, p in param_optimizer if any(
                    nd in n for nd in no_decay
                )
            ],
            'weight_decay': 0.0
        },
    ]

    num_train_steps = int(
        len(train_dataset) / config['train_batch'] / xm.xrt_world_size() * config['epochs']
    )
    print("%s : num train steps: %d",(device, num_train_steps))
    optimizer = AdamW(
        optimizer_parameters, 
        lr=5e-5 * xm.xrt_world_size()
    )
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=num_train_steps
    )

    best_auc = 0
    for epoch in range(config['epochs']):
        para_loader = pl.ParallelLoader(train_data_loader, [device])
        train_fn(
            para_loader.per_device_loader(device), 
            model, 
            optimizer, 
            device, 
            scheduler
        )
        
        para_loader = pl.ParallelLoader(valid_data_loader, [device])
        outputs, targets = eval_fn(
            para_loader.per_device_loader(device), 
            model, 
            device
        )

        targets = np.array(targets) >= 0.5
        auc = metrics.roc_auc_score(targets, outputs)
        print(f'[xla:{xm.get_ordinal()}]: AUC={auc}')
        if auc > best_auc:
            xm.save(model.state_dict(), config['save_dir'])
            best_auc = auc



def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()

## Train

Let's write the arguments in a dict and store in a json file. The above code will load this file and parse the arguments.

In [1]:
xmp.spawn(_mp_fn, args=(), nprocs=8, start_method='fork')

NameError: ignored

## Eval

There are two gotchas here. First the metrics functionality in the nlp package is still work-in-progress so we will use the official squad evaluation script. Second, for some reason which I couldn't figure out, the `.generate` method is not working on TPU so will need to do prediction on CPU. For predicting the validation set it almost takes 40 mins.

In [0]:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
from tqdm.auto import tqdm

In [135]:
model  = XLMRobertaClassification().load_state_dict("models/tpu").to('cpu')

NameError: ignored

In [0]:
pred = []
label = []
for batch in valid_dataset:
  pred.append(model(batch['input_ids'], batch['mask']))
  label.append(batch['target'])

In [0]:
pred[0], label[0]

('negative', 'positive')

In [0]:
from sklearn.metrics import classification_report
print(classification_report(pred, label))

              precision    recall  f1-score   support

    negative       0.86      0.81      0.84       874
    positive       0.77      0.82      0.79       649

    accuracy                           0.82      1523
   macro avg       0.81      0.82      0.81      1523
weighted avg       0.82      0.82      0.82      1523

