In [1]:
import sys
import os
current_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(current_dir, '..'))
sys.path.append(parent_dir)

from SHapRAG import*
import pandas as pd
import re
import pickle
import requests
import numpy as np
from tqdm import tqdm
import faiss
from sklearn.metrics.pairwise import cosine_similarity
from scipy.stats import spearmanr, pearsonr
from itertools import combinations
import math
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from sentence_transformers import SentenceTransformer # type: ignore
# from google.oauth2 import service_account
# import vertexai
# from vertexai.generative_models import GenerativeModel, Image, Part
# tqdm.pandas()

# vertexai.init(
#     project="oag-ai",
#     credentials=service_account.Credentials.from_service_account_file("google-credentials.json"),
# )
# Pin the process to a specific GPU. Guarded so the notebook still runs on machines with
# fewer GPUs (or no CUDA at all) instead of failing with "invalid device ordinal".
GPU_INDEX = 1  # Use GPU1, or 2, or 3
if torch.cuda.is_available() and torch.cuda.device_count() > GPU_INDEX:
    torch.cuda.set_device(GPU_INDEX)
else:
    print(f"GPU {GPU_INDEX} not available; keeping the default device.")

# BioASQ question/answer/passage splits on the Hugging Face Hub; only "train" is loaded here.
splits = {'train': 'question-answer-passages/train-00000-of-00001.parquet', 'test': 'question-answer-passages/test-00000-of-00001.parquet'}
df = pd.read_parquet("hf://datasets/enelpol/rag-mini-bioasq/" + splits["train"])

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the passage corpus and flatten embedded newlines so every passage and question is one line.
df1 = (
    pd.read_parquet("hf://datasets/enelpol/rag-mini-bioasq/text-corpus/test-00000-of-00001.parquet")
    .assign(passage=lambda corpus: corpus['passage'].str.replace(r'[\n]', ' ', regex=True))
)
df = df.assign(question=df['question'].str.replace(r'[\n]', ' ', regex=True))

In [3]:
_SENTENCE_MODELS = {}  # model name -> loaded SentenceTransformer, so each model is loaded once per kernel


def gen_embs(qtext, model="nomic", timeout=60):
    """Embed a single text.

    Parameters
    ----------
    qtext : str
        Text to embed.
    model : str
        "nomic" sends the text to the local Ollama embeddings endpoint
        (``nomic-embed-text``); any other value uses the
        ``abhinand/MedEmbed-large-v0.1`` SentenceTransformer.
    timeout : float
        Seconds to wait for the Ollama HTTP response ("nomic" branch only).

    Returns
    -------
    numpy.ndarray
        For "nomic", the 1-D vector from the JSON ``embedding`` field.  Otherwise
        the result of ``encode`` on a one-element list (presumably shape (1, dim)).

    Raises
    ------
    requests.HTTPError
        If the Ollama endpoint answers with an error status.
    """
    if model == "nomic":
        data = {
            "model": "nomic-embed-text",
            "prompt": qtext
        }
        response = requests.post('http://localhost:11434/api/embeddings', json=data, timeout=timeout)
        # Fail with a clear HTTP error instead of a KeyError on the missing 'embedding' field.
        response.raise_for_status()
        return np.array(response.json()['embedding'])

    model_name = "abhinand/MedEmbed-large-v0.1"
    encoder = _SENTENCE_MODELS.get(model_name)
    if encoder is None:
        # Loading the model costs far more than encoding one string, so cache it
        # instead of re-loading it on every call.
        encoder = SentenceTransformer(model_name)
        _SENTENCE_MODELS[model_name] = encoder
    return encoder.encode([qtext], convert_to_numpy=True)
    
# df1['embedding'] = df1['passage'].progress_apply(lambda x: gen_embs(x, model='medemb'))

# # Save the embeddings to a pickle file
# with open('embed_bioasq_medemb.pkl', 'wb') as f:
#     pickle.dump(df1['embedding'].tolist(), f)
# print("Embeddings saved to embed_bioasq_medemb.pkl")

