# T5 on TPU 💥🚀

In this notebook we will see how to train T5 model on TPU with Huggingface's awesome new [trainer](https://github.com/huggingface/transformers/blob/master/src/transformers/trainer.py). We will train T5 base model on SQUAD dataset for QA task. We will use the recently released amazing [nlp](https://github.com/huggingface/nlp) package to load and process the dataset in just few lines.

First make sure you are connected to the high RAM instance. This will not work on 12 GB colab instance.

In [0]:
# Crash on purpose to get more ram :
#import torch
#torch.tensor([10.]*10000000000)

Let's install [PyTorch/XLA](https://github.com/pytorch/xla) which enables PyTorch on TPU. Make sure you install the nightly version, as the trainer breaks on other versions.

In [1]:
VERSION = "nightly"  #@param ["1.5" , "20200325", "nightly"]
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $VERSION

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0100  4264  100  4264    0     0  23821      0 --:--:-- --:--:-- --:--:-- 23955
Updating TPU and VM. This may take around 2 minutes.
Updating TPU runtime to pytorch-nightly ...
Uninstalling torch-1.5.0+cu101:
Done updating TPU runtime: <Response [200]>
  Successfully uninstalled torch-1.5.0+cu101
Uninstalling torchvision-0.6.0+cu101:
  Successfully uninstalled torchvision-0.6.0+cu101
Copying gs://tpu-pytorch/wheels/torch-nightly-cp36-cp36m-linux_x86_64.whl...
- [1 files][ 90.6 MiB/ 90.6 MiB]                                                
Operation completed over 1 objects/90.6 MiB.                                     
Copying gs://tpu-pytorch/wheels/torch_xla-nightly-cp36-cp36m-linux_x86_64.whl...
\ [1 files][121.4 MiB/121.4 MiB]                       

Install transformers and the nlp package. Restart colab after this

In [2]:
!git clone https://github.com/huggingface/transformers.git
!pip install ./transformers


