In [None]:
from dotenv import load_dotenv
seed = 42  # global seed: used for the dataset shuffle and the per-run choice shuffling
load_dotenv()  # populate os.environ from a local .env file (presumably tokens the rag package needs - verify)

In [None]:
import random
import re

from datasets import load_dataset

from rag.generator import Llama3_8b, Llama2_4b, Gpt2
from rag.retriever import Medical, Wiki10k

# Takes a model response (string) as input and extracts the chosen answer
# (returns 0, 1, 2 or 3 for the picked answer, or 4 if no answer could be extracted).
def extract_answer(response, entry):
    """Map a free-text model response to the index of the option it picked.

    Only the text between "Answer:" and "Explanation:" is inspected (each
    marker is optional). Two strategies are tried in order:

    1. Option labels ``A)`` .. ``D)``: the label that occurs *first* in the
       text wins. A label glued to a preceding word character (e.g. the "A)"
       in "DNA)") is ignored.
    2. Choice text: the choice from ``entry['choices']`` that is mentioned
       first; if several start at the same position the longest one wins, so
       "cell wall" beats "cell".

    Args:
        response: Raw text returned by the generator.
        entry: Dataset row; only ``entry['choices']`` is read (the option
            strings in the order they were shown as A-D).

    Returns:
        0, 1, 2 or 3 for the picked option, or 4 (an index no solution can
        have, so the answer counts as wrong) when nothing could be extracted.
    """
    # Keep only the part between "Answer:" and "Explanation:".
    if "Answer:" in response:
        response = response.split("Answer:")[1]
    if "Explanation:" in response:
        response = response.split("Explanation:")[0]
    response = response.strip().replace("\"", "'").replace("\n", " ")

    # 1) Explicit label. re.search returns the leftmost match, so a response
    #    such as "C) ... because A) is wrong" yields C instead of being biased
    #    towards whichever letter happens to be tested first.
    label = re.search(r"\b([A-D])\)", response)
    if label:
        return "ABCD".index(label.group(1))

    # 2) Choice text. Only the four displayed options are considered so that
    #    no valid answer can collide with the "not found" value 4; empty
    #    strings are skipped because they would match every response.
    matches = [
        (response.find(choice), -len(choice), idx)
        for idx, choice in enumerate(entry['choices'][:4])
        if choice and choice in response
    ]
    if matches:
        return min(matches)[2]

    print("Could not extract answer")
    print(f"\n-----------------------------------\nResponse: '{response}'\n-----------------------------------")
    return 4  # not an existing index --> definitely wrong
        

def evaluate_model(dataset, get_query, get_solution, generator, retriever, limit=None):
    """Run a multiple-choice benchmark and count the correctly answered entries.

    For every entry the choices are shuffled (the global ``random`` module is
    re-seeded with ``seed`` first, so every evaluation sees the same order),
    an optional retrieved context is prepended to the query, the generator is
    prompted, and its response is mapped to an option index by
    ``extract_answer``.

    Args:
        dataset: Iterable of entries; each must be a mapping with a
            ``'choices'`` list.
        get_query: Callable ``entry -> str`` rendering question and options.
        get_solution: Callable ``entry -> int`` giving the index of the correct
            option within the (shuffled) ``entry['choices']``.
        generator: Callable invoked as ``generator(prompt, max_new_tokens=100)``
            and expected to return a string.
        retriever: Callable ``query -> iterable of str`` whose results are
            joined into the context, or ``None`` to evaluate without RAG.
        limit: Maximum number of entries to evaluate; ``None`` evaluates the
            whole dataset.

    Returns:
        Tuple ``(correct_answers, count)``; also returned when the dataset is
        exhausted before ``limit`` is reached.
    """
    count = 0
    correct_answers = 0
    random.seed(seed)  # reset so every evaluation sees identical choice shuffles

    for entry in dataset:
        if limit is not None and count >= limit:
            break

        # Shuffle a copy of the choices: the caller's entry is left untouched,
        # so re-running over a list-backed dataset starts from the same order.
        shuffled_choices = list(entry['choices'])
        random.shuffle(shuffled_choices)
        entry = dict(entry)
        entry['choices'] = shuffled_choices

        query = get_query(entry)
        solution = get_solution(entry)
        context = ""
        if retriever is not None:
            context = "\n".join(retriever(query))

        # NOTE: the indentation inside this literal is part of the prompt text;
        # it is kept unchanged so results stay comparable with earlier runs.
        prompt = f"""
        You are a helpful AI assistant. Answer based on the context provided. Be concise and answer either with A), B), C) or D). Add nothing else.
        
        Context:
        {context}
        
        Query:
        {query}
        
        Answer:
        """
        response = generator(prompt, max_new_tokens=100)

        model_answer = extract_answer(response, entry)

        # Report the index from get_solution instead of re-deriving it from
        # SciQ-specific fields, so any dataset/get_solution pair works.
        print(f"Question {count}: Expected={solution} Got={model_answer}")

        if solution == model_answer:
            correct_answers += 1

        count += 1

    # Reached when `limit` is hit or the dataset runs out first.
    return (correct_answers, count)


In [None]:
# SciQ test split in the BigBio QA schema, shuffled with the global seed; every
# evaluation below iterates this same object, so all runs see the same questions.
dataset = load_dataset("bigbio/sciq", "sciq_bigbio_qa", split="test").shuffle(seed=seed)
limit = 5  # number of questions evaluated per run
evaluations = []  # DatasetEvaluation records collected by the evaluation cells

