In [None]:
# CONFIGURATION (GLOBALS)
from rag.generator import Llama3_8b, Llama2_4b, Gpt2, Phi15_1b
from rag.retriever import Wiki10k, MedicalTextbook, WikiDoc, Wikipedia, WebSearch, PerfectSearch
import os
import json

class Model:
    """A generator under evaluation, bundled with its display name and entry limit.

    Attributes:
        name: Label used for printing, plot titles and as the key in the
            evaluations mapping.
        generator: The generator object passed on to the evaluation routine.
        limit: Maximum number of dataset entries to evaluate for this model.
    """

    def __init__(self, name, generator, limit):
        self.name, self.generator, self.limit = name, generator, limit

    def __str__(self):
        # The generator itself is deliberately left out of the summary.
        return "Model: name={} limit={}".format(self.name, self.limit)

class RetrieverData:
    """A retriever paired with the human-readable name used in logs and plots.

    Attributes:
        name: Display label of the retrieval setup.
        model: The retriever object; the configuration list also uses None
            here for the "without RAG" baseline.
    """

    def __init__(self, name, model):
        self.name, self.model = name, model

    def __str__(self):
        # Only the label is shown; the retriever object is not part of the text.
        return format(self.name)


class DatasetEvaluation:
    """Outcome counters of one evaluation run for a (model, retriever) pair.

    Attributes:
        model_name: Name of the evaluated model.
        retriever_name: Name of the retriever used during the run.
        correct: Number of correctly answered entries.
        unretrievable: Number of entries for which no answer could be extracted.
        total: Number of entries that were evaluated.
    """

    # Constructor argument order; to_dict() emits its keys in this order too.
    _FIELDS = ('model_name', 'retriever_name', 'correct', 'unretrievable', 'total')

    def __init__(self, model_name, retriever_name, correct, unretrievable, total):
        self.model_name = model_name
        self.retriever_name = retriever_name
        self.correct = correct
        self.unretrievable = unretrievable
        self.total = total

    def to_dict(self):
        """Return the evaluation as a plain dict (suitable for json.dump)."""
        return {field: getattr(self, field) for field in self._FIELDS}

    @classmethod
    def from_dict(cls, data):
        """Rebuild an evaluation from a dict produced by to_dict().

        Raises:
            KeyError: if one of the expected keys is missing from ``data``.
        """
        return cls(*(data[field] for field in cls._FIELDS))

    def __str__(self):
        return (
            f"{self.model_name}-{self.retriever_name}: "
            f"correct={self.correct} "
            f"unretrievable={self.unretrievable} "
            f"total={self.total}"
        )

    def __repr__(self):
        # Same text as str() so printed containers of evaluations stay readable.
        return str(self)

# Seed shared by the dataset shuffle and the per-entry choice shuffling.
seed=42
# Output directory for the generated plots; the cached evaluations live in the
# same directory as evaluations.json.
plot_dir = "./plots/MMLU - CollegeMedicine with WebSearch"
evaluations_path = os.path.join(plot_dir,'evaluations.json')

# Maximum number of dataset entries evaluated per model (passed to each Model).
limit = 30

# Generators to evaluate. Uncomment the entries that should be (re-)run; with
# every entry commented out the evaluation loop does nothing and the plots are
# built only from the evaluations loaded from evaluations_path.
models = [
#     Model("GPT 2", Gpt2(), limit), 
#     Model("Phi-1.5b", Phi15_1b(), limit), 
#     Model("Llama2-4b", Llama2_4b(), limit),
]

# Retrieval setups evaluated for every model above. A retriever of None means
# the generator is queried without any retrieved context.
retrievers = [
#    RetrieverData("Generator (without RAG)", None),
#    RetrieverData("RAG - Wikipedia (10k) dataset", Wiki10k()),
#    RetrieverData("RAG - Medical textbook dataset", MedicalTextbook()),
#    RetrieverData("RAG - WikiDoc dataset", WikiDoc()),
#    RetrieverData("RAG - WebSearch as retriever", WebSearch()),
#    RetrieverData("RAG - Wikipedia (english sentences) dataset", Wikipedia()),
#    RetrieverData("RAG - Perfect context", PerfectSearch()),
]

