In [None]:
# CONFIGURATION (GLOBALS)
from rag.generator import Llama3_8b, Llama2_4b, Gpt2, Phi15_1b

# Maximum number of dataset entries evaluated per run (passed to evaluate_model as `limit`).
limit=100
# Seed shared by the dataset shuffle and the per-question choice shuffle.
seed=42
# Generator under test; evaluate_model calls it with the prompt and max_new_tokens.
generator = Phi15_1b()
# Human-readable model name used in plot titles and plot file names.
generator_name = "Phi-1.5"
# Directory the plot images are written to.
plot_dir = "./plots"

In [None]:
from dotenv import load_dotenv
# Load variables from a .env file into the process environment (python-dotenv).
# NOTE(review): presumably supplies credentials/settings needed by the rag components — confirm.
load_dotenv()

In [None]:
from datasets import load_dataset
from rag.retriever import Medical, Wiki10k
import random
import re

def extract_answer(response, entry):
    """Extract the multiple-choice option picked in a model response.

    Only the text between the first "Answer:" marker and the following
    "Explanation:" marker (each only if present) is inspected. Double quotes
    are normalized to single quotes and newlines to spaces, then three
    strategies are tried in order:

    1. an option letter (A-D) followed by one of ``) , . ]`` at the very
       start of the text,
    2. the earliest such letter marker anywhere in the text, ignoring
       letters that merely end a word or number (e.g. the "A." in "DNA."),
    3. the literal text of one of the first four choices.

    Args:
        response: Raw text produced by the generator.
        entry: Mapping whose ``'choices'`` value lists the answer strings in
            the order they were presented to the model.

    Returns:
        0, 1, 2 or 3 for option A, B, C or D, or 4 if no answer could be
        extracted (an index that never equals a valid solution).
    """
    letters = "ABCD"

    # Keep everything between "Answer:" and "Explanation:".
    if "Answer:" in response:
        response = response.split("Answer:")[1]
    if "Explanation:" in response:
        response = response.split("Explanation:")[0]
    response = response.strip().replace("\"", "'").replace("\n", " ")

    # The symbols must be escaped before being placed in a character class:
    # an unescaped ']' closes the class early and leaves a stray literal ']'
    # in the pattern, so that a plain "A)" would never match.
    accepted_symbols = re.escape('),.]')
    marker = rf'([{letters}])[{accepted_symbols}]'

    # 1) The response starts directly with the option marker.
    leading = re.match(marker, response)
    if leading:
        return letters.index(leading.group(1))

    # 2) Earliest marker anywhere. The lookbehind rejects letters that are
    #    the tail of a word or number, which are not option markers.
    embedded = re.search(rf'(?<![A-Za-z0-9]){marker}', response)
    if embedded:
        return letters.index(embedded.group(1))

    # 3) Fall back to the literal choice text. The choice gets the same quote
    #    normalization as the response so both sides stay comparable.
    for index, choice in enumerate(entry['choices'][:len(letters)]):
        if choice.replace("\"", "'") in response:
            return index

    print("Could not extract answer")
    print(f"\n-----------------------------------\nResponse: '{response}'\n-----------------------------------")
    return 4  # not an existing index --> definitely wrong
        

