In [1]:
import pandas as pd
import re
import pickle
import requests
import numpy as np
from tqdm import tqdm
import faiss
from sklearn.metrics.pairwise import cosine_similarity
from scipy.stats import spearmanr
from itertools import combinations
import math
from sentence_transformers import SentenceTransformer
from google.oauth2 import service_account
import vertexai
from vertexai.generative_models import GenerativeModel, Image, Part
tqdm.pandas()

from rag_shap import*
vertexai.init(
    project="oag-ai",
    credentials=service_account.Credentials.from_service_account_file("google-credentials.json"),
)

import pandas as pd

splits = {'train': 'question-answer-passages/train-00000-of-00001.parquet', 'test': 'question-answer-passages/test-00000-of-00001.parquet'}
df = pd.read_parquet("hf://datasets/enelpol/rag-mini-bioasq/" + splits["train"])

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Read the text-corpus parquet of the rag-mini-bioasq dataset into df1.
df1 = pd.read_parquet("hf://datasets/enelpol/rag-mini-bioasq/text-corpus/test-00000-of-00001.parquet")
# Replace every newline character with a single space, both in the corpus
# passages and in the questions of `df` (loaded in the first cell).
# NOTE(review): this overwrites the columns in place, so the raw text is only
# recoverable by re-reading the parquet files.
df1['passage']=df1['passage'].str.replace(r'[\n]', ' ', regex=True)
df['question']=df['question'].str.replace(r'[\n]', ' ', regex=True)

In [3]:
_SENTENCE_ENCODERS = {}  # model name -> loaded SentenceTransformer, reused across calls


def gen_embs(qtext, model="nomic"):
    """Embed a single text with either a local Ollama model or MedEmbed.

    Parameters
    ----------
    qtext : str
        Text to embed.
    model : str, default "nomic"
        "nomic" posts the text to the Ollama embeddings endpoint on
        localhost:11434 using ``nomic-embed-text``; any other value encodes
        the text with the ``abhinand/MedEmbed-large-v0.1`` SentenceTransformer.

    Returns
    -------
    numpy.ndarray
        For "nomic", ``np.array`` of the ``embedding`` field of the JSON reply.
        Otherwise whatever ``encode([qtext], convert_to_numpy=True)`` returns
        (one entry per input text; callers in this notebook reshape the result
        to ``(1, -1)``).

    Raises
    ------
    requests.HTTPError
        If the Ollama endpoint answers with an error status.
    """
    if model == "nomic":
        payload = {
            "model": "nomic-embed-text",
            "prompt": qtext
        }
        reply = requests.post('http://localhost:11434/api/embeddings', json=payload)
        # Fail with the HTTP status instead of a KeyError on 'embedding' below.
        reply.raise_for_status()
        return np.array(reply.json()['embedding'])

    # Instantiating the SentenceTransformer is expensive; do it once and reuse
    # the instance for every later call (this function is applied per passage
    # and per query).
    encoder_name = "abhinand/MedEmbed-large-v0.1"
    encoder = _SENTENCE_ENCODERS.get(encoder_name)
    if encoder is None:
        encoder = _SENTENCE_ENCODERS[encoder_name] = SentenceTransformer(encoder_name)
    return encoder.encode([qtext], convert_to_numpy=True)
    
# df1['embedding'] = df1['passage'].progress_apply(lambda x: gen_embs(x, model='medemb'))

# # Save the embeddings to a pickle file
# with open('embed_bioasq_medemb.pkl', 'wb') as f:
#     pickle.dump(df1['embedding'].tolist(), f)
# print("Embeddings saved to embed_bioasq_medemb.pkl")

In [4]:
# Load previously pickled embeddings into `embeddings`.
# NOTE(review): presumably this is the list written by the commented-out export
# in the previous cell, but that code writes 'embed_bioasq_medemb.pkl' to the
# working directory while this reads from 'data/' -- confirm the file was moved.
# pickle.load can execute arbitrary code contained in the file; only load
# pickles that were produced locally.
with open("data/embed_bioasq_medemb.pkl", "rb") as f:
    embeddings = pickle.load(f)