def deserialize_evaluations(filename):
    """Load previously saved evaluations from a JSON file.

    Args:
        filename: Path of the JSON file; it maps a key to a list of
            dictionaries, each accepted by DatasetEvaluation.from_dict.

    Returns:
        A dict mapping each key to a list of DatasetEvaluation instances, or
        an empty dict when ``filename`` does not exist.
    """
    # Nothing stored yet: start from an empty result set.
    if not os.path.exists(filename):
        return {}

    with open(filename, 'r') as handle:
        raw = json.load(handle)

    # Turn every stored dictionary back into a DatasetEvaluation instance.
    restored = {}
    for key, records in raw.items():
        restored[key] = [DatasetEvaluation.from_dict(record) for record in records]
    return restored

# Start from an empty mapping, then replace it with whatever was saved by an
# earlier run (deserialize_evaluations returns {} when no file exists yet).
evaluations = {}
evaluations = deserialize_evaluations(evaluations_path)
# Show what was loaded so it is clear which results the later cells reuse.
print(evaluations)

In [None]:
from dotenv import load_dotenv
# Load variables from a .env file into the process environment.
# NOTE(review): presumably credentials needed by the retrievers/generators —
# confirm which variables are expected.
load_dotenv()

In [None]:
from datasets import load_dataset
import random
import re

# Takes a response (string) as input and extracts the chosen answer.
# Returns 0, 1, 2 or 3 for the picked answer, or 4 if no answer could be extracted.
def extract_answer(response, entry):
    """Extract the index of the answer option chosen in a model response.

    Only the text after the first "Answer:" marker and before the following
    "Explanation:" marker is inspected (the whole response if a marker is
    missing). A standalone option letter followed by '.', ',' or ')' wins;
    otherwise the choice whose text appears earliest in the response is used.

    Args:
        response: Raw text produced by the generator.
        entry: Dataset entry; ``entry['choices']`` holds the answer texts in
            the order in which they were presented as A) to D).

    Returns:
        0-3 for the picked option, or 4 when no answer could be extracted.
    """
    # Keep everything between "Answer:" and "Explanation:".
    if "Answer:" in response:
        response = response.split("Answer:")[1]
    if "Explanation:" in response:
        response = response.split("Explanation:")[0]

    def normalize(text):
        # Applied to the response AND to the choices so that texts containing
        # double quotes or line breaks can still be found in the response.
        return text.replace("\"", "'").replace("\n", " ")

    response = normalize(response.strip())

    # Match letter answers. The leading \b requires the letter to start a
    # word, so words that merely end in A-D ("DNA.", "USA,") are not mistaken
    # for an option letter.
    letter_match = re.search(r'\b([ABCD])[.,)]', response)
    if letter_match:
        return "ABCD".index(letter_match.group(1))

    # Match text answers within the first 500 characters; anything later (or
    # absent) counts as not answered.
    not_found = 500
    positions = []
    for choice in entry['choices'][:4]:
        needle = normalize(choice)
        # An empty choice would "match" at position 0, so it is never searched.
        position = response.find(needle) if needle else -1
        positions.append(position if 0 <= position < not_found else not_found)

    if positions:
        earliest = min(positions)
        if earliest < not_found:
            # On ties the first listed choice wins.
            return positions.index(earliest)

    return 4  # not an existing index --> definitely wrong
        

