In [2]:
import os
import sys

# Make the project root (the parent of the notebook's working directory) importable so
# the local SHapRAG package resolves. The membership check keeps this cell idempotent:
# re-running it no longer appends duplicate entries to sys.path.
current_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(current_dir, '..'))
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

# NOTE(review): this star import presumably supplies compute_logprob, ShapleyAttributor
# and beta_dist used in later cells; replace it with explicit names once SHapRAG's
# public API is settled.
from SHapRAG import *

# These imports deliberately come after the star import so that these bindings win over
# any same-named objects SHapRAG happens to export. numpy is imported explicitly because
# later cells call `np` directly.
import time

import numpy as np
import pandas as pd
from scipy.stats import pearsonr

In [3]:
# Toy retrieval corpus: three passages about Paris followed by two Berlin distractors.
# All attributions below are computed over exactly these five items, in this order.
documents = [
    # Relevant: weather in Paris, and Paris being the capital of France
    "The weather in Paris is sunny today.",
    "Paris is the capital of France.",
    "The sun is shining in Paris today",
    # Irrelevant distractors
    "Berlin is the capital of Germany.",
    "It is cloudy in Berlin today.",
]
# Additional passages from earlier experiments, currently disabled (presumably to keep
# n small, since exact Shapley needs 2**n utility calls):
#     "The Eiffel Tower is located in Paris, France.",
#     "France borders several countries including Germany.",
#     "The currency used in France is the Euro.",
#     "Paris hosted the Summer Olympics in 1900 and 1924.",
#     "Germany uses the Euro as well.",  # Redundant info

query = "What is the weather like in the capital of France?"
# Ideal answer fragment; it is passed to compute_logprob as the ground-truth answer,
# so the utility is presumably its log-probability given a subset of documents.
target_response = "Paris is sunny."

In [5]:
# Sanity check: let the model answer the query with the full context. With
# response=True the call appears to return the generated text (shown below).
compute_logprob(
    query=query,
    context=documents,
    response=True,
    ground_truth_answer=target_response,
)

Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.88it/s]
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Fetching 3 files: 100%|██████████| 3/3 [00:00<00:00, 12348.29it/s]


'Based on the given context, the weather in Paris is sunny today. The context explicitly states "The weather in Paris is sunny today." and also mentions'

In [None]:


# Build one attributor over the documents, run every estimator on the same
# query/target pair, then score each approximation against exact Shapley values.
attributor = ShapleyAttributor(
    items=documents,
    query=query,
    target_response=target_response,
    llm_caller=compute_logprob,
)

print("\n--- Computing ContextCite (Surrogate Weights) ---")
cc_weights = attributor.compute(method_name="contextcite", num_samples=16, lasso_alpha=0.01)
if cc_weights is not None:
    print("ContextCite Weights:", np.round(cc_weights, 4))

print("\n--- Computing Weakly Supervised Shapley (WSS) ---")
wss_result = attributor.compute(method_name="wss", num_samples=16, lasso_alpha=0.01, return_weights=True)
# The None checks used throughout this cell suggest compute() may return None on
# failure; guard the tuple unpacking so a failed WSS run cannot raise TypeError and
# abort every method after it.
wss_values, wss_surrogate_weights = wss_result if wss_result is not None else (None, None)
if wss_values is not None:
    print("WSS Values:", np.round(wss_values, 4))
    print("WSS Surrogate Weights:", np.round(wss_surrogate_weights, 4))

print("\n--- Computing TMC-Shapley ---")
# Uses the library's default iteration count (the recorded run logged T=5 for n=5).
tmc_values = attributor.compute(method_name="tmc")
if tmc_values is not None:
    print("TMC Values:", np.round(tmc_values, 4))

# Initialise the Beta-Shapley results so later code never depends on whether the
# optional branch below ran.
beta_u_values = None
beta_uniform_values = None
if beta_dist:  # Only if scipy is available
    print("\n--- Computing Beta-Shapley (U-shaped, more weight to ends) ---")
    # For U-shaped, alpha and beta < 1
    beta_u_values = attributor.compute(method_name="betashap", num_iterations=attributor.n_items, beta_a=0.5, beta_b=0.5)
    if beta_u_values is not None:
        print("BetaShap (U-shaped) Values:", np.round(beta_u_values, 4))

    print("\n--- Computing Beta-Shapley (Uniform, equivalent to standard MC) ---")
    beta_uniform_values = attributor.compute(method_name="betashap", num_iterations=attributor.n_items, beta_a=1.0, beta_b=1.0)
    if beta_uniform_values is not None:
        print("BetaShap (Uniform) Values:", np.round(beta_uniform_values, 4))

print("\n--- Computing Leave-One-Out (LOO) ---")
loo_values = attributor.compute(method_name="loo")
if loo_values is not None:
    print("LOO Values:", np.round(loo_values, 4))

print("\n--- Computing Exact Shapley ---")
exact_values = attributor.compute(method_name="exact", exact_confirm=False)  # Disable confirm for n=5 demo
if exact_values is not None:
    print("Exact Values:", np.round(exact_values, 4))

# Mean absolute error of each available approximation against the exact values.
mae_vs_exact = {}
if exact_values is not None:
    print("\n--- Mean absolute error vs Exact Shapley ---")
    for label, estimate in (
        ("WSS", wss_values),
        ("TMC", tmc_values),
        ("BetaShap (U)", beta_u_values),
        ("BetaShap (Uni)", beta_uniform_values),
        ("LOO", loo_values),
    ):
        if estimate is None:
            continue
        mae_vs_exact[label] = float(np.mean(np.abs(np.asarray(exact_values) - np.asarray(estimate))))
        print(f"MAE {label} vs Exact: {mae_vs_exact[label]:.4f}")