Cloning into 'transformers'...
remote: Enumerating objects: 65, done.[K
remote: Counting objects:   1% (1/65)[Kremote: Counting objects:   3% (2/65)[Kremote: Counting objects:   4% (3/65)[Kremote: Counting objects:   6% (4/65)[Kremote: Counting objects:   7% (5/65)[Kremote: Counting objects:   9% (6/65)[Kremote: Counting objects:  10% (7/65)[Kremote: Counting objects:  12% (8/65)[Kremote: Counting objects:  13% (9/65)[Kremote: Counting objects:  15% (10/65)[Kremote: Counting objects:  16% (11/65)[Kremote: Counting objects:  18% (12/65)[Kremote: Counting objects:  20% (13/65)[Kremote: Counting objects:  21% (14/65)[Kremote: Counting objects:  23% (15/65)[Kremote: Counting objects:  24% (16/65)[Kremote: Counting objects:  26% (17/65)[Kremote: Counting objects:  27% (18/65)[Kremote: Counting objects:  29% (19/65)[Kremote: Counting objects:  30% (20/65)[Kremote: Counting objects:  32% (21/65)[Kremote: Counting objects:  33% (22/65)[Kremote: Coun

## Load and process data

Let's load and process the dataset using the nlp library. We will process the examples in follwoing way to cast QA task in text-to-text setting

**input**
question: question_text  context: context 

**target**
answer_text

In [0]:
import torch
from transformers import XLMRobertaTokenizer, XLMRobertaModel
import torch.nn as nn 
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup
import json
import pandas as pd

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

In [4]:
tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-base')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=5069051.0, style=ProgressStyle(descript…




In [6]:
from google.colab import files
files.upload()

Saving kaggle.json to kaggle.json


{'kaggle.json': b'{"username":"doanquanvietnamca","key":"5c44ad334dfc534e12d04dc8373e0440"}'}

In [7]:
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!pip install kaggle
!kaggle competitions download -c jigsaw-multilingual-toxic-comment-classification

Downloading validation.csv.zip to /content
  0% 0.00/1.35M [00:00<?, ?B/s]
100% 1.35M/1.35M [00:00<00:00, 45.3MB/s]
Downloading test-processed-seqlen128.csv.zip to /content
 74% 22.0M/29.8M [00:00<00:00, 99.6MB/s]
100% 29.8M/29.8M [00:00<00:00, 99.3MB/s]
Downloading validation-processed-seqlen128.csv.zip to /content
  0% 0.00/3.44M [00:00<?, ?B/s]
100% 3.44M/3.44M [00:00<00:00, 114MB/s]
Downloading sample_submission.csv to /content
  0% 0.00/612k [00:00<?, ?B/s]
100% 612k/612k [00:00<00:00, 86.8MB/s]
Downloading test.csv.zip to /content
 40% 5.00M/12.4M [00:00<00:00, 49.0MB/s]
100% 12.4M/12.4M [00:00<00:00, 79.2MB/s]
Downloading jigsaw-unintended-bias-train.csv.zip to /content
 99% 288M/292M [00:01<00:00, 202MB/s]
100% 292M/292M [00:01<00:00, 201MB/s]
Downloading jigsaw-unintended-bias-train-processed-seqlen128.csv.zip to /content
 97% 632M/650M [00:05<00:00, 140MB/s]
100% 650M/650M [00:05<00:00, 131MB/s]
Downloading jigsaw-toxic-comment-train.csv.zip to /content
 99% 37.0M/37.3M [00:0

In [0]:
import zipfile as zf
file_name = ["jigsaw-toxic-comment-train.csv.zip", "jigsaw-unintended-bias-train.csv.zip","validation.csv.zip"]
for file in file_name:
  with  zf.ZipFile(file, 'r') as f:
    f.extractall('')

In [0]:
train1 = pd.read_csv("jigsaw-toxic-comment-train.csv")
train2 = pd.read_csv("jigsaw-unintended-bias-train.csv")
train2.toxic = train2.toxic.round().astype(int)
valid = pd.read_csv('validation.csv')
    #test = pd.read_csv('/kaggle/input/jigsaw-multilingual-toxic-comment-classification/test.csv')
    #sub = pd.read_csv('/kaggle/input/jigsaw-multilingual-toxic-comment-classification/sample_submission.csv')
train = pd.concat([
train1[['comment_text', 'toxic']],
train2[['comment_text', 'toxic']].query('toxic==1'),
train2[['comment_text', 'toxic']].query('toxic==0').sample(n=100000, random_state=0)])
train = train.replace('\n',' ', regex=True)
valid = valid.replace('\n',' ', regex=True)

In [14]:
#preprocessing data.
#remove punc, html, stop works,
from bs4 import BeautifulSoup
from nltk.tokenize import word_tokenize
import nltk
nltk.download('punkt')
nltk.download('wordnet')
import re
from tqdm import tqdm
from nltk.stem import WordNetLemmatizer
lemmatizer = WordNetLemmatizer()
def clean_sentences(df):
    reviews = []

    for sent in tqdm(df['comment_text']):
        
        #remove html content
        review_text = BeautifulSoup(sent).get_text()
        
        #remove non-alphabetic characters
        review_text = re.sub("[^a-zA-Z]"," ", review_text)
    
        #tokenize the sentences
        words = word_tokenize(review_text.lower())
        #lemmatize each word to its lemma
        lemma_words = [lemmatizer.lemmatize(i) for i in words]
    
        reviews.append(" ".join(lemma_words))

    return(reviews)

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Unzipping corpora/wordnet.zip.


In [15]:
train['comment_text'] = clean_sentences(train)
valid['comment_text'] = clean_sentences(valid)

  ' that document to Beautiful Soup.' % decoded_markup
  ' that document to Beautiful Soup.' % decoded_markup
  ' that document to Beautiful Soup.' % decoded_markup
  ' that document to Beautiful Soup.' % decoded_markup
  ' that document to Beautiful Soup.' % decoded_markup
  ' that document to Beautiful Soup.' % decoded_markup
  ' that document to Beautiful Soup.' % decoded_markup
  ' Beautiful Soup.' % markup)
  ' that document to Beautiful Soup.' % decoded_markup
  ' that document to Beautiful Soup.' % decoded_markup
  ' that document to Beautiful Soup.' % decoded_markup
  ' that document to Beautiful Soup.' % decoded_markup
  ' that document to Beautiful Soup.' % decoded_markup
  ' that document to Beautiful Soup.' % decoded_markup
  ' that document to Beautiful Soup.' % decoded_markup
  ' that document to Beautiful Soup.' % decoded_markup
  ' that document to Beautiful Soup.' % decoded_markup
  ' that document to Beautiful Soup.' % decoded_markup
  ' that document to Beautiful Sou

In [0]:
import re
train['comment_text'] = train['comment_text'].apply(lambda x: re.split('https:\/\/.*', str(x))[0])
valid['comment_text'] = valid['comment_text'].apply(lambda x: re.split('https:\/\/.*', str(x))[0])


In [0]:
from torch.utils.data import Dataset, DataLoader
from transformers import T5Tokenizer
import torch
class XLMRobertaDataset(Dataset):
  def __init__(self, tokenizer,df,  max_len=192):
    self.data_column = df["comment_text"].values
    self.value = df['toxic'].values
    self.max_len = max_len
    self.tokenizer = tokenizer
        
  def __len__(self):
      return len(self.data_column)

  def __getitem__(self, index):
    # tokenize inputs
    input_ = '<s> %s </s>'%(self.data_column[index])
    tokenized_inputs = self.tokenizer.batch_encode_plus( [input_], max_length=self.max_len, pad_to_max_length=True,return_tensors="pt" )
    source_ids = tokenized_inputs["input_ids"].squeeze()
    src_mask    = tokenized_inputs["attention_mask"].squeeze() # might need to squeeze
    return {"input_ids": source_ids, "attention_mask": src_mask, "target": torch.tensor(self.value[index])}

In [0]:
train_dataset = XLMRobertaDataset(tokenizer,train, 128)
valid_dataset = XLMRobertaDataset(tokenizer, valid, 128)

In [27]:
%%timeit
train_dataset[1]

The slowest run took 5.38 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 142 µs per loop


In [28]:
len(train_dataset), len(valid_dataset)

(435775, 8000)

In [0]:
# cach the dataset, so we can load it directly for training

torch.save(train_dataset, 'train_data.pt')
torch.save(valid_dataset, 'valid_data.pt')

For more details on how to use the nlp library check out this [notebook](https://colab.research.google.com/github/huggingface/nlp/blob/master/notebooks/Overview.ipynb).

## Write training script

Using the `Trainer` is pretty straightforward. Here are the 4 basic steps which are needed to use trainer.

1. **Parse the arguments needed**. These are divided in 3 parts for clarity and seperation (TrainingArguments, ModelArguments and DataTrainingArguments).

  1. **TrainingArguments**: These are basicaly the training hyperparameters such as learning rate, batch size, weight decay, gradient accumulation steps etc. See all possible arguments [here](https://github.com/huggingface/transformers/blob/master/src/transformers/training_args.py). These are used by the Trainer.

  2. **ModelArguments**: These are the arguments for the model that you want to use such as the model_name_or_path, tokenizer_name etc. You'll need these to load the model and tokenizer.

  3. **DataTrainingArguments**: These are as the name suggests arguments needed for the dataset. Such as the directory name where your files are stored etc. You'll need these to load/process the dataset.

  TrainingArguments are already defined in the `TrainingArguments` class, you'll need to define `ModelArguments` and `DataTrainingArguments` classes for your task.




2. Load train and eval datasets
3. Initialize the `Trainer`

    These are the mininum parameters which you'll for initializing `Trainer`. For full list check [here](https://github.com/huggingface/transformers/blob/master/src/transformers/trainer.py#L107)

    ```
      model: PreTrainedModel
      args: TrainingArguments
      train_dataset: Optional[Dataset]
      eval_dataset: Optional[Dataset]
    ```
4. Start training with  `trainer.train`

    Call `trainer.train` and let the magic begin!


There are lots of things which the trainer handles for you out of the box such as gradient_accumulation, fp16 training, setting up the optimizer and scheduler, logging with wandb etc. I didn't set-up wandb for this experiment, but will explore it for sure in future experiment.

In [0]:
def loss_fn(outputs, targets):
    return nn.BCEWithLogitsLoss()(outputs, targets.view(-1, 1))


def train_fn(data_loader, model, optimizer, device, scheduler):
    model.train()

    for bi, d in enumerate(data_loader):
        ids = d["input_ids"]
        #token_type_ids = d["token_type_ids"]
        mask = d["attention_mask"]
        targets = d["target"]

        ids = ids.to(device, dtype=torch.long)
        #token_type_ids = token_type_ids.to(device, dtype=torch.long)
        mask = mask.to(device, dtype=torch.long)
        targets = targets.to(device, dtype=torch.float)

        optimizer.zero_grad()
        outputs = model(
            ids,
           mask,
            #token_type_ids=token_type_ids
        )

        loss = loss_fn(outputs, targets)
        loss.backward()
        xm.optimizer_step(optimizer)
        if scheduler is not None:
            scheduler.step()

        if bi % 100 == 0:
            print(f'[xla:{xm.get_ordinal()}]: bi={bi}, loss={loss}')


def eval_fn(data_loader, model, device):
    model.eval()
    fin_targets = []
    fin_outputs = []
    with torch.no_grad():
        for bi, d in enumerate(data_loader):
            ids = d["input_ids"]
           # token_type_ids = d["token_type_ids"]
            mask = d["attention_mask"]
            targets = d["target"]

            ids = ids.to(device, dtype=torch.long)
            #token_type_ids = token_type_ids.to(device, dtype=torch.long)
            mask = mask.to(device, dtype=torch.long)
            targets = targets.to(device, dtype=torch.float)

            outputs = model(
                ids,
                mask,
                #token_type_ids=token_type_ids
            )
            fin_targets.extend(targets.cpu().detach().numpy().tolist())
            fin_outputs.extend(outputs.cpu().detach().numpy().tolist())
    return fin_outputs, fin_targets

In [0]:
class RobertaMultilayerClassification(torch.nn.Module):
    def __init__(self,config= None):
        super(RobertaMultilayerClassification,self).__init__()
        self.model = XLMRobertaModel.from_pretrained('xlm-roberta-base')
        self.dense = torch.nn.Linear(768, 1)
        self.config = config
        torch.nn.init.xavier_normal_(self.dense.weight)

    def forward(self, ids , attention_mask=None, token_type_ids=None):
        last_hidden_state = self.model(input_ids=ids, attention_mask = attention_mask,token_type_ids= token_type_ids)
        mean_last_hidden_state = self.pool_hidden_state(last_hidden_state)
        #mean_last_hidden_state = self.dropout(mean_last_hidden_state)
        logits = self.dense(mean_last_hidden_state)
        return logits
    def pool_hidden_state(self, last_hidden_state):
        last_hidden_state = last_hidden_state[0]
        mean_last_hidden_state = torch.mean(last_hidden_state, 1)
        return mean_last_hidden_state
        #return last_hidden_state

In [0]:
import time
time_train = [15,10,20,5,0]
config = {"epochs":2, "train_batch":4,\
          "valid_batch":4, "learning_rate":5e-5,\
          "save_dir":"models/tpu"}
def main():
    time_train = random.shuffle(time_train)
    print("train load")
    time.sleep(time_train[3])
    train_dataset  = torch.load("train_data.pt")
    train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True)

    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config['train_batch'],
        sampler=train_sampler,
        drop_last=True,
        num_workers=2
    )
    print("train done")
    time.sleep(time_train[4])
    print("valid load")
    valid_dataset  = torch.load("valid_data.pt")
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
          valid_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=False)

    valid_data_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config['valid_batch'],
        sampler=valid_sampler,
        drop_last=False,
        num_workers=1
    )
    print("valid done")
    device = xm.xla_device()
    model = RobertaMultilayerClassification().to(device)
    
    print('training started')
    time.sleep(time_train[0])
    # Initialize our Trainer
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_parameters = [
        {
            'params': [
                p for n, p in param_optimizer if not any(
                    nd in n for nd in no_decay
                )
            ], 
            'weight_decay': 0.001
        },
        {
            'params': [
                p for n, p in param_optimizer if any(
                    nd in n for nd in no_decay
                )
            ],
            'weight_decay': 0.0
        },
    ]

    num_train_steps = int(
        len(train_dataset) / config['train_batch'] / xm.xrt_world_size() * config['epochs']
    )
    time.sleep(time_train[1])
    print("%s : num train steps: %d"%(device, num_train_steps))
    optimizer = AdamW(
        optimizer_parameters, 
        lr=5e-5 * xm.xrt_world_size()
    )
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=num_train_steps
    )
    time.sleep(time_train[2])
    best_auc = 0
    for epoch in range(config['epochs']):
        para_loader = pl.ParallelLoader(train_data_loader, [device])
        train_fn(
            para_loader.per_device_loader(device), 
            model, 
            optimizer, 
            device, 
            scheduler
        )
        
        para_loader = pl.ParallelLoader(valid_data_loader, [device])
        outputs, targets = eval_fn(
            para_loader.per_device_loader(device), 
            model, 
            device
        )

        targets = np.array(targets) >= 0.5
        auc = metrics.roc_auc_score(targets, outputs)
        print(f'[xla:{xm.get_ordinal()}]: AUC={auc}')
        if auc > best_auc:
            xm.save(model.state_dict(), config['save_dir'])
            best_auc = auc



def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()

## Train

In [0]:
xmp.spawn(_mp_fn, args=(), nprocs=1, start_method='fork')

training started
%s : num train steps: %d (device(type='xla', index=1), 217887)
[xla:0]: bi=0, loss=1.0552258491516113
[xla:0]: bi=100, loss=0.46707791090011597
[xla:0]: bi=200, loss=0.3886053264141083
[xla:0]: bi=300, loss=0.6350961923599243
[xla:0]: bi=400, loss=0.5741508603096008
[xla:0]: bi=500, loss=0.5198253989219666
[xla:0]: bi=600, loss=1.1354632377624512
[xla:0]: bi=700, loss=0.5688650608062744
[xla:0]: bi=800, loss=0.5962872505187988
[xla:0]: bi=900, loss=0.7892423868179321
[xla:0]: bi=1000, loss=0.558315098285675
[xla:0]: bi=1100, loss=0.5810234546661377
[xla:0]: bi=1200, loss=0.7550468444824219
[xla:0]: bi=1300, loss=0.3292660713195801
[xla:0]: bi=1400, loss=0.5508790016174316
[xla:0]: bi=1500, loss=0.36974647641181946
[xla:0]: bi=1600, loss=0.8953086137771606
[xla:0]: bi=1700, loss=0.576103925704956
[xla:0]: bi=1800, loss=0.6993395090103149
[xla:0]: bi=1900, loss=0.5973218083381653
[xla:0]: bi=2000, loss=0.7874424457550049
[xla:0]: bi=2100, loss=0.6097589135169983
[xla:0]:

## Eval

In [0]:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
from tqdm.auto import tqdm

In [0]:
model  = XLMRobertaClassification().load_state_dict("models/tpu").to('cpu')

NameError: ignored

In [0]:
pred = []
label = []
for batch in valid_dataset:
  pred.append(model(batch['input_ids'], batch['mask']))
  label.append(batch['target'])

In [0]:
pred[0], label[0]

('negative', 'positive')

In [0]:
from sklearn.metrics import classification_report
print(classification_report(pred, label))

              precision    recall  f1-score   support

    negative       0.86      0.81      0.84       874
    positive       0.77      0.82      0.79       649

    accuracy                           0.82      1523
   macro avg       0.81      0.82      0.81      1523
weighted avg       0.82      0.82      0.82      1523