In [4]:
# Load the pre-computed MedEmbed passage embeddings (presumably produced by the commented-out
# embedding code above from df1['passage']).
# NOTE(review): pickle.load can execute arbitrary code -- only load files this project wrote.
with open("/data/embed_bioasq_medemb.pkl", "rb") as pickle_file:
    embeddings = pickle.load(pickle_file)

In [5]:
def normalize_embeddings(embeddings):
    """Scale each row (axis 1) of ``embeddings`` to unit L2 norm.

    Parameters
    ----------
    embeddings : array-like
        Array whose axis-1 slices are the vectors to normalise, typically (n, dim).

    Returns
    -------
    numpy.ndarray
        Array of the same shape with every row divided by its L2 norm.  Rows whose
        norm is zero are returned unchanged (all zeros) instead of becoming NaN.
    """
    vectors = np.asarray(embeddings)
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    # An all-zero row would otherwise divide by 0 and put NaNs into the FAISS index / query.
    safe_norms = np.where(norms == 0, 1, norms)
    return vectors / safe_norms

# Unit-length rows, so L2-distance ranking in the FAISS index is equivalent to cosine ranking.
embeddings = normalize_embeddings(np.asarray(embeddings))

def index_documents(method="faiss", index_name="recipes_nomic", es_host="http://localhost:9200",
                    documents=None, faiss_path="/data/bioasq_nomic_faiss.index"):
    """Build a search index over the global ``embeddings`` matrix.

    Parameters
    ----------
    method : {"faiss", "elasticsearch"}
        Backend to index into.
    index_name : str
        Elasticsearch index name (unused for FAISS).  NOTE(review): the default
        "recipes_nomic" looks like a leftover from another project.
    es_host : str
        Elasticsearch URL (elasticsearch branch only).
    documents : sequence of str, optional
        Texts stored next to each vector in Elasticsearch, row-aligned with
        ``embeddings``.  Defaults to ``df1['passage']``, the corpus the embeddings
        were computed from.
    faiss_path : str
        File the FAISS index is written to.

    Returns
    -------
    The FAISS index, or the Elasticsearch client.

    Raises
    ------
    ValueError
        If ``method`` is not a supported backend.
    """
    if method == "faiss":
        dimension = embeddings.shape[1]
        # Exact (brute-force) L2 index; with unit-normalised rows this ranks like cosine similarity.
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings)
        faiss.write_index(index, faiss_path)
        print("FAISS index saved.")
        return index
    if method == "elasticsearch":
        if documents is None:
            # No global `documents` exists in the visible notebook; the embeddings are row-aligned
            # with df1['passage'] (retrieve_documents maps FAISS ids back through it).
            documents = df1['passage'].tolist()
        # NOTE(review): `Elasticsearch` is not among the explicit imports; presumably it comes from
        # `from SHapRAG import *` -- confirm.
        es = Elasticsearch(es_host)
        mapping = {"mappings": {"properties": {"text": {"type": "text"}, "vector": {"type": "dense_vector", "dims": embeddings.shape[1]}}}}
        es.indices.create(index=index_name, body=mapping, ignore=400)
        for i, (text, vector) in enumerate(zip(documents, embeddings)):
            es.index(index=index_name, id=i, body={"text": text, "vector": vector.tolist()})
        print("Elasticsearch index created.")
        return es
    raise ValueError(f"Unknown indexing method {method!r}; expected 'faiss' or 'elasticsearch'.")
# The index is built once, e.g. with
# index_documents(method="faiss", index_name="bioasq_nomic_faiss", es_host="http://localhost:9200")
# and afterwards reused from disk.
# NOTE(review): index_documents() writes "/data/bioasq_nomic_faiss.index" by default, while this
# cell reads "/data/medemb_bioasq_faiss.index"; make sure the file matches the MedEmbed embeddings.
FAISS_INDEX_PATH = "/data/medemb_bioasq_faiss.index"
faiss_index = faiss.read_index(FAISS_INDEX_PATH)


