In [1]:
import sys
import os
import random
import gc
import time
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from scipy.stats import spearmanr, pearsonr, kendalltau, rankdata
from sklearn.metrics import ndcg_score
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import Accelerator

# Put the parent of the notebook's working directory on the import path so the
# project-local import that follows can be resolved from there.
current_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(current_dir, '..'))
# Guarded append: re-running this cell must not pile up duplicate sys.path entries.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)
from SHapRAG import ShapleyExperimentHarness

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ---------------------------------------------------------------------------
# Alternative example (disabled): a vitamin-deficiency corpus and its query.
# Uncomment this block and comment out the active one below to switch.
# ---------------------------------------------------------------------------
# docs = [
#     "Vitamin B1, also known as thiamine, is essential for glucose metabolism and neural function.",  # 🔑 Useful
#     "Chronic alcoholism can impair nutrient absorption, particularly leading to thiamine deficiency.",  # 🔑 Useful
#     "Vitamin C deficiency leads to scurvy, which presents with bleeding gums and joint pain.",
#     "Vitamin D deficiency is associated with rickets in children and osteomalacia in adults.",
#     "Vitamin B12 deficiency can cause neurological symptoms but is more common in strict vegans.",
#     "Folic acid is important for DNA synthesis and is crucial during pregnancy.",
#     "Vitamin A deficiency primarily affects vision and immune function.",
#     "Iron deficiency is the leading cause of anemia worldwide.",
#     "Calcium is essential for bone health and muscle contraction.",
#     "Vitamin K is important for blood clotting."
# ]
# query = "Which vitamin deficiency can lead to neurological symptoms and is commonly seen in chronic alcoholics?"

# Seed handed to every sampling-based estimator further down.
SEED = 42

# Question asked against the toy corpus.
query = "What is the weather like in the capital of Narniya?"

# Active example: a small synthetic corpus about fictional places.
docs = [
    "Chorvoq is the capital of Narniya.",
    "The weather in Chorvoq is sunny today.",
    "The sun is shining in Chorvoq today",
    "Nurik is the capital of Suvsambil.",  # Irrelevant
    "Narniya borders several countries including Suvsambil.",
    "The currency used in Narniya is the Euro.",
    "Chorvoq hosted the Summer Olympics in 1900 and 1924.",
    "Suvsambil uses the Euro as well.",  # Redundant info
    "It is cloudy in Nurik today.",
]

# Every document in the corpus is treated as retrieved.
NUM_RETRIEVED_DOCS = len(docs)

In [3]:
# --- Accelerator -------------------------------------------------------------
accelerator_main = Accelerator(mixed_precision="fp16")


def _log_on_main(message):
    """Print `message` only when running on the accelerator's main process."""
    if accelerator_main.is_main_process:
        print(message)


# --- Model and tokenizer -----------------------------------------------------
_log_on_main("Main Script: Loading model...")
model_path = "meta-llama/Llama-3.2-3B-Instruct"
model_cpu = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# When the tokenizer ships without a pad token, reuse EOS and mirror the id
# onto the model config (and the generation config, when one is present).
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    model_cpu.config.pad_token_id = tokenizer.pad_token_id
    _generation_config = getattr(model_cpu, 'generation_config', None)
    if _generation_config is not None:
        _generation_config.pad_token_id = tokenizer.pad_token_id

# --- Hand the model to the accelerator and switch to eval mode ---------------
_log_on_main("Main Script: Preparing model with Accelerator...")
prepared_model = accelerator_main.prepare(model_cpu)
unwrapped_prepared_model = accelerator_main.unwrap_model(prepared_model)
unwrapped_prepared_model.eval()
_log_on_main("Main Script: Model prepared and set to eval.")

# Synchronise all processes before the harness is built.
accelerator_main.wait_for_everyone()

# --- Experiment harness ------------------------------------------------------
harness = ShapleyExperimentHarness(
    items=docs,
    query=query,
    prepared_model_for_harness=prepared_model,
    tokenizer_for_harness=tokenizer,
    accelerator_for_harness=accelerator_main,
    verbose=True,
    utility_path=None,  # no utility file is supplied
)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Main Script: Loading model...


Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.86it/s]


Main Script: Preparing model with Accelerator...
Main Script: Model prepared and set to eval.
Generating target response based on full context...


