In [None]:
import sys
import os
import random
import gc
import time
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from scipy.stats import spearmanr, pearsonr, kendalltau, rankdata
from sklearn.metrics import ndcg_score
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import Accelerator

# Make the parent of the current working directory importable so that the
# project-local SHapRAG package resolves when the notebook runs from a subfolder.
current_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(current_dir, '..'))
sys.path.append(parent_dir)
# NOTE(review): wildcard import — ShapleyExperimentHarness used further down is
# presumably provided by this line; an explicit import would make that visible.
from SHapRAG import *

In [None]:
# Load the question set; the processing loop reads its `question` and `context` columns.
# index_col=False keeps pandas from using the first CSV column as the index.
df= pd.read_csv("../data/synthetic_data/20_synergy.csv",index_col=False)
# Alternative dataset (disabled):
# df= pd.read_csv("../data/complementary.csv")

In [None]:
# Hand-written toy example: a small document pool and a query over it.
# NOTE: the processing loop later in this notebook reassigns both `docs` and `query`
# from the CSV, so these values only matter for cells run before that loop.
#
# Earlier (disabled) example about vitamin deficiencies:
# docs = [
#     "Vitamin B1, also known as thiamine, is essential for glucose metabolism and neural function.",  #
#     "Chronic alcoholism can impair nutrient absorption, particularly leading to thiamine deficiency.",  # 
#     "Vitamin C deficiency leads to scurvy, which presents with bleeding gums and joint pain.",
#     "Vitamin D deficiency is associated with rickets in children and osteomalacia in adults.",
#     "Vitamin B12 deficiency can cause neurological symptoms but is more common in strict vegans.",
#     "Folic acid is important for DNA synthesis and is crucial during pregnancy.",
#     "Vitamin A deficiency primarily affects vision and immune function.",
#     "Iron deficiency is the leading cause of anemia worldwide.",
#     "Calcium is essential for bone health and muscle contraction.",
#     "Vitamin K is important for blood clotting."
# ]
# query = "Which vitamin deficiency can lead to neurological symptoms and is commonly seen in chronic alcoholics?"

# Active example: fictional places, so the answer has to come from the documents.
# Answering the query requires combining "Chorvoq is the capital of Narniya."
# with "The weather in Chorvoq is stormy today."
docs = [
"The sun is shining in Galaba today",
"Nurik is the capital of Suvsambil.", # Irrelevant
"Narniya borders several countries including Suvsambil.",
"The currency used in Narniya is the Euro.",
"The weather in Chorvoq is stormy today.",
"Chorvoq is the capital of Narniya.",
"Chorvoq hosted the Summer Olympics in 1900 and 1924.",
"Suvsambil uses the Euro as well.", # Redundant info
"It is cloudy in Nurik today."
]
query = "What is the weather like in the capital of Narniya?"
# Parameters
# Number of documents in the toy pool above.
NUM_RETRIEVED_DOCS = len(docs)


In [None]:
# Display the size of the hand-written document pool.
NUM_RETRIEVED_DOCS

In [None]:
# Seed handed to the sampling-based attribution methods via their `seed=` argument.
# NOTE(review): the global RNGs (random / numpy / torch) are not seeded in this cell.
SEED = 42
# Initialize Accelerator
# fp16 mixed precision is requested here; the model itself is also loaded in float16 below.
accelerator_main = Accelerator(mixed_precision="fp16")

# Load Model
if accelerator_main.is_main_process:
    print("Main Script: Loading model...")
# Alternative checkpoints (disabled):
# model_path = "mistralai/Mistral-7B-Instruct-v0.3"
# model_path = "meta-llama/Llama-3.1-8B-Instruct"
model_path = "Qwen/Qwen2.5-3B-Instruct"

# No device argument is given here; device placement is left to accelerator.prepare() below.
model_cpu = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Some tokenizers ship without a pad token; fall back to EOS and mirror the id
# into the model config and generation config so they agree with the tokenizer.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    model_cpu.config.pad_token_id = tokenizer.pad_token_id
    if hasattr(model_cpu, 'generation_config') and model_cpu.generation_config is not None:
        model_cpu.generation_config.pad_token_id = tokenizer.pad_token_id

if accelerator_main.is_main_process:
    print("Main Script: Preparing model with Accelerator...")
