Notebook adapted from:  
https://medium.com/askdata/train-t5-for-text-summarization-a1926f52d281  
https://colab.research.google.com/drive/14_A2kM8sOVpzwHn-0pMbfnD2htzI2Nte

# 0. Set up environment

In [1]:
import os
import torch
import numpy as np
import pandas as pd

from sklearn import model_selection
from torch import nn

SEED = 2557

In [2]:
%%script false  --no-raise-error
!pip install transformers
!pip install datasets

Let's use Weights & Biases (wandb) for experiment tracking.

In [3]:
import wandb
wandb.login()

%env WANDB_LOG_MODEL=true

[34m[1mwandb[0m: Currently logged in as: [33mbryanli[0m (use `wandb login --relogin` to force relogin)


env: WANDB_LOG_MODEL=true


In [4]:
%cd ../glucose/

GLUCOSE_DIR = os.getcwd()
TRAIN_PATH = os.path.join(GLUCOSE_DIR, 't5_data/t5_training_data.tsv')
TEST_PATH = os.path.join(GLUCOSE_DIR, 't5_data/t5_test_data.txt')

/mnt/nlpgridio3/data/bryanli/projects/stories/glucose


In [5]:
T5_HEADER = ['input', 'output']
df_train_orig = pd.read_csv(TRAIN_PATH, sep='\t', names=T5_HEADER)
df_test_orig = pd.read_csv(TEST_PATH, sep='\t', names=T5_HEADER)
# Prefix '#' so each training input starts with a dimension tag like '#2'
df_train_orig['input'] = '#' + df_train_orig['input']

# Data Preprocessing


In [6]:
def get_story_ids(story_col):
    stories = story_col.unique()
    story2id = {story: i for i, story in enumerate(stories)}
    return story_col.map(story2id)

def make_df(X_input):
    '''
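    Creates an intermediate df, used for later formatting of input/output. Assigns a unique `story_id` to each story.

    Args:
        X_input (pd.Series): input field of T5 GLUCOSE dataset

    Returns:
        pd.DataFrame with columns dim, story_before, target, story_after, story, story_id
    '''
    # Separate the dimension tag (e.g. '#1') from the story text
    X_split = X_input.str.split(': ', n=1, expand=True)
    dim, story = X_split[0], X_split[1]
    # The target sentence is wrapped in '*' markers: split into before / target / after
    selected_split = story.str.split('*', n=2, expand=True)
    story_before, target_sentence, story_after = selected_split[0], selected_split[1], selected_split[2]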
    story = story_before + target_sentence + story_after
    story_id = get_story_ids(story)
    d = {'dim': dim, 'story_before': story_before, 'target': target_sentence, 'story_after': story_after, 'story': story, 'story_id': story_id}
    df = pd.DataFrame(d)
    return df
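
To make the parsing in `make_df` concrete, here is a minimal stdlib sketch that applies the same two splits to a single string. The helper `parse_glucose_input` and the example strings are hypothetical; like `make_df`, it assumes the target sentence is wrapped in `*` markers.

```python
def parse_glucose_input(text):
    """Hypothetical helper mirroring make_df on one string.

    Returns (dim, story_before, target, story_after).
    """
    # Split once on ': ' to separate the dimension tag (e.g. '#1') from the story
    dim, story = text.split(': ', 1)
    # Split at most twice on '*': text before, the starred target, text after
    story_before, target, story_after = story.split('*', 2)
    return dim, story_before, target, story_after


# Hypothetical input in the '#<dim>: ... *<target>* ...' layout
example = '#1: Jane was in the bathroom. *But she dropped her phone in the toilet.* It no longer worked.'
dim, before, target, after = parse_glucose_input(example)
print((dim, before, target, after))
# ('#1', 'Jane was in the bathroom. ', 'But she dropped her phone in the toilet.', ' It no longer worked.')

# Rejoining the pieces recovers the story without the '*' markers, as in make_df
print(before + target + after)

# When the target is the first sentence, story_before is an empty string
print(repr(parse_glucose_input('#3: *First.* Second.')[1]))  # ''
```

The last case is why `get_in_out_df` later filters on `story_before != ''`: a first-sentence target has no preceding context to condition on.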

In [7]:
df_train = make_df(df_train_orig['input'])

Next, we split the dataset into train/val sets. We ensure that stories are not shared between the splits by randomly selecting 10% of the unique `story_id` values for validation.

In [8]:
story_ids = df_train['story_id'].unique()
ids_train, ids_val = model_selection.train_test_split(story_ids, test_size=.1, random_state=SEED)
df_train1 = df_train[df_train['story_id'].isin(ids_train)]
df_val1 = df_train[df_train['story_id'].isin(ids_val)]
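
The key property of this split is that it is drawn over unique story ids rather than rows, so every row of a story lands on the same side. Below is a minimal stdlib sketch of the same idea; it uses `random.shuffle` instead of sklearn, so it will not reproduce the notebook's exact split, and `split_by_story` and the toy rows are hypothetical.

```python
import math
import random


def split_by_story(rows, test_size=0.1, seed=2557):
    """Hypothetical group-aware split: all rows sharing a story_id stay together."""
    story_ids = sorted({row['story_id'] for row in rows})
    rng = random.Random(seed)
    rng.shuffle(story_ids)
    # Round the validation share up, as sklearn does for a float test_size
    n_val = math.ceil(len(story_ids) * test_size)
    val_ids = set(story_ids[:n_val])
    train = [row for row in rows if row['story_id'] not in val_ids]
    val = [row for row in rows if row['story_id'] in val_ids]
    return train, val


# Toy data: 20 stories with 3 annotated rows each
rows = [{'story_id': i // 3, 'dim': f'#{i % 3 + 1}'} for i in range(60)]
train, val = split_by_story(rows)
train_ids = {row['story_id'] for row in train}
val_ids = {row['story_id'] for row in val}
print(len(train), len(val), train_ids.isdisjoint(val_ids))  # 54 6 True
```

The same leakage check applies to the real frames: the `story_id` sets of `df_train1` and `df_val1` should be disjoint.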

# Experiment 1: Generation
Here, we frame the task as a generation problem.

In [9]:
def get_in_out_df(df):
    # For the next-sentence task, exclude rows with no sentences before or after the target
    df = df[(df['story_before'] != '') & (df['story_after'] != '')].reset_index()
    df['input'] = df['dim'] + ': ' + df['story_before'].str.strip()
    df['output'] = df['target']
    return df

The task setup is:  
input = #<dim\>: <story up to the target sentence\>  
output = <next sentence\> 
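
Applied to a single parsed row, the setup above amounts to the following minimal sketch, which mirrors `get_in_out_df` including its filter on empty context. The function name and toy values are hypothetical.

```python
def to_generation_pair(dim, story_before, target, story_after):
    """Hypothetical per-row version of get_in_out_df.

    Returns (input, output), or None when the row would be filtered out.
    """
    # Rows with no sentence before or after the target are excluded
    if story_before == '' or story_after == '':
        return None
    source = dim + ': ' + story_before.strip()
    return source, target


pair = to_generation_pair('#2', 'I took my dog to the dog park today. ',
                          'My dog ran around.', ' I got bored.')
print(pair)  # ('#2: I took my dog to the dog park today.', 'My dog ran around.')
print(to_generation_pair('#1', '', 'First.', ' Second.'))  # None
```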

In [10]:
df_train1 = get_in_out_df(df_train1)
df_val1 = get_in_out_df(df_val1)
df_train1 = df_train1.sample(frac=1, random_state=SEED)
df_val1 = df_val1.sample(frac=1, random_state=SEED)

In [11]:
df_train1

| | index | dim | story_before | target | story_after | story | story_id | input | output |
|---|---|---|---|---|---|---|---|---|---|
| 53538 | 96957 | #2 | Mike wanted to earn a pizza party at school. I... | That month, he worked as hard as he could. | Luckily, he got his grade up high enough and ... | Mike wanted to earn a pizza party at school. I... | 3146 | #2: Mike wanted to earn a pizza party at schoo... | That month, he worked as hard as he could. |
| 57487 | 103860 | #7 | My family recently got locked in an escape roo... | The clues were very tricky. | We missed escaping by thirty seconds! | My family recently got locked in an escape roo... | 1643 | #7: My family recently got locked in an escape... | The clues were very tricky. |
| 16717 | 30188 | #5 | Gary went to hang with new friends. | They wanted to test him out. | They asked him to tag a train car. He said he... | Gary went to hang with new friends. They wante... | 705 | #5: Gary went to hang with new friends. | They wanted to test him out. |
| 69568 | 125205 | #1 | Lynn gets a new dog. She loves her dog. She is... | Her dog runs off away from her. | She cannot catch him and he runs away. | Lynn gets a new dog. She loves her dog. She is... | 3154 | #1: Lynn gets a new dog. She loves her dog. Sh... | Her dog runs off away from her. |
| 132708 | 237451 | #10 | Rebecca had a brand new trampoline! She invite... | They all had a blast jumping on the trampoline. | One of them fell off, hit his head, and died. | Rebecca had a brand new trampoline! She invite... | 479 | #10: Rebecca had a brand new trampoline! She i... | They all had a blast jumping on the trampoline. |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 80734 | 145119 | #2 | I took my dog to the dog park today. | My dog ran around and played with all the othe... | I was starting to get bored of watching dogs ... | I took my dog to the dog park today. My dog ra... | 152 | #2: I took my dog to the dog park today. | My dog ran around and played with all the othe... |
| 90785 | 162861 | #6 | I tried serving sweet potato fries to my family. | None of them liked them at all. | They left most of them on their plates. We ta... | I tried serving sweet potato fries to my famil... | 2689 | #6: I tried serving sweet potato fries to my f... | None of them liked them at all. |
| 118507 | 212266 | #1 | Jane was in the bathroom. | But she dropped her phone in the toilet. | It no longer worked. So she needed a new one.... | Jane was in the bathroom. But she dropped her ... | 2391 | #1: Jane was in the bathroom. | But she dropped her phone in the toilet. |
| 88687 | 159058 | #2 | Ann left the house, nervous for her first day ... | The teacher smiled and showed Ann where to sit. | The girl in front of Ann turned around and sm... | Ann left the house, nervous for her first day ... | 3777 | #2: Ann left the house, nervous for her first ... | The teacher smiled and showed Ann where to sit. |


## Set up wandb

In [12]:
EXP_NAME = 'glucose_exp1'
wandb.init(name=EXP_NAME)

## Tokenization

In [16]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('t5-base')

TOK_SAVE_DIR = f"{GLUCOSE_DIR}/t5_data/tokenized/"

In [15]:
import datasets
ds_train = datasets.Dataset.from_pandas(df_train1)
ds_val = datasets.Dataset.from_pandas(df_val1)

In [17]:
def get_src_tgt_len(source_text, target_text):
    '''Returns the longest tokenized source and target lengths, in tokens (including the EOS token T5 appends).'''
    tokenized_source_text = tokenizer(list(source_text), truncation=False, padding=False)
    tokenized_target_text = tokenizer(list(target_text), truncation=False, padding=False)

    max_source = max((len(ids) for ids in tokenized_source_text['input_ids']), default=0)
    max_target = max((len(ids) for ids in tokenized_target_text['input_ids']), default=0)
    return max_source, max_target

# Lengths are measured on the training split only; longer validation examples get truncated by `encode`
max_source, max_target = get_src_tgt_len(df_train1['input'], df_train1['output'])
print(max_source, max_target)

67 31


In [18]:
# %%script false --no-raise-error

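def encode(batch):
    inp = tokenizer(batch['input'], padding='max_length', truncation=True, max_length=max_source)
    outp = tokenizer(batch['output'], padding='max_length', truncation=True, max_length=max_target)
    # Set padded label positions to -100 so the cross-entropy loss ignores them
    inp['labels'] = [
        [tok if tok != tokenizer.pad_token_id else -100 for tok in ids]
        for ids in outp['input_ids']
    ]
    return inp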

BATCH_SIZE_ENCODE = 512

ds_train = ds_train.map(encode, batched=True, batch_size=BATCH_SIZE_ENCODE)
ds_val = ds_val.map(encode, batched=True, batch_size=BATCH_SIZE_ENCODE)

ds_train.set_format('numpy', columns=['input_ids', 'attention_mask', 'labels'])
ds_val.set_format('numpy', columns=['input_ids', 'attention_mask', 'labels'])

# ds_train.save_to_disk(f'{TOK_SAVE_DIR}/train')
# ds_val.save_to_disk(f'{TOK_SAVE_DIR}/val')



In [20]:
%%script false --no-raise-error 
TOK_SAVE_DIR = f"{GLUCOSE_DIR}/t5_data/tokenized/"
ds_train = datasets.load_from_disk(f'{TOK_SAVE_DIR}/train')
ds_val = datasets.load_from_disk(f'{TOK_SAVE_DIR}/val')

In [21]:
COLS_TO_FORMAT = ['input_ids', 'labels', 'attention_mask']
ds_train.set_format(type='torch', columns=COLS_TO_FORMAT)
ds_val.set_format(type='torch', columns=COLS_TO_FORMAT)
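
Because `encode` pads the labels to `max_target` with `padding='max_length'`, the padded positions hold the tokenizer's pad id. PyTorch's `CrossEntropyLoss` skips targets equal to its default `ignore_index` of -100, so the usual Hugging Face convention is to replace pad ids in `labels` with -100; otherwise the loss also rewards the model for predicting padding. A minimal stdlib sketch of that replacement, assuming a pad id of 0 (the value the T5 tokenizers use); `mask_label_padding` is a hypothetical helper:

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss
T5_PAD_ID = 0        # assumption: pad_token_id of the t5-base tokenizer


def mask_label_padding(label_ids, pad_id=T5_PAD_ID):
    """Hypothetical helper: replace pad ids in each label sequence with -100."""
    return [[IGNORE_INDEX if tok == pad_id else tok for tok in seq]
            for seq in label_ids]


# Toy label ids padded to length 5 (1 is T5's end-of-sequence id)
labels = [[363, 19, 1, 0, 0], [71, 1, 0, 0, 0]]
print(mask_label_padding(labels))
# [[363, 19, 1, -100, -100], [71, 1, -100, -100, -100]]
```

The Hugging Face T5 model maps -100 back to the pad id when it shifts `labels` right to build decoder inputs, so masking the labels does not disturb teacher forcing.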

## Load pretrained model

In [28]:
from transformers import T5ForConditionalGeneration, Trainer, TrainingArguments

model = T5ForConditionalGeneration.from_pretrained('t5-base')

## Finetune

In [None]:
# os.environ["WANDB_WATCH"] = "false"
# CUDA_VISIBLE_DEVICES only takes effect if set before CUDA is initialized in this process
os.environ['CUDA_VISIBLE_DEVICES'] = '5,6,7'
OUTPUT_DIR = f'{GLUCOSE_DIR}/outputs/exp1'

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=3,
    per_device_train_batch_size=12,
    per_device_eval_batch_size=24,
    eval_accumulation_steps=1, # Number of eval steps to keep in GPU (the higher, the more VRAM used)
    # prediction_loss_only=True, # If I need to compute only the loss and not other metrics, setting this to True will use less RAM
    learning_rate=0.001,
    evaluation_strategy='steps', # Run evaluation every eval_steps
    save_steps=1000, # How often to save a checkpoint
    save_total_limit=1, # Number of maximum checkpoints to save
    remove_unused_columns=True, # Drops dataset columns not accepted by model.forward()
    run_name=EXP_NAME, # Wandb run name
    logging_steps=1000, # How often to log loss to wandb
    eval_steps=1000, # How often to run evaluation on the val_set
    logging_first_step=False, # Whether to also log the very first training step to wandb
    load_best_model_at_end=True, # Whether to load the best checkpoint (by metric_for_best_model) at the end of training
    metric_for_best_model="loss", # Use loss to evaluate best model.
    greater_is_better=False, # Best model is the one with the lowest loss, not highest.
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds_train,
    eval_dataset=ds_val,
)
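# Override the detected GPU count (a private attribute); Trainer normally infers it from the visible devices
trainer.args._n_gpu = 3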
trainer.train()
trainer.save_model(OUTPUT_DIR + '/model')



Step,Training Loss,Validation Loss
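
When planning `eval_steps` and `save_steps`, it helps to know roughly how many optimizer steps a run takes. When the Trainer wraps the model in `torch.nn.DataParallel` across several GPUs, each step consumes `per_device_train_batch_size × n_gpu` examples (here 12 × 3 = 36). A back-of-the-envelope sketch, assuming no gradient accumulation and that the final partial batch is kept; `estimate_steps` is hypothetical and the example count is a placeholder, not the real size of `ds_train`:

```python
import math


def estimate_steps(n_examples, per_device_batch=12, n_gpu=3, epochs=3, eval_steps=1000):
    """Approximate optimizer steps and evaluation count for a DataParallel run."""
    global_batch = per_device_batch * max(1, n_gpu)
    steps_per_epoch = math.ceil(n_examples / global_batch)
    total_steps = steps_per_epoch * epochs
    # Evaluation runs whenever the global step is a multiple of eval_steps
    n_evals = total_steps // eval_steps
    return global_batch, steps_per_epoch, total_steps, n_evals


# Placeholder size chosen only to make the arithmetic easy to follow
print(estimate_steps(36_000))  # (36, 1000, 3000, 3)
```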


In [None]:
ds_val[0]['attention_mask'].shape

In [25]:
ds_val

Dataset({
    features: ['__index_level_0__', 'attention_mask', 'dim', 'index', 'input', 'input_ids', 'labels', 'output', 'story', 'story_after', 'story_before', 'story_id', 'target'],
    num_rows: 18778
})