Fetching 3 files: 100%|██████████| 3/3 [00:00<00:00, 11715.93it/s]


Target response: 'The weather in the capital of Narniya is sunny.'
Pre-computing utilities as they were not loaded.
Starting pre-computation of utilities for 512 subsets using 1 processes...


                                                                                

Total utilities aggregated: 512/512




In [4]:
# Show the target response stored on the harness; as the cell's last
# expression, the notebook renders its value.
harness.target_response

'The weather in the capital of Narniya is sunny.'

In [10]:
# Compute metrics
# Exact values first, then each approximation at three sampling budgets,
# and finally leave-one-out. Keys are "<method><num_samples>".
results_for_query = {}
results_for_query["Exact"] = harness.compute_exact_shap()

m_samples_map = {"S": 32, "M": 64, "L": 100}
T_iterations_map = {"S": 10, "M": 15, "L": 20}

# (label used in the result key, `sur_type` argument) in insertion order.
wss_surrogates = (("FM", "fm"), ("GAM", "gam"), ("XG", "xgboost"))
# Number of non-empty subsets of the documents: the upper bound on samples.
max_subsets = 2 ** len(docs) - 1

for budget, requested in m_samples_map.items():
    # Clamp to [1, max_subsets]; the lower bound makes the value always positive.
    n_samples = max(1, min(requested, max_subsets))
    n_iterations = T_iterations_map[budget]

    results_for_query[f"ContextCite{n_samples}"] = harness.compute_contextcite_weights(
        num_samples=n_samples,
        sampling="kernelshap",
        seed=SEED,
    )

    for label, surrogate in wss_surrogates:
        results_for_query[f"WSS_{label}{n_samples}"] = harness.compute_wss(
            num_samples=n_samples,
            seed=SEED,
            distil=None,
            sampling="kernelshap",
            sur_type=surrogate,
            util='pure-surrogate',
            pairchecking=False,
        )

    results_for_query[f"BetaShap (U){n_samples}"] = harness.compute_beta_shap(
        num_iterations_max=n_iterations,
        beta_a=0.5,
        beta_b=0.5,
        max_unique_lookups=n_samples,
        seed=SEED,
    )
    results_for_query[f"TMC{n_samples}"] = harness.compute_tmc_shap(
        num_iterations_max=n_iterations,
        performance_tolerance=0.001,
        max_unique_lookups=n_samples,
        seed=SEED,
    )

results_for_query["LOO"] = harness.compute_loo()


(32, 9)
(32, 9)


                                                                      

(64, 9)
(64, 9)


                                                                      

(100, 9)
(100, 9)


                                                                      

In [11]:
# Evaluate metrics
# Score every approximation in `results_for_query` against the "Exact" entry
# with Pearson, Spearman, Kendall tau and NDCG@3, collecting one record per
# method in `all_metrics_data`, sorted by Pearson (best first).
all_metrics_data = []
exact_scores = results_for_query.get("Exact")
if exact_scores is not None:
    # Work on float arrays so the element-wise comparisons below behave the
    # same whether the harness returned lists or ndarrays.
    exact_arr = np.asarray(exact_scores, dtype=float)
    # Negative exact values are clipped to zero before serving as the NDCG
    # relevance vector.
    positive_exact_score = np.clip(exact_arr, a_min=0.0, a_max=None)

    for method, approx_scores in results_for_query.items():
        if method == "Exact" or approx_scores is None or len(approx_scores) != len(exact_arr):
            continue
        approx_arr = np.asarray(approx_scores, dtype=float)

        exact_is_constant = np.all(exact_arr == exact_arr[0])
        approx_is_constant = np.all(approx_arr == approx_arr[0])
        if exact_is_constant or approx_is_constant:
            # Correlations are undefined for a constant vector. Fall back to
            # 1.0 when the vectors agree and 0.0 otherwise. Kendall tau is
            # assigned here too: previously it was left unset in this branch,
            # raising NameError or reusing the previous method's value.
            pearson_c = spearman_c = kendall_c = (
                1.0 if np.allclose(exact_arr, approx_arr) else 0.0
            )
        else:
            pearson_c, _ = pearsonr(exact_arr, approx_arr)
            spearman_c, _ = spearmanr(exact_arr, approx_arr)
            # Rank in descending score order (ties share the average rank).
            exact_ranks = rankdata(-exact_arr, method="average")
            approx_ranks = rankdata(-approx_arr, method="average")
            kendall_c, _ = kendalltau(exact_ranks, approx_ranks)

        ndgc_scoring = ndcg_score([positive_exact_score], [approx_arr], k=3)

        all_metrics_data.append({
            "Method": method,
            "Pearson": pearson_c,
            "Spearman": spearman_c,
            "NDCG": ndgc_scoring,
            "KendallTau": kendall_c,
        })

    # One stable sort after the loop gives the same order as re-sorting on
    # every append.
    all_metrics_data.sort(key=lambda x: x["Pearson"], reverse=True)