In [6]:
def retrieve_documents(query, k=5, model="medemb"):
    """Return the ``k`` passages of ``df1`` nearest to ``query`` in the global FAISS index.

    Parameters
    ----------
    query : str
        Text to search for; embedded with ``gen_embs(query, model=model)``.
    k : int
        Number of neighbours requested from FAISS.
    model : str
        Embedding backend passed to ``gen_embs``; must match the one used to build the index.

    Returns
    -------
    (list of str, numpy.ndarray)
        The retrieved passages in FAISS result order, and the raw ``scores`` array
        from ``faiss_index.search`` (shape (1, k)).  If the index is an IndexFlatL2
        (as built by ``index_documents``), these are squared L2 distances, so lower
        means closer.
    """
    query_embedding = gen_embs(query, model=model)
    # FAISS expects a float32 (1, dim) matrix.
    query_embedding = normalize_embeddings(np.asarray(query_embedding, dtype=np.float32).reshape(1, -1))
    scores, indices = faiss_index.search(query_embedding, k)
    passages = df1['passage']
    # FAISS ids are row positions (vectors were added in corpus order), so look them up
    # positionally; -1 is FAISS padding when fewer than k vectors exist and is not a hit.
    return [passages.iloc[i] for i in indices[0] if i >= 0], scores

In [8]:
# Example question; retrieve its 10 nearest passages and their FAISS scores.
query = 'What is the implication of histone lysine methylation in medulloblastoma?'
docs, scores = retrieve_documents(query, k=10)

In [11]:
# Total number of characters of retrieved context handed to the LLM.
sum(len(passage) for passage in docs)

13011

In [8]:
# Build the harness once. It generates the target response from all documents and
# pre-computes the utility of every subset, so each attribution method below only reads cached values.
# NOTE(review): the saved output reports n=5 although the retrieval cell requests k=10;
# the outputs may come from an earlier run.
print("\nInstantiating ShapleyExperimentHarness (will pre-compute all utilities)...")
harness = ShapleyExperimentHarness(
    items=docs,
    query=query,
    # llm_model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0", # Use a smaller model for faster demo if preferred
    llm_model_name="meta-llama/Meta-Llama-3.1-8B-Instruct", # Or your chosen model
    verbose=True
)
print(f"Harness target_response automatically generated: '{harness.target_response}'")
print(f"Number of items (n): {harness.n_items}")


def _correlations_vs_exact(exact, approx):
    """Return (pearson, spearman) between two score vectors.

    Both coefficients are undefined when either vector is constant. In that case
    both are reported as 1.0 if the vectors are numerically identical, else 0.0.
    """
    if np.all(exact == exact[0]) or np.all(approx == approx[0]):
        agree = 1.0 if np.allclose(exact, approx) else 0.0
        return agree, agree
    pearson_c, _ = pearsonr(exact, approx)
    spearman_c, _ = spearmanr(exact, approx)
    return pearson_c, spearman_c


# 3. Compute attributions with each method (all from the pre-computed utilities)
results = {}
seed = 42  # For reproducibility of stochastic methods

print("\n--- Computing Attributions using Harness (from pre-computed utilities) ---")

# Sample budget for the regression-based estimators. With n items there are only 2**n
# distinct subsets, so for small n this can exceed the number of distinct coalitions.
m_samples_for_approx = 64
results["ContextCite"] = harness.compute_contextcite_weights(num_samples=m_samples_for_approx, lasso_alpha=0.0, seed=seed)  # LinReg
results["WSS"] = harness.compute_wss(num_samples=m_samples_for_approx, lasso_alpha=0.0, seed=seed)  # LinReg

T_iterations = harness.n_items * 20  # permutation budget for TMC / Beta-Shapley
results["TMC"] = harness.compute_tmc_shap(num_iterations=T_iterations, performance_tolerance=0.001, seed=seed)
# `beta_dist` is not defined in the visible notebook code (presumably star-imported from SHapRAG);
# look it up defensively so a missing name skips Beta-Shapley instead of raising NameError.
if globals().get("beta_dist"):
    results["BetaShap (U)"] = harness.compute_beta_shap(num_iterations=T_iterations, beta_a=0.5, beta_b=0.5, seed=seed)

