**Imports**

In [None]:
from simple_RAG_utilities import *
from config import prompt
from tqdm import tqdm

import json

In [None]:
# Build the simple RAG components in one call and unpack the returned triple.
# `model` and `vector_store` are used by the evaluation cells below
# (`prompt | model` and the retrieval calls); `embedding` is kept for reference.
embedding, model, vector_store = setup_simple_rag()

In [None]:
# Prepare the benchmark data for the "medmcqa" dataset.
# Each entry of `test_examples` is indexed further down as [0] = question and
# [1] = answer options; `ground_truth_answers` is compared position-by-position
# with the generated answers.
test_examples, ground_truth_answers = prepare_benchmark_data(dataset_name="medmcqa")

# Debug helper (disabled): print every example next to its expected answer.
# for example, answer in zip(test_examples, ground_truth_answers):
#     print(f"Example: {example}\nAnswer: {answer}\n\n")

Naive RAG Precision

In [None]:
# Naive RAG evaluation: answer every benchmark example with the simple RAG
# pipeline, then report the fraction of answers that exactly match the
# ground truth (printed under the label "Precision").
simple_rag_answers = []
precision = 0.0

# Pipe the prompt into the model so both run as a single chain.
chain = prompt | model

for example in tqdm(test_examples, desc="Generating Answers", total=len(test_examples)):
    # example[0] is the question, example[1] the answer options.
    simple_rag_answers.append(run_simple_rag(chain, example[0], example[1], vector_store))

# Count exact matches; answers and ground truth are aligned by position.
correct_answers = sum(
    answer == ground_truth_answers[position]
    for position, answer in enumerate(simple_rag_answers)
)

# An empty benchmark would otherwise raise ZeroDivisionError; report 0.0 instead.
precision = correct_answers / len(simple_rag_answers) if simple_rag_answers else 0.0
print(f"Precision: {precision}")

In [None]:
# Perturbation study for the simple RAG pipeline: for every benchmark question,
# answer it with the retrieved passage, then re-ask with each perturbed variant
# of that passage and record the perturbations that change the answer.
simple_rag_responses = []

# Pipe the prompt into the model so both run as a single chain.
chain = prompt | model

# Second argument of remove_word_span (presumably the size of the removed word
# span - confirm against simple_RAG_utilities); it is also part of the result
# file name, so both are derived from this single constant.
span_size = 5

# Maps "question_<index>" -> original answer plus every answer-changing perturbation.
temp_dict = {}

for index, example in tqdm(enumerate(test_examples), desc="Generating Responses", total=len(test_examples)):
    question = example[0]
    options = example[1]

    context = retrieve_relevant_document(question, vector_store)
    if len(context) == 0:
        # Nothing was retrieved, so there is no passage to perturb; skipping
        # avoids an IndexError that would abort the whole (long) run.
        continue

    # NOTE(review): str.capitalize() upper-cases the first character and
    # lower-cases the rest of the passage - confirm this normalisation is intended.
    context_text = context[0].page_content.capitalize()
    original_response = generate_response(chain, {"paragraph": context_text, "question": question, "options": options})

    # Only questions answered with a plain option letter are analysed further.
    if original_response in ["A", "B", "C", "D"]:
        temp_dict[f"question_{index}"] = { "original_response": original_response }
        perturbations = remove_word_span(context_text, span_size)

        for per in tqdm(perturbations, desc=f"Processing Perturbations (Test Example: {index})", total=len(perturbations)):
            # Each perturbation is indexed as (perturbed text, removed token, position).
            perturbed_text = per[0]
            removed_token = per[1]
            position = per[2]

            temp_response = generate_response(chain, {"paragraph": perturbed_text, "question": question, "options": options})

            # Keep only perturbations that flip the model's answer.
            if temp_response != original_response:
                temp_dict[f"question_{index}"][f"perturbation_{position}"] = {"perturbed_text": perturbed_text, "removed_token(critical)": removed_token, "answer": temp_response}