In [5]:
def normalize_embeddings(embeddings):
    """Scale every row of a 2-D collection of vectors to unit L2 norm.

    Parameters
    ----------
    embeddings : array-like
        Row vectors; the norm is taken along axis 1.

    Returns
    -------
    numpy.ndarray
        Array of the same shape whose non-zero rows have L2 norm 1. Rows that
        are entirely zero are returned unchanged (as zeros) instead of NaN.
    """
    vectors = np.asarray(embeddings)
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    # A zero vector has no direction: divide it by 1 so it stays zero rather
    # than turning into NaN and poisoning later similarity computations.
    norms = np.where(norms == 0, 1, norms)
    return vectors / norms

embeddings = normalize_embeddings(embeddings)

def index_documents(method="faiss", index_name="recipes_nomic", es_host="http://localhost:9200",
                    index_path="data/bioasq_nomic_faiss.index"):
    """Index the notebook-level `embeddings` with FAISS or Elasticsearch.

    Parameters
    ----------
    method : str, default "faiss"
        "faiss" builds a flat L2 index and writes it to `index_path`;
        "elasticsearch" creates `index_name` on `es_host` and uploads one
        document per (text, vector) pair.
    index_name : str
        Elasticsearch index name (unused by the FAISS branch).
    es_host : str
        Elasticsearch URL (unused by the FAISS branch).
    index_path : str
        File the FAISS index is written to. Defaults to the path that was
        previously hard-coded.

    Returns
    -------
    The FAISS index, or the Elasticsearch client.

    Raises
    ------
    ValueError
        If `method` is neither "faiss" nor "elasticsearch" (this used to
        return None silently).
    """
    if method not in ("faiss", "elasticsearch"):
        raise ValueError(
            f"Unknown indexing method {method!r}; expected 'faiss' or 'elasticsearch'."
        )

    if method == "faiss":
        # The FAISS Python API works on C-contiguous float32 matrices.
        vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
        dimension = vectors.shape[1]
        index = faiss.IndexFlatL2(dimension)
        index.add(vectors)
        faiss.write_index(index, index_path)
        print("FAISS index saved.")
        return index

    # NOTE(review): neither `Elasticsearch` nor `documents` is defined in the
    # visible notebook; unless `from rag_shap import *` provides them, this
    # branch raises NameError -- confirm before using it.
    es = Elasticsearch(es_host)
    mapping = {"mappings": {"properties": {"text": {"type": "text"}, "vector": {"type": "dense_vector", "dims": embeddings.shape[1]}}}}
    # ignore=400 is passed so that an HTTP 400 from index creation does not abort.
    es.indices.create(index=index_name, body=mapping, ignore=400)
    for i, (text, vector) in enumerate(zip(documents, embeddings)):
        es.index(index=index_name, id=i, body={"text": text, "vector": vector.tolist()})
    print("Elasticsearch index created.")
    return es
# index_documents(method="faiss", index_name="bioasq_nomic_faiss", es_host="http://localhost:9200")
# Load a prebuilt FAISS index from disk instead of rebuilding it.
# NOTE(review): index_documents() writes "data/bioasq_nomic_faiss.index" by
# default, which is a different file from the one read here -- confirm that
# "data/medemb_bioasq_faiss.index" was built from the same embeddings/row order.
faiss_index = faiss.read_index("data/medemb_bioasq_faiss.index")


In [6]:
def retrieve_documents(query, k=5):
    """Return the passages nearest to `query` in the FAISS index.

    Parameters
    ----------
    query : str
        Text to search for; embedded with the "medemb" model and L2-normalised.
    k : int, default 5
        Number of neighbours requested from the index.

    Returns
    -------
    tuple(list, numpy.ndarray)
        The matching passages from ``df1['passage']`` in ranked order, and the
        score matrix exactly as returned by ``faiss_index.search`` (one row per
        query, so callers read ``scores[0]``).
    """
    query_embedding = gen_embs(query, model="medemb")
    query_embedding = normalize_embeddings(query_embedding.reshape(1, -1))
    # The FAISS Python API works on C-contiguous float32 matrices.
    query_embedding = np.ascontiguousarray(query_embedding, dtype=np.float32)
    scores, indices = faiss_index.search(query_embedding, k)

    passages = df1['passage']
    # FAISS ids are row positions, so look them up positionally (.iloc) rather
    # than by index label. FAISS pads with -1 when fewer than k neighbours
    # exist; those trailing slots are skipped.
    return [passages.iloc[i] for i in indices[0] if i >= 0], scores