results["LOO"] = harness.compute_loo()
# Exact Shapley needs the utilities of all 2**n subsets, which were pre-computed above.
results["Exact"] = harness.compute_exact_shap()

# 4. Display results
print("\n\n--- Attribution Scores (from Harness) ---")
item_labels = [f'Doc {i}' for i in range(harness.n_items)]

# Keep only well-formed results: a numpy array with one score per item.
valid_results = {k: v for k, v in results.items() if v is not None and isinstance(v, np.ndarray) and len(v) == harness.n_items}

if valid_results:
    results_df = pd.DataFrame(valid_results, index=item_labels)
    print(results_df.round(4))

    if "Exact" in valid_results:
        print("\n--- Evaluation Metrics vs Exact Shapley ---")
        exact_scores = valid_results["Exact"]
        metrics_data = []
        for method, approx_scores in valid_results.items():
            if method == "Exact":
                continue
            pearson_c, spearman_c = _correlations_vs_exact(exact_scores, approx_scores)
            metrics_data.append({
                "Method": method,
                "Pearson": pearson_c,
                "Spearman": spearman_c
            })

        if metrics_data:
            metrics_df = pd.DataFrame(metrics_data).set_index("Method")
            print(metrics_df.round(4))
        else:
            print("No approximate methods to compare against Exact.")
else:
    print("No valid attribution results were computed by the harness.")


Instantiating ShapleyExperimentHarness (will pre-compute all utilities)...
Loading LLM 'meta-llama/Meta-Llama-3.1-8B-Instruct' on device 'cuda'...


Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.96it/s]
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


LLM loaded successfully.
Generating target_response using all items...


Fetching 3 files: 100%|██████████| 3/3 [00:00<00:00, 13751.82it/s]


Generated target_response: 'The implication of histone lysine methylation in medulloblastoma is that it contributes to the pathogenesis of the disease. The study found that copy number aberrations of genes involved in histone lysine methylation, particularly at H3'
Pre-computing utilities for all 32 subsets (n=5)...


Pre-computing Utilities: 100%|██████████| 32/32 [06:39<00:00, 12.49s/it]


Pre-computation complete. Made 32 LLM calls.
Harness target_response automatically generated: 'The implication of histone lysine methylation in medulloblastoma is that it contributes to the pathogenesis of the disease. The study found that copy number aberrations of genes involved in histone lysine methylation, particularly at H3'
Number of items (n): 5

--- Computing Attributions using Harness (from pre-computed utilities) ---
Computing ContextCite Weights (m=64, using pre-computed utilities)...
Computing Weakly Supervised Shapley (m=64, using pre-computed utilities)...
Computing TMC-Shapley (T=100, using pre-computed utilities)...


                                                                    

Computing Beta-Shapley (T=100, α=0.5, β=0.5, using pre-computed utilities)...


                                                                         

Computing LOO (n=5, using pre-computed utilities)...
Computing Exact Shapley (using pre-computed utilities)...


--- Attribution Scores (from Harness) ---
       ContextCite      WSS      TMC  BetaShap (U)      LOO    Exact
Doc 0      58.7644  58.4856  59.2063       57.7403  45.8141  58.4856
Doc 1       1.7093   3.4574   3.9607        7.7198   1.1588   3.4574
Doc 2       3.4685   3.8260   3.6045        5.6109  -1.0610   3.8260
Doc 3       7.9895   6.6452   5.7517        4.3082   1.2248   6.6452
Doc 4       1.6065   2.2331   2.1240        3.9184   0.6422   2.2331

--- Evaluation Metrics vs Exact Shapley ---
              Pearson  Spearman
Method                         
ContextCite    0.9990       1.0
WSS            1.0000       1.0
TMC            0.9998       0.9
BetaShap (U)   0.9951       0.6
LOO            0.9976       0.7