def evaluate_model(dataset, get_query, get_solution, generator, retriever, limit):
    """Evaluate a generator (optionally with a retriever) on multiple-choice entries.

    For every entry the answer choices are shuffled (seeded by the global
    ``seed``), a prompt is built from the optional retrieved context and the
    query, and the generated response is mapped back to an option index with
    ``extract_answer``.

    Args:
        dataset: Iterable of entries; each entry provides ``entry['choices']``.
        get_query: Callable building the query text from a (shuffled) entry.
        get_solution: Callable returning the index of the correct choice of an
            entry before shuffling.
        generator: Callable ``generator(prompt, max_new_tokens=...)`` returning
            the response text; must also offer ``get_tokenizer()``.
        retriever: Callable returning context strings for a query, or None to
            run without retrieved context.
        limit: Maximum number of entries to evaluate.

    Returns:
        Tuple ``(correct_answers, unretrievable, count)``. The tuple is also
        returned when the dataset holds fewer than ``limit`` entries.
    """
    count = 0
    correct_answers = 0
    unretrievable = 0
    random.seed(seed)  # set the seed for the shuffle method

    # The tokenizer does not depend on the entry, so fetch it only once.
    tokenizer = generator.get_tokenizer()

    for entry in dataset:
        # Work on a copy so shuffling never reorders the caller's data.
        entry = dict(entry)
        entry['choices'] = list(entry['choices'])

        solution_entry = entry['choices'][get_solution(entry)]
        random.shuffle(entry['choices'])
        query = get_query(entry)
        solution = entry['choices'].index(solution_entry)
        context = ""
        context_prompt = ""
        if retriever is not None:
            context = "Context:\n" + "\n".join(retriever(query))
            context_prompt = "Answer based on the context provided."

        # trim context if necessary
        context = context[:2048]  # hard cap
        query_tokens = len(tokenizer.tokenize(query))
        # 400 as the length of the prompt + some extra.
        # Stop once the context is empty: trimming further cannot help and
        # would otherwise loop forever when the query alone exceeds the budget.
        while context and query_tokens + 400 + len(tokenizer.tokenize(context)) > tokenizer.model_max_length:
            words = context.split()
            context = ' '.join(words[:-30])

        prompt = f"""
        You are a helpful AI assistant. {context_prompt} Think step by step and answer either with A), B), C) or D). Add nothing else.
        
        {context}
        
        Query:
        {query}
        
        Answer:
        """

        response = generator(prompt, max_new_tokens=200)

        # answer extraction (4 means no answer could be extracted)
        model_answer = extract_answer(response, entry)
        model_text_answer = entry['choices'][model_answer] if 0 <= model_answer <= 3 else "None"

        print(f"Question {count}: Expected={solution} Got={model_answer}")
        print(f"\t Expected={entry['choices'][solution]} Got={model_text_answer}")

        if solution == model_answer:
            correct_answers += 1
        elif model_answer == 4:
            unretrievable += 1

        count += 1

        if count >= limit:
            break

    # Reached both when the limit was hit and when the dataset ran out first.
    return (correct_answers, unretrievable, count)


In [None]:
# Benchmark data: test split of the MMLU "college_medicine" subset.
dataset = load_dataset("cais/mmlu", 'college_medicine', split="test")
#dataset = load_dataset("bigbio/sciq", "sciq_bigbio_qa", split="test")
# Shuffle with the global seed so the first `limit` entries are the same on every run.
dataset = dataset.shuffle(seed=seed)

# Filter out long questions (more than 500 characters).
dataset = dataset.filter(lambda entry: len(entry['question']) <= 500)

In [None]:
from tqdm import tqdm 

# Evaluate every configured model with every configured retriever.
for model in tqdm(models):
    # Results of a re-evaluated model replace whatever was loaded for it before.
    evaluations[model.name] = []
    for retriever in tqdm(retrievers):
        print(f"Evaluating model {model.name} - {retriever.name}")
        correct_answers, unretrievable_answers, total_answers = evaluate_model(
            dataset=dataset,
            # Question followed by the four (already shuffled) choices labelled A) to D).
            get_query= lambda entry: f"{entry['question']}\n A) {entry['choices'][0]}\n B) {entry['choices'][1]}\n C) {entry['choices'][2]}\n D) {entry['choices'][3]}\n",
            get_solution= lambda entry: entry['answer'], # how to get index of correct solution
            #get_solution= lambda entry: entry['choices'].index(entry['answer'][0]),
            generator = model.generator,
            retriever = retriever.model,
            limit=model.limit
        )

        # Record the counters under the model's name for plotting and serialization.
        evaluation = DatasetEvaluation(model.name, retriever.name, correct_answers, unretrievable_answers, total_answers)
        evaluations[model.name].append(evaluation)
        print(evaluation)

In [None]:
# comparison chart
import matplotlib.pyplot as plt
from itertools import cycle
import os
import math

def find_y_lim():
    """Return a common upper y-axis limit for the overview charts.

    Takes the highest percentage of correct answers over all evaluations in
    the global ``evaluations`` mapping, rounds it up to the next multiple of
    10, adds 10 as headroom for the bar labels and caps the result at 100.

    Evaluations with ``total == 0`` are skipped (their percentage is
    undefined); with nothing to plot the limit is 10.
    """
    highest = max(
        (
            e.correct / e.total * 100
            for model_evaluations in evaluations.values()
            for e in model_evaluations
            if e.total
        ),
        default=0,  # no evaluations at all, or only empty ones
    )
    return min(10 + math.ceil(highest / 10) * 10, 100)

# One shared y-axis limit so the charts of different models are comparable.
y_lim = find_y_lim()

# Create the output directory once up front; exist_ok replaces the separate
# existence check that was repeated for every model.
os.makedirs(plot_dir, exist_ok=True)