class DatasetEvaluation:
    """Outcome of one evaluate_model run, collected for the summary plot.

    Attributes:
        name: Run label; shown as the bar's x tick label in the plot.
        correct: Number of correctly answered questions.
        total: Number of questions evaluated.
    """

    def __init__(self, name, correct, total):
        self.name, self.correct, self.total = name, correct, total

In [None]:
# RAG run: GPT-2 generator with the Wiki10k retriever.
correct_answers, total_answers = evaluate_model(
    dataset=dataset,
    get_query=lambda entry: (
        f"{entry['question']}\n"
        f" A) {entry['choices'][0]}\n"
        f" B) {entry['choices'][1]}\n"
        f" C) {entry['choices'][2]}\n"
        f" D) {entry['choices'][3]}\n"
    ),
    get_solution=lambda entry: entry['choices'].index(entry['answer'][0]),
    generator=Gpt2(),
    retriever=Wiki10k(),
    limit=limit,
)
print(f"RAG - Wikipedia dataset: Correct={correct_answers} total={total_answers}")
evaluations.append(DatasetEvaluation("RAG - Wikipedia dataset", correct_answers, total_answers))

In [None]:
# Optional RAG run with the Medical retriever. Disabled by default; set the flag
# to True to include it in the comparison. (Disabling it with a bare
# triple-quoted string made Jupyter echo the whole string as cell output.)
RUN_MEDICAL_EVAL = False
if RUN_MEDICAL_EVAL:
    correct_answers, total_answers = evaluate_model(
        dataset=dataset,
        get_query=lambda entry: f"{entry['question']}\n A) {entry['choices'][0]}\n B) {entry['choices'][1]}\n C) {entry['choices'][2]}\n D) {entry['choices'][3]}\n",
        get_solution=lambda entry: entry['choices'].index(entry['answer'][0]),
        generator=Gpt2(),
        retriever=Medical(),
        limit=limit,
    )
    print(f"RAG - Medical dataset: Correct={correct_answers} total={total_answers}")
    evaluations.append(DatasetEvaluation("RAG - Medical dataset", correct_answers, total_answers))

In [None]:
# Baseline run: the same GPT-2 generator with an empty context (no retriever).
correct_answers, total_answers = evaluate_model(
    dataset=dataset,
    get_query=lambda entry: (
        f"{entry['question']}\n"
        f" A) {entry['choices'][0]}\n"
        f" B) {entry['choices'][1]}\n"
        f" C) {entry['choices'][2]}\n"
        f" D) {entry['choices'][3]}\n"
    ),
    get_solution=lambda entry: entry['choices'].index(entry['answer'][0]),
    generator=Gpt2(),
    retriever=None,
    limit=limit,
)

print(f"Generator (without RAG): Correct={correct_answers} total={total_answers}")
evaluations.append(DatasetEvaluation("Generator (without RAG)", correct_answers, total_answers))

In [None]:
# Optional RAG run with web search as retriever. Disabled by default; a bare
# triple-quoted string is no longer used to disable it (Jupyter echoed it as output).
# NOTE(review): `WebSearch` is not imported anywhere in this notebook; import it
# (presumably from rag.retriever - verify) before setting the flag to True.
RUN_WEBSEARCH_EVAL = False
if RUN_WEBSEARCH_EVAL:
    correct_web_model, total_web_model = evaluate_model(
        dataset=dataset,
        get_query=lambda entry: f"{entry['question']}\n A) {entry['choices'][0]}\n B) {entry['choices'][1]}\n C) {entry['choices'][2]}\n D) {entry['choices'][3]}\n",
        get_solution=lambda entry: entry['choices'].index(entry['answer'][0]),
        generator=Gpt2(),
        retriever=WebSearch(),
        limit=limit,
    )

    print(f"RAG - WebSearch as retriever: Correct={correct_web_model} total={total_web_model}")
    # Record this run's own counts, not the correct_answers/total_answers left
    # over from an earlier cell.
    evaluations.append(DatasetEvaluation("RAG - WebSearch as retriever", correct_web_model, total_web_model))

In [None]:
import matplotlib.pyplot as plt

# Share of correctly answered questions per run, in percent; a run that
# evaluated no questions is shown as 0 % instead of raising ZeroDivisionError.
correct_percentages = [
    e.correct / e.total * 100 if e.total else 0.0 for e in evaluations
]

# One colour per bar, cycling through the palette, so neighbouring runs never
# share a colour (random.choices could pick the same colour twice).
color_choices = ['red', 'blue', 'green', 'orange', 'purple', 'cyan', 'magenta', 'yellow']
colors = [color_choices[i % len(color_choices)] for i in range(len(evaluations))]

fig, ax = plt.subplots()
bar_width = 0.35
index = range(len(evaluations))

# Bars are identified by their x tick labels; a legend for a multi-coloured
# bar container would only show the first bar's colour, so none is drawn.
bar1 = ax.bar(index, correct_percentages, bar_width, color=colors)

ax.set_xlabel('Model type')
ax.set_ylabel('Correct answers (%)')
ax.set_title('Model performance overview')
ax.set_ylim(0, 100)
ax.set_xticks(index)
ax.set_xticklabels([evaluation.name for evaluation in evaluations], rotation=15, ha='right')
fig.tight_layout()  # keep the rotated tick labels inside the saved image

fig.savefig('model_performance_overview.png', dpi=100)
plt.show()