In [None]:
#!pip install bert_score

In [None]:
import os
import logging
import transformers
import bert_score
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from bert_score import score, BERTScorer
from datasets import load_dataset
import matplotlib.pyplot as plt
from matplotlib import rcParams

# Silence transformers' INFO/WARNING chatter during tokenizer/config/model
# loading (and generation) through the library's public verbosity API, rather
# than reaching into per-module `logger` attributes, which are internal and
# may be renamed or moved between transformers releases.
transformers.logging.set_verbosity_error()

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Show the installed bert_score version as this cell's output (useful when comparing scores across runs).
bert_score.__version__

In [None]:
# Global plot styling: hide tick marks, enlarge axis labels and draw a grid
# beneath the plotted data.
rcParams.update({
    "xtick.major.size": 0,
    "xtick.minor.size": 0,
    "ytick.major.size": 0,
    "ytick.minor.size": 0,
    "axes.labelsize": "large",
    "axes.axisbelow": True,
    "axes.grid": True,
})

In [None]:
# Experiment configuration. Paths are relative to the notebook's directory.
project_root = '..'

# Hugging Face Hub dataset id and the optional config ("variant") name.
# None means "use the dataset's default config" (same as omitting it in load_dataset).
dataset_name = 'vblagoje/lfqa_support_docs'
dataset_variant = None
dir_name = 'lfqa'
# Alternative dataset, which needs an explicit variant:
# dataset_name = 'wikitext'
# dataset_variant = 'wikitext-2-raw-v1'
# dir_name = 'wikitext'

data_dir = os.path.join(project_root, 'data', dir_name)
model_dir = os.path.join(project_root, 'models')

# Causal LM checkpoint to evaluate (alternative kept for comparison).
# model_checkpoint = 'gpt2'
model_checkpoint = 'mjphayes/distilgpt2-lfqa'

In [None]:
# Load the dataset. `dataset_variant` (the HF config name) is only needed for
# multi-config datasets; if it is None or not defined in the config cell, the
# dataset's default config is loaded (passing name=None == omitting it).
# Only a missing variable is tolerated here: network/auth/bad-variant errors
# from load_dataset now surface instead of being swallowed by a bare `except:`
# and silently retried with a different config.
try:
    dataset_config = dataset_variant
except NameError:
    dataset_config = None
datasets = load_dataset(dataset_name, dataset_config, cache_dir=data_dir)

In [None]:
# Display the loaded dataset splits.
# NOTE(review): the variable name `datasets` shadows the `datasets` library name;
# harmless here (only `load_dataset` is imported from it) but a rename would be clearer.
datasets

In [None]:
# Evaluate on the validation split.
test_data = datasets['validation']

In [None]:
def transform(examples):
    """Flatten one raw record into a ``{"question", "answer"}`` pair.

    ``examples`` is a single record (``map`` is called without ``batched=True``).
    Its ``'input'`` becomes the question, and the ``'answer'`` field of the
    first entry in ``'output'`` becomes the reference answer.
    """
    question_text = examples['input']
    reference = examples['output'][0]['answer']
    return {"question": question_text, "answer": reference}

In [None]:
# Reduce each validation record to a (question, answer) pair and drop the raw
# columns. Building from `datasets['validation']` (instead of re-assigning
# from `test_data`) makes the cell idempotent: re-running it no longer fails
# because the columns in remove_columns were already removed.
test_data = datasets['validation'].map(
    transform,
    num_proc=4,
    remove_columns=["input", "output", "meta", "id"],
)

In [None]:
# Inspect the transformed split; it should now expose `question` and `answer`.
test_data

In [None]:
# Left padding keeps each prompt's last token right next to the generated
# continuation, as decoder-only generation expects.
tokenizer = AutoTokenizer.from_pretrained(
    model_checkpoint,
    padding_side="left",
    use_fast=True,
    cache_dir=model_dir,
)

In [None]:
# Causal-LM weights for the configured checkpoint (cached under model_dir).
model = AutoModelForCausalLM.from_pretrained(
    model_checkpoint,
    cache_dir=model_dir,
)

In [None]:
# Use the first validation example as a running demo.
example = test_data[0]
example

In [None]:
# Demo prompt built with exactly the template generation() uses ("... \n Answer:"
# with no trailing space). A trailing space made the demo differ from the
# evaluated prompt. With GPT-2 byte-level BPE, a dangling space also becomes a
# standalone token, which tends to degrade the continuation.
test_prompt = f"Question: {example['question']}? \n Answer:"

In [None]:
# Many causal LMs ship without a pad token; reuse EOS so padding is possible.
tokenizer.pad_token = tokenizer.eos_token
model_inputs = tokenizer(
    [test_prompt],
    max_length=200,
    padding='max_length',
    return_tensors="pt",
)

In [None]:
# Sample a continuation for the demo prompt. Seed first so the sampled text is
# reproducible across kernel restarts. Pass the pad id explicitly (the prompt
# is padded with the pad token) instead of relying on generate()'s fallback.
transformers.set_seed(42)
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=300,
    do_sample=True,
    pad_token_id=tokenizer.pad_token_id,
)
# Decoded text contains the prompt followed by the sampled continuation.
bable = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

