# Options Generator

In [1]:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import BartForConditionalGeneration

# make sure to include cuda
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
tokenizer = AutoTokenizer.from_pretrained('facebook/bart-base')

cpu


In [2]:
data = load_dataset("nlp-group-6/sciq-with-generated-questions")
train_data = data['train']
val_data = data['validation']

In [3]:
max_input = 512
max_target = 128
batch_size = 36

In [4]:
# dataset has:
# question, distractor3, distractor1, distractor2, correct_answer, support

def calculate_max_length(data):
    max_input = 0
    max_target = 0
    for i in range(len(data)):
        max_input = max(max_input, len(data[i]["generated_question"]), len(data[i]["correct_answer"]))
        max_target = max(max_target, len(data[i]["distractor1"]) + len(data[i]["distractor2"]) + len(data[i]["distractor3"]))
    return max_input, max_target

# dataset has:
# question, distractor3, distractor1, distractor2, correct_answer, support
def pre_process_data(data):
    question_answer_context = [question + "</s><s>" + correct_answer + "</s><s>" + support for question, correct_answer, support in zip(data['generated_question'], data['correct_answer'], data['support'])]
    
    # tokenize the data
    inputs = tokenizer(question_answer_context, padding="max_length", truncation=True, max_length=max_input, return_tensors="pt")
    targets = tokenizer(data['distractor1'], data['distractor2'], data['distractor3'], padding="max_length", truncation=True, max_length=max_target, return_tensors="pt")
    return {"input_ids": inputs.input_ids, "attention_mask": inputs.attention_mask, "labels": targets.input_ids}

max_input, max_target = calculate_max_length(train_data)
train_data = train_data.map(pre_process_data, batched=True).shuffle(seed=42)

max_input, max_target = calculate_max_length(val_data)
val_data = val_data.map(pre_process_data, batched=True).shuffle(seed=42)

Map:   0%|          | 0/11679 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [5]:
# empty memory
torch.cuda.empty_cache()

In [6]:
model.to(device)
args = Seq2SeqTrainingArguments(
    output_dir="./results_option_generation",
    evaluation_strategy='epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size= batch_size,
    gradient_accumulation_steps=2,
    weight_decay=0.01,
    save_total_limit=2,
    num_train_epochs=32,
    predict_with_generate=True,
    eval_accumulation_steps=32,
    fp16=True #available only with CUDA
)


trainer = Seq2SeqTrainer(
    model, 
    args,
    train_dataset=train_data,
    eval_dataset=val_data,
    tokenizer=tokenizer,
)

trainer.train()
# lets save the model
OUT_DIR = "sciq_options1_generator"
model.save_pretrained(OUT_DIR)
tokenizer.save_pretrained(OUT_DIR)


ValueError: FP16 Mixed precision training with AMP or APEX (`--fp16`) and FP16 half precision evaluation (`--fp16_full_eval`) can only be used on CUDA or NPU devices or certain XPU devices (with IPEX).

In [11]:
OUT_DIR = "sciq_options1_generator"

In [12]:
from transformers import BartTokenizer, BartForConditionalGeneration

model = BartForConditionalGeneration.from_pretrained(f"./{OUT_DIR}")
tokenizer = BartTokenizer.from_pretrained(f"./{OUT_DIR}")
# put them both on the same device
_ = model.to(device)

#TODO: The correct answer might be in the options.

In [14]:
input_text = "What amazing machines smash particles that are smaller than atoms into each other head-on?"
correct_answer = "particle accelerators"

input_ids = tokenizer(input_text, correct_answer, return_tensors="pt").input_ids.to(device)
output = model.generate(input_ids, max_length=128, num_beams=4, num_return_sequences=3, early_stopping=True)
outputs = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in output]
print(outputs)

['particle accelerators', 'kinetic accelerators', 'neutron accelerators']
