<a href="https://colab.research.google.com/github/jeanlucjackson/w266_final_project/blob/main/code/model_training/awesome_T5_pt_train_seq.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive at /content/drive; the data and model paths used in this
# notebook all live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
# Checkpoint directory the model is loaded from before fine-tuning.
model_name = "/content/drive/MyDrive/w266 NLP Final Project/Models/T5_base_pt_long.sn/"

# Identifier handed to T5Tokenizer.from_pretrained.
tokenizer_name = "google/t5-v1_1-base"

# Name of the dataset folder under the data root.
dataset_name = "quac"

# Token length that inputs and targets are padded / truncated to.
max_length = 1024

# Per-device batch size, used for both training and evaluation.
batch_size = 4

# Modify this filepath to where you want to save the model after fine-tuning
dir_path = "/content/drive/MyDrive/w266 NLP Final Project/Models/T5_base_pt_long.snq/"


## Question Generation using T5 in Colab without running out of RAM

This notebook is based on Natalie Ahn's notebook showing how to fine-tune T5 in Colab without running out of RAM. It is limited to PyTorch and has been modified to increase the speed of the dataset generators by dramatically reducing the amount of disk I/O.

In [3]:
!pip install -q transformers

In [4]:
!pip install -q sentencepiece

In [5]:
import os
import re
import numpy as np
import pandas as pd
import json

from pprint import pprint

import tensorflow as tf
import torch

from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

import warnings
# Ignore every FutureWarning raised from this point on.
warnings.simplefilter(action='ignore', category=FutureWarning)

In [6]:
# Root data folder on Drive; the selected dataset's folder name is appended to it.
dataset_root = "/content/drive/MyDrive/w266 NLP Final Project/Data/"

# Locations of the training and validation pair files for the chosen dataset.
dataset_folder = f"{dataset_root}{dataset_name}"
training_file = f"{dataset_folder}/train_pairs.csv"
validation_file = f"{dataset_folder}/valid_pairs.csv"

In [7]:
# Read the training split first, then the validation split, with default CSV settings.
training_df, validation_df = (
    pd.read_csv(csv_path) for csv_path in (training_file, validation_file)
)

In [8]:
# Preview the first three source/target pairs of the training split.
training_df[['orig', 'target']].head(3)

Unnamed: 0,orig,target
0,generate question: answer: Jonny Lang's Still ...,What are some of the performances he has done?
1,generate question: answer: The circuits overca...,How far could the images be sent?
2,generate question: answer: In the generic plan...,Are there any other interesting aspects about ...


In [9]:
# Preview the first three source/target pairs of the validation split.
validation_df[['orig', 'target']].head(3)

Unnamed: 0,orig,target
0,generate question: answer: Greaves later admit...,Did Jimmy learn a lesson at West Ham United?
1,generate question: answer: Kapoor is of Hindu ...,What do you find interesting about the article?
2,"generate question: answer: On September 20, 20...",What number album was You Fail Me?