def evaluate_model(dataset, get_query, get_solution, generator, retriever, limit):
    """Run the generator on multiple-choice questions and count the outcomes.

    For every entry the choices are shuffled (seeded with the module-level
    ``seed``), a prompt is built from the optional retrieved context and the
    query, the generator is called, and the picked option is extracted from
    its response with ``extract_answer``.

    Args:
        dataset: Iterable of mappings; each must provide ``'choices'`` plus
            whatever ``get_query`` / ``get_solution`` read.
        get_query: Callable mapping an entry (with shuffled choices) to the
            query text inserted into the prompt.
        get_solution: Callable mapping an entry (with shuffled choices) to
            the index of the correct choice.
        generator: Callable ``generator(prompt, max_new_tokens=...)``
            returning the response text.
        retriever: Callable mapping the query to an iterable of context
            strings, or ``None`` to evaluate without retrieval.
        limit: Stop after this many evaluated entries.

    Returns:
        Tuple ``(correct_answers, unretrievable, count)``. It is returned
        both when ``limit`` is reached and when the dataset is exhausted
        first.
    """
    count = 0
    correct_answers = 0
    unretrievable = 0
    random.seed(seed)  # set the seed for the shuffle method

    for entry in dataset:
        # Shuffle a private copy so the caller's entry is never mutated and
        # repeated runs over the same in-memory entries stay reproducible.
        entry = dict(entry)
        entry['choices'] = list(entry['choices'])
        random.shuffle(entry['choices'])

        query = get_query(entry)
        solution = get_solution(entry)
        context = ""
        context_prompt = ""
        if retriever is not None:
            context = "Context:\n" + "\n".join(retriever(query))
            context_prompt = "Answer based on the context provided."

        # The prompt text (including its indentation) is part of the model
        # input and is kept exactly as it was.
        prompt = f"""
        You are a helpful AI assistant. {context_prompt} Be concise and answer either with A), B), C) or D). Add nothing else.
        
        {context}
        
        Query:
        {query}
        
        Answer:
        """
        response = generator(prompt, max_new_tokens=100)
        print(f"\n#########################################\n{response}\n#########################################\n")

        # answer extraction
        model_answer = extract_answer(response, entry)

        # Log the same expected index that is used for scoring.
        print(f"Question {count}: Expected={solution} Got={model_answer}")

        if solution == model_answer:
            correct_answers += 1
        elif model_answer == 4:
            unretrievable += 1

        count += 1

        if count >= limit:
            break

    # Also reached when the dataset holds fewer than `limit` entries, so the
    # caller always receives a tuple instead of None.
    return (correct_answers, unretrievable, count)


In [None]:
# Test split of the SciQ QA configuration, shuffled once with the global seed
# so every evaluation run below sees the questions in the same order.
dataset = load_dataset(
    "bigbio/sciq",
    "sciq_bigbio_qa",
    split="test",
).shuffle(seed=seed)

# Collects one DatasetEvaluation per evaluation run for the plots further down.
evaluations = []

class DatasetEvaluation:
    """Outcome counts of a single evaluation run, labelled with a display name."""

    def __init__(self, name, correct, unretrievable, total):
        # name: label of the run; correct / unretrievable / total: answer counts.
        self.name, self.correct = name, correct
        self.unretrievable, self.total = unretrievable, total

    def __str__(self):
        # Single-line summary: "<name>: correct=.. unretrievable=.. total=.."
        counts = " ".join((
            f"correct={self.correct}",
            f"unretrievable={self.unretrievable}",
            f"total={self.total}",
        ))
        return f"{self.name}: {counts}"

In [None]:
# Baseline run: no retriever, so the generator answers from the question alone.
correct_answers, unretrievable_answers, total_answers = evaluate_model(
    dataset=dataset,
    # Question followed by its four lettered options, one per line.
    get_query=lambda entry: (
        f"{entry['question']}\n"
        f" A) {entry['choices'][0]}\n"
        f" B) {entry['choices'][1]}\n"
        f" C) {entry['choices'][2]}\n"
        f" D) {entry['choices'][3]}\n"
    ),
    # Position of the first listed answer within the current choice order.
    get_solution=lambda entry: entry['choices'].index(entry['answer'][0]),
    generator=generator,
    retriever=None,
    limit=limit,
)

evaluation = DatasetEvaluation(
    "Generator (without RAG)",
    correct_answers,
    unretrievable_answers,
    total_answers,
)
evaluations.append(evaluation)
print(evaluation)

In [None]:
# RAG run using the Wiki10k retriever to supply the context.
correct_answers, unretrievable_answers, total_answers = evaluate_model(
    dataset=dataset,
    # Question followed by its four lettered options, one per line.
    get_query=lambda entry: (
        f"{entry['question']}\n"
        f" A) {entry['choices'][0]}\n"
        f" B) {entry['choices'][1]}\n"
        f" C) {entry['choices'][2]}\n"
        f" D) {entry['choices'][3]}\n"
    ),
    # Position of the first listed answer within the current choice order.
    get_solution=lambda entry: entry['choices'].index(entry['answer'][0]),
    generator=generator,
    retriever=Wiki10k(),
    limit=limit,
)