# `prepared_model` is what the harness receives later; eval mode is set on the
# unwrapped module returned by accelerator.unwrap_model().
prepared_model = accelerator_main.prepare(model_cpu)
unwrapped_prepared_model = accelerator_main.unwrap_model(prepared_model)
unwrapped_prepared_model.eval()
if accelerator_main.is_main_process:
    print("Main Script: Model prepared and set to eval.")

# Synchronize all processes before any cell starts using the model.

accelerator_main.wait_for_everyone()

In [None]:
# Run every attribution method on the first `num_questions_to_run` rows of `df`
# and score each approximation against the exact Shapley values.
num_questions_to_run=1
all_metrics_data = []  # one dict per (question, method): agreement metrics vs. "Exact"
all_results=[]         # one dict per question: method name -> attribution scores
for i in tqdm(range(num_questions_to_run), desc="Processing Questions", disable=not accelerator_main.is_main_process):
    query = df.question[i]
    if accelerator_main.is_main_process:
        print(f"\n--- Question {i+1}/{num_questions_to_run}: {query[:60]}... ---")

    # NOTE(review): `df` comes from read_csv, so this cell value is presumably a
    # string; `len(docs)` below only counts documents if the harness/CSV really
    # provides a list here — confirm whether the column needs parsing first.
    docs=df.context[i]
    # Initialize Harness (created on every process; only the main process queries it below)
    harness = ShapleyExperimentHarness(
        items=docs,
        query=query,
        prepared_model_for_harness=prepared_model,
        tokenizer_for_harness=tokenizer,
        accelerator_for_harness=accelerator_main,
        verbose=True,
        utility_path=None
    )
    # Compute metrics
    results_for_query = {}

    if accelerator_main.is_main_process:
        results_for_query["Exact"] = harness.compute_exact_shap()

        # Sampling budgets and iteration caps per size bucket.
        m_samples_map = {"S": 32, "M": 64, "L": 100}
        T_iterations_map = {"S": 5, "M": 10, "L":20}

        # Number of possible subsets of the items; computed once instead of per bucket.
        num_subsets = 2**len(docs)
        for size_key, num_s in m_samples_map.items():
            # For the S and M buckets, never request more samples than there are
            # non-trivial subsets; the L bucket always uses its nominal budget.
            # NOTE: result keys are built from `actual_samples`, so two buckets
            # capped to the same value overwrite each other's entries.
            if num_subsets < num_s and size_key != "L":
                actual_samples = max(1, num_subsets - 1)
            else:
                actual_samples = num_s

            if actual_samples > 0:
                results_for_query[f"ContextCite{actual_samples}"] = harness.compute_contextcite_weights(num_samples=actual_samples, sampling="kernelshap", seed=SEED)

                # results_for_query[f"WSS_BGAM{actual_samples}"] = harness.compute_wss(num_samples=actual_samples, seed=SEED, distil=None, sampling="kernelshap",sur_type="boosted_gam", util='pure-surrogate', pairchecking=False)
                # compute_wss returns a pair here; the second element is kept in `F` for inspection.
                results_for_query[f"WSS_FM{actual_samples}"], F = harness.compute_wss(num_samples=actual_samples, seed=SEED, distil=None, sampling="kernelshap",sur_type="fm", util='pure-surrogate', pairchecking=False)
                # results_for_query[f"WSS_XGB{actual_samples}"] = harness.compute_wss(num_samples=actual_samples, seed=SEED, distil=None, sampling="kernelshap",sur_type="xgboost", util='pure-surrogate', pairchecking=False)
                results_for_query[f"BetaShap (U){actual_samples}"] = harness.compute_beta_shap(num_iterations_max=T_iterations_map[size_key], beta_a=0.5, beta_b=0.5, max_unique_lookups=actual_samples, seed=SEED)
                results_for_query[f"TMC{actual_samples}"] = harness.compute_tmc_shap(num_iterations_max=T_iterations_map[size_key], performance_tolerance=0.001, max_unique_lookups=actual_samples, seed=SEED)

        results_for_query["LOO"] = harness.compute_loo()

        exact_scores = results_for_query.get("Exact")
        all_results.append(results_for_query)
        if exact_scores is not None:
            # Array views so the element-wise comparisons below also work for plain lists.
            exact_arr = np.asarray(exact_scores)
            # ndcg_score needs non-negative relevance values, so negative exact scores are clipped to 0.
            positive_exact_score = np.clip(exact_arr, a_min=0.0, a_max=None)
            for method, approx_scores in results_for_query.items():
                if method == "Exact" or approx_scores is None:
                    continue
                if len(approx_scores) != len(exact_scores):
                    print(f"    Score length mismatch for method {method} (Exact: {len(exact_scores)}, Approx: {len(approx_scores)}). Skipping metrics.")
                    continue

                approx_arr = np.asarray(approx_scores)
                if np.all(exact_arr == exact_arr[0]) or np.all(approx_arr == approx_arr[0]):
                    # Correlations are undefined when either vector is constant:
                    # report 1.0 if the vectors coincide, else 0.0. Kendall's tau is
                    # set here too — previously it was left unassigned in this branch,
                    # raising NameError or silently reusing the previous method's value.
                    agreement = 1.0 if np.allclose(exact_arr, approx_arr) else 0.0
                    pearson_c = spearman_c = kendall_c = agreement
                else:
                    pearson_c, _ = pearsonr(exact_arr, approx_arr)
                    spearman_c, _ = spearmanr(exact_arr, approx_arr)
                    # Descending ranks: the largest score gets rank 1, ties share the average rank.
                    exact_ranks = rankdata(-exact_arr, method="average")
                    approx_ranks = rankdata(-approx_arr, method="average")
                    # kendalltau returns (tau, p-value); only tau is kept.
                    kendall_c, _ = kendalltau(exact_ranks, approx_ranks)
                ndgc_scoring  = ndcg_score(
                    [positive_exact_score],
                    [approx_arr],
                    k = 3 # focus on top k document scoring
                )

                all_metrics_data.append({
                    "Question_Index": i, "Query": query, "Method": method,
                    "Pearson": pearson_c, "Spearman": spearman_c, "NDCG" : ndgc_scoring, "KendallTau" : kendall_c,
                })
        else:
            print(f"    Skipping metric calculation for Q{i} as Exact Shapley was not computed or failed.")

    # Let every process catch up with the main process before releasing memory.
    accelerator_main.wait_for_everyone()

    if torch.cuda.is_available():
        if accelerator_main.is_main_process: # Print from one process
            print(f"Attempting to empty CUDA cache on rank {accelerator_main.process_index} after Q{i}")
        # Collect Python garbage first so memory held by unreachable tensors is
        # returned to the caching allocator before the cache is emptied.
        gc.collect()
        torch.cuda.empty_cache()
        if accelerator_main.is_main_process:
            print(f"CUDA cache empty attempt complete on rank {accelerator_main.process_index}.")
    accelerator_main.wait_for_everyone()


