In [None]:
# CONFIGURATION (GLOBALS)
from rag.generator import Llama3_8b, Llama2_4b, Gpt2, Phi15_1b
from rag.retriever import Medical, Wiki10k

class Model:
    """A generator under evaluation, bundled with its label and question budget.

    Attributes:
        name: Label used as the key in ``evaluations`` and in plot titles/file names.
        generator: Object handed to ``evaluate_model`` as its ``generator`` argument.
        limit: Value handed to ``evaluate_model`` as its ``limit`` argument.
    """

    def __init__(self, name, generator, limit):
        self.name, self.generator, self.limit = name, generator, limit

    def __str__(self):
        # The generator is deliberately left out of the textual form.
        return f"Model: name={self.name} limit={self.limit}"

class RetrieverData:
    """A retriever setup to compare, bundled with its display label.

    Attributes:
        name: Label shown on chart axes/titles and used in plot file names.
        model: Object handed to ``evaluate_model`` as its ``retriever`` argument;
            ``None`` means the generator is queried without any retrieved context.
    """

    def __init__(self, name, model):
        self.name, self.model = name, model

    def __str__(self):
        return f"{self.name}"

# Seed shared by the dataset shuffle and the per-question choice shuffle.
seed = 42
# Directory the chart cells write their PNG files to.
plot_dir = "./plots"

# Generators to evaluate; `limit` is the number of questions asked per
# (model, retriever) combination. Uncomment an entry to include it in the run.
models = [
    Model(name="GPT 2", generator=Gpt2(), limit=5),
#   Model(name="Phi-1.5b", generator=Phi15_1b(), limit=5),
#   Model(name="Llama2-4b", generator=Llama2_4b(), limit=5),
#   Model(name="Llama3-8b", generator=Llama3_8b(), limit=5),
]

# Retriever setups each model is evaluated with, in chart order.
# A `model` of None is the baseline: the generator answers without context.
retrievers = [
    RetrieverData(name="Generator (without RAG)", model=None),
    RetrieverData(name="RAG - Wikipedia dataset", model=Wiki10k()),
    RetrieverData(name="RAG - Medical dataset", model=Medical()),
#   NOTE: WebSearch is not imported in this cell; import it before enabling.
#   RetrieverData(name="RAG - WebSearch as retriever", model=WebSearch()),
]

In [None]:
from dotenv import load_dotenv
# Load variables from a local .env file into the process environment (python-dotenv).
# NOTE(review): presumably supplies tokens/credentials needed by the rag
# generators or retrievers -- confirm which variables are actually required.
load_dotenv()

In [None]:
from datasets import load_dataset
import random
import re

def extract_answer(response, entry):
    """Extract the multiple-choice option a model picked from its raw response.

    Only the text between the first "Answer:" and the following "Explanation:"
    is inspected (the whole response when those markers are absent). The
    option is then determined by the first strategy that succeeds:

      1. an option letter A-D directly followed by one of ``) , . ]`` at the
         very start of the text,
      2. the earliest stand-alone option letter of that form anywhere in the
         text (letters that merely end a word, e.g. "DNA.", are ignored),
      3. the first of the (at most four) choices whose text occurs verbatim.

    Args:
        response: Raw generator output (string).
        entry: Dataset row; only ``entry['choices']`` is read, and it must be
            in the same order in which the options were shown to the model.

    Returns:
        int: 0, 1, 2 or 3 for the picked option, or 4 if no answer could be
        extracted (4 is never a valid index, so it always counts as wrong).
    """
    # Keep everything between "Answer:" and "Explanation:".
    if "Answer:" in response:
        response = response.split("Answer:")[1]
    if "Explanation:" in response:
        response = response.split("Explanation:")[0]
    response = response.strip().replace("\"", "'").replace("\n", " ")

    letters = "ABCD"
    accepted_symbols = '),.]'
    # The symbols MUST be escaped: interpolated raw, the ']' closes the
    # character class early ("[),.]]") and the pattern then demands a literal
    # ']' after the symbol, so a plain "A)" would never match.
    symbol_class = f"[{re.escape(accepted_symbols)}]"

    # 1. Option letter right at the start of the answer text.
    leading = re.match(rf"([{letters}]){symbol_class}", response)
    if leading:
        return letters.index(leading.group(1))

    # 2. Earliest stand-alone option letter anywhere in the text. The
    #    lookbehind rejects letters that are the tail of a word or number.
    standalone = re.search(rf"(?<![A-Za-z0-9])([{letters}]){symbol_class}", response)
    if standalone:
        return letters.index(standalone.group(1))

    # 3. Fall back to the literal choice text. Only the first four choices are
    #    considered because index 4 is reserved as the "not found" marker.
    #    Quotes are normalised the same way as in the response above.
    for index, choice in enumerate(entry['choices'][:len(letters)]):
        normalized_choice = str(choice).replace("\"", "'")
        if normalized_choice and normalized_choice in response:
            return index

    print("Could not extract answer")
    print(f"\n-----------------------------------\nResponse: '{response}'\n-----------------------------------")
    return 4  # non-existent index --> definitely wrong
        