# Persist the collected results; the span size is encoded in the file name.
with open(rf"../results/simple_rag_result_per_word_{span_size}.json", "w", encoding="utf-8") as json_file:
    json.dump(temp_dict, json_file, indent=3)

**Comparison of the Exhaustive (RAG-Ex) Approach (span_size: 5) and KGRAG-Ex: LLM Calls and Token Count for MMLU**

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Load the per-example usage statistics recorded for both approaches.
with open(r"../results/mmlu_calls_amount.json", "r", encoding="utf-8") as json_file:
    kg_rag = json.load(json_file)

with open(r"../results/mmlu_calls_amount_simple.json", "r", encoding="utf-8") as json_file:
    simple_rag = json.load(json_file)

# Pair the two result sets by test id. Ids that exist only in the KG results
# are reported and skipped instead of raising KeyError.
example_comparisons = {}

for test in kg_rag:
    if test not in simple_rag:
        print(f"Missing simple RAG data for {test}")
        continue
    example_comparisons[f"comparison_of_{test}"] = {"simple": simple_rag[test], "kg": kg_rag[test]}

with open(r"../results/mmlu_calls_amount_comparison.json", "w", encoding="utf-8") as json_file:
    json.dump(example_comparisons, json_file, indent=3)

data = example_comparisons

llm_calls_simple = []
llm_calls_kg = []

tokens_simple = []
tokens_kg = []

# Collect each metric only when both approaches report it, so the paired
# lists stay aligned for the per-example differences below.
for key, values in data.items():
    simple = values["simple"]
    kg = values["kg"]

    if "llm_calls" in simple and "llm_calls" in kg:
        llm_calls_simple.append(simple["llm_calls"])
        llm_calls_kg.append(kg["llm_calls"])
    else:
        print(f"Missing llm_calls data for {key}")

    if "total_tokens" in simple and "total_tokens" in kg:
        tokens_simple.append(simple["total_tokens"])
        tokens_kg.append(kg["total_tokens"])
    else:
        print(f"Missing total_tokens data for {key}")

# Median per approach and the difference (simple - KG) for each metric.
median_llm_simple = np.median(llm_calls_simple)
median_llm_kg = np.median(llm_calls_kg)
diff_llm = median_llm_simple - median_llm_kg

median_tokens_simple = np.median(tokens_simple)
median_tokens_kg = np.median(tokens_kg)
diff_tokens = median_tokens_simple - median_tokens_kg

print("\n--- Median Summary ---")
print(f"{'Metric':<20} {'Simple':>10} {'KG':>10} {'Diff (S - KG)':>15}")
print("-" * 55)
print(f"{'LLM Calls':<20} {median_llm_simple:>10.1f} {median_llm_kg:>10.1f} {diff_llm:>15.1f}")
print(f"{'Total Tokens':<20} {median_tokens_simple:>10.1f} {median_tokens_kg:>10.1f} {diff_tokens:>15.1f}")

# Per-example raw differences (simple - KG); consumed by the disabled plot below.
llm_calls_diff_raw = [s - k for s, k in zip(llm_calls_simple, llm_calls_kg)]
tokens_diff_raw = [s - k for s, k in zip(tokens_simple, tokens_kg)]

# Optional visualisation of the raw differences (currently disabled).
# plt.figure(figsize=(12, 6))
# plt.subplot(2, 1, 1)
# plt.plot(llm_calls_diff_raw, marker='o', linestyle='-', color='blue')
# plt.axhline(diff_llm, color='gray', linestyle='--', label='Median Difference')
# plt.title('Raw Difference in LLM Calls (Simple - KG)')
# plt.ylabel('Difference')
# plt.grid(True)
# plt.legend()

# plt.subplot(2, 1, 2)
# plt.plot(tokens_diff_raw, marker='o', linestyle='-', color='green')
# plt.axhline(diff_tokens, color='gray', linestyle='--', label='Median Difference')
# plt.title('Raw Difference in Total Tokens (Simple - KG)')
# plt.xlabel('Example Index')
# plt.ylabel('Difference')
# plt.grid(True)
# plt.legend()

# plt.tight_layout()
# plt.show()