In [None]:
# Read variables from a .env file into the process environment (python-dotenv).
# NOTE(review): presumably this is where WEB_SEARCH_TOKEN for the web search
# retriever is expected to be defined — confirm.
from dotenv import load_dotenv
load_dotenv()

In [None]:
# Retriever that queries a web search API instead of a predefined database.
# Save your API key as an environment variable: WEB_SEARCH_TOKEN=<token>

import json
import os

import requests

class WebSearch():
    """Retriever that obtains its context from the Tavily web search API.

    The API key is read from the ``WEB_SEARCH_TOKEN`` environment variable
    every time the retriever is called.
    """

    SEARCH_URL = 'https://api.tavily.com/search'

    def __init__(self, timeout=30):
        """
        Args:
            timeout: Seconds to wait for the search API before the request
                is aborted.
        """
        self.timeout = timeout

    def __call__(self, query, k=10):
        """Search the web for ``query`` and return the result texts.

        Args:
            query: Search string sent to the API.
            k: Maximum number of results to request and to return.

        Returns:
            The ``content`` fields of the usable results joined with
            newlines; an empty string when there is no usable result.

        Raises:
            RuntimeError: If ``WEB_SEARCH_TOKEN`` is not set.
            requests.HTTPError: If the API responds with an error status.
        """
        api_key = os.environ.get("WEB_SEARCH_TOKEN")
        if not api_key:
            raise RuntimeError(
                "WEB_SEARCH_TOKEN is not set; export it or add it to your .env file"
            )

        parameters = {
            "api_key": api_key,
            "query": query,
            "search_depth": "basic",
            "include_answer": False,
            "include_images": False,
            "include_raw_content": False,
            "max_results": k,
            "include_domains": [],
            "exclude_domains": []
        }

        # A timeout keeps the notebook from hanging on a stalled connection.
        response = requests.post(self.SEARCH_URL, json=parameters, timeout=self.timeout)
        # Fail loudly on HTTP errors instead of with a KeyError further down.
        response.raise_for_status()
        result = response.json()

        contextEntries = []

        for entry in result.get("results", []):
            # Honour the caller's k instead of a hard-coded cap.
            if len(contextEntries) >= k:
                break

            content = entry.get("content")
            if not content or content == "[Removed]":
                continue

            # optionally trim content
            # content = content[:50]

            contextEntries.append(content)

        return "\n".join(contextEntries)

In [None]:
from datasets import load_dataset
from rag.generator import Llama3_8b, Llama2_4b
from rag.retriever import Medical, Wiki10k
import random

def evaluate_model(dataset, get_query, get_solution, generator, retriever, limit):
    """Run a multiple-choice benchmark and count the correct model answers.

    For each entry the answer choices are shuffled with a fixed seed, a prompt
    is built from the (optional) retrieved context and the query, and the
    generator's reply is checked for a leading ``A)``, ``B)``, ``C)`` or ``D)``.

    Args:
        dataset: Iterable of mapping entries; each entry must contain a
            ``'choices'`` list. The entries are not modified.
        get_query: Callable that turns a shuffled entry into the query text.
        get_solution: Callable that returns the index of the correct choice
            for a shuffled entry.
        generator: Callable ``generator(prompt, max_new_tokens=...)`` that
            returns the model response as a string.
        retriever: Callable ``retriever(query)`` that returns the context,
            either as one string or as an iterable of strings, or ``None``
            to evaluate without context.
        limit: Evaluation stops as soon as this many entries were processed.

    Returns:
        Tuple ``(correct_answers, count)``. ``count`` is smaller than
        ``limit`` when the dataset runs out of entries first.
    """
    answer_index = {"A)": 0, "B)": 1, "C)": 2, "D)": 3}
    # Private, seeded RNG: same shuffles as seeding the global one with 42,
    # but independent of (and harmless to) other users of the random module.
    rng = random.Random(42)
    count = 0
    correct_answers = 0

    for entry in dataset:
        # Shuffle a copy so that the caller's data keeps its original order.
        entry = dict(entry)
        entry['choices'] = list(entry['choices'])
        rng.shuffle(entry['choices'])
        query = get_query(entry)
        solution = get_solution(entry)

        context = ""
        if retriever is not None:
            retrieved = retriever(query)
            # Joining a plain string would insert a newline between every
            # character, so strings are used as they are.
            context = retrieved if isinstance(retrieved, str) else "\n".join(retrieved)

        prompt = f"""
        You are a helpful AI assistant. Answer based on the context provided. Be concise and answer either with A), B), C) or D). Add nothing else.
        
        Context:
        {context}
        
        Query:
        {query}
        
        Answer:
        """
        response = generator(prompt, max_new_tokens=100)
        #print(f"{response}\n----------------------------------------\n\n")

        # get everything between "Answer:" and "Explanation:"
        if "Answer:" in response:
            response = response.split("Answer:")[1]
        if "Explanation:" in response:
            response = response.split("Explanation:")[0]
        response = response.strip()
        response = response.replace("\"", "'")
        response = response.replace("\n", " ")

        #print("Response: " + response)

        model_answer = answer_index.get(response[:2])
        if model_answer is None:
            print("Could not extract answer")
            model_answer = 4 # not an existing index --> definitely wrong

        # Report the solution from get_solution rather than assuming a
        # dataset-specific 'answer' field.
        print(f"Question {count}: Expected={solution} Got={model_answer}")

        if solution == model_answer:
            correct_answers += 1

        count += 1
        if count >= limit:
            break

    # Also reached when the dataset holds fewer than `limit` entries.
    return (correct_answers, count)