# One bar chart per model: percentage of correct answers for each retriever.
for model_name in evaluations.keys():
    correct_percentages = [e.correct / e.total * 100 for e in evaluations[model_name]]

    color_choices = cycle(['#007F73', '#4CCD99', '#FFC700', '#FFF455', '#95D2B3'])
    colors = [next(color_choices) for _ in range(len(evaluations[model_name]))]

    fig, ax = plt.subplots()
    bar_width = 0.35
    bars = [e.retriever_name for e in evaluations[model_name]]

    bar_container = ax.bar(bars, correct_percentages, bar_width, color=colors)
    ax.bar_label(bar_container, labels=[f'{perc:.2f}%' for perc in correct_percentages])

    ax.set_title(f'{model_name}: performance overview ')
    ax.set_xlabel('Retriever type')
    ax.set_ylabel('Correct answers (%)')
    ax.set_ylim(0, y_lim)
    plt.xticks(rotation=50)
    fig.tight_layout()

    fig.savefig(os.path.join(plot_dir, f'performance_overview_{model_name}.png'), dpi=300)
    plt.show()
    # Release the figure so figures do not pile up across loop iterations.
    plt.close(fig)

In [None]:
import math

def find_y_lim(model_name):
    """Return the upper y-axis limit for the detail charts of one model.

    Considers, for every evaluation of ``model_name`` in the global
    ``evaluations`` mapping, the percentages of correct, false and
    unretrievable answers. The largest of them is rounded up to the next
    multiple of 10, 10 is added as headroom for the bar labels and the result
    is capped at 100.

    Evaluations with ``total == 0`` are skipped (their percentages are
    undefined); if nothing remains the limit is 10.

    Raises:
        KeyError: if ``model_name`` is not a key of ``evaluations``.
    """
    shares = []
    for evaluation in evaluations[model_name]:
        if not evaluation.total:
            continue
        shares.extend([
            evaluation.correct / evaluation.total * 100,
            (evaluation.total - (evaluation.correct + evaluation.unretrievable)) / evaluation.total * 100,
            evaluation.unretrievable / evaluation.total * 100,
        ])

    highest = max(shares, default=0)
    return min(10 + math.ceil(highest / 10) * 10, 100)

# Make sure the output directory exists even when this cell is run on its own.
os.makedirs(plot_dir, exist_ok=True)

# One chart per (model, retriever) evaluation: share of correct, false and
# unretrievable answers in percent.
for model_name in evaluations.keys():
    # Same y-axis limit for all charts of one model so they are comparable.
    y_lim = find_y_lim(model_name)
    for evaluation in evaluations[model_name]:
        fig, ax = plt.subplots()
        bar_width = 0.35
        bars = ["Correct", "False answer", "Unretrievable answer"]
        values = [
            evaluation.correct / evaluation.total * 100,
            (evaluation.total - (evaluation.correct + evaluation.unretrievable)) / evaluation.total * 100,
            evaluation.unretrievable / evaluation.total * 100,
        ]

        color_choices = cycle(['#7469B6', '#AD88C6', '#E1AFD1', '#FFE6E6'])
        colors = [next(color_choices) for _ in range(len(bars))]

        bar_container = ax.bar(bars, values, bar_width, color=colors)
        ax.bar_label(bar_container, labels=[f'{val:.2f}%' for val in values])

        ax.set_title(f'{evaluation.model_name}: {evaluation.retriever_name} details')
        ax.set_xlabel('Type')
        ax.set_ylabel('Answers')
        ax.set_ylim(0, y_lim)

        fig.savefig(os.path.join(plot_dir, f'{evaluation.model_name}_{evaluation.retriever_name}_details.png'), dpi=300)
        plt.show()
        # Release the figure so figures do not pile up across loop iterations.
        plt.close(fig)

In [None]:
def serialize_evaluations(evaluations, filename):
    """Write the evaluations to ``filename`` as indented JSON.

    Args:
        evaluations: Mapping of a key to a list of objects offering
            ``to_dict()``.
        filename: Target path. Missing parent directories are created.
    """
    serialized_evaluations = {
        key: [evaluation.to_dict() for evaluation in value] for key, value in evaluations.items()
    }

    # Serialize before opening the file: a value that cannot be encoded then
    # raises without truncating a previously saved file.
    payload = json.dumps(serialized_evaluations, indent=4)

    # The target directory may not exist yet (e.g. when no plot was saved).
    directory = os.path.dirname(filename)
    if directory:
        os.makedirs(directory, exist_ok=True)

    with open(filename, 'w') as f:
        f.write(payload)


# Show the final results, then persist them so a later run can reload them
# from evaluations_path instead of re-evaluating.
print(evaluations)
serialize_evaluations(evaluations, evaluations_path)