<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/milmor/NLP/blob/main/Notebooks/20_Flan-T5_hf.ipynb">
    <img src="https://www.tensorflow.org/images/colab_logo_32px.png" />
    Run in Google Colab</a>
  </td>
</table>

# Fine-tune Flan-T5

In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Disable tensorflow debugging logs
os.environ["KERAS_BACKEND"] = "torch"
import keras
import torch
import pandas as pd
import pathlib
import random

torch.__version__

'2.5.1+cu124'

In [2]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers import T5Tokenizer, T5ForConditionalGeneration
from datasets import Dataset

In [3]:
checkpoint = "google/flan-t5-base"
# checkpoint = "t5-small"
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

inputs = tokenizer("Tell something people love", return_tensors="pt")
outputs = model.generate(**inputs,  max_new_tokens=20)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

['a good story']


In [4]:
task_prefix = "Translate from English to Spanish: "
sentences = ["I like to read.", "The black dog."]

inputs = tokenizer([task_prefix + sentence for sentence in sentences], return_tensors="pt", padding=True)

output_sequences = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    do_sample=False,  # disable sampling to test if batching affects output
    max_new_tokens=20
)

print(tokenizer.batch_decode(output_sequences, skip_special_tokens=True))

['Mientras leer.', 'El chihuahua negra.']


## 1.- Conjuntos de datos

In [5]:
text_file = keras.utils.get_file(
    fname="spa-eng.zip",
    origin="http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip",
    extract=True,
)
text_file = pathlib.Path(text_file).parent / "spa-eng" / "spa.txt"

In [6]:
with open(text_file) as f:
    lines = f.read().split("\n")[:-1]

len(lines)

118964

In [7]:
translation = []
idx = []
for i, line in enumerate(lines):
    eng, spa = line.split("\t")
    idx.append(i)
    translation.append({'es': spa, 'en':eng})

translation[0], idx[0]

({'es': 'Ve.', 'en': 'Go.'}, 0)

In [8]:
my_dict = { "id": idx, "translation": translation}

pairs = Dataset.from_dict(my_dict)
pairs = pairs.train_test_split(test_size=0.02)
pairs

DatasetDict({
    train: Dataset({
        features: ['id', 'translation'],
        num_rows: 116584
    })
    test: Dataset({
        features: ['id', 'translation'],
        num_rows: 2380
    })
})

## 2.- Pipeline

In [9]:
source_lang = "en"
target_lang = "es"

prefix = "translate English to Spanish: "


def preprocess_function(examples):
    inputs = [prefix + example[source_lang] for example in examples["translation"]]
    targets = [example[target_lang] for example in examples["translation"]]

    model_inputs = tokenizer(inputs, text_target=targets, max_length=64, truncation=True)

    return model_inputs

In [10]:
tokenized_pairs = pairs.map(preprocess_function, batched=True)

Map:   0%|          | 0/116584 [00:00<?, ? examples/s]

Map:   0%|          | 0/2380 [00:00<?, ? examples/s]

In [11]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

In [12]:
from transformers import DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint)

## 3.- Entrenamiento

In [13]:
# pip install sacrebleu

In [14]:
def preprocess_logits_for_metrics(logits, labels):
    """
    Original Trainer may have a memory leak. 
    This is a workaround to avoid storing too many tensors that are not needed.
    """
    pred_ids = torch.argmax(logits[0], dim=-1)
    return pred_ids, labels

In [15]:
import numpy as np
import evaluate

metric = evaluate.load("sacrebleu")


def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [[label.strip()] for label in labels]

    return preds, labels


def compute_metrics(eval_preds):

    preds, labels = eval_preds

    if isinstance(preds, tuple):

        preds = preds[0]
    preds= np.where(preds != -100, preds, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)

    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True, clean_up_tokenization_spaces=True)

    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": result["score"]}

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}

    return result

In [16]:
max_steps = 2000
training_args = Seq2SeqTrainingArguments(
    output_dir="./ckpt-flan",
    eval_strategy="steps",
    eval_steps=250,
    learning_rate=2e-4,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=1,
    fp16=False, # fp16 nan loss,
    max_steps=max_steps,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_pairs["train"],
    eval_dataset=tokenized_pairs["test"],
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics
)

trainer.train()

max_steps is given, it will override any value given in num_train_epochs
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Step,Training Loss,Validation Loss,Bleu,Gen Len
250,No log,0.961517,9.9434,29.6437
500,1.259200,0.890793,10.7162,29.6437
750,1.259200,0.845328,11.0239,29.6437
1000,1.086400,0.814795,11.4636,29.6437
1250,1.086400,0.794564,11.9756,29.6437
1500,0.997200,0.771946,12.3276,29.6437
1750,0.997200,0.760798,12.5314,29.6437
2000,0.953800,0.754867,12.5798,29.6437


TrainOutput(global_step=2000, training_loss=1.0741595764160157, metrics={'train_runtime': 233.5918, 'train_samples_per_second': 136.991, 'train_steps_per_second': 8.562, 'total_flos': 925727744507904.0, 'train_loss': 1.0741595764160157, 'epoch': 0.2744613695622341})

In [17]:
model = AutoModelForSeq2SeqLM.from_pretrained(f"./ckpt-flan/checkpoint-{max_steps}")

output_sequences = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    do_sample=False,  # disable sampling to test if batching affects output
    max_new_tokens=20
)

print(tokenizer.batch_decode(output_sequences, skip_special_tokens=True))

['Me gusta leer.', 'El perro negro.']