def evaluate_model(dataset, get_query, get_solution, generator, retriever, limit):
    """Ask a generator multiple-choice questions and tally its answers.

    For every entry the choices are shuffled in place (seeded with the global
    ``seed``), a prompt is built (optionally preceded by retrieved context),
    the generator is queried and the picked option is extracted from the
    response with ``extract_answer``.

    Args:
        dataset: Iterable of entries; each entry must have a mutable
            ``entry['choices']`` list, which is shuffled in place.
        get_query: Callable ``entry -> str`` building the question text.
        get_solution: Callable ``entry -> int`` returning the index of the
            correct choice (called after the shuffle).
        generator: Callable ``(prompt, max_new_tokens=...) -> str``.
        retriever: Callable ``query -> iterable of str`` whose results are
            joined into the context, or ``None`` to query without context.
        limit: Stop once this many questions have been asked. At least one
            question is asked whenever the dataset is non-empty.

    Returns:
        tuple: ``(correct_answers, unretrievable, count)``. The tuple is also
        returned when the dataset runs out before ``limit`` is reached.
    """
    count = 0
    correct_answers = 0
    unretrievable = 0
    random.seed(seed)  # set the seed for the shuffle method

    for entry in dataset:
        random.shuffle(entry['choices'])
        query = get_query(entry)
        solution = get_solution(entry)
        context = ""
        context_prompt = ""
        if retriever is not None:
            context = "Context:\n" + "\n".join(retriever(query))
            context_prompt = "Answer based on the context provided."

        # The whitespace inside this literal is part of the prompt sent to the
        # generator and is kept exactly as it was.
        prompt = f"""
        You are a helpful AI assistant. {context_prompt} Be concise and answer either with A), B), C) or D). Add nothing else.
        
        {context}
        
        Query:
        {query}
        
        Answer:
        """
        response = generator(prompt, max_new_tokens=100)
        print(f"\n#########################################\n{response}\n#########################################\n")

        # answer extraction
        model_answer = extract_answer(response, entry)

        print(f"Question {count}: Expected={solution} Got={model_answer}")

        if solution == model_answer:
            correct_answers += 1
        elif model_answer == 4:
            # 4 is extract_answer's marker for "no answer found".
            unretrievable += 1

        count += 1

        if count >= limit:
            break

    # Single exit point: previously a dataset shorter than `limit` fell off
    # the loop and returned None, which broke tuple unpacking in the caller.
    return (correct_answers, unretrievable, count)


In [None]:
# "bigbio/sciq" dataset (config "sciq_bigbio_qa", test split), shuffled with
# the global seed so every run sees the questions in the same order.
dataset = load_dataset("bigbio/sciq", "sciq_bigbio_qa", split="test").shuffle(seed=seed)

