# Imports

In [None]:
import gc
import logging
import matplotlib.pyplot as plt
import metrics
import numpy as np
import os
import random
import torch

from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers.optimization import Adafactor
from unidecode import unidecode

# Classes and functions

In [2]:
class MyDataset(Dataset):
    """Tokenized examples for one split of a tab-separated task file.

    Every non-blank line of the file is read as
    ``input<TAB>target[<TAB>target...]``. The text is passed through
    ``unidecode`` before splitting. One target per line is picked with
    ``random.choice``; for the ``"test"`` split the full list of targets of
    every line is additionally kept in ``all_targets``.

    Args:
        args: object providing ``data_dir``, ``task_name``, ``task_prefix``,
            ``test_filename``, ``max_input_len`` and ``max_target_len``.
        tokenizer: callable tokenizer returning a mapping with ``input_ids``
            and ``attention_mask``, and exposing ``pad_token_id``.
        split: split name. ``"test"`` reads ``args.test_filename``; any other
            value reads ``<task_prefix>_<split>.tsv``.

    Raises:
        ValueError: if a non-blank line has no target column.
    """

    def __init__(self, args, tokenizer, split):
        super().__init__()
        self.split = split
        self.all_targets = []  # Only for test set. Empty otherwise.

        if split == "test":
            filename = args.test_filename
        else:
            filename = args.task_prefix + "_" + split + ".tsv"
        path = os.path.join(args.data_dir, args.task_name, filename)

        sources = []
        targets = []
        with open(path, encoding="utf-8") as fin:
            for line_no, line in enumerate(fin, start=1):
                text = unidecode(line).strip()
                if not text:
                    # Blank lines (e.g. a trailing newline at the end of the
                    # file) carry no example; skip them instead of crashing.
                    continue
                fields = text.split("\t")
                if len(fields) < 2:
                    raise ValueError(
                        "{}:{}: expected an input and at least one target "
                        "separated by tabs".format(path, line_no))
                sources.append(fields[0])
                targets.append(random.choice(fields[1:]))
                if split == "test":
                    self.all_targets.append(fields[1:])

        tokenized_input = tokenizer(sources, padding=True,
                                    truncation=True, max_length=args.max_input_len)
        self.input_ids = tokenized_input['input_ids']
        self.attention_mask = tokenized_input['attention_mask']

        if split != "test":
            tokenized_output = tokenizer(targets, padding=True,
                                         truncation=True, max_length=args.max_target_len)
            # Replace padding token id's of the labels by -100 so it's ignored by the loss
            pad_id = tokenizer.pad_token_id
            self.label_ids = [
                [-100 if token_id == pad_id else token_id for token_id in row]
                for row in tokenized_output['input_ids']
            ]

    def __getitem__(self, idx):
        """Return the tensors of example ``idx``.

        The result is ``(input_ids, attention_mask)`` for the test split and
        ``(input_ids, attention_mask, label_ids)`` otherwise, each as a
        ``torch.LongTensor``.
        """
        item = (torch.LongTensor(self.input_ids[idx]),
                torch.LongTensor(self.attention_mask[idx]))
        if self.split == "test":
            return item
        return item + (torch.LongTensor(self.label_ids[idx]),)

    def __len__(self):
        """Return the number of examples in the split."""
        return len(self.input_ids)

