In [None]:
import json
import os
import re
import string
import sys
import time
from collections import Counter

import torch
import transformers
# NOTE(review): load_metric is unused in this notebook; it is deprecated in favor of the
# `evaluate` library and removed in newer `datasets` releases - drop it if this import fails.
from datasets import load_dataset, load_metric
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from transformers import Adafactor
from transformers import AutoTokenizer
from transformers import T5Config, T5ForConditionalGeneration

## Load WikiQA Dataset
Load the WikiQA dataset and split it: the first 90% of the original training split is used for training, and the last 10% is held out for validation.

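*Note*: For a training-set-size ablation, lower the percentage in `'train[:90%]'`, e.g. `'train[:45%]'` for half the original training size.
Leave `'train[-10%:]'` unchanged so that every run is validated on the same held-out examples.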

In [None]:
# Positions of the train / validation / test splits in the list returned below
TRAIN_IDX, VAL_IDX, TEST_IDX = range(3)

# Download WikiQA: first 90% of train for training, last 10% of train for
# validation, and the official test split.
dataset = load_dataset('wiki_qa', split=['train[:90%]', 'train[-10%:]', 'test'])

## Hyperparameters

In [None]:
DROPOUT = 0.05  # Dropout; 0.05 in original paper
LR = 1e-3  # Learning rate; 1e-3 in original paper
BATCH_SIZE = 8  # Training batch size; also the padding chunk size when tokenizing the training split
VAL_BATCH_SIZE = 8  # Validation batch size; also the padding chunk size when tokenizing the validation split
SEED = 42  # Random seed, so dropout (the main source of randomness in training) is repeatable

# Seed PyTorch's RNGs (all devices) so that Restart & Run All reproduces results.
# GPU kernels may still introduce small nondeterminism.
torch.manual_seed(SEED)

In [None]:
MODEL_NAME = "t5-small"  # pretrained checkpoint used for both the tokenizer and the model

## Preprocessing
* Load tokenizer
* Tokenize and pad in batches

In [None]:
# Fetch the tokenizer that matches the pretrained checkpoint
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Sanity check: AutoTokenizer is expected to hand back the fast (Rust-backed) implementation
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)

In [None]:
def prepare_examples(examples):
    '''Passed to Dataset.map (batched=True). Tokenize questions and answers.

    Each row is treated as one (question, answer) pair; no answer selection or
    deduplication is done here.

    Questions and answers are padded to the longest sequence in the current
    map batch. In the targets ('label_ids'), padded positions are replaced with
    -100, the label id that Hugging Face seq2seq losses ignore. Without this,
    the model would also be trained to predict the padding tokens.

    Args:
        examples: A batch dict with 'question' and 'answer' lists of strings.

    Returns:
        The tokenizer output for the questions (including 'input_ids' and
        'attention_mask') plus 'label_ids': one list of token ids per answer.
    '''
    ignore_index = -100

    # Tokenize questions
    tokenized_examples = tokenizer(examples["question"], padding=True)

    # Tokenize targets; use the answers' attention mask to locate padding so
    # that real tokens are never masked, whatever their id.
    answers = tokenizer(examples['answer'], padding=True)
    tokenized_examples['label_ids'] = [
        [token if keep else ignore_index for token, keep in zip(ids, mask)]
        for ids, mask in zip(answers['input_ids'], answers['attention_mask'])
    ]
    return tokenized_examples

def get_tokenized(dataset, idx, batch_size):
    '''Run prepare_examples over one split and return the tokenized split.

    All original columns are removed, so only the tokenizer outputs remain.
    Because prepare_examples pads within each map batch, `batch_size` also
    determines which examples are padded to a common length.

    Args:
        dataset:    The list of splits to preprocess.
        idx:        Which split to use (one of TRAIN_IDX, VAL_IDX, TEST_IDX).
        batch_size: Number of examples handed to prepare_examples per call.
    '''
    split = dataset[idx]
    original_columns = split.column_names
    return split.map(
        prepare_examples,
        batched=True,
        batch_size=batch_size,
        remove_columns=original_columns,
    )

In [None]:
# Keep only rows whose answer is marked correct (label == 1) in the train and
# validation splits; the test split is left untouched.
for split_idx in (TRAIN_IDX, VAL_IDX):
    dataset[split_idx] = dataset[split_idx].filter(lambda row: row['label'] == 1)

# Tokenize each split, padding in chunks of the given batch size
tokenized_train = get_tokenized(dataset, TRAIN_IDX, BATCH_SIZE)
tokenized_val = get_tokenized(dataset, VAL_IDX, VAL_BATCH_SIZE)

# Expose model inputs and targets as torch tensors
for tokenized_split in (tokenized_train, tokenized_val):
    tokenized_split.set_format('torch', columns=['input_ids', 'attention_mask', 'label_ids'])

## Load T5 Model Checkpoint
* Load config separately to update dropout rate

In [None]:
# Build the config from the checkpoint, overriding the dropout rate.
# task_specific_params is cleared: the checkpoint's per-task settings are
# irrelevant to this task and may cause some obscure error.
config = T5Config.from_pretrained(
    MODEL_NAME,
    dropout_rate=DROPOUT,
    task_specific_params={},
)

# Run on the GPU when one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Pretrained weights + our config, moved to the chosen device
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, config=config).to(device)

assert model.config.dropout_rate == DROPOUT

## Utility Functions
* Get remaining memory
* Metrics
* Evaluation functions

In [None]:
def get_memory(device=0):
    '''Return bytes reserved by PyTorch's CUDA caching allocator but not allocated to tensors.

    This is the free space *inside* PyTorch's cache (reserved - allocated),
    not the total free memory on the GPU. Watching it across iterations is
    used to test for memory leaks.

    Args:
        device: CUDA device index (defaults to 0, the previously hard-coded device).
    '''
    reserved = torch.cuda.memory_reserved(device)
    allocated = torch.cuda.memory_allocated(device)
    return reserved - allocated

In [None]:
# Adapted from the SQUAD evaluation script:
# https://github.com/allenai/bi-att-flow/blob/master/squad/evaluate-v1.1.py

exclude = set(string.punctuation)

_ARTICLE_PATTERN = re.compile(r'\b(a|an|the)\b')


def normalize_answer(s):
    """Canonicalize an answer string before comparison.

    Applied in order: lower-case, drop every character in `exclude`
    (ASCII punctuation), replace the standalone words a/an/the with a space,
    then collapse whitespace runs to single spaces and trim the ends.
    """
    text = s.lower()
    text = ''.join(filter(lambda ch: ch not in exclude, text))
    text = _ARTICLE_PATTERN.sub(' ', text)
    return ' '.join(text.split())


def f1_score(prediction, ground_truth):
    '''Token-level F1 between a predicted and a gold answer string.

    Both strings go through normalize_answer and are split on whitespace.
    Shared tokens are counted with multiplicity (bag-of-tokens intersection).
    Returns 0 when the two share no token.
    '''
    pred_toks = normalize_answer(prediction).split()
    gold_toks = normalize_answer(ground_truth).split()
    overlap = sum((Counter(pred_toks) & Counter(gold_toks)).values())
    if not overlap:
        return 0
    precision = 1.0 * overlap / len(pred_toks)
    recall = 1.0 * overlap / len(gold_toks)
    return 2 * precision * recall / (precision + recall)

def exact_match_score(prediction, ground_truth):
    '''Whether prediction and ground truth coincide after normalize_answer.

    Returns a bool (True counts as 1 when summed, False as 0).
    '''
    normalized_prediction = normalize_answer(prediction)
    normalized_truth = normalize_answer(ground_truth)
    return normalized_prediction == normalized_truth


def evaluate(gold_answers, predictions):
    '''Return F1 and exact match score, given the gold answers and predictions.

    The lists are paired element-wise: predictions[i] is scored only against
    gold_answers[i] (one gold answer per prediction, no max over alternatives).
    Scores are averaged over all pairs and reported as percentages.

    Args:
        gold_answers:   A list of correct answer strings, one per prediction.
        predictions:    A list of predicted answer strings.

    Returns:
        {'exact_match': float, 'f1': float}; both 0.0 for empty input.

    Raises:
        ValueError: If the two lists differ in length (zip would otherwise
            silently drop the unmatched tail).
    '''
    if len(gold_answers) != len(predictions):
        raise ValueError(
            'gold_answers and predictions differ in length: '
            '{} vs {}'.format(len(gold_answers), len(predictions)))

    total = len(predictions)
    if total == 0:
        return {'exact_match': 0.0, 'f1': 0.0}

    exact_match = f1 = 0
    for gt, prediction in zip(gold_answers, predictions):
        exact_match += exact_match_score(prediction, gt)
        f1 += f1_score(prediction, gt)

    return {
        'exact_match': 100.0 * exact_match / total,
        'f1': 100.0 * f1 / total,
    }

In [None]:
def evaluate_dataset(model, dataloader, dataset_idx, max_length=16):
    '''Evaluate generated answers against the gold answers of an entire split.

    Generation runs in eval mode (dropout disabled). The model's previous
    train/eval mode is restored afterwards, so this is safe to call in the
    middle of a training loop.

    Gold answers are read by position from the global `dataset[dataset_idx]`,
    so `dataloader` must iterate the tokenized version of that same split in
    the same order (no shuffling).

    Args:
        model:          The model to evaluate.
        dataloader:     The DataLoader of the tokenized split.
        dataset_idx:    One of {TRAIN/VAL/TEST}_IDX
        max_length:     Maximum length of each generated answer, in tokens.

    Returns:
        {'exact_match': float, 'f1': float}: percentages averaged over all
        examples (not over batches); both 0.0 for an empty dataloader.
    '''
    split = dataset[dataset_idx]
    was_training = model.training
    model.eval()

    exact_match_sum = 0.0
    f1_sum = 0.0
    num_examples = 0
    index = 0
    try:
        with torch.no_grad():
            for batch in tqdm(dataloader):
                batch_size = len(batch['input_ids'])

                # Slice rows first so only this batch's answers are materialized
                answers = split[index : index + batch_size]['answer']
                index += batch_size

                inputs = batch['input_ids'].to(device)
                mask = batch['attention_mask'].to(device)

                # early_stopping only affects beam search; kept as in the original
                outs = model.generate(input_ids=inputs, attention_mask=mask,
                                      max_length=max_length, early_stopping=True)

                # Decode and skip special tokens such as padding.
                preds = [tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]

                result = evaluate(answers, preds)
                # evaluate() returns per-batch means; weight by batch size so a
                # short final batch does not skew the split-level average.
                exact_match_sum += result['exact_match'] * batch_size
                f1_sum += result['f1'] * batch_size
                num_examples += batch_size
    finally:
        model.train(was_training)

    if num_examples == 0:
        return {'exact_match': 0.0, 'f1': 0.0}
    return {
        'exact_match': exact_match_sum / num_examples,
        'f1': f1_sum / num_examples,
    }

## Training

In [None]:
# Training schedule: number of passes over the training set, and how often
# (in epochs) to evaluate on the validation set and save a checkpoint.
NUM_EPOCHS = 251
EVAL_EVERY = 10
# Optional diagnostics, off by default: also evaluate on the training set, and
# print PyTorch's cached-but-unallocated GPU memory after every epoch.
EVAL_TRAIN = False
LOG_MEMORY = False

t0 = time.time()

# No shuffling: each DataLoader batch must coincide with one padding chunk
# from get_tokenized (same batch size) so default collation can stack the
# tensors, and evaluate_dataset looks up gold answers by position.
train_loader = DataLoader(tokenized_train, batch_size=BATCH_SIZE)
val_loader = DataLoader(tokenized_val, batch_size=VAL_BATCH_SIZE)

optim = Adafactor(model.parameters(), lr=LR, relative_step=False)

# Histories recorded at each evaluation; saved in every checkpoint for plotting
loss_list = []
f1_list = []
exact_match_list = []

for epoch in range(NUM_EPOCHS):
    # Re-enable dropout every epoch, since evaluation below switches to eval mode
    model.train()
    cur_loss = 0.  # sum (not mean) of the batch losses in this epoch

    # Train
    for batch in tqdm(train_loader):
        optim.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        label_ids = batch['label_ids'].to(device)
        # The mask stops the encoder from attending to the questions' padding
        loss = model(input_ids=input_ids,
                     attention_mask=attention_mask,
                     labels=label_ids).loss
        cur_loss += loss.item()
        loss.backward()
        optim.step()

    # Evaluate, print and save a checkpoint every EVAL_EVERY epochs
    if epoch % EVAL_EVERY == 0:
        # Disable dropout so generation is not perturbed by it
        model.eval()
        val_result = evaluate_dataset(model, val_loader, VAL_IDX)
        elapsed = time.time() - t0
        print(f"Epoch {epoch} - Loss: {cur_loss}; Val F1: {val_result['f1']}; "
              f"Val exact match: {val_result['exact_match']}%; Elapsed {elapsed:.2f} secs")

        if EVAL_TRAIN:
            train_result = evaluate_dataset(model, train_loader, TRAIN_IDX)
            print(f"    Train F1: {train_result['f1']}; "
                  f"Train exact match: {train_result['exact_match']}%")

        # Record loss and metric for plotting later
        loss_list.append(cur_loss)
        f1_list.append(val_result['f1'])
        exact_match_list.append(val_result['exact_match'])
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optim.state_dict(),
            'loss_list': loss_list,
            'f1_list': f1_list,
            'exact_match_list': exact_match_list,
            'elapsed': elapsed,
            }, 'webqa-model-{}.pth'.format(epoch))

    if LOG_MEMORY and torch.cuda.is_available():
        print('Cached-but-unallocated GPU memory: {:.2f} GB'.format(get_memory() / 1024 / 1024 / 1024))


## Final Evaluation
This re-runs validation on the final model. It is mainly here for clarity, since validation metrics are already computed every 10 epochs during training.

The test split is not used, because the original paper also reports validation results only.

In [None]:
# Switch off dropout, then score the trained model on the validation split
model.eval()
val_result = evaluate_dataset(model, val_loader, VAL_IDX)
print(f'Final validation results: {val_result}')