In [None]:
import pandas as pd
from advanced_rag import answer_query, llm
# from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer, util

def is_correct(pred: str, gold: str, semantic_model, threshold: float = 0.65) -> bool:
    """Decide whether a predicted answer matches the gold answer semantically.

    Both texts are embedded with ``semantic_model`` (one ``encode`` call each,
    prediction first) and compared with cosine similarity.

    Args:
        pred: Answer produced by the system under evaluation.
        gold: Reference answer.
        semantic_model: Embedding model exposing ``encode(text, convert_to_tensor=True)``,
            e.g. a ``SentenceTransformer``.
        threshold: Minimum cosine similarity (inclusive) that counts as correct.

    Returns:
        ``True`` when the similarity is greater than or equal to ``threshold``.
    """
    pred_vec, gold_vec = (
        semantic_model.encode(text, convert_to_tensor=True) for text in (pred, gold)
    )
    similarity = util.cos_sim(pred_vec, gold_vec).item()
    return threshold <= similarity

def compute_accuracy(results_df, group_col):
    """Return the mean of the ``Correct`` column for each value of ``group_col``.

    Args:
        results_df: DataFrame containing ``group_col`` and a boolean ``Correct`` column.
        group_col: Name of the column to group by.

    Returns:
        DataFrame with two columns, ``group_col`` and ``Accuracy``, one row per group.
    """
    per_group = results_df.groupby(group_col).agg(Accuracy=("Correct", "mean"))
    return per_group.reset_index()


def evaluate_rag(rag_query_fn, questions_df, semantic_model):
    """Answer every question with ``rag_query_fn`` and score it against the gold answer.

    Args:
        rag_query_fn: Callable taking a question string. It may return a ``str``
            or an object exposing its text through a ``content`` attribute; any
            other non-string result is converted with ``str()``.
        questions_df: DataFrame with ``Question``, ``Answer``, ``Category``,
            ``IsTrick`` and ``Difficulty`` columns, one question per row.
        semantic_model: Embedding model passed through to ``is_correct``.

    Returns:
        DataFrame with one row per question and the columns ``Question``,
        ``Gold Answer``, ``Predicted Answer``, ``Category``, ``Difficulty``,
        ``IsTrick`` and ``Correct``. The columns are present even when
        ``questions_df`` is empty, so downstream grouping by them still works.
    """
    columns = [
        "Question",
        "Gold Answer",
        "Predicted Answer",
        "Category",
        "Difficulty",
        "IsTrick",
        "Correct",
    ]
    results = []
    for _, row in questions_df.iterrows():
        q, gold, category, trick, difficulty = (
            row["Question"],
            row["Answer"],
            row["Category"],
            row["IsTrick"],
            row["Difficulty"],
        )

        raw_pred = rag_query_fn(q)
        # NOTE(review): `llm.invoke` looks like a chat-model call that returns a
        # message object rather than a str — confirm. Take its `.content` so the
        # embedder and the JSON export receive plain text; plain strings pass
        # through unchanged.
        pred = getattr(raw_pred, "content", raw_pred)
        if not isinstance(pred, str):
            pred = str(pred)

        correct = is_correct(pred, gold, semantic_model)
        results.append({
            "Question": q,
            "Gold Answer": gold,
            "Predicted Answer": pred,
            "Category": category,
            "Difficulty": difficulty,
            "IsTrick": trick,
            "Correct": correct,
        })

    # Explicit columns keep the schema stable for an empty question set.
    return pd.DataFrame(results, columns=columns)

def compute_metrics(results_df):
    """Print overall accuracy, then accuracy broken down by category and difficulty.

    Args:
        results_df: DataFrame with ``Correct``, ``Category`` and ``Difficulty``
            columns, such as the one returned by ``evaluate_rag``.

    Returns:
        None; the report is written to stdout.
    """
    # Both breakdowns are computed before anything is printed.
    breakdowns = (
        ("\n=== Accuracy by Category ===", compute_accuracy(results_df, "Category")),
        ("\n=== Accuracy by Difficulty ===", compute_accuracy(results_df, "Difficulty")),
    )
    overall = results_df["Correct"].mean()

    print("\n=== Overall Accuracy ===")
    print(overall)

    for header, table in breakdowns:
        print(header)
        print(table)



# Evaluate each query function on the same question set, saving per-question
# results to a separate JSON file per system and printing its accuracy report.
eval_questions_path = "evaluation/1984_test_questions.json"

# Loop-invariant: the embedding model and the questions are shared by every run.
model = SentenceTransformer('all-MiniLM-L6-v2')
df = pd.read_json(eval_questions_path)

# Label -> query function. The label names the output file so runs do not
# overwrite each other. ("llm_only" presumably means the model without
# retrieval — confirm against advanced_rag.)
systems = {
    "rag": answer_query,
    "llm_only": llm.invoke,
}

for label, rag in systems.items():
    results_df = evaluate_rag(rag_query_fn=rag, questions_df=df, semantic_model=model)

    # A real destination path is required: to_json("") cannot open a file.
    results_df.to_json(f"evaluation/{label}_results.json", orient="records", indent=2)

    print(f"\n##### {label} #####")
    compute_metrics(results_df)



Preparing query expansion...
Expanding query...

Expanded Prompts:
 - What item does Winston purchase from the vintage store
 - and what makes it important?
What artifact does Winston acquire at the antique store
 - and what's its significance?
What relic does Winston buy in the second-hand shop
 - and why is it meaningful? 
What keepsake does Winston purchase from the old shop
 - and what's its importance?

Retrieving context and generating answer...
Done!


Preparing query expansion...
Expanding query...

Expanded Prompts:
 - Who conveys the message to Winston about meeting in a location without shadows
 - What informs Winston of their future meeting in a place free from darkness
 - Who communicates with Winston about their encounter in an area devoid of night
 - Who delivers the information to Winston that they will meet where there is no shadow.

Retrieving context and generating answer...
Done!


Preparing query expansion...
Expanding query...

Expanded Prompts:
 - What phrase is