evaluation = DatasetEvaluation(
    "RAG - Wikipedia dataset",
    correct_answers,
    unretrievable_answers,
    total_answers,
)
evaluations.append(evaluation)
print(evaluation)

In [None]:
# RAG run using the Medical retriever to supply the context.
correct_answers, unretrievable_answers, total_answers = evaluate_model(
    dataset=dataset,
    # Question followed by its four lettered options, one per line.
    get_query=lambda entry: (
        f"{entry['question']}\n"
        f" A) {entry['choices'][0]}\n"
        f" B) {entry['choices'][1]}\n"
        f" C) {entry['choices'][2]}\n"
        f" D) {entry['choices'][3]}\n"
    ),
    # Position of the first listed answer within the current choice order.
    get_solution=lambda entry: entry['choices'].index(entry['answer'][0]),
    generator=generator,
    retriever=Medical(),
    limit=limit,
)

evaluation = DatasetEvaluation(
    "RAG - Medical dataset",
    correct_answers,
    unretrievable_answers,
    total_answers,
)
evaluations.append(evaluation)
print(evaluation)

In [None]:
"""
correct_web_model, unretrievable_answers, total_web_model = evaluate_model(
    dataset=dataset,
    get_query= lambda entry: f"{entry['question']}\n A) {entry['choices'][0]}\n B) {entry['choices'][1]}\n C) {entry['choices'][2]}\n D) {entry['choices'][3]}\n",
    get_solution= lambda entry: entry['choices'].index(entry['answer'][0]),
    generator = generator,
    retriever = WebSearch(),
    limit=limit
)

evaluation = DatasetEvaluation("RAG - WebSearch as retriever", correct_answers, unretrievable_answers, total_answers)
evaluations.append(evaluation)
print(evaluation)
"""

In [None]:
# comparison chart
import matplotlib.pyplot as plt
from itertools import cycle
import os

# Share of correctly answered questions per evaluation run, in percent.
# A run without any evaluated question is plotted as 0 % instead of raising
# ZeroDivisionError.
correct_percentages = [
    e.correct / e.total * 100 if e.total else 0.0
    for e in evaluations
]

# One colour per bar, repeating the palette if there are more runs than colours.
color_choices = cycle(['#007F73', '#4CCD99', '#FFC700', '#FFF455'])
colors = [next(color_choices) for _ in range(len(evaluations))]

fig, ax = plt.subplots()
bar_width = 0.35
bars = [e.name for e in evaluations]

ax.bar(bars, correct_percentages, bar_width, color=colors)

ax.set_title(f'{generator_name}: performance overview ')
ax.set_xlabel('Retriever type')
ax.set_ylabel('Correct answers (%)')

# exist_ok=True replaces the separate existence check and does not fail if
# the directory is created between the check and the call.
os.makedirs(plot_dir, exist_ok=True)

# Save through the explicit figure handle rather than pyplot's current figure.
fig.savefig(os.path.join(plot_dir, f'performance_overview_{generator_name}.png'), dpi=100)
plt.show()

In [None]:
# One bar chart per evaluation run, splitting its answers into three outcomes.
for evaluation in evaluations:
    bars = ["Correct", "False answer", "Unretrievable answer"]
    values = [
        evaluation.correct,
        # Whatever is neither correct nor unretrievable was extracted but wrong.
        evaluation.total - (evaluation.correct + evaluation.unretrievable),
        evaluation.unretrievable,
    ]

    # Fresh palette per figure so every chart starts with the same colour.
    color_choices = cycle(['#7469B6', '#AD88C6', '#E1AFD1', '#FFE6E6'])
    colors = [next(color_choices) for _ in bars]

    bar_width = 0.35
    fig, ax = plt.subplots()
    ax.bar(bars, values, bar_width, color=colors)

    ax.set_title(f'{generator_name}: {evaluation.name} details')
    ax.set_xlabel('Type')
    ax.set_ylabel('Answers')

    # Written into plot_dir, which is expected to exist already at this point.
    plt.savefig(
        os.path.join(plot_dir, f'{generator_name}_{evaluation.name}_details.png'),
        dpi=100,
    )
    plt.show()