In [7]:
query='What is the implication of histone lysine methylation in medulloblastoma?'
docs, scores=retrieve_documents(query=query, k=5)

In [8]:
docs

['We used high-resolution SNP genotyping to identify regions of genomic gain and  loss in the genomes of 212 medulloblastomas, malignant pediatric brain tumors.  We found focal amplifications of 15 known oncogenes and focal deletions of 20  known tumor suppressor genes (TSG), most not previously implicated in  medulloblastoma. Notably, we identified previously unknown amplifications and  homozygous deletions, including recurrent, mutually exclusive, highly focal  genetic events in genes targeting histone lysine methylation, particularly that  of histone 3, lysine 9 (H3K9). Post-translational modification of histone  proteins is critical for regulation of gene expression, can participate in  determination of stem cell fates and has been implicated in carcinogenesis.  Consistent with our genetic data, restoration of expression of genes controlling  H3K9 methylation greatly diminishes proliferation of medulloblastoma in vitro.  Copy number aberrations of genes with critical roles in writi

In [9]:
df.answer[0]

'Aberrant patterns of H3K4, H3K9, and H3K27 histone lysine methylation were shown to result in histone code alterations, which induce changes in gene expression, and affect the proliferation rate of cells in medulloblastoma.'

In [10]:
compute_logprob(query, df.answer[0], context="".join(docs), response=True)

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


'Histone lysine methylation is a crucial epigenetic event in medulloblastoma, which contributes to the'

In [None]:
# NOTE(review): scratch cell. cosine_similarity is already imported in the
# first cell, and the call below passes no arguments, so it is expected to
# raise TypeError when executed -- remove the cell or supply the two matrices.
from sklearn.metrics.pairwise import cosine_similarity
cosine_similarity()

In [11]:
# Build the ShapRAG attribution object for the documents retrieved for `query`.
# The target response is a hard-coded string and compute_logprob is passed as
# the LLM scoring callback.
shap_calculator = ShapRAG(
    docs=docs,
    query=query,
    target_response='Histone lysine methylation is a crucial epigenetic event in medulloblastoma, which contributes to the pathogenesis of this disease. Recently, biochemical techniques have enabled researchers to study the histone modifications more easily and accurately. One of these modifications, histone lysine methylation, has been shown to be highly stable and to represent an epigenetic alteration. Extensive biochemical analyses',
    llm_caller=compute_logprob
)

# Compute Shapley values
# NOTE(review): n_actual is not read anywhere else in the visible notebook.
n_actual = len(docs)
# Number of sampled subsets handed to compute_shapley_values.
# NOTE(review): the original remark here said "As requested for n=10 scenario",
# but `docs` was retrieved with k=5 in this notebook -- confirm the intended value.
num_samples_to_run = 4

# NOTE(review): binding the result to the name `shapley_values` shadows any
# callable of that name; later cells call shapley_values(docs, query) and
# shapley_values(docs) as a function, which will fail once this cell has run.
# The name is kept because the Spearman cell reads it.
shapley_values, lasso_weights = shap_calculator.compute_shapley_values(
    num_samples=num_samples_to_run,
    lasso_alpha=0.01,
    return_weights=True
)

# Print both attributions per document, in retrieval order.
print("\n--- Results ---")
print("LASSO Weights (Direct Attribution from Surrogate):")
for i, w in enumerate(lasso_weights):
    print(f"  Doc {i}: {w:.4f}")

print("\nShapley Values (Calculated from HYBRID Utilities):")
for i, s_val in enumerate(shapley_values):
    print(f"  Doc {i}: {s_val:.4f}")

Starting ShapRAG computation for 5 documents...


                                                         

Calculating true utilities for 4 sampled subsets...


LLM Calls: 100%|██████████| 4/4 [31:12<00:00, 468.10s/it]


Training surrogate linear model (LASSO)...
Surrogate Model Weights (w): [ 39.2358 116.5072  -0.       0.      -9.1505]
Surrogate Model Intercept (b): -175.5774
Building hybrid utility set for all 32 subsets...


                                                                 

Calculating exact Shapley values from hybrid utilities...


                                                          

ShapRAG computation finished.

--- Results ---
LASSO Weights (Direct Attribution from Surrogate):
  Doc 0: 39.2358
  Doc 1: 116.5072
  Doc 2: -0.0000
  Doc 3: 0.0000
  Doc 4: -9.1505

Shapley Values (Calculated from HYBRID Utilities):
  Doc 0: 39.2381
  Doc 1: 116.5099
  Doc 2: -0.0035
  Doc 3: 0.0023
  Doc 4: -9.1544




In [12]:
correlation, p_value = spearmanr(shapley_values, lasso_weights)
print(f" Spearman Correlation: {correlation:.4f}, p-value: {p_value:.4f}\n")


 Spearman Correlation: 0.9747, p-value: 0.0048



In [None]:
hub='ollama'
model="llama2"
respin=query_rag(query=query, retrieved_docs=docs, hub=hub, model=model)
print(respin)

In [None]:
from sklearn.metrics import precision_score, recall_score, f1_score
import numpy as np

def _ids_by_passage(passages):
    """Map each given passage text to the set of corpus ids in df1 with that exact text."""
    rows = df1.loc[df1['passage'].isin(set(passages)), ['passage', 'id']]
    lookup = {}
    for text, passage_id in zip(rows['passage'], rows['id']):
        lookup.setdefault(text, set()).add(passage_id)
    return lookup


def _precision_recall_f1(relevant_ids, retrieved_ids, k):
    """Return (precision@k, recall@k, F1@k) for one ranked prefix of length k."""
    # Duplicate passage texts can map one retrieved passage to several ids;
    # cap the hit count at k so precision never exceeds 1.
    true_positives = min(len(relevant_ids & retrieved_ids), k)
    precision_at_k = true_positives / k
    recall_at_k = true_positives / len(relevant_ids) if relevant_ids else 0
    denominator = precision_at_k + recall_at_k
    f1_at_k = 2 * precision_at_k * recall_at_k / denominator if denominator > 0 else 0
    return precision_at_k, recall_at_k, f1_at_k


def compute_precision_recall(k_max, query, query_idx):
    """Compare retrieval quality of the original query and a Shapley-reformulated one.

    The reformulated query is the retrieved passage with the highest Shapley
    value. Both queries are retrieved once at depth `k_max`, and metrics are
    computed on the ranked prefixes of length 1..k_max.

    Parameters
    ----------
    k_max : int
        Largest cut-off to evaluate (inclusive).
    query : str
        Original question text.
    query_idx : int
        Row position of the question in `df`, used to read its
        ``relevant_passage_ids``.

    Returns
    -------
    tuple of six lists
        ``(precision, recall, f1score, precision_new, recall_new, f1score_new)``,
        each with exactly `k_max` entries where entry ``i`` is the metric at
        ``k = i + 1``. The ``*_new`` lists belong to the reformulated query.

    Notes
    -----
    Earlier behaviour silently skipped any k for which a passage text matched
    more than one corpus row, producing lists of different lengths per query
    (which cannot be averaged per k), and never evaluated ``k == k_max``.
    """
    # Get initial documents and shapley scores.
    # NOTE(review): `shapley_values` must be the attribution *function* here; the
    # ShapRAG cell rebinds that name to a result, so run order matters -- confirm.
    docs, _ = retrieve_documents(query=query, k=k_max)
    shap = shapley_values(docs, query)

    # Select new query based on highest Shapley value
    new_query = docs[int(np.argmax(shap))]
    docs_reformulated, _ = retrieve_documents(query=new_query, k=k_max)

    # query_idx is a row position, so read it positionally.
    relevant_ids = set(df['relevant_passage_ids'].iloc[query_idx])

    # One pass over the corpus resolves the ids of every retrieved passage.
    lookup = _ids_by_passage(list(docs) + list(docs_reformulated))

    precision, recall, f1score = [], [], []
    precision_new, recall_new, f1score_new = [], [], []

    for k in range(1, k_max + 1):
        # The top-k result is the length-k prefix of the ranked top-k_max list.
        ids_old = set().union(*(lookup.get(text, ()) for text in docs[:k]))
        ids_new = set().union(*(lookup.get(text, ()) for text in docs_reformulated[:k]))

        p_old, r_old, f1_old = _precision_recall_f1(relevant_ids, ids_old, k)
        p_new, r_new, f1_new = _precision_recall_f1(relevant_ids, ids_new, k)

        precision.append(p_old)
        recall.append(r_old)
        f1score.append(f1_old)

        precision_new.append(p_new)
        recall_new.append(r_new)
        f1score_new.append(f1_new)

    return precision, recall, f1score, precision_new, recall_new, f1score_new


# Run the evaluation on the first 10 questions of `df`.
# Each list collects one per-k metric list per question.
precision, recall, f1score = [], [], []
precision_new, recall_new, f1score_new = [], [], []

# The loop variable is deliberately not named `query`: later cells pair the
# notebook-level `query` with the `docs` retrieved for it, and reusing the name
# here would leave it pointing at the last evaluated question.
for j, question in enumerate(df['question'][:10]):
    prec, rec, f1, prec_new, rec_new, f1_new = compute_precision_recall(5, question, j)
    precision.append(prec)
    recall.append(rec)
    f1score.append(f1)
    precision_new.append(prec_new)
    recall_new.append(rec_new)
    f1score_new.append(f1_new)


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Turn the per-question metric lists into 2-D arrays (rows: questions, columns: k).
precision, recall, f1score = np.array(precision), np.array(recall), np.array(f1score)
precision_new, recall_new, f1score_new = (
    np.array(precision_new),
    np.array(recall_new),
    np.array(f1score_new),
)

# Mean over questions for every cut-off k.
avg_precision, avg_recall, avg_f1score = (
    precision.mean(axis=0),
    recall.mean(axis=0),
    f1score.mean(axis=0),
)
avg_precision_new, avg_recall_new, avg_f1score_new = (
    precision_new.mean(axis=0),
    recall_new.mean(axis=0),
    f1score_new.mean(axis=0),
)

k_values = list(range(1, precision.shape[1] + 1))

# One panel per metric: (title, y-axis label, original curve, reformulated curve).
panels = [
    ('Average Precision', 'Precision', avg_precision, avg_precision_new),
    ('Average Recall', 'Recall', avg_recall, avg_recall_new),
    ('Average F1 Score', 'F1 Score', avg_f1score, avg_f1score_new),
]

plt.figure(figsize=(15, 5))
for panel_no, (panel_title, y_label, curve_original, curve_shap) in enumerate(panels, start=1):
    plt.subplot(1, 3, panel_no)
    plt.plot(k_values, curve_original, label='Original', marker='o')
    plt.plot(k_values, curve_shap, label='Shap Retriever', marker='s')
    plt.title(panel_title)
    plt.xlabel('Top-k')
    plt.ylabel(y_label)
    plt.legend()
    plt.grid(True)

plt.tight_layout()
plt.show()

In [None]:
for a, b in zip(scores[0], docs):
    print(f"{a}\t{b[:230]}\n")

In [None]:
for a, b in zip(scores[0], docs):
    print(f"{a}\t{b[:230]}\n")

In [None]:
hub='google'
model="publishers/meta/models/llama-3.3-70b-instruct-maas"
respin=query_rag(query=query, retrieved_docs=docs, hub=hub, model=model)
print(respin)

In [None]:
hub='google'
model="gemini-2.5-pro-exp-03-25"
resp=query_rag(query=query, retrieved_docs=docs, hub=hub, model=model)

In [None]:
print(resp)

In [None]:
# Cosine similarity between the MedEmbed embedding of the reference answer for
# `query` and the embedding of a generated response.
# NOTE(review): `resp1` is never assigned in the visible notebook (only `resp`
# and `respin` are), so this cell raises NameError on a fresh run -- confirm
# which response was meant.
cosine_similarity(normalize_embeddings(gen_embs(df[df['question']==query]['answer'].values[0], model='medemb').reshape(1, -1)), normalize_embeddings(gen_embs(resp1, model='medemb').reshape(1, -1)))

In [None]:
cosine_similarity(normalize_embeddings(gen_embs(df[df['question']==query]['answer'].values[0], model='medemb').reshape(1, -1)), normalize_embeddings(gen_embs(shapmax, model='medemb').reshape(1, -1)))

In [None]:
cosine_similarity(normalize_embeddings(gen_embs(df[df['question']==query]['answer'].values[0], model='medemb').reshape(1, -1)), normalize_embeddings(gen_embs(shapres, model='medemb').reshape(1, -1)))

In [None]:
shap=shapley_values(docs)

In [None]:
shap

In [None]:
shapres=ragshap(shap, retrival_type='reponse')

In [None]:
shapmax=ragshap(shap, retrival_type='max_shap')

In [None]:
shap.argmax()