In [None]:
from dotenv import load_dotenv
# Read key/value pairs from a local .env file into the process environment
# (python-dotenv). Presumably this supplies credentials/tokens needed by the
# `rag` package imported below — TODO confirm which variables are required.
load_dotenv()

In [None]:
from datasets import load_dataset
from rag.generator import Llama3_8b, Gpt2
from rag.retriever import Medical, Wiki10k

# Answer prefixes the model is asked to use, in choice-index order.
_CHOICE_LABELS = ("A)", "B)", "C)", "D)")

# Index reported when no label could be parsed. It is outside the A-D range,
# so it never matches a valid four-choice solution.
_UNPARSEABLE_ANSWER = 4


def _extract_choice_index(response):
    """Return the 0-based choice index found in a generator response.

    Only the text after the first "Answer:" marker and before the following
    "Explanation:" marker is considered. The index is 0-3 when that text
    starts with "A)", "B)", "C)" or "D)"; otherwise `_UNPARSEABLE_ANSWER`.
    """
    if "Answer:" in response:
        response = response.split("Answer:")[1]
    if "Explanation:" in response:
        response = response.split("Explanation:")[0]
    answer_text = response.strip()

    for choice_index, label in enumerate(_CHOICE_LABELS):
        if answer_text.startswith(label):
            return choice_index
    return _UNPARSEABLE_ANSWER


def evaluate_model(dataset, get_query, get_solution, generator, retriever, limit=None):
    """Score a generator on multiple-choice questions, optionally with retrieval.

    Args:
        dataset: Iterable of entries to evaluate.
        get_query: Callable mapping an entry to the query text placed in the prompt.
        get_solution: Callable mapping an entry to the expected 0-based choice index.
        generator: Callable invoked as ``generator(prompt, max_new_tokens=100)``
            that returns the generated text as a string.
        retriever: Callable mapping the query to an iterable of context strings
            (joined with newlines), or ``None`` to run without any context.
        limit: Maximum number of entries to evaluate. ``None`` evaluates the
            whole dataset.

    Returns:
        Tuple ``(correct_answers, count)`` with the number of correct answers
        and the number of entries evaluated. A dataset shorter than ``limit``
        (or empty) still yields a tuple.
    """
    count = 0
    correct_answers = 0

    for entry in dataset:
        query = get_query(entry)
        solution = get_solution(entry)
        context = "\n".join(retriever(query)) if retriever is not None else ""

        # The prompt text (including its indentation) is part of the model
        # input; keep it unchanged so results stay comparable.
        prompt = f"""
        You are a helpful AI assistant. Answer based on the context provided. Be concise and answer either with A), B), C) or D).
        
        Context:
        {context}
        
        Query:
        {query}
        
        Answer:
        """
        response = generator(prompt, max_new_tokens=100)

        model_answer = _extract_choice_index(response)
        if model_answer == _UNPARSEABLE_ANSWER:
            print("Could not extract answer")

        # Report the value from get_solution so the log matches what is scored
        # and does not depend on a particular dataset schema.
        print(f"Expected={solution} Got={model_answer}")

        if solution == model_answer:
            correct_answers += 1

        count += 1
        if limit is not None and count >= limit:
            break

    return (correct_answers, count)


In [None]:
# Test split of the SciQ dataset in its BigBio QA configuration.
dataset = load_dataset("bigbio/sciq", "sciq_bigbio_qa", split="test")

# Evaluate the generator with retrieved context.
# The query is the question followed by the four lettered choices; the
# solution is the position of the first listed answer within the choices.
# NOTE(review): the result names say "medical" but the retriever used here is
# Wiki10k — confirm the naming is intended.
correct_medical_rag, total_medical_rag = evaluate_model(
    dataset=dataset,
    get_query=lambda entry: (
        f"{entry['question']}\n"
        f" A) {entry['choices'][0]}\n"
        f" B) {entry['choices'][1]}\n"
        f" C) {entry['choices'][2]}\n"
        f" D) {entry['choices'][3]}\n"
    ),
    get_solution=lambda entry: entry['choices'].index(entry['answer'][0]),
    generator=Llama3_8b(),
    retriever=Wiki10k(),
    limit=10,
)
print(f"RAG: Correct={correct_medical_rag} total={total_medical_rag}")

In [None]:
# Baseline: same query/solution construction, but with no retriever, so the
# prompt's context section is empty. A fresh Llama3_8b instance is constructed.
# NOTE(review): limit is 5 here but 10 in the RAG cell — confirm the two runs
# are meant to cover different numbers of questions.
correct_model, total_model = evaluate_model(
    dataset=dataset,
    get_query=lambda entry: (
        f"{entry['question']}\n"
        f" A) {entry['choices'][0]}\n"
        f" B) {entry['choices'][1]}\n"
        f" C) {entry['choices'][2]}\n"
        f" D) {entry['choices'][3]}\n"
    ),
    get_solution=lambda entry: entry['choices'].index(entry['answer'][0]),
    generator=Llama3_8b(),
    retriever=None,
    limit=5,
)
print(f"Normal model: Correct={correct_model} total={total_model}")

In [None]:
import matplotlib.pyplot as plt

# NOTE(review): these hard-coded counts overwrite whatever the evaluation
# cells computed — presumably results pasted from a longer run; confirm.
correct_medical_rag, total_medical_rag = 33, 100
correct_model, total_model = 66, 100

# One (correct, total) pair per model, in the same order as `models`.
models = ['RAG', 'Normal']
data = [(correct_medical_rag, total_medical_rag), (correct_model, total_model)]
percentages = [correct / total * 100 for correct, total in data]

# Bar chart: one bar per model showing the share of correct answers.
fig, ax = plt.subplots()
bar_width = 0.35
index = range(len(models))

bar1 = ax.bar(index, percentages, bar_width, label='Correct Percentage')

# Axis labels, title, and one tick per model.
ax.set(
    xlabel='Model',
    ylabel='Percentage',
    title='Correct Answers Percentage by Category',
)
ax.set_xticks(index)
ax.set_xticklabels(models)
ax.legend()

# Render the figure.
plt.show()