# Maps a model name to the list of DatasetEvaluation results collected for it.
evaluations = dict()

class DatasetEvaluation:
    """Outcome of evaluating one model with one retriever setup.

    Attributes:
        model_name: Name of the evaluated model.
        retriever_name: Name of the retriever setup that was used.
        correct: Number of correctly answered questions.
        unretrievable: Number of responses no answer could be extracted from.
        total: Number of questions asked.
    """

    def __init__(self, model_name, retriever_name, correct, unretrievable, total):
        self.model_name, self.retriever_name = model_name, retriever_name
        self.correct, self.unretrievable, self.total = correct, unretrievable, total

    def __str__(self):
        return f"{self.model_name}-{self.retriever_name}: correct={self.correct} unretrievable={self.unretrievable} total={self.total}"

In [None]:
def _build_query(entry):
    """Render the question followed by its four options labelled A) to D)."""
    return f"{entry['question']}\n A) {entry['choices'][0]}\n B) {entry['choices'][1]}\n C) {entry['choices'][2]}\n D) {entry['choices'][3]}\n"


def _solution_index(entry):
    """Return the position of the correct answer text within the entry's choices."""
    return entry['choices'].index(entry['answer'][0])


# Evaluate every configured model against every retriever setup and collect
# one DatasetEvaluation per combination under the model's name.
for model in models:
    model_results = []
    evaluations[model.name] = model_results
    for retriever in retrievers:
        n_correct, n_unretrievable, n_total = evaluate_model(
            dataset=dataset,
            get_query=_build_query,
            get_solution=_solution_index,
            generator=model.generator,
            retriever=retriever.model,
            limit=model.limit,
        )

        evaluation = DatasetEvaluation(model.name, retriever.name, n_correct, n_unretrievable, n_total)
        model_results.append(evaluation)
        print(evaluation)

In [None]:
# comparison chart
import matplotlib.pyplot as plt
from itertools import cycle
import os
import math

def find_y_lim(evaluation_groups=None):
    """Pick a shared y-axis upper bound for the per-model overview charts.

    The bound is the best accuracy (in percent) over all evaluations, rounded
    up to the next multiple of ten, plus ten points of headroom, capped at 100.

    Args:
        evaluation_groups: Optional iterable of iterables of objects exposing
            numeric ``correct`` and ``total`` attributes. When omitted, the
            evaluations of every configured model are used (globals ``models``
            and ``evaluations``), which is the original behaviour.

    Returns:
        int: Upper y-axis limit between 10 and 100. Evaluations with a total
        of zero are ignored; with nothing to inspect the result is 10.
    """
    if evaluation_groups is None:
        evaluation_groups = [evaluations[configured.name] for configured in models]

    # default=0 keeps max() from raising when there are no models/evaluations;
    # the `if e.total` filter avoids dividing by zero.
    max_value = max(
        (e.correct / e.total * 100 for group in evaluation_groups for e in group if e.total),
        default=0,
    )
    return min(10 + math.ceil(max_value / 10) * 10, 100)

# One overview bar chart per model: percentage of correct answers for each
# retriever setup. All charts share the same y-axis limit for comparability.
y_lim = find_y_lim()

# Create the output directory once up front; exist_ok makes this a no-op when
# it already exists and avoids the check-then-create race.
os.makedirs(plot_dir, exist_ok=True)

