## Fine-tuning a Seq2Seq model (T5) in Colab with limited RAM

This notebook extends [Training NLP models in Colab without running out of RAM](https://github.com/datasci-w266/2024-summer-main/blob/master/materials/walkthrough_notebooks/keras_with_limited_ram/keras_training_with_limited_ram.ipynb). The series focuses on how to avoid running out of memory by loading only part of your data at a time while you train, and by saving model checkpoints as you go. We recommend reading the earlier notebook first: it explains these techniques more fully, using a fine-tuned BERT model as the example.

This notebook focuses on sequence-to-sequence (encoder-decoder, text generation) models like T5, because the way you fine-tune the Hugging Face pretrained versions of those models is a bit different from how you fine-tune BERT. With T5, you use the full pretrained model end-to-end without adding any layers.

That said, you can still set up the training process much as you would for BERT. This notebook is for TensorFlow models, which lets you use Keras. We also have a [similar notebook for PyTorch models](https://github.com/datasci-w266/2024-summer-main/blob/master/materials/walkthrough_notebooks/keras_with_limited_ram/fine_tune_t5_with_limited_ram_pytorch.ipynb), since some Hugging Face pretrained models are only available in PyTorch.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/datasci-w266/2024-summer-main/blob/master/materials/walkthrough_notebooks/keras_with_limited_ram/fine_tune_t5_with_limited_ram_keras.ipynb)

In [1]:
!pip install -q transformers==4.37.2


In [2]:
!pip install sentencepiece



In [3]:
import os
import re
import numpy as np
import pandas as pd

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

from transformers import T5Tokenizer, TFT5ForConditionalGeneration

### Data

To fine-tune T5, we'll use the dataset from the [week 6 lesson notebook](https://github.com/datasci-w266/2024-summer-main/blob/master/materials/lesson_notebooks/lesson_6_Machine_Translation.ipynb) for translating Shakespeare to modern English. You can [download the dataset here](https://github.com/cocoxu/Shakespeare), or use [the copy that is in the lesson_notebooks directory](https://github.com/datasci-w266/2024-summer-main/blob/master/materials/lesson_notebooks/train_plays-org-mod.txt) of the class git repo. Either way, upload the file to your Drive folder.

In [4]:
# This cell will authenticate you and mount your Drive in the Colab.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
# Modify this path to where you saved the Shakespeare data in your Drive
text_file = 'drive/MyDrive/ISchool/MIDS/266/data/train_plays-org-mod.txt'

In [6]:
with open(text_file) as f:
    lines = f.read().split('\n')[:-1]

prefix = 'translate old to modern: '
text_pairs = []
for line in lines:
    orig, target = line.split('\t')
    text_pairs.append({'orig': prefix + orig, 'target': target})

In [7]:
# Look at some examples
for _ in range(5):
    print(np.random.choice(text_pairs))

{'orig': "translate old to modern: Think what thou wilt, I am thy lover's grace; And like Limander am I trusty still.", 'target': "Think what you will, I am your lover's grace; And like Limander, I am still trusty."}
{'orig': 'translate old to modern: Your brother is but young and tender, and, for your love I would be loath to foil him, as I must for my own honor if he come in.', 'target': 'Your brother is young and inexperienced, and because of my affection for you, I’d hate to crush him—though I’d have to, if he challenged me.'}
{'orig': 'translate old to modern: Give me some help.', 'target': 'Give me some help.'}
{'orig': 'translate old to modern: I had rather hear my dog bark at a crow than a man swear he loves me.', 'target': 'I would rather listen to my dog bark at a crow than hear a man swear that he loves me.'}
{'orig': 'translate old to modern: I will ask him for my place again; he shall tell me I am a drunkard!', 'target': 'I’ll ask him for my job back; he’ll tell me I am a 

In [8]:
# Let's create some splits
np.random.shuffle(text_pairs)
num_valid_samples = int(0.15 * len(text_pairs))
num_train_samples = len(text_pairs) - 2 * num_valid_samples
train_pairs = text_pairs[:num_train_samples]
valid_pairs = text_pairs[num_train_samples : num_train_samples + num_valid_samples]
test_pairs = text_pairs[num_train_samples + num_valid_samples :]

print(f"{len(text_pairs)} total pairs")
print(f"{len(train_pairs)} training pairs")
print(f"{len(valid_pairs)} validation pairs")
print(f"{len(test_pairs)} test pairs")

19088 total pairs
13362 training pairs
2863 validation pairs
2863 test pairs


In [10]:
# Save splits to separate csv files, to load only part at a time later
train_file = 'drive/MyDrive/ISchool/MIDS/266/data/train_pairs.csv'
valid_file = 'drive/MyDrive/ISchool/MIDS/266/data/valid_pairs.csv'
test_file = 'drive/MyDrive/ISchool/MIDS/266/data/test_pairs.csv'

pd.DataFrame(train_pairs).to_csv(train_file)
pd.DataFrame(valid_pairs).to_csv(valid_file)
pd.DataFrame(test_pairs).to_csv(test_file)

### Preprocessor and Data Generator

As in the earlier notebook for BERT models, we'll define a preprocessing function that takes a tokenizer and one batch of text data, tokenizes the text and returns the inputs to the model (input vocab ids, input attention mask, and output vocab ids as labels).

Then we'll define a data generator class that will load one batch of data from file every time keras gets a new batch for training. This way, we don't load all of our data into memory at once. The data generator will call the preprocessing function, returning a list of model inputs plus the labels.

For a seq2seq model, we pass in not only the input_ids and attention_mask for the encoder (original text) but also the decoder_input_ids (vocab ids for the output text). The T5 model has a handy method, `_shift_right`, that shifts the output vocab ids (i.e. the labels) over by one position, so that the decoder inputs begin with the decoder start token.
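
To make the shift concrete, here is a minimal pure-Python sketch of what shifting the labels right does. It is an illustration, not the library's implementation: `shift_right` is a hypothetical helper, the token ids are made up, and it assumes T5's convention that the decoder start token is the pad token (id 0) and that `</s>` has id 1.

```python
def shift_right(label_ids, decoder_start_token_id=0):
    """Prepend the start token to every row and drop the row's last id."""
    return [[decoder_start_token_id] + row[:-1] for row in label_ids]


# Made-up ids for illustration: two word tokens, </s> (1), then padding (0)
labels = [[37, 1712, 1, 0, 0]]
decoder_input_ids = shift_right(labels)
print(decoder_input_ids)  # [[0, 37, 1712, 1, 0]]
```

At position t the decoder therefore sees the target tokens before t and is trained to predict `labels[t]`.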

In [11]:
def preprocess_data(text_pairs, tokenizer, model, max_length=128):
    orig_text = [orig for orig, target in text_pairs]
    orig_encoded = tokenizer.batch_encode_plus(
        orig_text,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='tf'
    )

    orig_input_ids = np.array(orig_encoded["input_ids"], dtype="int32")
    orig_attention_masks = np.array(orig_encoded["attention_mask"], dtype="int32")

    target_text = [target for orig, target in text_pairs]
    target_encoded = tokenizer.batch_encode_plus(
        target_text,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_tensors='tf'
    )

    label_ids = np.array(target_encoded['input_ids'])
    decoder_input_ids = model._shift_right(label_ids)

    return [orig_input_ids, orig_attention_masks, decoder_input_ids], label_ids

In [35]:
class TranslationDataGenerator(tf.keras.utils.Sequence):

    def __init__(self,
                 tokenizer,
                 model,
                 n_examples,
                 data_filename,
                 max_length=128,
                 batch_size=16,
                 shuffle=True):

        self.tokenizer = tokenizer
        self.model = model
        self.n_examples = n_examples
        self.data_filename = data_filename
        self.max_length = max_length
        self.batch_size = batch_size
        self.shuffle = shuffle

        # Initialize row order, call on_epoch_end to shuffle row indices
        self.row_order = np.arange(1, self.n_examples+1)
        self.on_epoch_end()

    def __len__(self):
        # Return the number of full batches per epoch (integer division
        # drops any leftover rows that don't fill a final batch)
        return self.n_examples // self.batch_size

    def __getitem__(self, idx):
        batch_start = idx * self.batch_size
        batch_end = (idx + 1) * self.batch_size

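        # Rows to skip are the ones in the shuffled row_order before and
        # after the chunk we'll use for this batch. Convert both slices to
        # lists first: with shuffle=False, row_order is still a numpy array,
        # and `+` would add elementwise instead of concatenating.
        batch_idx_skip = list(self.row_order[:batch_start]) + list(self.row_order[batch_end:])
        df = pd.read_csv(self.data_filename, skiprows=batch_idx_skip)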

        text_pairs = df[['orig', 'target']].values.astype(str).tolist()

        batch_data = preprocess_data(
            text_pairs,
            self.tokenizer,
            self.model,
            self.max_length
        )

        return batch_data

    def __call__(self):
        for i in range(self.__len__()):
            yield self.__getitem__(i)

            if i == self.__len__()-1:
                self.on_epoch_end()

    def on_epoch_end(self):
        if self.shuffle:
            self.row_order = list(np.random.permutation(self.row_order))
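
The row-skipping logic in `__getitem__` can be hard to picture, so here is a small standalone sketch of just that step, using plain Python lists and toy sizes. The helper name `rows_to_skip` and the numbers are assumptions for illustration; the real generator passes the resulting list to `pd.read_csv(..., skiprows=...)`.

```python
import random


def rows_to_skip(row_order, idx, batch_size):
    """Return the CSV row numbers that do NOT belong to batch `idx`."""
    batch_start = idx * batch_size
    batch_end = (idx + 1) * batch_size
    return list(row_order[:batch_start]) + list(row_order[batch_end:])


n_examples, batch_size = 10, 4

# Row 0 of the CSV is the header, so data rows are numbered 1..n_examples
row_order = list(range(1, n_examples + 1))
random.Random(0).shuffle(row_order)

# Integer division: 2 full batches, the 2 leftover rows are not used
n_batches = n_examples // batch_size
for idx in range(n_batches):
    skipped = set(rows_to_skip(row_order, idx, batch_size))
    kept = [row for row in row_order if row not in skipped]
    print(idx, kept)
```

The header row (0) is never in the skip list, so pandas still reads the column names. With the 13,362 training pairs and a batch size of 16 used here, the same integer division gives 835 batches per epoch and leaves 2 rows out; because the order is reshuffled every epoch, the rows left out change from epoch to epoch.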

### Pretrained model

Hugging Face's pretrained TensorFlow models are Keras models, so you could call .compile() and .fit() directly on the pretrained T5 model. But for sequence-to-sequence models, it can be tricky to make sure the right inputs are going into the right part of the model (encoder vs. decoder, etc.).

Even though we aren't adding any other layers, we can still create a Keras model wrapper around the pretrained T5 model. That way, we can pass the right inputs into the model using keyword arguments.

We'll use the first output of the T5 model (the logits for the output vocab) as the output of the overall model, and compile with cross-entropy loss. Then we can call .fit() on the wrapper model like we did in the last notebook, passing in the data generators for train and validation data instead of a fully loaded dataset.
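
As a reminder of what a cross-entropy loss on logits computes, here is a minimal pure-Python sketch for a single output position. It is a toy illustration with a made-up four-token vocabulary, not the Keras implementation; `sparse_ce_from_logits` is a hypothetical helper name.

```python
import math


def sparse_ce_from_logits(logits, label_id):
    """Cross-entropy for one position: -log(softmax(logits)[label_id])."""
    max_logit = max(logits)  # subtract the max for numerical stability
    log_sum_exp = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
    return log_sum_exp - logits[label_id]


# Toy vocabulary of 4 tokens; the correct token has id 2
uniform = sparse_ce_from_logits([0.0, 0.0, 0.0, 0.0], label_id=2)
confident = sparse_ce_from_logits([0.0, 0.0, 5.0, 0.0], label_id=2)
print(uniform, confident)  # the confident prediction has the lower loss
```

Keras averages this quantity over every position in the batch. In the setup in this notebook, the labels are padded to `max_length` and the padding ids are not masked, so padded positions appear to count toward both the loss and the reported accuracy. Keep that in mind when reading the accuracy number.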

In [31]:
# Load the pretrained tensorflow model

model_name = 't5-base'
t5_tokenizer = T5Tokenizer.from_pretrained(model_name)
t5_model = TFT5ForConditionalGeneration.from_pretrained(model_name)

All model checkpoint layers were used when initializing TFT5ForConditionalGeneration.

All the layers of TFT5ForConditionalGeneration were initialized from the model checkpoint at t5-base.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.


In [36]:
# Create the data generators for train and validation data, tensorflow version

max_length = 32
batch_size = 16

train_data_generator = TranslationDataGenerator(
    tokenizer=t5_tokenizer,
    model=t5_model,
    n_examples=len(train_pairs),
    data_filename=train_file,
    max_length=max_length,
    batch_size=batch_size
)

valid_data_generator = TranslationDataGenerator(
    tokenizer=t5_tokenizer,
    model=t5_model,
    n_examples=len(valid_pairs),
    data_filename=valid_file,
    max_length=max_length,
    batch_size=batch_size
)

In [37]:
def build_t5_training_wrapper_model(t5_model, max_length):
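    # shape must be a tuple, so write (max_length,) with a trailing comma
    input_ids = layers.Input(shape=(max_length,), dtype=tf.int32, name='input_ids')
    attention_mask = layers.Input(shape=(max_length,), dtype=tf.int32, name='attention_mask')
    # These are the right-shifted labels that feed the decoder, not the labels themselves
    decoder_input_ids = layers.Input(shape=(max_length,), dtype=tf.int32, name='decoder_input_ids')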

    t5_logits = t5_model(input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids)[0]

    model = tf.keras.models.Model(inputs=[input_ids, attention_mask, decoder_input_ids],
                                  outputs=[t5_logits])
    model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])

    return model

In [38]:
model_wrapper = build_t5_training_wrapper_model(t5_model, max_length)

In [39]:
# As in the first notebook, we should add a model checkpoint callback to save
# the trained model weights after each epoch. Edit the filepath to where
# you want to save the weights in your own Drive

checkpoint_dir = 'drive/MyDrive/ISchool/MIDS/266/model_checkpoints/'
checkpoint_filepath = checkpoint_dir + 't5_shakespeare_weights.{epoch:02d}-{val_accuracy:.2f}.hdf5'
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True)
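
The placeholders in `checkpoint_filepath` are ordinary Python format fields that the callback fills in with the epoch number and the logged metrics. A quick sketch of the resulting filename, using hypothetical values for the end of the first epoch:

```python
template = 't5_shakespeare_weights.{epoch:02d}-{val_accuracy:.2f}.hdf5'

# Hypothetical values: first epoch, validation accuracy of 0.8512
filename = template.format(epoch=1, val_accuracy=0.8512)
print(filename)  # t5_shakespeare_weights.01-0.85.hdf5
```

Because the template refers to `val_accuracy`, it can only be filled in when validation metrics are logged, which is why validation data needs to be passed to `.fit()`.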

In [40]:
# Now call .fit on the model_wrapper, passing in the data generators and the
# model checkpoint callback

model_wrapper.fit(train_data_generator,
                  validation_data=valid_data_generator,
                  epochs=1,
                  callbacks=[model_checkpoint_callback])



<keras.src.callbacks.History at 0x7fcb5028a4d0>

### Does it work?

Depending on your task, you'll add your own model evaluation after training. Here's a simple check to make sure it does seem to have fine-tuned T5 for this new task we defined.

In [41]:
for test_input_text in ['Hence forth thou shalt not vex me e\'er again.',
                        'Dost thou foresake me?',
                        'Makest thine own dinner.']:
    test_inputs = t5_tokenizer([prefix + test_input_text], return_tensors='tf')
    test_output_ids = t5_model.generate(test_inputs['input_ids'])

    print([t5_tokenizer.decode(out_ids, skip_special_tokens=True,
                               clean_up_tokenization_spaces=False) for out_ids in test_output_ids])

['You won’t vex me again.']
['Do you want to see me?']
['Make your own dinner.']


In [42]:
# To pick back up where you left off, load the saved model weights
# (Edit the filename to the last saved one that you want to load)

checkpoint_filepath = checkpoint_dir + 't5_shakespeare_weights.01-0.85.hdf5'
model_wrapper.load_weights(checkpoint_filepath)
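
If you would rather not edit the filename by hand, one option is to pick the checkpoint with the highest epoch number from the directory. The helper below is a hypothetical sketch, not part of the original notebook. It assumes the filename pattern used by the checkpoint callback above, and the demo uses a temporary directory with made-up filenames instead of your Drive.

```python
import os
import re
import tempfile


def latest_checkpoint(checkpoint_dir, prefix='t5_shakespeare_weights.'):
    """Return the path of the checkpoint with the highest epoch, or None."""
    pattern = re.compile(re.escape(prefix) + r'(\d+)-\d+\.\d+\.hdf5$')
    best_epoch, best_path = -1, None
    for name in os.listdir(checkpoint_dir):
        match = pattern.match(name)
        if match and int(match.group(1)) > best_epoch:
            best_epoch = int(match.group(1))
            best_path = os.path.join(checkpoint_dir, name)
    return best_path


# Demo with empty placeholder files (the names are made up)
with tempfile.TemporaryDirectory() as tmp_dir:
    for name in ['t5_shakespeare_weights.01-0.85.hdf5',
                 't5_shakespeare_weights.02-0.87.hdf5',
                 'notes.txt']:
        open(os.path.join(tmp_dir, name), 'w').close()
    print(os.path.basename(latest_checkpoint(tmp_dir)))
    # t5_shakespeare_weights.02-0.87.hdf5
```

In the notebook, you could then call `model_wrapper.load_weights(latest_checkpoint(checkpoint_dir))`. If you continue training afterwards, the `initial_epoch` argument of `.fit()` keeps the epoch numbers in new checkpoint filenames from starting over at 01.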

In [43]:
# Still works?
for test_input_text in ['Hence forth thou shalt not vex me e\'er again.',
                        'Dost thou foresake me?',
                        'Makest thine own dinner.']:
    test_inputs = t5_tokenizer([prefix + test_input_text], return_tensors='tf')
    test_output_ids = t5_model.generate(test_inputs['input_ids'])

    print([t5_tokenizer.decode(out_ids, skip_special_tokens=True,
                               clean_up_tokenization_spaces=False) for out_ids in test_output_ids])

['You won’t vex me again.']
['Do you want to see me?']
['Make your own dinner.']
