# Imports

In [None]:
import gc
import json
import logging
import matplotlib.pyplot as plt
import numpy as np
import os
import random
import torch
import utils

from torch.optim import Adam
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Dataset class

In [None]:
class ICTDevDataset(Dataset):
    """In-context fine-tuning dataset built from per-task example lists.

    For every training example of every task, a prompt is assembled by
    ``utils.create_input_text`` from demos sampled (via ``utils.sample_demos``)
    out of that task's ``dev_examples`` plus the example's input, and one
    target string is drawn at random from the example's candidate targets.
    Prompts whose tokenized length exceeds ``args.max_input_len`` are skipped
    rather than truncated.

    Args:
        raw_data: iterable of task dicts read with the keys ``'task_name'``,
            ``'train_examples'`` and ``'dev_examples'``. Each train example is
            indexed as ``example[0]`` (input) and ``example[1]`` (a sequence of
            candidate targets) — assumed schema, TODO confirm against the data.
        tokenizer: Hugging Face tokenizer; called with ``padding``,
            ``truncation`` and ``max_length`` and must expose ``pad_token_id``.
        args: object providing ``k``, ``max_input_len`` and ``max_target_len``.

    Items are ``(input_ids, attention_mask, label_ids)`` tuples of
    ``torch.LongTensor``; padded label positions hold -100 so the loss
    ignores them.
    """

    def __init__(self, raw_data, tokenizer, args):
        examples = []
        for task in raw_data:
            # The label count depends only on the task, so look it up once per
            # task instead of once per training example.
            num_labels = utils.n_label(task['task_name'])
            for train_example in task['train_examples']:
                demos = utils.sample_demos(task['dev_examples'], args.k, num_labels)
                input_text = utils.create_input_text(demos, train_example[0], 'label:', '. ')
                target = random.choice(train_example[1])
                # Over-long prompts are dropped instead of being truncated.
                if len(tokenizer(input_text)['input_ids']) <= args.max_input_len:
                    examples.append((input_text, target))

        input_texts = [text for text, _ in examples]
        target_texts = [target for _, target in examples]
        tokenized_input = tokenizer(input_texts, padding=True, truncation=True, max_length=args.max_input_len)
        tokenized_output = tokenizer(target_texts, padding=True, truncation=True, max_length=args.max_target_len)

        self.input_ids = tokenized_input['input_ids']  # list of lists of token ids
        self.attention_mask = tokenized_input['attention_mask']
        # Replace padding token ids in the labels by -100 so they are ignored by the loss.
        self.label_ids = self._mask_pad_tokens(tokenized_output['input_ids'], tokenizer.pad_token_id)

    @staticmethod
    def _mask_pad_tokens(sequences, pad_token_id, ignore_index=-100):
        """Return new lists where every ``pad_token_id`` is replaced by ``ignore_index``."""
        return [
            [ignore_index if token == pad_token_id else token for token in sequence]
            for sequence in sequences
        ]

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask, label_ids)`` for example ``idx`` as LongTensors."""
        return (
            torch.LongTensor(self.input_ids[idx]),
            torch.LongTensor(self.attention_mask[idx]),
            torch.LongTensor(self.label_ids[idx]),
        )

    def __len__(self):
        """Return the number of examples that survived the length filter."""
        return len(self.input_ids)

# Parameters

In [None]:
class TrainingArgs:
    """Paths and hyper-parameters for the fine-tuning run.

    Every setting has a notebook default; any of them can be overridden by
    keyword, e.g. ``TrainingArgs(k=4, n_epochs=3)``. An unknown keyword raises
    ``TypeError`` so a typo is not silently ignored.
    """

    def __init__(self, **overrides):
        self.train_data = 'data/train-train_classification_test_classification.json'
        self.val_data = 'data/val-train_classification_test_classification.json'
        self.data_dir = 'data'
        self.output_dir = 'output'  # log.txt and per-epoch checkpoints are written here
        self.checkpoint_path = None  # None: start from the pretrained t5_model weights
        self.finetuned_model_name = 'mini_classification_{}.pt'  # formatted with the epoch index
        self.t5_model = 't5-base'
        self.learning_rate = 1e-5
        self.batch_size = 8
        self.k = 8  # passed to utils.sample_demos; presumably the number of demos per prompt
        self.max_input_len = 1024  # prompts with more tokens are dropped from the dataset
        self.max_target_len = 32  # targets are truncated to this many tokens
        self.st_epoch = 0  # index of the first epoch (used to resume numbering)
        self.n_epochs = 1

        defaults = vars(self)
        for name, value in overrides.items():
            if name not in defaults:
                raise TypeError("TrainingArgs got an unexpected keyword argument {!r}".format(name))
            setattr(self, name, value)

args = TrainingArgs()

# Preparation
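Set up logging, seed the random number generators, load the tokenizer and the training/validation data, then load the model and create the optimizer.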

In [None]:
# Send INFO-level records from the root logger to <output_dir>/log.txt.
# The cell is safe to re-run: the output directory is created if missing, and
# a FileHandler already attached for the same file is replaced rather than
# duplicated (otherwise every record would be written once per re-run and the
# old file handles would stay open).
os.makedirs(args.output_dir, exist_ok=True)

logger = logging.getLogger()
logger.setLevel(level=logging.INFO)

logFileFormatter = logging.Formatter(
    fmt='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
)
log_path = os.path.join(args.output_dir, 'log.txt')

for existing_handler in list(logger.handlers):
    if (isinstance(existing_handler, logging.FileHandler)
            and existing_handler.baseFilename == os.path.abspath(log_path)):
        logger.removeHandler(existing_handler)
        existing_handler.close()

fileHandler = logging.FileHandler(filename=log_path)
fileHandler.setFormatter(logFileFormatter)
fileHandler.setLevel(level=logging.INFO)

logger.addHandler(fileHandler)

In [None]:
# Seed the RNGs via the project helper and pick the compute device.
utils.random_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Keep the tokenizer's nominal length limit in sync with the dataset's input
# cap instead of repeating the 1024 constant.
tokenizer = T5Tokenizer.from_pretrained(args.t5_model, model_max_length=args.max_input_len)

# Read training set (file handle is closed after parsing).
with open(args.train_data, encoding='utf-8') as f:
    train_raw_data = json.load(f)
train_dataset = ICTDevDataset(train_raw_data, tokenizer, args)
train_loader = DataLoader(train_dataset, sampler=RandomSampler(train_dataset), batch_size=args.batch_size)

# Read validation set. It is iterated in a fixed order: T5's loss is averaged
# over the label tokens of each batch, so reshuffling would regroup examples
# and make the epoch-level validation loss (used to pick the best epoch)
# fluctuate even when the weights are unchanged.
with open(args.val_data, encoding='utf-8') as f:
    val_raw_data = json.load(f)
val_dataset = ICTDevDataset(val_raw_data, tokenizer, args)
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)

In [None]:
# Load the pretrained T5 model onto the selected device, optionally overwrite
# its weights from a saved state dict, then build the optimizer over all of
# its parameters.
model = T5ForConditionalGeneration.from_pretrained(args.t5_model)
model.to(device)

if args.checkpoint_path is not None:
    # NOTE(review): only model weights are restored here; the Adam optimizer
    # below always starts fresh, so resuming via st_epoch does not restore
    # optimizer state.
    model.load_state_dict(
        torch.load(args.checkpoint_path, map_location=device)
    )

optimizer = Adam(params=model.parameters(), lr=args.learning_rate)

# Fine-tuning

In [None]:
def _batch_loss(batch):
    """Move one ``(input_ids, attention_mask, label_ids)`` batch to ``device`` and return the model loss."""
    input_ids, attention_mask, label_ids = batch
    return model(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
        labels=label_ids.to(device)).loss


train_losses = []  # mean training loss of each epoch
val_losses = []  # mean validation loss of each epoch
best_loss = float("inf")
step = 0  # number of optimizer steps taken so far

for epoch in range(args.st_epoch, args.st_epoch + args.n_epochs):
    model.train()
    train_loss = 0
    for batch in tqdm(train_loader, desc="Epoch {}".format(epoch)):
        loss = _batch_loss(batch)
        train_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        step += 1

    train_loss = train_loss / len(train_loader)
    train_losses.append(train_loss)
    logger.info('Epoch {}; Train loss: {}'.format(epoch, train_loss))

    # Evaluate with validation set. Gradients are not needed, so no_grad()
    # avoids building (and retaining) autograd graphs for these passes.
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            val_loss += _batch_loss(batch).item()
    val_loss = val_loss / len(val_loader)
    val_losses.append(val_loss)
    logger.info('Validation loss: {}'.format(val_loss))

    if val_loss < best_loss:
        best_loss = val_loss
        logger.info('Found model with best loss at epoch {}'.format(epoch))

    # Always save checkpoint in every epoch
    torch.save(
        model.state_dict(),
        os.path.join(args.output_dir, args.finetuned_model_name.format(epoch))
    )

    # Collect unreachable Python objects first so the tensors they held are
    # freed, then hand the now-unused cached CUDA blocks back to the driver.
    gc.collect()
    torch.cuda.empty_cache()