for model in models:
    model_evaluations = evaluations[model.name]
    # Guard against a zero total so an empty run plots 0% instead of crashing.
    correct_percentages = [e.correct / e.total * 100 if e.total else 0.0 for e in model_evaluations]

    color_choices = cycle(['#007F73', '#4CCD99', '#FFC700', '#FFF455'])
    colors = [next(color_choices) for _ in model_evaluations]

    fig, ax = plt.subplots()
    bar_width = 0.35
    bars = [e.retriever_name for e in model_evaluations]

    bar_container = ax.bar(bars, correct_percentages, bar_width, color=colors)
    ax.bar_label(bar_container, labels=[f'{perc:.2f}%' for perc in correct_percentages])

    ax.set_title(f'{model.name}: performance overview ')
    ax.set_xlabel('Retriever type')
    ax.set_ylabel('Correct answers (%)')
    ax.set_ylim(0, y_lim)

    # Save through the figure object rather than pyplot's "current figure".
    fig.savefig(os.path.join(plot_dir, f'performance_overview_{model.name}.png'), dpi=100)
    plt.show()
    # Release the figure; otherwise every figure created in the loop stays alive.
    plt.close(fig)

In [None]:
import math

def find_y_lim(model_evaluations=None):
    """Pick a shared y-axis upper bound for the detail charts of one model.

    For every evaluation the percentages of correct, false and unretrievable
    answers are computed; the bound is the largest of them rounded up to the
    next multiple of ten, plus ten points of headroom, capped at 100.

    Args:
        model_evaluations: Optional iterable of objects exposing numeric
            ``correct``, ``unretrievable`` and ``total`` attributes. When
            omitted, ``evaluations[model.name]`` is used, where ``model`` is
            the module-level variable bound by a surrounding
            ``for model in models`` loop -- so the zero-argument form only
            works while such a loop variable exists.

    Returns:
        int: Upper y-axis limit between 10 and 100. Evaluations with a total
        of zero are ignored; with nothing to inspect the result is 10.
    """
    if model_evaluations is None:
        model_evaluations = evaluations[model.name]

    shares = []
    for evaluation in model_evaluations:
        total = evaluation.total
        if not total:
            continue  # nothing was asked; avoid dividing by zero
        false_answers = total - (evaluation.correct + evaluation.unretrievable)
        shares.extend(
            part / total * 100
            for part in (evaluation.correct, false_answers, evaluation.unretrievable)
        )

    # default=0 keeps max() from raising when there is nothing to inspect.
    max_value = max(shares, default=0)
    return min(10 + math.ceil(max_value / 10) * 10, 100)

# One detail bar chart per (model, retriever) evaluation, splitting the
# answers into correct / false / unretrievable percentages.

# Make sure the output directory exists even if no earlier cell created it.
os.makedirs(plot_dir, exist_ok=True)

for model in models:
    if not evaluations[model.name]:
        continue  # nothing to plot for this model

    # find_y_lim() reads the module-level loop variable `model` bound above.
    y_lim = find_y_lim()
    for evaluation in evaluations[model.name]:
        total = evaluation.total
        false_answers = total - (evaluation.correct + evaluation.unretrievable)

        fig, ax = plt.subplots()
        bar_width = 0.35
        bars = ["Correct", "False answer", "Unretrievable answer"]
        # Guard against a zero total so an empty run plots 0% instead of crashing.
        values = [
            part / total * 100 if total else 0.0
            for part in (evaluation.correct, false_answers, evaluation.unretrievable)
        ]

        color_choices = cycle(['#7469B6', '#AD88C6', '#E1AFD1', '#FFE6E6'])
        colors = [next(color_choices) for _ in bars]

        bar_container = ax.bar(bars, values, bar_width, color=colors)
        ax.bar_label(bar_container, labels=[f'{val:.2f}%' for val in values])

        ax.set_title(f'{evaluation.model_name}: {evaluation.retriever_name} details')
        ax.set_xlabel('Type')
        ax.set_ylabel('Answers')
        ax.set_ylim(0, y_lim)

        # Save through the figure object rather than pyplot's "current figure".
        fig.savefig(os.path.join(plot_dir, f'{evaluation.model_name}_{evaluation.retriever_name}_details.png'), dpi=100)
        plt.show()
        # Release the figure; otherwise every figure created in the loop stays alive.
        plt.close(fig)