In [None]:
# Take the line following the prompt's "\n" and drop the "Answer:" label plus
# surrounding whitespace. The previous `split('\n')[1]` left " Answer: " in the result.
pred_answer = bable.split('\n')[1].split('Answer:', 1)[-1].strip()

In [None]:
# Inspect the extracted demo answer.
pred_answer

In [None]:
def generation(input, model=model):
    """Sample an answer from the causal LM for a single question.

    Builds the prompt "Question: {input}? <newline> Answer:", pads it on the
    tokenizer's padding side to 200 tokens, samples up to 300 new tokens and
    decodes the whole sequence (special tokens skipped).

    Args:
        input: the question text. The name is kept for caller compatibility,
            although it shadows the ``input`` builtin inside this function.
        model: causal LM used for generation. Defaults to the global ``model``
            as it was when this cell ran, because default values are bound at
            definition time. Re-run this cell after reloading the model.

    Returns:
        The decoded text between the first ``"Question: "`` and the next one
        (if the model invents another turn, that turn is cut off). The result
        is the echoed question followed by the "Answer:" label and the sampled
        text, so callers extract the answer by splitting on the label.
    """
    # padding='max_length' raises if the tokenizer has no pad token (typical for
    # GPT-2-family tokenizers), so don't depend on an earlier cell having set it.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    prompt = f"Question: {input}? \n Answer:"
    model_inputs = tokenizer([prompt], return_tensors="pt", padding='max_length', max_length=200)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=300,
        do_sample=True,
        # Explicit pad id so generate() does not have to fall back / warn.
        pad_token_id=tokenizer.pad_token_id,
    )
    decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return decoded.split('Question: ')[1]

In [None]:
# Spot-check the generator on one validation question (index chosen ad hoc).
num = 22
sample_question = test_data[num]['question']
test_answer = generation(sample_question)
test_answer

In [None]:
# Small evaluation subset: the first 20 validation examples.
mini_val = test_data.select(indices=range(20))
print(mini_val.num_rows)
mini_val

In [None]:
# Generate one sampled answer per evaluation question and collect it alongside
# the question and reference answer. Seeded so the reported scores are reproducible.
transformers.set_seed(42)
questions = []
ideal_answers = []
gen_answers = []
for q_a in mini_val:
    question = q_a['question']
    # generation() echoes the prompt; keep only the text after the last
    # "Answer:" label. Splitting on "Answer:" (no trailing space) also handles
    # samples that don't start with a space. In that case the old 'Answer: '
    # split found no match and returned the whole prompt + answer.
    gen_answer = generation(question).split('Answer:')[-1].strip()
    questions.append(question)
    ideal_answers.append(q_a['answer'])
    gen_answers.append(gen_answer)



In [None]:
# Spot-check: re-apply the "text after the last 'Answer: '" extraction to answer #9.
rectified = gen_answers[9].rsplit('Answer: ', 1)[-1]
rectified

In [None]:
# One row per evaluated question: question, reference answer, generated answer.
df = pd.DataFrame(
    dict(
        Questions=questions,
        Ideal=ideal_answers,
        Generated=gen_answers,
    )
)

In [None]:
# Side-by-side view of questions, reference answers and generated answers.
df

In [None]:
# BERTScore of each generated answer (candidate) against its reference answer,
# without baseline rescaling: per-example precision, recall and F1.
P, R, F1 = score(
    cands=gen_answers,
    refs=ideal_answers,
    lang='en',
    verbose=True,
)

In [None]:
# System-level score = mean of the per-example F1 values.
print(f"System level F1 score: {F1.mean().item():.3f}")

In [None]:
# Per-example F1 scores (raw, as computed by the cell above).
F1

In [None]:
# `plt` (matplotlib.pyplot) is already imported in the imports cell at the top;
# re-importing it mid-notebook is unnecessary.

In [None]:
# Distribution of per-example BERTScore F1 (raw scores, when cells run in order).
fig, ax = plt.subplots()
ax.hist(F1, bins=20)
ax.set(
    title="BERTScore F1 per generated answer (raw)",
    xlabel="score",
    ylabel="counts",
)
plt.show()

In [None]:
# Scorer whose outputs are rescaled against bert_score's English baseline.
scorer = BERTScorer(
    rescale_with_baseline=True,
    lang="en",
)

In [None]:
# Baseline-rescaled BERTScore for the same candidate/reference pairs.
# Note: this rebinds P, R and F1, replacing the raw scores computed earlier.
P, R, F1 = scorer.score(cands=gen_answers, refs=ideal_answers)

In [None]:
# Per-example F1 scores, now from the baseline-rescaled scorer.
F1

In [None]:
# System-level (mean) F1 for the rescaled scores.
rescaled_mean = F1.mean()
print(f"System level F1 score: {rescaled_mean:.3f}")

In [None]:
# Distribution of per-example BERTScore F1 after baseline rescaling.
fig, ax = plt.subplots()
ax.hist(F1, bins=20)
ax.set(
    title="BERTScore F1 per generated answer (baseline-rescaled)",
    xlabel="score",
    ylabel="counts",
)
plt.show()