In [None]:
# Load the generator LLM and tokenizer directly, then run evaluate_shap.
# NOTE(review): the harness cell above already loaded the same 8B model and likely still holds it,
# so this may be a second copy in GPU memory. `evaluate_shap` is not defined in the visible notebook
# (presumably from `from SHapRAG import *`) and may read the `model` / `tokenizer` globals -- confirm.
model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"
llm_dtype = torch.float32 if device != "cuda" else torch.float16
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=llm_dtype).to(device)
model.eval()


evaluate_shap(query, n_rag=3)

In [None]:
# Re-select GPU 1 only if it exists, so this cell cannot fail with "invalid device ordinal".
# This affects future allocations only; models already loaded stay where they are.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    torch.cuda.set_device(1)  # Use GPU1, or 2, or 3


In [None]:
import time  # explicit: this cell times every method and `time` is not among the top-level imports

print("--- Initializing Experiment Harness (Pre-computing all utilities) ---")
start_harness_time = time.time()
# NOTE(review): `target_response` and `compute_logprob` are not defined in the visible notebook
# (presumably from `from SHapRAG import *` or an earlier session) -- confirm before re-running.
harness = ShapleyExperimentHarness(
    items=docs,
    query=query,
    target_response=target_response,
    llm_caller=compute_logprob,
)
print(f"Harness initialized in {time.time() - start_harness_time:.2f}s. "
        f"{len(harness.all_true_utilities)} utilities computed.")


def _timed(label, compute):
    """Call ``compute()``, print "<label> took <seconds>s", and return its result."""
    started = time.time()
    value = compute()
    print(f"{label} took {time.time() - started:.4f}s")
    return value


results_exp = {}
seed_exp = 42
n_permutations_exp = harness.n_items * 10  # permutation budget for TMC / Beta-Shapley

print("\n--- Running Methods using Pre-computed Utilities ---")

results_exp["Exact"] = _timed("Exact Shapley (from cache)", harness.compute_exact_shap)
results_exp["ContextCite"] = _timed(
    "ContextCite (from cache)",
    lambda: harness.compute_contextcite_weights(num_samples=16, lasso_alpha=0.0, seed=seed_exp),
)
results_exp["WSS"] = _timed(
    "WSS (from cache)",
    lambda: harness.compute_wss(num_samples=16, lasso_alpha=0.0, seed=seed_exp),
)
results_exp["TMC"] = _timed(
    "TMC (from cache)",
    lambda: harness.compute_tmc_shap(num_iterations=n_permutations_exp, performance_tolerance=0.001, seed=seed_exp),
)
# `beta_dist` is not defined in the visible notebook code (presumably star-imported); a missing
# name skips Beta-Shapley instead of raising NameError.
if globals().get("beta_dist"):
    results_exp["BetaShap (U)"] = _timed(
        "BetaShap (U, from cache)",
        lambda: harness.compute_beta_shap(num_iterations=n_permutations_exp, beta_a=0.5, beta_b=0.5, seed=seed_exp),
    )
results_exp["LOO"] = _timed("LOO (from cache)", harness.compute_loo)

# Display results. Only well-formed arrays with one score per item are tabulated, because
# pd.DataFrame would reject mismatched lengths (same filter as the first harness cell).
print("\n\n--- Experiment Harness: Comparison Table ---")
valid_results_exp = {k: v for k, v in results_exp.items() if v is not None and isinstance(v, np.ndarray) and len(v) == harness.n_items}
if valid_results_exp:
    results_df_exp = pd.DataFrame(valid_results_exp, index=[f'Item {i}' for i in range(harness.n_items)])
    print(results_df_exp.round(4))
else:
    print("No valid results were computed by the harness.")

In [None]:
from scipy.stats import spearmanr, pearsonr