In [10]:
def preprocess_data_pt(text_pair, tokenizer, max_length=max_length):
    """Tokenize one (source, target) text pair into PyTorch training tensors.

    Args:
        text_pair: Two-element sequence ``(orig_text, target_text)``.
        tokenizer: Tokenizer providing ``batch_encode_plus``.
        max_length: Length both texts are padded / truncated to.

    Returns:
        dict with the keys ``input_ids``, ``attention_mask`` and ``labels``;
        each value is the first (only) row of the corresponding encoded batch.
        Padding positions in ``labels`` are set to -100 so that the
        cross-entropy loss ignores them.
    """
    orig_text, target_text = text_pair

    def _encode(text):
        # Source and target share identical tokenization settings.
        return tokenizer.batch_encode_plus(
            [text],
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

    orig_encoded = _encode(orig_text)
    target_encoded = _encode(target_text)

    # Targets are padded up to max_length. Without masking, the loss would be
    # computed over all those padding tokens; -100 is the index the loss skips.
    # The attention mask (0 == padding) identifies exactly the padded positions.
    label_ids = target_encoded['input_ids'][0].masked_fill(
        target_encoded['attention_mask'][0] == 0, -100
    )

    return {'input_ids': orig_encoded['input_ids'][0],
            'attention_mask': orig_encoded['attention_mask'][0],
            'labels': label_ids}

In [11]:
class TranslationDataGeneratorPT(tf.keras.utils.Sequence):
    """Serve tokenized (orig, target) examples one at a time from a CSV file.

    The CSV is read into memory once in ``__init__``; each ``__getitem__``
    tokenizes a single row via ``preprocess_data_pt``. Rows are visited in
    ``row_order``, which holds 1-based row numbers and is reshuffled by
    ``on_epoch_end`` when ``shuffle`` is true.
    """

    def __init__(self,
                 tokenizer,
                 n_examples,
                 data_filename,
                 max_length=max_length,
                 shuffle=True):
        """
        Args:
            tokenizer: Tokenizer forwarded to ``preprocess_data_pt``.
            n_examples: Number of rows (from the top of the file) to serve.
            data_filename: Path of a CSV file with ``orig`` and ``target`` columns.
            max_length: Padding / truncation length forwarded to the tokenizer.
            shuffle: Whether ``on_epoch_end`` permutes the row order.

        Raises:
            ValueError: If ``n_examples`` exceeds the number of rows in the file.
        """
        self.tokenizer = tokenizer
        self.n_examples = n_examples
        self.data = pd.read_csv(data_filename)
        self.max_length = max_length
        self.shuffle = shuffle

        # Fail fast: otherwise an out-of-range row only surfaces as an
        # IndexError when that row is requested, possibly deep into training.
        if self.n_examples > len(self.data):
            raise ValueError(
                f"n_examples ({self.n_examples}) exceeds the number of rows "
                f"in {data_filename} ({len(self.data)})"
            )

        # Initialize row order, call on_epoch_end to shuffle row indices
        self.row_order = np.arange(1, self.n_examples+1)
        self.on_epoch_end()

    def __len__(self):
        """Return the number of examples served."""
        return self.n_examples

    def __getitem__(self, idx):
        """Return the tokenized example at position ``idx`` of the current row order."""
        # row_order stores 1-based row numbers; convert to a 0-based position.
        row_position = self.row_order[idx] - 1

        text_pair = self.data[['orig', 'target']].iloc[row_position].values.astype(str)

        return preprocess_data_pt(
            text_pair,
            self.tokenizer,
            self.max_length
        )

    def __call__(self):
        """Yield every example once, then trigger ``on_epoch_end``."""
        last_index = len(self) - 1
        for i in range(len(self)):
            yield self[i]

            if i == last_index:
                self.on_epoch_end()

    def on_epoch_end(self):
        """Permute the row order when shuffling is enabled."""
        if self.shuffle:
            self.row_order = list(np.random.permutation(self.row_order))

In [12]:
# Load the tokenizer (`tokenizer_name`) and the starting model checkpoint
# (`model_name`, which the config cell points at a directory on Drive).

tokenizer = T5Tokenizer.from_pretrained(tokenizer_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

In [13]:
# Build the PyTorch example generators for the training and validation splits.
# Each one re-reads its CSV and serves as many rows as the matching DataFrame holds.

train_data_generator = TranslationDataGeneratorPT(
    tokenizer,
    len(training_df),
    training_file,
    max_length=max_length,
)

valid_data_generator = TranslationDataGeneratorPT(
    tokenizer,
    len(validation_df),
    validation_file,
    max_length=max_length,
)

In [14]:
# Trainer output goes to a "checkpoints" folder inside the model directory.
file_path = f"{dir_path}checkpoints"

# NOTE(review): newer transformers releases appear to rename
# `evaluation_strategy` to `eval_strategy` -- confirm against the installed version.
args = Seq2SeqTrainingArguments(
    output_dir=file_path,
    num_train_epochs=1,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    evaluation_strategy='epoch',
    save_steps=2000,
    save_total_limit=5,
)

In [15]:
# Wire the model, the training arguments and both data generators into the trainer.

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train_data_generator,
    eval_dataset=valid_data_generator,
)

In [None]:
# Run fine-tuning with the trainer built from the model, training arguments
# and data generators defined earlier in this notebook.

trainer.train()

***** Running training *****
  Num examples = 69109
  Num Epochs = 1
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 4
  Gradient Accumulation steps = 1
  Total optimization steps = 17278
  Number of trainable parameters = 247577856


Epoch,Training Loss,Validation Loss


In [None]:
# Save the trained model to `dir_path`, the output location set in the config cell.
trainer.save_model(dir_path)

### Post Training Review

In [None]:
# Load the TensorBoard notebook extension, which provides the %tensorboard magic.
%load_ext tensorboard

In [None]:
%tensorboard --logdir "/content/drive/MyDrive/w266 NLP Final Project/Models/T5_base_pt_long.snq/checkpoints"

### Does it seem to have worked?


In [None]:
# Take an independent copy of the first 20 validation rows for a quick manual check.
sample_df = validation_df.iloc[:20].copy()

In [None]:
# Generate one question per sampled input with the in-memory model.
# The model has just been through training, so make sure dropout is disabled.
model.eval()

predictions = []
for input_text in sample_df['orig']:
  inputs = tokenizer(input_text, return_tensors='pt', max_length=max_length, truncation=True)
  # Send the inputs to whichever device the model is on instead of assuming CUDA.
  output_ids = model.generate(inputs['input_ids'].to(model.device),
                              attention_mask=inputs['attention_mask'].to(model.device),
                              max_length=50)
  prediction = "".join(tokenizer.decode(out_ids, skip_special_tokens=True,
                                        clean_up_tokenization_spaces=False) for out_ids in output_ids)
  predictions.append(prediction)

sample_df['prediction'] = predictions

In [None]:
sample_df

You can load the model you trained with the same `.from_pretrained` function you use for pretrained models. If you look in your Drive folder, at the filepath you used in the trainer arguments, you'll see a checkpoint folder. Pass the full path to that checkpoint folder to `.from_pretrained` to load the saved model again later. The final model written by `trainer.save_model(dir_path)` can be loaded the same way by passing `dir_path`, as the next cell does.

In [None]:
# Reload the fine-tuned model from `dir_path`.
# NOTE: despite the `bart_` prefix in the variable name, this is a T5 model.
bart_model_saved = T5ForConditionalGeneration.from_pretrained(dir_path)

In [None]:
# Use the first GPU when one is available; otherwise fall back to the CPU
# instead of failing on a hard-coded CUDA device.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Module.to() returns the module; assigning the result keeps the model repr out
# of the cell output (which is what the trailing `pass` was used for).
bart_model_saved = bart_model_saved.to(device)

In [None]:
# Show the directory the model was saved to and reloaded from.
dir_path

In [None]:
# Still works? Repeat the generation check with the model reloaded from disk.
# The generation cap matches the first check (50 tokens) so both sets of outputs
# are comparable; `max_length` (the input truncation length) is not a sensible
# output cap.
target_device = bart_model_saved.device

predictions = []
for input_text in sample_df['orig']:
  inputs = tokenizer(input_text, return_tensors='pt', max_length=max_length, truncation=True)
  # Send the inputs to whichever device the model is on instead of assuming CUDA.
  output_ids = bart_model_saved.generate(inputs['input_ids'].to(target_device),
                                         attention_mask=inputs['attention_mask'].to(target_device),
                                         max_length=50)
  prediction = "".join(tokenizer.decode(out_ids, skip_special_tokens=True,
                                        clean_up_tokenization_spaces=False) for out_ids in output_ids)
  predictions.append(prediction)

predictions