# Summarise the per-question metrics as one row per method (main process only).
if accelerator_main.is_main_process:
    if not all_metrics_data:
        print("\nNo metrics were collected. This might be due to all calculations failing or only non-main processes running sections.")
    else:
        metrics_df_all_questions = pd.DataFrame(all_metrics_data)
        print("\n\n--- Average Correlation Metrics Across All Questions ---")

        # Output column -> (source column, aggregation); dict order fixes the column order.
        summary_spec = {
            "Avg_Pearson": ("Pearson", "mean"),
            "Avg_Spearman": ("Spearman", "mean"),
            "Avg_Kendall": ("KendallTau", "mean"),
            "Avg_NDCG": ("NDCG", "mean"),
            "Num_Valid_Queries": ("Query", "nunique"),
        }
        per_method = metrics_df_all_questions.groupby("Method")
        average_metrics = per_method.agg(**summary_spec)
        # Best-correlated methods first.
        average_metrics = average_metrics.sort_values(by="Avg_Pearson", ascending=False)

        print(average_metrics.round(4))

# Last barrier, then tear down the torch.distributed process group if one exists.
# NOTE(review): once the group is destroyed, anything run afterwards must not rely
# on collective operations (barriers, gathers) through this accelerator.
accelerator_main.wait_for_everyone()
if accelerator_main.is_main_process:
    print("Script finished.")

