# Generate greedy answers
Greedy answers are used to judge whether the model answers a question correctly. An answer $s$ is counted as correct when it matches the reference answer $s'$ under the fuzzy-matching criterion $L(s, s') = \mathbb{1}_{\text{RougeL}(s, s') > 0.3}$ (page 7).


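Structure of the saved `greedy_answers.pkl` dictionary, keyed by question index in the TriviaQA validation split: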

```python 
{ 1131: {"greedy_answer": ..., 
         "rouge_l_score": ...
        }, 
  4295: ...
}
```
`rouge_l_score` is the Rouge-L F-measure of the generated (greedy) answer against the reference answer.

In [24]:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, OPTForCausalLM
import yaml
import os
import pickle
from rouge_score import rouge_scorer
from tqdm import tqdm

# Run settings (model cache dir, checkpoint name, output length, save path) are read from config.yaml.
with open("config.yaml", "r") as config_file:
    config = yaml.safe_load(config_file)

In [4]:
# Load data
model_dir = config["model_dir"]
data_trivia = load_dataset("trivia_qa", "rc.nocontext")
data_trivia = data_trivia.remove_columns(["question_source", "entity_pages", "search_results"])
data_trivia_train = data_trivia["train"]
data_trivia_val = data_trivia["validation"]

Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

In [5]:
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [7]:
checkpoint = config["checkpoint"]
hub_name = f"facebook/{checkpoint}"  # OPT checkpoint name on the Hugging Face Hub
tokenizer = AutoTokenizer.from_pretrained(hub_name, cache_dir=model_dir)
model = OPTForCausalLM.from_pretrained(hub_name, cache_dir=model_dir).to(device)

In [8]:
# Few-shot prompt: the first ten training examples, one "QUESTION:<q>ANSWER:<a>" pair per line,
# in the same layout later used for the question being answered.
selected_training_data = data_trivia_train.select(range(0, 10))
ten_shot_prompt = "".join(
    f"QUESTION:{example['question']}ANSWER:{example['answer']['value']}\n"
    for example in selected_training_data
)

# Strings whose first token must never be generated (passed to generate() as bad_words_ids).
stop_strings = [
    "Q:", "Question:", "QUESTION:", "questions:", " Q:", " Question:", " QUESTION:", " questions:",
    "A:", "Answer:", "ANSWER:", "answers:", " A:", " Answer:", " ANSWER:", " answers:", "Answers:",
    " Answers:",
    "Topic:", " Topic:", "TOPIC:", " TOPIC:",
    ".", " .", "...", " ...", "?", " ?", ":", " :", "!", " !",
]
# Index 1 is taken because index 0 of the tokenized output is a special token; only that single
# token per string is banned. bad_words_ids expects a list of token-id lists.
stop_tokens = [[tokenizer(text)["input_ids"][1]] for text in stop_strings]

# A newline ends generation; it also serves as the padding id (index 1 again skips the special token).
eos_token = tokenizer("\n")["input_ids"][1]
tokenizer.pad_token_id = tokenizer.eos_token_id = eos_token

# Upper bound on the number of tokens a generated answer may contain.
max_new_tokens = config["max_output_length"]

In [20]:
save_path = config["path_to_saved_generations"]
# Indices (into the validation split) of the questions to answer, stored as comma-separated
# integers, possibly across several lines. Empty fields (blank lines, trailing commas) are
# ignored; otherwise int("") would raise ValueError.
with open(os.path.join(save_path, "group_indices.txt"), "r") as f:
    indices_questions = {
        int(field)  # int() tolerates surrounding whitespace such as the trailing "\n"
        for line in f
        for field in line.split(",")
        if field.strip()
    }

In [18]:
# Rouge-L scorer (with stemming) used to grade each greedy answer against the reference answer.
# Based on: https://thepythoncode.com/article/calculate-rouge-score-in-python#rouge-l
scorer = rouge_scorer.RougeScorer(rouge_types=["rougeL"], use_stemmer=True)

In [25]:
# Resume from previously saved generations (if any) so an interrupted run can continue.
# NOTE: pickle.load can execute arbitrary code; only load a checkpoint this notebook wrote itself.
greedy_answers_path = os.path.join(save_path, "greedy_answers.pkl")
if os.path.exists(greedy_answers_path):
    with open(greedy_answers_path, "rb") as f:
        result = pickle.load(f)
else:
    result = dict()

for idx in tqdm(indices_questions):
    if idx in result:
        # tqdm.write keeps the progress bar intact, unlike print.
        tqdm.write(f"Question {idx} already exists. Skipping...")
        continue  # actually skip: do not regenerate or overwrite the saved answer

    result_question = dict()
    question = ten_shot_prompt + "QUESTION:" + data_trivia_val[idx]["question"] + "ANSWER:"
    answer = data_trivia_val[idx]["answer"]["value"]

    inputs = tokenizer(question, padding=False, truncation=False, return_tensors="pt").to(device)
    length_input = inputs["input_ids"].shape[1]

    # Greedy decoding: do_sample=False with num_beams=1 always takes the most probable token,
    # independent of whatever defaults the model's generation config carries.
    # **inputs also forwards the attention mask alongside input_ids.
    output_generate = model.generate(**inputs,
                                     max_new_tokens=max_new_tokens,
                                     do_sample=False,
                                     num_beams=1,
                                     eos_token_id=eos_token,
                                     bad_words_ids=stop_tokens)

    # Decode only the newly generated tokens, as one sequence. Decoding token by token and
    # joining the pieces can garble characters that are split across several tokens.
    output = tokenizer.decode(output_generate[0][length_input:], skip_special_tokens=True)
    output = output.replace("\n", "")
    result_question["greedy_answer"] = output
    # RougeScorer.score takes (target, prediction); the F-measure is stored.
    rouge_l_score = scorer.score(answer, output)["rougeL"].fmeasure
    result_question["rouge_l_score"] = rouge_l_score
    result[idx] = result_question

    # Checkpoint after every new answer. Dump to a temporary file and atomically swap it in,
    # so an interruption mid-write cannot corrupt the existing checkpoint.
    tmp_path = greedy_answers_path + ".tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(result, f)
    os.replace(tmp_path, greedy_answers_path)

100%|██████████| 4484/4484 [14:25<00:00,  5.18it/s]