In [12]:
# Collect the per-method metric records into a table.
metrics_df_all_questions = pd.DataFrame(all_metrics_data)

# Banner, then the table rounded to four decimals without the index column.
print(
    "\n\n============================",
    "     Correlation Metrics",
    "============================",
    sep="\n",
)
print(metrics_df_all_questions.round(4).to_string(index=False))



     Correlation Metrics
         Method  Pearson  Spearman   NDCG  KendallTau
      WSS_XG100   0.9839    0.9833 1.0000      0.9444
       WSS_FM64   0.9836    0.9667 0.9939      0.8889
      WSS_FM100   0.9803    0.9167 0.9939      0.7778
     WSS_GAM100   0.9759    0.9833 1.0000      0.9444
 ContextCite100   0.9746    0.9333 1.0000      0.8333
       WSS_XG64   0.9744    0.8833 0.9572      0.7222
          TMC64   0.8922    0.9000 0.8764      0.7778
  ContextCite64   0.8912    0.8333 0.9572      0.6667
      WSS_GAM64   0.8716    0.8333 0.9572      0.6667
       WSS_FM32   0.8376    0.8667 0.9939      0.7222
         TMC100   0.8283    0.8500 0.8794      0.7222
       WSS_XG32   0.8206    0.6500 0.9623      0.4444
  ContextCite32   0.7918    0.7667 0.9856      0.6111
 BetaShap (U)64   0.7812    0.7667 0.8375      0.6111
BetaShap (U)100   0.7254    0.7500 0.8090      0.5556
          TMC32   0.5529    0.3333 0.8525      0.2778
            LOO   0.5386    0.4833 0.7130      0.3333
 

In [9]:
# Dump the raw attribution vector of every method that produced one,
# rounded to four decimals.
print("     Approximate Scores", "============================", sep="\n")
for name, scores in results_for_query.items():
    if scores is None:
        continue
    print(f"\nMethod: {name}")
    print(np.round(scores, 4))
# Synchronise all processes once the printout is done.
accelerator_main.wait_for_everyone()


     Approximate Scores

Method: Exact
[ 8.0031  6.6352  7.3816  1.6224  1.4788  3.6278 -0.1013  1.9443  4.7314]

Method: ContextCite32
[ 6.3575  5.8934  8.1046 -0.6918  4.8176  0.     -0.6662  2.4168  3.1519]

Method: WSS_FM32
[ 8.8592  8.0813  7.6554 -2.0209  5.2515 -0.3355 -2.8367  1.9663  6.5187]

Method: WSS_GAM32
[ 2.0012  2.9679  4.5823  4.0981  1.7362  2.1423  1.8241  0.5755 -0.5652]

Method: BetaShap (U)32
[ 8.6719 18.6788 -0.4797  1.6714 -0.2147 16.3356  5.5487  0.194  14.0823]

Method: TMC32
[ 8.9268 11.0555 -0.5118  1.4568  0.1562  4.4524  1.9504  1.1015  9.2995]

Method: ContextCite64
[ 5.0769  6.3683  5.9478  1.7715  0.8193  0.4781 -0.2882  2.6711  3.5066]

Method: WSS_FM64
[ 7.0617  5.8033  5.6797  1.6611  0.722   2.2276 -1.4683  1.4746  3.4887]

Method: WSS_GAM64
[ 4.6595  5.8435  5.66    2.0062  0.9093  0.4179 -0.0542  2.6583  3.1954]

Method: BetaShap (U)64
[10.0383 12.6489 21.0797  1.6714 -0.1064 10.794   5.5487  3.3006  7.6744]

Method: TMC64
[8.6751 8.1645 4.9616 1