process_group_active = torch.distributed.is_available() and torch.distributed.is_initialized()
if not process_group_active:
    # Nothing to destroy (e.g. a plain single-process run).
    if accelerator_main.is_local_main_process:
        print(f"Rank {accelerator_main.process_index} (Local Main): Distributed environment not initialized or not available, skipping destroy_process_group.")
else:
    if accelerator_main.is_local_main_process:
        print(f"Rank {accelerator_main.process_index} (Local Main): Manually destroying process group...")
    torch.distributed.destroy_process_group()
    if accelerator_main.is_local_main_process:
        print(f"Rank {accelerator_main.process_index} (Local Main): Process group destroyed.")

if accelerator_main.is_main_process:
    print("Script fully exited.")

In [None]:
# Display the `target_response` attribute of the harness built for the last processed
# question (presumably the model answer being attributed — confirm in SHapRAG).
harness.target_response

In [None]:
# Call compute_exhaustive_top_k(2) on the last question's harness and display the result
# (presumably an exhaustive search over item subsets of size 2 — TODO confirm in SHapRAG).
harness.compute_exhaustive_top_k(2)

In [None]:
# Evaluate metrics
# Recomputes the agreement metrics for the most recent `results_for_query` only.
# Note that this resets `all_metrics_data`, discarding any rows collected earlier.
all_metrics_data = []
exact_scores = results_for_query.get("Exact")
if exact_scores is not None:
    # Array views so the element-wise comparisons below also work for plain lists.
    exact_arr = np.asarray(exact_scores)
    # ndcg_score needs non-negative relevance values, so negative exact scores are clipped to 0.
    positive_exact_score = np.clip(exact_arr, a_min=0.0, a_max=None)
    for method, approx_scores in results_for_query.items():
        # Skip the reference itself, failed methods, and vectors of the wrong length.
        if method == "Exact" or approx_scores is None or len(approx_scores) != len(exact_scores):
            continue

        approx_arr = np.asarray(approx_scores)
        if np.all(exact_arr == exact_arr[0]) or np.all(approx_arr == approx_arr[0]):
            # Correlations are undefined when either vector is constant:
            # report 1.0 if the vectors coincide, else 0.0. Kendall's tau is set here
            # too — previously it was left unassigned in this branch, raising
            # NameError or silently reusing the previous method's value.
            agreement = 1.0 if np.allclose(exact_arr, approx_arr) else 0.0
            pearson_c = spearman_c = kendall_c = agreement
        else:
            pearson_c, _ = pearsonr(exact_arr, approx_arr)
            spearman_c, _ = spearmanr(exact_arr, approx_arr)
            # Descending ranks: the largest score gets rank 1, ties share the average rank.
            exact_ranks = rankdata(-exact_arr, method="average")
            approx_ranks = rankdata(-approx_arr, method="average")
            kendall_c, _ = kendalltau(exact_ranks, approx_ranks)
        # NDCG restricted to the top 3 positions.
        ndgc_scoring = ndcg_score([positive_exact_score], [approx_arr], k=3)

        all_metrics_data.append({
            "Method": method,
            "Pearson": pearson_c, "Spearman": spearman_c, "NDCG": ndgc_scoring, "KendallTau": kendall_c
        })

    # Sort once after collecting every row (best Pearson first) rather than after each append.
    all_metrics_data.sort(key=lambda row: row["Pearson"], reverse=True)


In [None]:
# Tabulate the collected metric rows and print them as plain text under a banner.
metrics_df_all_questions = pd.DataFrame(all_metrics_data)

banner_lines = [
    "\n\n============================",
    "     Correlation Metrics",
    "============================",
]
for banner_line in banner_lines:
    print(banner_line)

# Round for readability and drop the positional index from the printout.
rounded_metrics_text = metrics_df_all_questions.round(4).to_string(index=False)
print(rounded_metrics_text)

In [None]:
# Print the raw score vector of every method for every processed question.
# No barrier is issued here: `all_results` is only filled on the main process, so a
# per-result wait_for_everyone() would be reached a different number of times on
# each rank (and would run after the process group has already been torn down).
print("     Approximate Scores")
print("============================")
for question_results in all_results:  # dedicated name so the question index `i` is not clobbered
    for method, approx_scores in question_results.items():
        if approx_scores is not None:
            print(f"\nMethod: {method}")
            print(np.round(approx_scores, 4))
