In [24]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, default_data_collator, get_linear_schedule_with_warmup
from datasets import Dataset
# from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import accuracy_score
import json
import json_lines
import os
from tqdm import tqdm
from collections import defaultdict

In [20]:
# class MultipleChoiceDataset(Dataset):
#     def __init__(self, data):
#         self.data = data

#     def __len__(self):
#         return len(self.data)

#     def __getitem__(self, idx):
#         return self.data[idx]

In [21]:
def get_riddlesense_prompt(question, options):
    prompt = \
"""
Question: {}

What is the correct answer to the question from the following choices?
Options: 
(A): {}
(B): {}
(C): {}
(D): {}
(E): {}""".format(question, options[0], options[1], options[2], options[3], options[4])
    return prompt

In [26]:
def load_data(file_path):
    raw_data = []
    with open(file_path, 'rb') as f: 
        for item in json_lines.reader(f):
            raw_data.append(item)
    
    data = defaultdict(list)
    for item in raw_data:
        data['question'].append(item['question']['stem'])
        data['options'].append([_['text'] for _ in item['question']['choices']])
        data['answer'].append(item['answerKey'])
    return data

train_data = load_data("data/rs_train.jsonl")
valid_data = load_data("data/rs_dev.jsonl")

In [28]:
train_dataset = Dataset.from_dict(train_data)
valid_dataset = Dataset.from_dict(valid_data)

In [4]:
answer_map = {'A':0, 'B':1, 'C':2, 'D':3, 'E':4}
def load_data(file_path):
    data = []
    with open(file_path, 'rb') as f: 
        for item in json_lines.reader(f):
            data.append(item)

    processed_data = []
    for item in data:
        question = item['question']['stem']
        options = [_['text'] for _ in item['question']['choices']]
        answer = item['answerKey']
        text = get_riddlesense_prompt(question, options)

        model_inputs = tokenizer(
            text,
            truncation=True,
            max_length=512,
            padding='max_length',
            return_attention_mask=True,
            # return_tensors='pt'
        )

        # labels = tokenizer(answer, max_length=2, padding="max_length", truncation=True, return_tensors="pt")
        labels = tokenizer(answer, max_length=2, padding="max_length", truncation=True, return_tensors="pt")
        labels = labels["input_ids"]
        labels[labels == tokenizer.pad_token_id] = -100
        model_inputs["labels"] = labels
        processed_data.append(model_inputs)

        # input_ids = torch.stack([example['input_ids'] for example in examples]).squeeze()
        # attention_mask = torch.stack([example['attention_mask'] for example in examples]).squeeze()
        # label = torch.tensor(answer_map[item['answerKey']])
        # processed_data.append((input_ids, attention_mask, label))

    return processed_data

In [5]:
model_name = 'google/flan-t5-small'
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [6]:
train_data = load_data("data/rs_train.jsonl")
valid_data = load_data("data/rs_dev.jsonl")

In [7]:
train_dataset = MultipleChoiceDataset(train_data)
valid_dataset = MultipleChoiceDataset(valid_data)

batch_size = 4

train_loader = DataLoader(train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=batch_size)
valid_loader = DataLoader(valid_dataset, shuffle=False, collate_fn=default_data_collator, batch_size=batch_size)

In [8]:
# hyperparameters
lr = 1e-2
num_epochs = 2
batch_size = 8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [9]:
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=(len(train_loader) * num_epochs),
)

In [10]:
model = model.to(device)

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for step, batch in enumerate(tqdm(train_loader)):
        # batch = {k: v.to(device) for k, v in batch.items()}
        inputs = {key: val.to(device) for key, val in batch.items() if key != "labels"}
        print(inputs["input_ids"].shape)
        labels = batch["labels"].to(device)
        outputs = model(**inputs, labels=labels)
        # print(batch)
        # outputs = model(**batch)
        loss = outputs.loss
        total_loss += loss.detach().float()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

    model.eval()
    eval_loss = 0
    eval_preds = []
    true_preds = []
    for step, batch in enumerate(tqdm(valid_loader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
        loss = outputs.loss
        eval_loss += loss.detach().float()
        eval_preds.extend(
            tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(), skip_special_tokens=True)
        )
        true_preds.extend(
            tokenizer.batch_decode(batch["labels"].detach().cpu().numpy(), skip_special_tokens=True)
        )

    eval_epoch_loss = eval_loss / len(eval_dataloader)
    eval_ppl = torch.exp(eval_epoch_loss)
    train_epoch_loss = total_loss / len(train_dataloader)
    train_ppl = torch.exp(train_epoch_loss)
    print(f"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=}")

model.save_pretrained(f'{model_name[7:]}_finetuned')

  0%|          | 0/878 [00:00<?, ?it/s]

torch.Size([4, 512])


  0%|          | 0/878 [00:00<?, ?it/s]


ValueError: too many values to unpack (expected 2)

In [11]:
from transformers import DataCollatorForSeq2Seq

# we want to ignore tokenizer pad token in the loss
label_pad_token_id = -100
# Data collator
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=label_pad_token_id,
    pad_to_multiple_of=8
)

In [14]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

output_dir=f"{model_name[7:]}_full_ft"

# Define training args
training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
	auto_find_batch_size=True,
    learning_rate=1e-3,
    num_train_epochs=2,
    logging_dir=f"{output_dir}/logs",
    logging_strategy="steps",
    logging_steps=500,
    save_strategy="no",
    # report_to="tensorboard",
)

# Create Trainer instance
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_loader,
)
# model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
trainer.train()

TypeError: 'DataLoader' object is not subscriptable

In [None]:
correct = 0
total = 0
for pred, true in zip(eval_preds, true_preds):
    if pred.strip() == true.strip():
        correct += 1
    total += 1
accuracy = correct / total * 100
print(f"{accuracy=} % on the evaluation dataset")
print(f"{eval_preds[:10]=}")
print(f"{true_preds[:10]=}")

In [None]:
from huggingface_hub import notebook_login
notebook_login()
# model.push_to_hub("devanshrj/t5-large_PREFIX_TUNING_SEQ2SEQ", use_auth_token=True)

## Inference

In [60]:
def get_brainteaser_prompt(question, options):
    prompt = \
"""
Question: {}

What is the correct answer to the question from the following choices?
Options: 
(A): {}
(B): {}
(C): {}
(D): {}""".format(question, options[0], options[1], options[2], options[3])
    return prompt

In [63]:
question = "Mr. and Mrs. Mustard have six daughters and each daughter has one brother. But there are only 9 people in the family, how is that possible?"
options = ["Some daughters get married and have their own family.", "Each daughter shares the same brother.", "Some brothers were not loved by family and moved away.", "None of above."]
bt_prompt = get_brainteaser_prompt(question, options)

In [64]:
inputs = tokenizer(bt_prompt, return_tensors="pt")

In [66]:
model.to(device)
with torch.no_grad():
    inputs = {k: v.to(device) for k, v in inputs.items()}
    outputs = model.generate(input_ids=inputs["input_ids"], max_new_tokens=10)
    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))

['(C)']
