---
title: "Hello, world! Huggingface T5 finetuning"
description: How to finetune Flan-T5 with samsum dataset?
date: "2023-02-10"
categories: [tutorial]
image: "cover.png"
---

![MJ: computer scientist coding to train AI model, studio ghibli --ar 16:9 --niji](cover.png)

^[MidJourney implies the future belongs to children playing Scratch lol.]

A learning note from reproducing this [amazing post by Philipp Schmid](https://www.philschmid.de/fine-tune-flan-t5).

```python
#| code-fold: true

import os

from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
)
from datasets import load_dataset
import evaluate
import nltk
import numpy as np
import pandas as pd
import wandb

nltk.download("punkt", quiet=True)
```


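```python
#| code-fold: true

checkpoint = "google/flan-t5-base"
dataset_name = "samsum"

# Root directory for fine-tuning outputs; falls back to the current directory if the env var is unset.
ft_output_dir = os.getenv("HF_FINETUNE_OUTPUT_DIR", ".")
model_name = checkpoint.split("/")[-1]
hub_model_id = f"{model_name}-{dataset_name}"
model_output_dir = os.path.join(ft_output_dir, hub_model_id)

# W&B reads the project name from this environment variable.
os.environ["WANDB_PROJECT"] = hub_model_id
```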

## Load dataset
`samsum` is a dialogue summarization dataset: each example pairs a conversation with its summary, and the goal is to summarize the conversation. The dataset is available on [Huggingface](https://huggingface.co/datasets/samsum).

```python
ds = load_dataset(dataset_name)
ds
```



```
DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 14732
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 819
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 818
    })
})
```

```python
example = ds["train"][0]
example
```


```
{'id': '13818513',
 'dialogue': "Amanda: I baked  cookies. Do you want some?\r\nJerry: Sure!\r\nAmanda: I'll bring you tomorrow :-)",
 'summary': 'Amanda baked cookies and will bring Jerry some tomorrow.'}
```

## `max_length` analysis
Investigate [truncation and padding](https://huggingface.co/docs/transformers/main/en/pad_truncation#padding-and-truncation) to get statistics on dialogue and summary token length.

Outlier-long inputs can cause out-of-memory (OOM) errors during training, since a single long example forces its whole batch to be padded to that length.

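```python
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
# Naive model parallelism: spreads T5's blocks across all visible GPUs.
# (T5-specific, and deprecated in later transformers releases.)
model.parallelize()

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
```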


```python
# Token counts per example, without padding or truncation.
tk_dialogue = tokenizer(ds["train"]["dialogue"])["input_ids"]
tk_summary = tokenizer(ds["train"]["summary"])["input_ids"]
pd.set_option('display.float_format', lambda x: '%.1f' % x)

df = pd.DataFrame(
    {"dialogue": [len(d) for d in tk_dialogue], "summary": [len(s) for s in tk_summary]}
)
print(df.describe())
```

```
       dialogue  summary
count   14732.0  14732.0
mean      149.0     28.9
std       110.7     15.1
min         1.0      2.0
25%        66.0     17.0
50%       120.0     26.0
75%       202.0     37.0
max      1153.0     94.0
```


My first hunch was not to truncate the input at all, and just pad each batch to its longest sequence.
The setting would be `tokenizer(batch_sentences, padding=True)`.

However, it seems that [truncation is inevitable in production](https://twitter.com/RamaswmySridhar/status/1621870502766858241). You need to find a balance and curb long-input outliers.

For this dataset, a maximum of 1153 tokens is not too extreme.
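If you do decide to truncate, one way to pick a cutoff is to check what fraction of examples a candidate `max_length` leaves untouched. A minimal sketch with NumPy; the `lengths` array is made-up toy data standing in for the token counts above, and `coverage` is a hypothetical helper, not a library function.

```python
import numpy as np

def coverage(lengths, max_length):
    """Fraction of examples whose token count fits within max_length."""
    lengths = np.asarray(lengths)
    return float((lengths <= max_length).mean())

# Toy token counts (hypothetical), with one long outlier.
lengths = np.array([66, 120, 202, 149, 30, 1153, 90, 310])

# Percentile-based cutoff: the length that 90% of examples do not exceed.
cutoff = int(np.percentile(lengths, 90))
print(cutoff, coverage(lengths, cutoff))
print(coverage(lengths, 512))  # 0.875: 7 of 8 examples fit
```

On the real data you would pass the `dialogue` column of `df` instead of the toy array.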

### Padding experiments
Let's experiment with different padding strategies and see how they affect batching and training.

First, do it without truncation:

```python
tk_dialogue = tokenizer(ds["train"]["dialogue"], padding=True)["input_ids"]
tk_summary = tokenizer(ds["train"]["summary"], padding=True)["input_ids"]
pd.set_option('display.float_format', lambda x: '%.1f' % x)

df = pd.DataFrame(
    {"dialogue": [len(d) for d in tk_dialogue], "summary": [len(s) for s in tk_summary]}
)
print(df.describe())
```

```
       dialogue  summary
count   14732.0  14732.0
mean     1153.0     94.0
std         0.0      0.0
min      1153.0     94.0
25%      1153.0     94.0
50%      1153.0     94.0
75%      1153.0     94.0
max      1153.0     94.0
```


As expected: tokenizing the whole column in one call treats the entire training corpus as a single batch, so every sequence is padded to the corpus maximum, 1153 tokens (94 for summaries).
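To see why padding per batch is cheaper than padding to the corpus maximum, here is a toy count of padding tokens under both schemes. It is a plain-Python sketch with made-up lengths and hypothetical helper names; batches are taken in order, like a `DataLoader` with the default `shuffle=False`.

```python
import numpy as np

def pad_tokens_global(lengths):
    """Padding tokens if every sequence is padded to the corpus-wide max."""
    lengths = np.asarray(lengths)
    return int((lengths.max() - lengths).sum())

def pad_tokens_per_batch(lengths, batch_size):
    """Padding tokens if each batch is padded to its own longest sequence."""
    lengths = np.asarray(lengths)
    total = 0
    for start in range(0, len(lengths), batch_size):
        batch = lengths[start:start + batch_size]
        total += int((batch.max() - batch).sum())
    return total

lengths = [10, 12, 50, 8, 9, 11, 7, 6]     # toy token counts
print(pad_tokens_global(lengths))          # 287
print(pad_tokens_per_batch(lengths, 4))    # 131
```

The single long sequence (50) only inflates the batch it lands in, instead of the whole corpus.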

Now try the same idea per batch: a `DataLoader` with `batch_size=8`, where the collator pads each batch to its own longest sequence.

```python
from torch.utils.data import DataLoader

# Tokenize on the fly; the collator pads each batch to its longest sequence.
collator = DataCollatorForSeq2Seq(tokenizer, padding=True)
dl = DataLoader(ds['train'].with_transform(lambda x: tokenizer(x['dialogue'])), batch_size=8, collate_fn=collator)

# Padded sequence length of every batch.
tk_batched = np.array([batch['input_ids'].shape[-1] for batch in dl])

print(len(tk_batched), len(dl))
print(len(np.unique(tk_batched)))

# max / mean / min over the *distinct* batch lengths
np.unique(tk_batched).max(), np.unique(tk_batched).mean(), np.unique(tk_batched).min()
```


```
1842 1842
482

(1153, 389.02904564315355, 92)
```

1842 batches, with 482 unique lengths. This is fine for `pytorch` but would be brutal for JAX `jit`, since every change of input shape would [trigger a recompilation](https://huggingface.co/docs/transformers/main/en/model_doc/t5#training).

> If training on TPU, it is recommended to pad all examples of the dataset to the same length or make use of pad_to_multiple_of to have a small number of predefined bucket sizes to fit all examples in. Dynamically padding batches to the longest example is not recommended on TPU as it triggers a recompilation for every batch shape that is encountered during training thus significantly slowing down the training. only padding up to the longest example in a batch) leads to very slow training on TPU.

The point that padding only to the longest example in each batch slows training down applies to `pytorch` as well.

Try `pad_to_multiple_of=8` to reduce the number of distinct batch lengths.

```python
collator = DataCollatorForSeq2Seq(tokenizer, padding=True, pad_to_multiple_of=8)
dl = DataLoader(ds['train'].with_transform(lambda x: tokenizer(x['dialogue'])), batch_size=8, collate_fn=collator)


tk_batched = np.array([batch['input_ids'].shape[-1] for batch in dl])

print(len(tk_batched), len(dl))
print(len(np.unique(tk_batched)))

np.unique(tk_batched).max(), np.unique(tk_batched).mean(), np.unique(tk_batched).min()
```

```
1842 1842
91

(1160, 485.27472527472526, 96)
```

1842 batches with 91 unique lengths, much better. 
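The arithmetic behind `pad_to_multiple_of` is a ceiling to the next multiple, which is why the max and min above became 1160 and 96 instead of 1153 and 92, and why many nearby lengths collapse into one shape. A plain-Python sketch; `round_up` and `distinct_shapes` are hypothetical helpers, and the batch lengths are random toy data, not samsum.

```python
import numpy as np

def round_up(length, multiple):
    """Round length up to the next multiple, as pad_to_multiple_of does."""
    return ((length + multiple - 1) // multiple) * multiple

def distinct_shapes(batch_max_lengths, multiple=None):
    """Number of distinct padded batch lengths, optionally rounded up."""
    if multiple is None:
        return len(set(batch_max_lengths))
    return len({round_up(n, multiple) for n in batch_max_lengths})

print(round_up(1153, 8), round_up(92, 8))  # 1160 96

rng = np.random.default_rng(0)
# Toy per-batch max lengths in [90, 400) (hypothetical data).
batch_max = rng.integers(90, 400, size=1000).tolist()
print(distinct_shapes(batch_max), distinct_shapes(batch_max, 8))
```

With rounding to 8, lengths in [90, 400) can only map to the 39 multiples of 8 from 96 to 400.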

### Truncation experiment
How does `truncation=True` change anything? According to the Hugging Face docs, `tokenizer(batch_sentences, padding=True, truncation=True)` and `tokenizer(batch_sentences, padding=True)` both pad to the longest sequence in the batch; the difference is that the former also truncates to the maximum input length the model accepts.
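A toy model of the two calls, assuming padding to the longest sequence in the batch and a hypothetical `max_length=512` standing in for the tokenizer's limit. It shows they only diverge when some sequence exceeds that limit.

```python
def padded_length(lengths, max_length=512, truncation=False):
    """Length of a batch padded to its longest sequence,
    optionally truncating each sequence to max_length first."""
    if truncation:
        lengths = [min(n, max_length) for n in lengths]
    return max(lengths)

short_batch = [120, 66, 202]
long_batch = [120, 1153, 202]
print(padded_length(short_batch), padded_length(short_batch, truncation=True))  # 202 202
print(padded_length(long_batch), padded_length(long_batch, truncation=True))    # 1153 512
```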

Let's try it out.

```python
collator = DataCollatorForSeq2Seq(tokenizer, padding=True, pad_to_multiple_of=8)
dl = DataLoader(ds['train'].with_transform(lambda x: tokenizer(x['dialogue'], truncation=True)), batch_size=8, collate_fn=collator)


tk_batched = np.array([batch['input_ids'].shape[-1] for batch in dl])

print(len(tk_batched), len(dl))
print(len(np.unique(tk_batched)))

np.unique(tk_batched).max(), np.unique(tk_batched).mean(), np.unique(tk_batched).min()
```

```
1842 1842
51

(512, 311.52941176470586, 96)
```

`truncation=True` in the tokenizer truncates the dialogue to 512 tokens, the tokenizer's `model_max_length` for T5. However, T5 uses relative position embeddings and has no hard maximum input length; the 512 cap is an artificial limitation imposed by the transformers library through the tokenizer.

Be careful with this behavior: silent truncation means silently losing input information during training.
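A cheap guard is to count what truncation would cut before turning it on, using untruncated token counts like those from the `max_length` analysis. A minimal sketch; `truncation_report` is a hypothetical helper, the lengths are toy data, and in practice `max_length` would be `tokenizer.model_max_length`.

```python
def truncation_report(lengths, max_length):
    """Count examples that truncation would cut and the tokens they would lose."""
    over = [n for n in lengths if n > max_length]
    lost = sum(n - max_length for n in over)
    return len(over), lost

lengths = [149, 1153, 66, 600, 202]  # toy token counts
n_cut, n_lost = truncation_report(lengths, 512)
print(f"{n_cut} examples truncated, {n_lost} tokens dropped")
# 2 examples truncated, 729 tokens dropped
```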

### Source implementation
In the [source notebook](https://www.philschmid.de/fine-tune-flan-t5):
```python
from datasets import concatenate_datasets

# Longest tokenized dialogue across train + test (note: computed with truncation=True)
tokenized_inputs = concatenate_datasets([dataset["train"], dataset["test"]]).map(lambda x: tokenizer(x["dialogue"], truncation=True), batched=True, remove_columns=["dialogue", "summary"])
max_source_length = max([len(x) for x in tokenized_inputs["input_ids"]])

def preprocess_function(sample,padding="max_length"):
    # add prefix to the input for t5
    inputs = ["summarize: " + item for item in sample["dialogue"]]

    # tokenize inputs
    model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True)
    ...  # rest of the source function omitted here
```
1. It pads every input to the corpus-wide max length. Because the length scan uses `truncation=True`, that max is 512 (the tokenizer limit above) rather than 1153. With a mean dialogue length of 149 tokens, roughly 360 padding tokens are processed per instance on average, across 14732 training instances. That wastes a lot of memory and computation.
2. I use `flan-t5`, the instruction-tuned heir of the LM-adapted T5, which makes prepending `summarize:` to the input unnecessary.
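A back-of-envelope check of the padding overhead of `padding="max_length"`. Because the source scans lengths with `truncation=True`, its `max_source_length` comes out at 512, not 1153. This sketch assumes that value and reuses the untruncated mean of 149 from the statistics table, so it is an approximation; `padding_overhead` is a hypothetical helper.

```python
def padding_overhead(mean_length, pad_to, n_examples):
    """Approximate padding tokens when every example is padded to pad_to."""
    per_example = pad_to - mean_length
    return per_example, per_example * n_examples

per_example, total = padding_overhead(mean_length=149, pad_to=512, n_examples=14732)
print(per_example, total)  # 363 5347716
print(f"{per_example / 512:.0%} of each padded sequence is padding")  # 71%
```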

## Training
### Prepare for trainer

```python
# No truncation: the longest training dialogue is only 1153 tokens, which should be fine.
def preprocess(examples):
    output = tokenizer(examples["dialogue"])
    output["labels"] = tokenizer(examples["summary"])["input_ids"]
    return output

# tokenize the dataset and drop the raw text columns
tk_ds = ds.map(preprocess, batched=True).remove_columns(ds['train'].column_names)

# load the evaluation metric
rouge = evaluate.load('rouge')

# postprocessing necessary for rouge
def compute_metrics(eval_preds):
    preds, labels = eval_preds

    # -100 marks padding ignored by the loss; swap it for the pad id so the ids can be decoded
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # rougeLsum expects one sentence per line
    decoded_preds = [
        "\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds
    ]
    decoded_labels = [
        "\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels
    ]

    result = rouge.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )
    return result

# Pad each batch dynamically, rounding lengths up to a multiple of 8.
collator = DataCollatorForSeq2Seq(tokenizer, padding=True, pad_to_multiple_of=8)

args = Seq2SeqTrainingArguments(
    output_dir=model_output_dir,
    evaluation_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    save_total_limit=2,
    num_train_epochs=1,
    bf16=True,
    gradient_accumulation_steps=4,
    predict_with_generate=True,
    save_strategy="epoch",
    load_best_model_at_end=True,
    hub_model_id=hub_model_id,
    report_to="wandb",
)
```
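Why `compute_metrics` swaps out `-100` before decoding: `DataCollatorForSeq2Seq` pads labels with its `label_pad_token_id` (default `-100`), which PyTorch's cross-entropy skips via its default `ignore_index=-100`, but a tokenizer cannot decode `-100`. A NumPy sketch of both steps; the token ids are toy values, and `pad_labels` and `restore_pad` are hypothetical helpers, not the collator's code.

```python
import numpy as np

PAD_ID = 0           # T5's pad token id
IGNORE_INDEX = -100  # label padding value ignored by the loss

def pad_labels(label_seqs, pad_value=IGNORE_INDEX):
    """Right-pad label sequences to the longest one with pad_value."""
    width = max(len(s) for s in label_seqs)
    return np.array([s + [pad_value] * (width - len(s)) for s in label_seqs])

def restore_pad(ids, pad_id=PAD_ID):
    """Replace -100 with the pad id so the ids can be decoded."""
    return np.where(ids != IGNORE_INDEX, ids, pad_id)

labels = pad_labels([[363, 1], [8, 9, 10, 1]])  # toy ids; 1 is T5's </s>
print(labels)               # [[363, 1, -100, -100], [8, 9, 10, 1]]
print(restore_pad(labels))  # [[363, 1, 0, 0], [8, 9, 10, 1]]
```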

```python
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=tk_ds["train"],
    eval_dataset=tk_ds["validation"],
    tokenizer=tokenizer,
    data_collator=collator,
    compute_metrics=compute_metrics,
)
```

### Fire up the training

```python
trainer.train()
```

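```python
#| code-fold: true

wandb.finish()

total_flos = trainer.state.total_flos
# The last log_history entry is the end-of-training summary with train_runtime.
runtime = trainer.state.log_history[-1]['train_runtime']
utilization = total_flos / 1e12 / runtime # in tflops
print(f"{utilization:.2f} TFLOPS")
```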
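How the TFLOPS figures below are derived. As far as I understand, `transformers` estimates `total_flos` with the common rule of about 6 FLOPs per parameter per input token (forward plus backward), counting padded input tokens too; treat that as an assumption. If it holds, extra padding (e.g. a larger `pad_to_multiple_of`) also inflates the reported TFLOPS. A sketch with toy numbers; flan-t5-base has roughly 250M parameters.

```python
def estimate_train_flos(num_tokens, num_params):
    """Rough training FLOPs: ~6 per parameter per token (forward + backward)."""
    return 6 * num_tokens * num_params

def achieved_tflops(total_flos, runtime_s):
    """Average throughput in TFLOPS: total FLOPs / 1e12 / seconds."""
    return total_flos / 1e12 / runtime_s

flos = estimate_train_flos(num_tokens=1_000_000, num_params=250_000_000)
print(flos)                        # 1500000000000000
print(achieved_tflops(flos, 100))  # 15.0
```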

## Result
`rouge-1: 47.8%` is in the same range as the source blog, even though this run was trained for only 1 epoch to save time.

### About TFLOPS
- `model.parallelize()`
  - `20.43` tflops.
  - Peak memory: GPU1: 16.6G, GPU2: 14.9G
- No `model.parallelize()`, vanilla Hugging Face trainer.
  - `16.66` tflops.
  - Peak memory: GPU1: 22.27G, GPU2: 21.93G
  - Higher GPU utilization (~90%), yet slower training and a larger memory footprint. Why...?
- `pad_to_multiple_of=64` -> `19.72` tflops
  - Not ready to innovate on [dark magic](https://twitter.com/karpathy/status/1621578354024677377) yet LoL.
- No `pad_to_multiple_of=8` -> `20.38` tflops
  - No need to do this religiously. It makes no difference with `pytorch` on this dataset.