--- Computing ContextCite (Surrogate Weights) ---
Starting contextcite (n=5, m=16)...
Computing true utilities for 16 samples...


Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.96it/s]
Fetching 3 files: 100%|██████████| 3/3 [00:00<00:00, 11204.73it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.15it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.25it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.22it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.16it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.23it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.20it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.26it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.04it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.04it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.16it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.16it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00

Training surrogate model...
  Surrogate Weights (w): [ 2.6634 -0.1679  1.9914  1.6188  3.2961]
  Surrogate Intercept (b): -19.4259
Method 'contextcite' finished in 85.42s.
ContextCite Weights: [ 2.6634 -0.1679  1.9914  1.6188  3.2961]

--- Computing Weakly Supervised Shapley (WSS) ---
Starting wss (n=5, m=16)...
Computing true utilities for 16 samples...


Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.16it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.12it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.24it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.24it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.21it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.23it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.26it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.25it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.19it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.15it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.25it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.17it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.19it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00

Training surrogate model...
  Surrogate Weights (w): [ 3.3178 -1.4586  1.2333  1.1678  3.4187]
  Surrogate Intercept (b): -18.6363
Building hybrid utility set for WSS...
Calculating Shapley values from hybrid utilities for WSS...


                                                                    

Method 'wss' finished in 77.58s.
WSS Values: [ 3.7737 -1.4224  1.5222  1.6907  3.6614]
WSS Surrogate Weights: [ 3.3178 -1.4586  1.2333  1.1678  3.4187]

--- Computing TMC-Shapley ---
Starting TMC-Shapley (n=5, T=5)...


Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.27it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.17it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.24it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.20it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.15it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.10it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.15it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.16it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.22it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.23it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.22it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.11it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.23it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00

TMC-Shapley made 14 unique LLM calls.
Method 'tmc' finished in 68.08s.
TMC Values: [ 3.2356  2.8438  1.076  -4.4321  6.5021]

--- Computing Beta-Shapley (U-shaped, more weight to ends) ---
Starting Beta-Shapley (n=5, T=5, α=0.5, β=0.5)...


Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.14it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.18it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.17it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.14it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.10it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.14it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.16it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.20it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.23it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.24it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.14it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.15it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.15it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00

Beta-Shapley made 19 unique LLM calls.
Method 'betashap' finished in 95.05s.
BetaShap (U-shaped) Values: [5.1041 0.1326 2.7947 0.1765 2.7262]

--- Computing Beta-Shapley (Uniform, equivalent to standard MC) ---
Starting Beta-Shapley (n=5, T=5, α=1.0, β=1.0)...


Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.09it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.24it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.17it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.15it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.14it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.10it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.21it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.08it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.25it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.18it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.22it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.18it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.24it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00

Beta-Shapley made 16 unique LLM calls.
Method 'betashap' finished in 80.13s.
BetaShap (Uniform) Values: [ 1.5119  0.2666  4.6781 -0.7403  3.5093]

--- Computing Leave-One-Out (LOO) ---
Starting Leave-One-Out (LOO) computation (n=5)...
  Calculating utility of full set V(N)...


Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.11it/s]


  Calculating utility V(N-i) for 5 items...


Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.17it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.25it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.16it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.18it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.17it/s]
                                                        

LOO computation made 6 unique LLM calls.
Method 'loo' finished in 28.87s.
LOO Values: [-1.1096 -1.7544 -2.4805  0.1765  0.0977]

--- Computing Exact Shapley ---
Starting Exact Shapley (n=5, 32 LLM calls)...


Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.17it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.25it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.17it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.12it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.13it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.15it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.14it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  3.12it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.63it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.77it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.88it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.85it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.85it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00

Method 'exact' finished in 159.52s.
Exact Values: [ 3.7975  0.8453  2.0478 -0.7774  3.3123]

MAE WSS vs Exact: 1.1269
MAE TMC vs Exact: 2.0753
MAE BetaShap (U) vs Exact: 0.8612
MAE BetaShap (Uni) vs Exact: 1.1457

MAE LOO vs Exact: 3.2407




In [None]:
# Rank agreement between each approximation and exact Shapley. Spearman's rho is
# scale-invariant, so ContextCite's surrogate weights can be compared as well. With only
# five items, rho is coarse and the p-values carry little information.
from scipy.stats import spearmanr

spearman_results = {}
if exact_values is None:
    print("Exact Shapley values unavailable; skipping rank-correlation check.")
else:
    for label, estimate in (
        ("WSS", wss_values),
        ("TMC", tmc_values),
        ("LOO", loo_values),
        ("ContextCite (weights)", cc_weights),
    ):
        if estimate is None:
            continue
        rho, pval = spearmanr(exact_values, estimate)
        spearman_results[label] = (rho, pval)
        print(f"Spearman {label} vs Exact: {rho:.4f} (p={pval:.3g})")

# Kept for compatibility with the original cell, which only exposed the WSS result.
correlation, p_value = spearman_results.get("WSS", (float("nan"), float("nan")))

In [None]:
# Standalone probe: no `context` argument is passed here, unlike the earlier call.
# NOTE(review): "Paris" is not the capital of Kazakhstan, so this presumably acts as a
# negative control (the target should get a low probability) — confirm intent.
probe_kwargs = dict(
    query="What is the capital of Kazakhstan?",
    ground_truth_answer="Paris",
    return_token_probs=True,
)
logp, p, model_out, token_probs = compute_logprob(**probe_kwargs)

for label, value in (
    ("Log probability:", logp),
    ("Total probability:", p),
    ("Model generated:", model_out),
):
    print(label, value)