In [3]:
def random_seed(value):
    """Seed the random number generators used in this file with ``value``.

    Seeds PyTorch (``torch.manual_seed`` and ``torch.cuda.manual_seed``),
    NumPy and Python's ``random`` module, and sets
    ``torch.backends.cudnn.deterministic`` to ``True``.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(value)
    torch.backends.cudnn.deterministic = True


def _validation_loss(model, val_loader):
    """Return the mean per-batch loss of ``model`` over ``val_loader``.

    Puts the model in eval mode and runs the forward passes with gradient
    tracking disabled. The caller is responsible for switching the model back
    to train mode.
    """
    model.eval()
    total = 0.0
    # No backward pass happens here, so do not let autograd record a graph
    # for every validation batch.
    with torch.no_grad():
        for input_ids, attention_mask, label_id in val_loader:
            loss = model(input_ids=input_ids.to(device),
                         attention_mask=attention_mask.to(device),
                         labels=label_id.to(device)).loss
            total += loss.item()
    return total / len(val_loader)


def train(args, logger, model, optimizer):
    """Fine-tune ``model`` and checkpoint it whenever validation loss improves.

    Builds the "train" and "dev" datasets with the module-level ``TOKENIZER``
    and runs on the module-level ``device``. Every ``args.eval_period`` steps
    the validation loss is computed; when it beats the best value so far the
    model's ``state_dict`` is saved to
    ``args.output_dir / args.finetuned_model_name.format(step)``.
    Training stops after ``args.n_epochs`` epochs or ``args.total_steps``
    optimizer steps, whichever comes first.

    Args:
        args: settings object; this function reads ``batch_size``,
            ``n_epochs``, ``eval_period``, ``total_steps``, ``output_dir`` and
            ``finetuned_model_name`` (plus whatever ``MyDataset`` reads).
        logger: logger that receives the per-evaluation loss messages.
        model: model whose forward call accepts ``input_ids``,
            ``attention_mask`` and ``labels`` and returns an object with a
            ``loss`` attribute.
        optimizer: optimizer already bound to the model's parameters.

    Returns:
        The step at which the lowest validation loss was observed, or 0 if no
        evaluation ran.
    """
    # Read training set
    train_dataset = MyDataset(args, TOKENIZER, "train")
    train_loader = DataLoader(train_dataset, sampler=RandomSampler(train_dataset), batch_size=args.batch_size)

    # Read validation set
    val_dataset = MyDataset(args, TOKENIZER, "dev")
    val_loader = DataLoader(val_dataset, sampler=RandomSampler(val_dataset), batch_size=args.batch_size)

    # Checkpoints are written below; make sure their directory exists.
    os.makedirs(args.output_dir, exist_ok=True)

    # NOTE: the two loss histories are recorded but not returned.
    train_losses = []  # Train loss every eval_period step
    val_losses = []  # Val loss every eval_period step
    train_loss = []  # Per-batch losses since the last evaluation
    step = 0
    best_loss = float("inf")
    best_step = 0
    stop_training = False
    model.train()

    for epoch in range(args.n_epochs):
        for batch in tqdm(train_loader, desc="Epoch {}".format(epoch)):
            input_ids, attention_mask, label_id = batch
            loss = model(input_ids=input_ids.to(device),
                         attention_mask=attention_mask.to(device),
                         labels=label_id.to(device)).loss
            train_loss.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            step += 1

            # Evaluate with validation set
            if step % args.eval_period == 0:
                val_loss = _validation_loss(model, val_loader)
                mean_train_loss = np.mean(train_loss)
                val_losses.append(val_loss)
                train_losses.append(mean_train_loss)
                logger.info('Step: {}; Train loss: {}; Val loss: {}'.format(step, mean_train_loss, val_loss))
                train_loss = []

                if val_loss < best_loss:
                    best_loss = val_loss
                    best_step = step
                    logger.info('Found model with best loss at step {}'.format(step))
                    torch.save(
                        model.state_dict(),
                        os.path.join(args.output_dir, args.finetuned_model_name.format(step)))
                model.train()

            if step >= args.total_steps:
                stop_training = True
                break

        torch.cuda.empty_cache()
        gc.collect()
        if stop_training:
            break

    return best_step


def evaluate(args, logger, model, best_step):
    """Score the checkpoint saved at ``best_step`` on the test split.

    Builds the "test" dataset with the module-level ``TOKENIZER``, loads the
    checkpoint ``args.finetuned_model_name.format(best_step)`` from
    ``args.output_dir`` into ``model``, generates one prediction per test
    input with ``max_length=args.max_target_len``, and logs the score computed
    by ``metrics.evaluate`` using ``metrics.METRICS[args.task_name]``.
    Nothing is returned.
    """
    # Test data, iterated in file order.
    test_dataset = MyDataset(args, TOKENIZER, "test")
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size)

    # Restore the best checkpoint before predicting.
    checkpoint_path = os.path.join(
        args.output_dir, args.finetuned_model_name.format(best_step))
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    predictions = []
    for input_ids, attention_mask in tqdm(test_loader):
        generated = model.generate(
            input_ids=input_ids.to(device),
            attention_mask=attention_mask.to(device),
            max_length=args.max_target_len,
            early_stopping=True)
        decoded = TOKENIZER.batch_decode(generated, skip_special_tokens=True)
        predictions += decoded

    metric = metrics.METRICS[args.task_name]
    test_performance = metrics.evaluate(
        predictions, test_dataset.all_targets, metric)
    logger.info("Task: {}; Test score: {}; Metric: {}".format(
        args.task_prefix, test_performance, metric))


def plot_learning_curve(args, last_step, train_losses, val_losses):
    """Plot train and validation loss against the training step.

    The i-th loss of each series is drawn at step ``(i + 1) * args.eval_period``;
    steps beyond ``last_step`` (and any losses without a matching step) are
    dropped. The figure is displayed with ``plt.show()``.
    """
    eval_steps = range(args.eval_period, last_step + 1, args.eval_period)
    for label, losses in (("train", train_losses), ("val", val_losses)):
        points = list(zip(eval_steps, losses))
        plt.plot(*zip(*points), label=label)
    plt.xlabel('step')
    plt.ylabel('loss')
    plt.legend()
    plt.show()


def get_task_prefixes_and_testfile(data_path: str, task_name: str) -> tuple:
    """Return the task prefixes and the test file name of a task.

    Scans the ``.tsv`` files in ``<data_path>/<task_name>`` in sorted order.
    A file's prefix is its name up to (not including) the last underscore,
    e.g. ``adversarialqa_32_13`` for ``adversarialqa_32_13_train.tsv``.

    Args:
        data_path: directory that contains one sub-directory per task.
        task_name: name of the task sub-directory to scan.

    Returns:
        A ``(prefixes, test_filename)`` tuple. ``prefixes`` is the list of
        distinct prefixes in order of first appearance; the prefix of the
        test file itself is included. ``test_filename`` is the last file name
        (in sorted order) ending in ``test.tsv``, or ``None`` if there is none.
    """
    task_dir = os.path.join(data_path, task_name)
    # A dict keeps insertion order and de-duplicates in one step.
    prefixes = {}
    test_filename = None
    for filename in sorted(os.listdir(task_dir)):
        if not filename.endswith(".tsv"):
            continue
        if filename.endswith("test.tsv"):
            test_filename = filename
        prefixes.setdefault(filename.rpartition("_")[0], None)
    return list(prefixes), test_filename

# Parameters

In [4]:
# Task sub-directories (looked up under TrainingArgs.data_dir) to process, in order.
TASK_NAMES = ["paws copy", "glue-mrpc copy"]


class TrainingArgs:
    """Plain container for the settings shared by training and evaluation."""

    def __init__(self):
        # --- Per-task fields: left empty here, overwritten by the main loop ---
        self.task_name = ""
        self.task_prefix = ""
        self.test_filename = ""
        self.finetuned_model_name = ""

        # --- Locations ---
        self.data_dir = 'data/crossfit'
        self.output_dir = 'output'

        # --- Model and tokenization ---
        self.t5_model = 'google/t5-v1_1-base'
        self.max_input_len = 1024  # Truncation length for inputs
        self.max_target_len = 8  # Truncation length for targets / generation limit

        # --- Training schedule ---
        self.batch_size = 8
        self.n_epochs = 1
        self.total_steps = 1000  # Stop after this many optimizer steps
        # Run validation every `eval_period` training steps
        # (with the value 1 this is after every single step).
        self.eval_period = 1
        # Intended period for periodic checkpoints; not read by any active code in this file.
        self.save_period = 200

# Set up

In [5]:
# Shared run configuration, compute device and tokenizer.
args = TrainingArgs()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
TOKENIZER = T5Tokenizer.from_pretrained(args.t5_model, model_max_length=1024)

# logging.FileHandler does not create missing parent directories, so make
# sure the output directory exists before opening the log file in it.
os.makedirs(args.output_dir, exist_ok=True)

# Send INFO-level records of the root logger to <output_dir>/log.txt.
logger = logging.getLogger()
logger.setLevel(level=logging.INFO)

logFileFormatter = logging.Formatter(
    fmt='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
)
fileHandler = logging.FileHandler(filename=os.path.join(args.output_dir, 'log.txt'))
fileHandler.setFormatter(logFileFormatter)
fileHandler.setLevel(level=logging.INFO)

logger.addHandler(fileHandler)

# Main

In [None]:
# Fine-tune and evaluate a fresh model for every prefix of every task.
model = None
optimizer = None
for task_name in TASK_NAMES:
    prefixes, test_filename = get_task_prefixes_and_testfile(args.data_dir, task_name)
    if prefixes and test_filename is None:
        # Evaluation needs the test file; fail before spending time on training.
        raise FileNotFoundError(
            "No file ending in 'test.tsv' found for task {!r} in {!r}".format(
                task_name, os.path.join(args.data_dir, task_name)))
    for prefix in prefixes:
        # Update args
        args.task_name = task_name
        args.task_prefix = prefix
        args.test_filename = test_filename
        args.finetuned_model_name = prefix + "_{}.pt"

        # Drop the previous run's model and optimizer state before building
        # the next one, so two models are never held in memory at once.
        model = None
        optimizer = None
        gc.collect()
        torch.cuda.empty_cache()

        model = T5ForConditionalGeneration.from_pretrained(args.t5_model).to(device)
        optimizer = Adafactor(model.parameters())
        random_seed(0)

        print("Start training for task {}".format(args.task_prefix))
        best_step = train(args, logger, model, optimizer)
        print("Evaluating task {}".format(args.task_prefix))
        evaluate(args, logger, model, best_step)