In [None]:
# Load the test split of the SciQ dataset and shuffle it with a fixed seed.
dataset = load_dataset("bigbio/sciq", "sciq_bigbio_qa", split="test").shuffle(seed=42)

# Number of questions after which each evaluation run stops.
limit = 5


def build_choice_query(entry):
    """Render the question followed by its four choices labelled A) to D)."""
    return f"{entry['question']}\n A) {entry['choices'][0]}\n B) {entry['choices'][1]}\n C) {entry['choices'][2]}\n D) {entry['choices'][3]}\n"


def correct_choice_index(entry):
    """Return the position of the first listed answer within the choices."""
    return entry['choices'].index(entry['answer'][0])


# Evaluate the generator with the Wiki10k retriever supplying the context.
rag_scores = evaluate_model(
    dataset=dataset,
    get_query=build_choice_query,
    get_solution=correct_choice_index,
    generator=Llama2_4b(),
    retriever=Wiki10k(),
    limit=limit,
)
correct_medical_rag, total_medical_rag = rag_scores
print(f"RAG: Correct={correct_medical_rag} total={total_medical_rag}")

In [None]:
# Baseline run: same dataset, generator and limit, but without a retriever,
# so the prompt's context section stays empty.
baseline_kwargs = {
    "dataset": dataset,
    "get_query": lambda entry: f"{entry['question']}\n A) {entry['choices'][0]}\n B) {entry['choices'][1]}\n C) {entry['choices'][2]}\n D) {entry['choices'][3]}\n",
    "get_solution": lambda entry: entry['choices'].index(entry['answer'][0]),
    "generator": Llama2_4b(),
    "retriever": None,
    "limit": limit,
}
correct_model, total_model = evaluate_model(**baseline_kwargs)

print(f"Normal model: Correct={correct_model} total={total_model}")

In [None]:
# Disabled on purpose: the WebSearch retriever needs the WEB_SEARCH_TOKEN
# environment variable. The code is kept inside a plain string so that the
# cell has no effect; remove the surrounding triple quotes to run it.
"""
web_search = WebSearch()

correct_web_model, total_web_model = evaluate_model(
    dataset=dataset,
    get_query= lambda entry: f"{entry['question']}\n A) {entry['choices'][0]}\n B) {entry['choices'][1]}\n C) {entry['choices'][2]}\n D) {entry['choices'][3]}\n",
    get_solution= lambda entry: entry['choices'].index(entry['answer'][0]),
    generator = Llama2_4b(),
    # WebSearch returns one string; it is wrapped in a list because
    # evaluate_model joins the retriever output with newlines.
    retriever = lambda query: [web_search(query)],
    limit=limit
)

print(f"Web Search model: Correct={correct_web_model} total={total_web_model}")
"""

In [None]:
import matplotlib.pyplot as plt

# Bar chart comparing the share of correct answers of both evaluation runs.
models = ['RAG', 'Normal']
data = [(correct_medical_rag, total_medical_rag), (correct_model, total_model)]

# Convert every (correct, total) pair into a percentage.
percentages = [correct / total * 100 for correct, total in data]

bar_width = 0.35
index = range(len(models))

fig, ax = plt.subplots()
bar1 = ax.bar(index, percentages, bar_width, label='Correct Percentage')

# Axis labels and title in one call; tick positions and tick labels follow.
ax.set(
    xlabel='Model type',
    ylabel='Correct answers',
    title='Model performance overview',
)
ax.set_xticks(index)
ax.set_xticklabels(models)
ax.legend()

# Write the figure to disk before showing it.
plt.savefig('model_performance_overview.png', dpi=100)
plt.show()