# MAE, Pearson and Spearman of every approximate method against Exact Shapley.
if "Exact" in valid_results_exp:
    print("\n--- Experiment Harness: Metrics vs Exact Shapley ---")
    exact_scores = valid_results_exp["Exact"]
    for method, approx_scores in valid_results_exp.items():
        if method == "Exact":
            continue
        if approx_scores is None or len(approx_scores) != len(exact_scores):
            print(f"{method}: Could not compute metrics (scores missing or length mismatch).")
            continue

        mae = np.mean(np.abs(exact_scores - approx_scores))

        either_constant = np.all(exact_scores == exact_scores[0]) or np.all(approx_scores == approx_scores[0])
        if either_constant:
            # Correlation is undefined for a constant vector: report 1.0 only if the scores coincide.
            pearson_corr_val = spearman_corr_val = 1.0 if np.allclose(exact_scores, approx_scores) else 0.0
        else:
            try:
                pearson_corr_val, _ = pearsonr(exact_scores, approx_scores)
                spearman_corr_val, _ = spearmanr(exact_scores, approx_scores)
            except ValueError:  # e.g. NaNs in the inputs
                pearson_corr_val = spearman_corr_val = np.nan

        print(f"{method}: MAE={mae:.4f}, Pearson={pearson_corr_val:.4f}, Spearman={spearman_corr_val:.4f}")

In [None]:
from scipy.stats import spearmanr, pearsonr

# Spearman rank correlation of each approximate method with Exact Shapley.
# NOTE(review): this cell used to read per-method globals (exact_values, wss_values, tmc_values,
# beta_u_values, loo_values) that are not defined anywhere in the visible notebook, so it raised
# NameError on a fresh kernel; the scores are now taken from the experiment-harness results.
exact_values = valid_results_exp.get("Exact")
if exact_values is not None:
    print("Spearman rank correlation score:")
    # Beta-Shapley is only present when it was computed (i.e. beta_dist was available).
    for method in ("WSS", "TMC", "BetaShap (U)", "LOO"):
        method_values = valid_results_exp.get(method)
        if method_values is not None:
            rho, _ = spearmanr(exact_values, method_values)  # a correlation, not an error metric
            print(f" {method} vs Exact: {rho}")

In [None]:
# Pearson correlation of each approximate method with Exact Shapley.
# NOTE(review): this cell used to read per-method globals (exact_values, wss_values, tmc_values,
# beta_u_values, loo_values) that are not defined anywhere in the visible notebook, so it raised
# NameError on a fresh kernel; the scores are now taken from the experiment-harness results.
exact_values = valid_results_exp.get("Exact")
if exact_values is not None:
    print("Pearson correlation score:")
    # Beta-Shapley is only present when it was computed (i.e. beta_dist was available).
    for method in ("WSS", "TMC", "BetaShap (U)", "LOO"):
        method_values = valid_results_exp.get(method)
        if method_values is not None:
            r_value, _ = pearsonr(exact_values, method_values)  # a correlation, not an error metric
            print(f" {method} vs Exact: {r_value}")

In [None]:
# Answer the question with a local Ollama model over the retrieved passages.
# A dedicated name is used: assigning to `model` would overwrite the Hugging Face model object.
hub = 'ollama'
rag_model = "llama2"
respin = query_rag(query=query, retrieved_docs=docs, hub=hub, model=rag_model)
print(respin)

In [None]:
from sklearn.metrics import precision_score, recall_score, f1_score
import numpy as np

