# Imports

In [None]:
import json
import logging
import metrics
import os
import pandas as pd
import random
import torch
import utils

from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Dataset class

In [None]:
class ICTTestDataset(Dataset):
    """Tokenized in-context test set: each item is ``demos`` + one test query.

    Every test example is rendered with ``utils.create_input_text`` together
    with the shared ``demos``. Examples whose rendered prompt is longer than
    ``args.max_input_len`` tokens are dropped (not truncated), so
    ``self.examples`` can be shorter than ``test_examples``; the number of
    dropped examples is stored in ``self.num_skipped`` and logged.

    Attributes:
        examples: list of ``[input_text, test_example[1]]`` for kept examples,
            aligned with ``input_ids`` / ``attention_mask``.
        input_ids: padded token ids, one list per kept example.
        attention_mask: matching attention masks.
        num_skipped: how many test examples were dropped for length.
    """

    def __init__(self, test_examples, demos, tokenizer, args):
        self.examples = []
        self.num_skipped = 0
        kept_ids = []
        for test_example in test_examples:
            # test_example[0] is rendered as the query; test_example[1] is kept
            # alongside it (presumably the gold label -- confirm against metrics.evaluate).
            input_text = utils.create_input_text(demos, test_example[0], 'label:', '. ')
            # Tokenize once and reuse the ids for both the length filter and the batch.
            ids = tokenizer(input_text)['input_ids']
            if len(ids) <= args.max_input_len:
                self.examples.append([input_text, test_example[1]])
                kept_ids.append(ids)
            else:
                self.num_skipped += 1

        if self.num_skipped:
            logging.getLogger(__name__).warning(
                'Skipped %d of %d test examples longer than %d tokens',
                self.num_skipped,
                len(self.examples) + self.num_skipped,
                args.max_input_len,
            )

        if kept_ids:
            # Pad to the longest kept example; equivalent to re-tokenizing the
            # texts with padding=True since every kept sequence already fits.
            padded = tokenizer.pad({'input_ids': kept_ids}, padding=True)
            self.input_ids = padded['input_ids']
            self.attention_mask = padded['attention_mask']
        else:
            self.input_ids = []
            self.attention_mask = []

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask)`` for item ``idx`` as LongTensors."""
        return torch.LongTensor(self.input_ids[idx]), torch.LongTensor(self.attention_mask[idx])

    def __len__(self):
        return len(self.input_ids)

# Parameters

In [None]:
class TrainingArgs:
    """Configuration for this evaluation notebook.

    Defaults are hard-coded below; individual values can be overridden with
    keyword arguments, e.g. ``TrainingArgs(batch_size=4)``. Unknown names
    raise ``TypeError`` so a typo is not silently ignored.
    """

    def __init__(self, **overrides):
        # Paths
        self.test_data = 'data/test-train_classification_test_classification.json'
        self.output_dir = 'output'
        self.checkpoint_path = 'output/train_classification_test_classification_9.pt'
        # Model name passed to from_pretrained for tokenizer and model
        self.t5_model = 't5-base'
        # Evaluation settings
        self.batch_size = 8
        self.k = 8  # second argument to utils.sample_demos (presumably demos per prompt -- confirm)
        self.n_prompt = 4  # number of demo sets sampled per task
        self.max_input_len = 1024  # prompts longer than this many tokens are skipped
        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError(
                    'TrainingArgs got an unexpected keyword argument {!r}'.format(name))
            setattr(self, name, value)


args = TrainingArgs()

# Preparation

In [None]:
# Configure the root logger to write INFO+ records to <output_dir>/log.txt.
logger = logging.getLogger()
logger.setLevel(level=logging.INFO)

logFileFormatter = logging.Formatter(
    fmt='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
)

# FileHandler does not create missing directories.
os.makedirs(args.output_dir, exist_ok=True)
log_path = os.path.abspath(os.path.join(args.output_dir, 'log.txt'))

# Re-running this cell would otherwise attach another handler for the same
# file to the root logger, duplicating every line in log.txt.
for existing_handler in list(logger.handlers):
    if isinstance(existing_handler, logging.FileHandler) and existing_handler.baseFilename == log_path:
        logger.removeHandler(existing_handler)
        existing_handler.close()

fileHandler = logging.FileHandler(filename=log_path)
fileHandler.setFormatter(logFileFormatter)
fileHandler.setLevel(level=logging.INFO)

logger.addHandler(fileHandler)

In [None]:
# Seed, pick a device, load tokenizer + fine-tuned checkpoint, and read the test tasks.
utils.random_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Keep the tokenizer's length limit in sync with the dataset's length filter.
tokenizer = T5Tokenizer.from_pretrained(args.t5_model, model_max_length=args.max_input_len)
model = T5ForConditionalGeneration.from_pretrained(args.t5_model).to(device)
# map_location lets a checkpoint saved on another device load on this one.
model.load_state_dict(torch.load(args.checkpoint_path, map_location=device))
model.eval()

# Close the file deterministically; JSON is read as UTF-8 regardless of locale.
with open(args.test_data, encoding='utf-8') as test_file:
    test_raw_data = json.load(test_file)

# Generate prompts

In [None]:
def _sample_prompt_row(task_entry):
    """Build one DataFrame row for a task.

    The row is ``[task_name, demos_1, len_1, ..., demos_n, len_n]`` with
    ``n = args.n_prompt``: each demo set comes from ``utils.sample_demos`` and
    is paired with the token count of its demo-only rendering.
    """
    name = task_entry['task_name']
    row = [name]
    for _ in range(args.n_prompt):
        sampled = utils.sample_demos(task_entry['dev_examples'], args.k, utils.n_label(name))
        rendered = utils.create_input_text(sampled, None, 'label:', '. ')
        row += [sampled, len(tokenizer(rendered)['input_ids'])]
    return row


data = [_sample_prompt_row(task_entry) for task_entry in test_raw_data]

df = pd.DataFrame(data)
df

In [None]:
# Task names the evaluation loop skips entirely.
EXCLUDED_TASKS = [
    'yelp_polarity',
    'tab_fact',
]

# Evaluation

In [None]:
def _predict_labels(loader):
    """Decode every batch of ``loader`` with sampling disabled; return strings in order."""
    decoded = []
    for batch_ids, batch_mask in tqdm(loader):
        generated = model.generate(
            input_ids=batch_ids.to(device),
            attention_mask=batch_mask.to(device),
            do_sample=False,
        )
        decoded.extend(tokenizer.batch_decode(generated, skip_special_tokens=True))
    return decoded


result = []

for task in test_raw_data:
    task_name = task['task_name']
    if task_name in EXCLUDED_TASKS:
        continue
    logger.info('Evaluating on task {}...'.format(task_name))

    # Only the first sampled demo set (column 1 of the task's df row) is used.
    demos = df.loc[df[0] == task_name].iloc[0][1]
    test_dataset = ICTTestDataset(task['test_examples'], demos, tokenizer, args)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size)

    predictions = _predict_labels(test_loader)

    metric_name = metrics.METRICS[task_name]
    test_performance = metrics.evaluate(predictions, test_dataset.examples, metric_name)
    logger.info('Test score: {}; Metric: {}'.format(test_performance, metric_name))

    result.append([task_name, predictions, test_performance])