Source: [Fine-tune FLAN-T5 with DeepSpeed (philschmid.de)](https://www.philschmid.de/fine-tune-flan-t5-deepspeed)

In [22]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
import numpy as np
import pandas as pd

In [2]:
# CNN/DailyMail v3.0.0: train / validation / test splits with 'article', 'highlights', 'id' columns
ds = load_dataset(path="cnn_dailymail", name="3.0.0")
ds

Downloading builder script:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/9.88k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/15.1k [00:00<?, ?B/s]

Downloading and preparing dataset cnn_dailymail/3.0.0 to /workspaces/seed/cache/hf_dataset/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de...


Downloading data files:   0%|          | 0/5 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/159M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/376M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/12.3M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/661k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/572k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/287113 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/13368 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/11490 [00:00<?, ? examples/s]

Dataset cnn_dailymail downloaded and prepared to /workspaces/seed/cache/hf_dataset/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 287113
    })
    validation: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 13368
    })
    test: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 11490
    })
})

In [4]:
# Only the tokenizer of the FLAN-T5 XXL checkpoint is loaded here; it is used to preprocess the dataset.
checkpoint = "google/flan-t5-xxl"
tk = AutoTokenizer.from_pretrained(checkpoint)

In [34]:
def prompt_template(inputs, template="Summarize the following news article:\n{input}\nSummary:\n"):
    """Wrap each raw article in the summarization instruction prompt (batched).

    Args:
        inputs: iterable of article strings (e.g. the ``article`` column of a batch).
        template: prompt text containing a single ``{input}`` placeholder where
            each article is inserted.

    Returns:
        list of prompt strings, one per input, in the same order as ``inputs``.
    """
    # Plain (non-f) string: `{input}` is a str.format placeholder filled per article.
    # Braces inside the article text are not interpreted, because only the
    # template itself is parsed by .format().
    return [template.format(input=article) for article in inputs]

In [46]:
# Tokenize prompts (inputs) and highlights (targets) in a single batched pass.
# Inputs are truncated to MAX_INPUT_LENGTH tokens and targets to MAX_TARGET_LENGTH tokens.
# Padding is deferred to the data collator so each batch is padded only to its own longest sequence.
MAX_INPUT_LENGTH = 1024
MAX_TARGET_LENGTH = 256


def preprocess(batch):
    """Tokenize a batch of CNN/DailyMail examples for seq2seq training.

    Args:
        batch: batched dataset slice (dict of column -> list) with ``article``
            and ``highlights`` columns.

    Returns:
        the tokenizer output for the prompted articles (``input_ids``,
        ``attention_mask``) plus ``labels`` holding the token ids of the
        truncated highlights.
    """
    model_inputs = tk(
        prompt_template(batch["article"]),
        max_length=MAX_INPUT_LENGTH,
        truncation=True,
    )
    targets = tk(batch["highlights"], max_length=MAX_TARGET_LENGTH, truncation=True)
    model_inputs["labels"] = targets["input_ids"]
    return model_inputs


tk_ds = ds.map(
    preprocess,
    batched=True,
    # drop the raw text columns so only the model's input format remains
    remove_columns=ds["train"].column_names,
)

tk_ds

  0%|          | 0/288 [00:00<?, ?ba/s]

  0%|          | 0/288 [00:00<?, ?ba/s]

  0%|          | 0/14 [00:00<?, ?ba/s]

  0%|          | 0/12 [00:00<?, ?ba/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 287113
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 13368
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 11490
    })
})

In [47]:
# simple data analysis to find out max token length for each dataset
# max length of raw (prompted tokens + target tokens) for train, val, test = (5329, 3021, 3799)
# max length of truncated input + target for each split = 1280

def count_token(batch):
    """Count the combined number of input and label tokens per example.

    Args:
        batch: batched dataset slice (dict of column -> list) with
            ``input_ids`` and ``labels`` columns, each a list of token-id sequences.

    Returns:
        ``{"length": [...]}`` where each entry is ``len(input_ids) + len(labels)``
        for the corresponding example.
    """
    # Sum the lengths instead of concatenating (i + l): concatenation allocates a
    # throwaway list per example, and for numpy/torch-formatted columns `+` would
    # add element-wise instead of concatenating.
    return {
        "length": [
            len(input_ids) + len(labels)
            for input_ids, labels in zip(batch["input_ids"], batch["labels"])
        ]
    }

len_ds = tk_ds.map(count_token, batched=True)

# report the longest (input + label) token sequence in every split
for split_name, split_ds in len_ds.items():
    longest = np.max(split_ds["length"])
    print(f"max length of {split_name} = {longest}")

  0%|          | 0/288 [00:00<?, ?ba/s]

  0%|          | 0/14 [00:00<?, ?ba/s]

  0%|          | 0/12 [00:00<?, ?ba/s]

max length of train = 1280
max length of validation = 1280
max length of test = 1280


I was keen to carry on, but then I saw the offloading results in the source article... Lesson learned: check the results section first before doing anything.