# Correct Answer Generator (TODO)

In [None]:
!pip install -r requirements.txt

In [None]:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import BartTokenizer, BartForConditionalGeneration

# Pick the compute device: CUDA when PyTorch can see a GPU, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

# Load the tokenizer and the seq2seq model, both from the
# 'facebook/bart-base' checkpoint.
tokenizer = AutoTokenizer.from_pretrained('facebook/bart-base')
model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')

In [None]:
# Fetch the SciQ dataset and keep handles to the two splits used below.
data = load_dataset("allenai/sciq")
train_data, test_data = data['train'], data['test']

In [None]:
# Tokenization length limits (question input / answer target) and the
# per-device batch size handed to the trainer.
max_input, max_target = 512, 128
batch_size = 36

In [None]:
# dataset has:
# question, distractor3, distractor1, distractor2, correct_answer, support

def calculate_max_length(data, length_fn=len):
    """Return the largest question and correct-answer lengths found in ``data``.

    Args:
        data: An iterable of rows, each a mapping with ``"question"`` and
            ``"correct_answer"`` string fields.
        length_fn: Callable that maps one text to its length. The default,
            ``len``, counts *characters*, not tokens. When the result is used
            as a tokenizer ``max_length``, pass a token-counting callable
            instead, e.g. ``lambda text: len(tokenizer(text)["input_ids"])``.

    Returns:
        A ``(max_input, max_target)`` tuple: the longest question and the
        longest correct answer, as measured by ``length_fn``. Both are ``0``
        when ``data`` is empty.
    """
    max_input = 0
    max_target = 0
    # Iterate rows directly: a single pass, instead of two indexed row
    # lookups per iteration.
    for row in data:
        max_input = max(max_input, length_fn(row["question"]))
        max_target = max(max_target, length_fn(row["correct_answer"]))
    return max_input, max_target

# dataset has:
# question, distractor3, distractor1, distractor2, correct_answer, support
def pre_process_data(data):
    """Tokenize a batch of rows into model inputs and training labels.

    Questions become the encoder input (padded/truncated to ``max_input``) and
    correct answers become the labels (padded/truncated to ``max_target``).
    Relies on the module-level ``tokenizer``, ``max_input`` and ``max_target``.

    Args:
        data: A batch with ``'question'`` and ``'correct_answer'`` columns, as
            passed by ``Dataset.map(..., batched=True)``.

    Returns:
        A dict with ``input_ids``, ``attention_mask`` and ``labels``.
    """
    # tokenize the data
    inputs = tokenizer(data['question'], padding="max_length", truncation=True, max_length=max_input, return_tensors="pt")
    targets = tokenizer(data['correct_answer'], padding="max_length", truncation=True, max_length=max_target, return_tensors="pt")
    # Padding positions must not contribute to the loss. -100 is the index the
    # model's cross-entropy loss ignores, so replace every padded label with it.
    labels = targets.input_ids.masked_fill(targets.attention_mask == 0, -100)
    return {"input_ids": inputs.input_ids, "attention_mask": inputs.attention_mask, "labels": labels}

# Derive the padding/truncation lengths once, from the training split only,
# and reuse them for both splits so train and eval rows are tokenized with the
# same limits. calculate_max_length(train_data) measures character counts, so
# its result is only allowed to lower the configured max_input / max_target,
# never to raise them past what was configured for the model.
longest_question, longest_answer = calculate_max_length(train_data)
max_input = min(max_input, longest_question)
max_target = min(max_target, longest_answer)

train_data = train_data.map(pre_process_data, batched=True)
test_data = test_data.map(pre_process_data, batched=True)

In [None]:
# Ask PyTorch to release cached CUDA memory it is no longer using before training starts.
torch.cuda.empty_cache()

In [None]:
# TODO: add versioning

# Fine-tune the model on question -> correct_answer pairs, evaluating once per
# epoch, then save the model and tokenizer to OUT_DIR.
model.to(device)
args = Seq2SeqTrainingArguments(
    output_dir="./results_option_generation",
    evaluation_strategy='epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=2,
    weight_decay=0.01,
    save_total_limit=2,
    num_train_epochs=32,
    predict_with_generate=True,
    eval_accumulation_steps=32,
    # fp16 mixed precision is available only with CUDA. Tie it to the selected
    # device so the CPU fallback does not fail when the arguments are built.
    fp16=(device.type == "cuda"),
)


trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=train_data,
    eval_dataset=test_data,
    tokenizer=tokenizer,
)

trainer.train()
# lets save the model
#!!! we also have the one with context.
OUT_DIR = "sciq_correct_answer_generator"
model.save_pretrained(OUT_DIR)
tokenizer.save_pretrained(OUT_DIR)


In [None]:
from transformers import BartTokenizer, BartForConditionalGeneration

# Reload the fine-tuned tokenizer and model from the directory they were saved to.
tokenizer = BartTokenizer.from_pretrained(f"./{OUT_DIR}")
model = BartForConditionalGeneration.from_pretrained(f"./{OUT_DIR}")
# Move the model to the selected device; the result is bound to `_`,
# presumably so the notebook does not echo the model's repr.
_ = model.to(device)

#TODO: The correct answer might be in the options.

In [None]:
# Smoke-test the fine-tuned model on a single question.
input_text = "What amazing machines smash particles that are smaller than atoms into each other head-on?"
# Reference answer, kept only for comparing against the printed generations.
correct_answer = "particle accelerators"

# Encode the question alone. pre_process_data feeds the model question-only
# inputs during training, so passing the reference answer as a second segment
# would leak the answer into the prompt and use an input format the model was
# never trained on.
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
# Beam search with 4 beams, returning the 3 best candidate answers.
output = model.generate(input_ids, max_length=128, num_beams=4, num_return_sequences=3, early_stopping=True)
outputs = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in output]
print(outputs)