def compute_precision_recall(k_max, query, query_idx):
    """Compare retrieval quality of ``query`` with a Shapley-reformulated query.

    The top-``k_max`` passages for ``query`` are attributed with ``shapley_values``.
    The passage with the highest Shapley value is then used verbatim as a new query.
    For each cut-off ``k`` in ``1 .. k_max - 1``, precision, recall and F1 of both
    rankings are computed against the gold ``relevant_passage_ids`` of row
    ``query_idx`` in ``df``.

    Parameters
    ----------
    k_max : int
        Number of passages retrieved and attributed; metrics cover k = 1 .. k_max - 1.
    query : str
        Original question text.
    query_idx : int
        Row position of the question in ``df``.

    Returns
    -------
    tuple of six lists
        precision, recall, f1 for the original query, then the same three for the
        reformulated query.  Element ``k - 1`` of each list is the value at cut-off ``k``.
    """
    # Retrieve once at the largest cut-off; each smaller cut-off k uses the length-k prefix.
    # Assumes `faiss_index` returns hits best-first (true for exact indexes such as IndexFlatL2).
    docs, _ = retrieve_documents(query=query, k=k_max)
    shap = shapley_values(docs, query)

    # The passage with the highest Shapley value becomes the reformulated query.
    new_query = docs[np.argmax(shap)]
    docs_new, _ = retrieve_documents(query=new_query, k=k_max)

    # Passage text -> corpus id, built once instead of scanning df1 twice per passage per k.
    # A text stored under several ids maps to its first id, so no cut-off is skipped.
    # Skipping would make the per-question lists ragged and break the averaging/plotting cell.
    unique_passages = df1.drop_duplicates(subset='passage', keep='first')
    passage_to_id = dict(zip(unique_passages['passage'], unique_passages['id']))

    relevant_ids = set(df['relevant_passage_ids'].iloc[query_idx])

    def _metrics_at(retrieved, k):
        # Precision, recall and F1 of the first k retrieved passages against relevant_ids.
        retrieved_ids = {passage_to_id[passage] for passage in retrieved[:k] if passage in passage_to_id}
        tp = len(relevant_ids & retrieved_ids)
        prec = tp / k
        rec = tp / len(relevant_ids) if relevant_ids else 0
        f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0
        return prec, rec, f1

    precision, recall, f1score = [], [], []
    precision_new, recall_new, f1score_new = [], [], []

    # NOTE(review): range(1, k_max) stops at k_max - 1. This is kept so the result lengths are
    # unchanged; use range(1, k_max + 1) if the k = k_max cut-off is wanted.
    for k in range(1, k_max):
        p_old, r_old, f1_old = _metrics_at(docs, k)
        p_new, r_new, f1_new = _metrics_at(docs_new, k)

        precision.append(p_old)
        recall.append(r_old)
        f1score.append(f1_old)

        precision_new.append(p_new)
        recall_new.append(r_new)
        f1score_new.append(f1_new)

    return precision, recall, f1score, precision_new, recall_new, f1score_new


# Run the precision/recall evaluation on the first N training questions.
# A comprehension is used on purpose: the previous `for j, query in ...` loop overwrote the
# global `query`, which later cells pair with `docs` retrieved for the original question.
N_EVAL_QUESTIONS = 10
K_MAX_EVAL = 5
per_question = [
    compute_precision_recall(K_MAX_EVAL, question, position)
    for position, question in enumerate(df['question'].iloc[:N_EVAL_QUESTIONS])
]

# One list per metric, each holding one inner list (over cut-offs) per question.
precision = [row[0] for row in per_question]
recall = [row[1] for row in per_question]
f1score = [row[2] for row in per_question]
precision_new = [row[3] for row in per_question]
recall_new = [row[4] for row in per_question]
f1score_new = [row[5] for row in per_question]


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Stack the per-question metric lists into (n_questions, n_cutoffs) arrays.
precision, recall, f1score = map(np.array, (precision, recall, f1score))
precision_new, recall_new, f1score_new = map(np.array, (precision_new, recall_new, f1score_new))

# Average each metric over the evaluated questions (one value per cut-off k).
avg_precision, avg_recall, avg_f1score = (metric.mean(axis=0) for metric in (precision, recall, f1score))
avg_precision_new, avg_recall_new, avg_f1score_new = (
    metric.mean(axis=0) for metric in (precision_new, recall_new, f1score_new)
)

k_values = list(range(1, precision.shape[1] + 1))

# One panel per metric: original query vs. the Shapley-reformulated query.
panels = (
    ('Average Precision', 'Precision', avg_precision, avg_precision_new),
    ('Average Recall', 'Recall', avg_recall, avg_recall_new),
    ('Average F1 Score', 'F1 Score', avg_f1score, avg_f1score_new),
)
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for ax, (title, ylabel, original_curve, shap_curve) in zip(axes, panels):
    ax.plot(k_values, original_curve, label='Original', marker='o')
    ax.plot(k_values, shap_curve, label='Shap Retriever', marker='s')
    ax.set_title(title)
    ax.set_xlabel('Top-k')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True)

fig.tight_layout()
plt.show()

In [None]:
# Each retrieved passage (first 230 characters) next to its FAISS score.
for score, passage in zip(scores[0], docs):
    print(f"{score}\t{passage[:230]}\n")

In [None]:
# Same listing as above: FAISS score, tab, then the first 230 characters of the passage.
for hit_score, hit_text in zip(scores[0], docs):
    snippet = hit_text[:230]
    print(f"{hit_score}\t{snippet}\n")

In [None]:
# Answer the question with Llama 3.3 70B on Vertex AI over the retrieved passages.
# A dedicated name is used: assigning to `model` would overwrite the Hugging Face model object.
hub = 'google'
rag_model = "publishers/meta/models/llama-3.3-70b-instruct-maas"
respin = query_rag(query=query, retrieved_docs=docs, hub=hub, model=rag_model)
print(respin)

In [None]:
# Answer the question with Gemini over the retrieved passages.
# A dedicated name is used: assigning to `model` would overwrite the Hugging Face model object.
hub = 'google'
rag_model = "gemini-2.5-pro-exp-03-25"
resp = query_rag(query=query, retrieved_docs=docs, hub=hub, model=rag_model)

In [None]:
# Show the Gemini response generated in the previous cell.
print(resp)

In [None]:
# Cosine similarity between the gold BioASQ answer for `query` and a generated response.
# NOTE(review): `resp1` is not defined in the visible notebook (the Gemini cell assigns `resp`);
# confirm which response is meant.
reference_answer = df[df['question'] == query]['answer'].values[0]
reference_vec = normalize_embeddings(gen_embs(reference_answer, model='medemb').reshape(1, -1))
candidate_vec = normalize_embeddings(gen_embs(resp1, model='medemb').reshape(1, -1))
cosine_similarity(reference_vec, candidate_vec)

In [None]:
# Cosine similarity between the gold BioASQ answer for `query` and the 'max_shap' ragshap output.
# NOTE(review): `shapmax` is only assigned in a later cell, so on Restart & Run All this cell
# raises NameError; move it below the ragshap cells.
gold_text = df[df['question'] == query]['answer'].values[0]
gold_vec = normalize_embeddings(gen_embs(gold_text, model='medemb').reshape(1, -1))
shapmax_vec = normalize_embeddings(gen_embs(shapmax, model='medemb').reshape(1, -1))
cosine_similarity(gold_vec, shapmax_vec)

In [None]:
# Cosine similarity between the gold BioASQ answer for `query` and the 'reponse' ragshap output.
# NOTE(review): `shapres` is only assigned in a later cell, so on Restart & Run All this cell
# raises NameError; move it below the ragshap cells.
gold_answer_text = df[df['question'] == query]['answer'].values[0]
gold_answer_vec = normalize_embeddings(gen_embs(gold_answer_text, model='medemb').reshape(1, -1))
shapres_vec = normalize_embeddings(gen_embs(shapres, model='medemb').reshape(1, -1))
cosine_similarity(gold_answer_vec, shapres_vec)

In [None]:
# Shapley value of each retrieved document for the current question. `query` is passed to match
# compute_precision_recall, which calls shapley_values(docs, query).
shap = shapley_values(docs, query)

In [None]:
# Display the Shapley values computed in the previous cell.
shap

In [None]:
# Run ragshap in 'reponse' mode on the Shapley values; the result is later embedded and compared
# with the gold answer via cosine similarity.
# NOTE(review): `retrival_type` / 'reponse' are passed verbatim; confirm these spellings match ragshap.
shapres=ragshap(shap, retrival_type='reponse')

In [None]:
# Run ragshap in 'max_shap' mode on the Shapley values; the result is later embedded and compared
# with the gold answer via cosine similarity.
shapmax=ragshap(shap, retrival_type='max_shap')

In [None]:
# Index of the largest Shapley value (assumes `shap` is a 1-D array aligned with `docs` -- confirm).
shap.argmax()