In [1]:
import sys
import os
import random
import gc
import time
import torch
import numpy as np
import pandas as pd
import pickle
import ast
from tqdm import tqdm
from scipy.sparse import csr_matrix
import itertools
from scipy.stats import spearmanr, pearsonr, kendalltau, rankdata
from sklearn.metrics import ndcg_score
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import Accelerator
# import nltk
# nltk.download('punkt')
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
current_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(current_dir, '..'))
sys.path.append(parent_dir)
from SHapRAG import *
from SHapRAG.utils import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the MuSiQue training split (JSON Lines: one record per line) into a DataFrame.
# Later cells read the `question`, `paragraphs` and `answer` columns of this frame.
df=pd.read_json("../data/musique/musique_ans_v1.0_train.jsonl", lines=True)

In [3]:
def get_titles(lst, max_items=10):
    """Select paragraph texts for one question, supporting paragraphs first.

    Despite the name, this returns the ``paragraph_text`` of each entry, not
    its title.

    Args:
        lst: Iterable of paragraph dicts. Each must have a ``paragraph_text``
            key and may have an ``is_supporting`` flag.
        max_items: Maximum number of texts to return (default 10, the value
            that was previously hard-coded).

    Returns:
        A list of at most ``max_items`` paragraph texts: every supporting
        paragraph in input order, followed by the non-supporting paragraphs
        whose text does not duplicate a supporting one.
    """
    # Paragraphs explicitly flagged as supporting come first.
    supporting = [d['paragraph_text'] for d in lst if d.get('is_supporting') == True]
    # Set lookup keeps the duplicate filter linear instead of quadratic.
    supporting_texts = set(supporting)
    # Everything else (flag False or missing), minus texts already selected.
    others = [
        d['paragraph_text']
        for d in lst
        if d.get('is_supporting') != True and d['paragraph_text'] not in supporting_texts
    ]
    return (supporting + others)[:max_items]

# Replace each row's list of paragraph dicts with the list of paragraph texts chosen by get_titles.
# NOTE(review): this overwrites df.paragraphs in place, so the cell is not idempotent —
# re-running it on the already-transformed column fails because get_titles calls .get on each
# element, which is then a plain string. Reload df before re-running.
df.paragraphs=df.paragraphs.apply(get_titles)

In [4]:
# Seed forwarded as `seed=SEED` to the sampling-based attribution calls in the experiment cell.
# NOTE(review): the global RNGs (random / numpy / torch) are not seeded anywhere in this cell.
SEED = 42
# Initialize Accelerator
accelerator_main = Accelerator(mixed_precision="fp16")

# Load Model
if accelerator_main.is_main_process:
    print("Main Script: Loading model...")
# Exactly one model_path is active; the commented lines are alternative checkpoints.
# model_path = "mistralai/Mistral-7B-Instruct-v0.3"
model_path = "meta-llama/Llama-3.1-8B-Instruct"
# model_path = "Qwen/Qwen2.5-3B-Instruct"

# Weights are loaded in half precision; device placement is left to accelerator.prepare below.
model_cpu = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained(model_path)
# If the tokenizer has no pad token, reuse EOS as padding and mirror that choice
# into the model config and (when present) the generation config.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    model_cpu.config.pad_token_id = tokenizer.pad_token_id
    if hasattr(model_cpu, 'generation_config') and model_cpu.generation_config is not None:
        model_cpu.generation_config.pad_token_id = tokenizer.pad_token_id

if accelerator_main.is_main_process:
    print("Main Script: Preparing model with Accelerator...")
# `prepared_model` is what gets handed to ContextAttribution later; the unwrapped
# handle is only used here to switch the underlying module to eval mode.
prepared_model = accelerator_main.prepare(model_cpu)
unwrapped_prepared_model = accelerator_main.unwrap_model(prepared_model)
unwrapped_prepared_model.eval()
if accelerator_main.is_main_process:
    print("Main Script: Model prepared and set to eval.")

# Barrier: presumably blocks until every process has reached this point — see the Accelerate docs.

accelerator_main.wait_for_everyone()

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Main Script: Loading model...


Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.46it/s]


Main Script: Preparing model with Accelerator...
Main Script: Model prepared and set to eval.


In [None]:
# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
num_questions_to_run = 50
START_INDEX = 19850  # first dataset row evaluated; rows START_INDEX .. START_INDEX+49 are used
k_values = [1, 2, 3, 4, 5]
all_results = []  # one dict per question: query, response, attributions, metrics
extras = []       # one dict per question: interaction results keyed by method name

# Sampling budgets to sweep. Only the values are used; the size labels are informational.
m_samples_map = {"XS": 32, "S":64, "M":128, "L":264, "XL":528, "XXL":724, "XXXL":1024}

# Utilities are cached per question under a directory named after the model.
utility_cache_base_dir = f"../Experiment_data/musique/{model_path.split('/')[1]}/four/"


def gtset_k():
    """Return the document indices passed to ``recall_at_k`` as ground truth.

    NOTE(review): the list is hard-coded and identical for every question —
    confirm it matches the supporting-paragraph positions produced by
    ``get_titles``.
    """
    return [0, 1,2,3,5]


for offset in range(num_questions_to_run):
    # `i` is the dataset row; `offset` is the position within this run.
    i = START_INDEX + offset
    query = df.question[i]
    if accelerator_main.is_main_process:
        print(f"\n--- Question {offset + 1}/{num_questions_to_run} (dataset index {i}): {query[:60]}... ---")

    docs = df.paragraphs[i]
    utility_cache_filename = f"utilities_q_idx{i}.pkl"
    current_utility_path = os.path.join(utility_cache_base_dir, utility_cache_filename)

    if accelerator_main.is_main_process:
        os.makedirs(os.path.dirname(current_utility_path), exist_ok=True)

    # The harness is built on every process; only the main process runs the methods below.
    harness = ContextAttribution(
        items=docs,
        query=query,
        prepared_model=prepared_model,
        prepared_tokenizer=tokenizer,
        accelerator=accelerator_main,
        utility_cache_path=current_utility_path
    )
    full_budget=pow(2,harness.n_items)  # number of distinct document subsets
    # res = evaluate(df.question[i], harness.target_response, df.answer[i])
    # res='True'
    if accelerator_main.is_main_process:
        methods_results = {}   # method name -> per-document attribution
        metrics_results = {}   # metric name -> per-method score
        extra_results = {}     # method name -> interaction result

        # Fitted surrogate models, kept for the R² / LDS / top-k evaluation below.
        fm_models = {}
        methods_results['Exact-Shap']=harness._calculate_shapley()
        for actual_samples in m_samples_map.values():
            print(f"Running sample size: {actual_samples}")
            methods_results[f"ContextCite_{actual_samples}"], fm_models[f"ContextCite_{actual_samples}"] = harness.compute_contextcite(
                num_samples=actual_samples, seed=SEED
            )
            # FM surrogate at several ranks (1, 2, 4, 8)
            for rank in [1,2,4,8]:
                methods_results[f"FM_WeightsLK_{rank}_{actual_samples}"], extra_results[f"Flk_{rank}_{actual_samples}"], fm_models[f"FM_WeightsLK_{rank}_{actual_samples}"] = harness.compute_wss(
                    num_samples=actual_samples,
                    seed=SEED,
                    sampling="kernelshap",
                    sur_type="fm",
                    rank=rank
                )
                # methods_results[f"FM_WeightsLU_{rank}_{actual_samples}"], extra_results[f"Flu_{rank}_{actual_samples}"], fm_models[f"FM_WeightsLU_{rank}_{actual_samples}"] = harness.compute_wss(
                #     num_samples=actual_samples,
                #     seed=SEED,
                #     sampling="kernelshap",
                #     sur_type="fm",
                #     rank=rank
                # )
            # methods_results[f"FM_u_dynamic_{actual_samples}"], extra_results[f"FM_u_dynamic_{actual_samples}"], fm_models[f"FM_u_dynamic_{actual_samples}"] = harness.compute_wss_dynamic_pruning_reuse_utility(num_samples=actual_samples)
            # Dynamic-pruning variants: fixed final rank vs. 'elbow' pruning strategy.
            methods_results[f"FM_k_dynamic_{actual_samples}"], extra_results[f"FM_k_dynamic_{actual_samples}"], fm_models[f"FM_k_dynamic_{actual_samples}"] = harness.compute_wss_dynamic_pruning_reuse_utility(num_samples=actual_samples, initial_rank=1, final_rank=2)
            methods_results[f"FM_k_dynamice_{actual_samples}"], extra_results[f"FM_k_dynamice_{actual_samples}"], fm_models[f"FM_k_dynamice_{actual_samples}"] = harness.compute_wss_dynamic_pruning_reuse_utility(num_samples=actual_samples, initial_rank=1, pruning_strategy='elbow')
            try:
                # attributionsspex, interactionspex = harness.compute_spex(sample_budget=actual_samples, max_order=2)
                attributionshap, interactionshap, fm_models[f"FSII_{actual_samples}"] = harness.compute_fsii(sample_budget=actual_samples, max_order=2)
                # attributionban, interactionban, fm_models[f"FBII_{actual_samples}"] = harness.compute_fbii(sample_budget=actual_samples, max_order=harness.n_items)
                # methods_results[f"FBII_{actual_samples}"] = attributionban
                methods_results[f"FSII_{actual_samples}"] = attributionshap
                # methods_results[f"Spex_{actual_samples}"] = attributionsspex

                extra_results.update({
                    f"Int_FSII_{actual_samples}": interactionshap
                    # f"Int_FBII_{actual_samples}":interactionban,
                    # f"Int_Spex_{actual_samples}":interactionspex
                })
            except Exception as exc:
                # FSII stays best-effort (a failure must not abort the sweep), but the
                # failure is reported: a silently missing FSII_* key would otherwise only
                # show up later as a confusing KeyError in the analysis cells.
                print(f"[warn] FSII failed for dataset index {i} at budget {actual_samples}: {exc!r}")

        # methods_results["LOO"] = harness.compute_loo()
        # methods_results["ARC-JSD"] = harness.compute_arc_jsd()
        attributionxs, interactionxs, fm_models["Exact-FSII"] = harness.compute_exact_fsii(max_order=2)

        extra_results["Exact-FSII"] = interactionxs
        methods_results["Exact-FSII"] = attributionxs

        # --- Evaluation Metrics ---
        metrics_results["topk_probability"] = harness.evaluate_topk_performance(
            methods_results, fm_models, k_values
        )

        # R²
        metrics_results["R2"] = harness.r2(methods_results,100,mode='logit-prob', models=fm_models)
        metrics_results['Recall']=harness.recall_at_k(gtset_k(), methods_results, k_values)
        metrics_results["Delta_R2"] = harness.delta_r2(methods_results,100,mode='logit-prob', models=fm_models)

        # LDS per method
        metrics_results["LDS"] = harness.lds(methods_results,100,mode='logit-prob', models=fm_models)

        all_results.append({
            "query_index": i,
            "query": query,
            "ground_truth": df.answer[i],
            "response": harness.target_response,
            "methods": methods_results,
            "metrics": metrics_results
        })
        extras.append(extra_results)



--- Question 19851/50: What county shares border with another county adjacent to th... ---
Main Process: Attempting to load utility cache from ../Experiment_data/musique/Llama-3.1-8B-Instruct/four/utilities_q_idx19850.pkl...
Successfully loaded 1024 cached utility entries.
Running sample size: 32
Initial scores: [6.5892857  0.36849948 7.37868511 0.0155777  0.09110336 1.39551912
 0.53505493 1.57352739 0.78223756 0.59002116]
We are keeping 7 documents
Initial scores: [7.56467186 1.20196381 7.20846697 0.15168068 0.91531719 0.82470893
 3.18235842 0.81675852 1.44531187 0.2713567 ]
We are keeping 2 documents
Running sample size: 64
Initial scores: [5.47379741 1.0248676  7.11213604 0.9471321  0.24264287 0.47874157
 2.2543925  0.75076352 0.73539528 0.23320522]
We are keeping 7 documents
Initial scores: [5.47651263 1.46204204 6.44396793 0.55362868 0.22188169 0.6661499
 1.43128399 0.14328284 0.12177976 0.19374594]
We are keeping 2 documents
Running sample size: 128
Initial scores: [5.42267428 1

  res = np.real(U_slice/U_slice[0])[1:]
  res = np.real(U_slice/U_slice[0])[1:]


SPEX approximation completed.
Running sample size: 724
Initial scores: [1.60708657e+01 7.86015977e-03 1.04617854e+00 2.17575781e-01
 5.98598953e-02 1.49355166e-01 6.54527285e-03 1.99961671e-01
 6.99685680e-02 5.13892316e-02]
We are keeping 7 documents
Initial scores: [16.09047475  0.018277    1.05188007  0.22455553  0.07702564  0.18145334
  0.01908652  0.19339018  0.0673523   0.08553774]
We are keeping 1 documents


  res = np.real(U_slice/U_slice[0])[1:]
  res = np.real(U_slice/U_slice[0])[1:]


SPEX approximation completed.
Running sample size: 1024
Initial scores: [1.60597830e+01 2.68139545e-02 1.04320715e+00 2.08947838e-01
 6.96973850e-02 1.57148180e-01 7.30500418e-03 1.95857226e-01
 6.94808613e-02 7.67257530e-02]
We are keeping 7 documents
Initial scores: [1.60597830e+01 2.68139545e-02 1.04320715e+00 2.08947838e-01
 6.96973850e-02 1.57148180e-01 7.30500418e-03 1.95857226e-01
 6.94808613e-02 7.67257530e-02]
We are keeping 1 documents
SPEX approximation completed.
Computing actual utility deltas for 488 (subset, player) pairs...
Computing delta R² for Exact-Shap...
Exact-Shap delta R² score: 0.9795
Computing delta R² for ContextCite_32...
ContextCite_32 delta R² score: 0.9749
Computing delta R² for FM_WeightsLK_1_32...
FM_WeightsLK_1_32 delta R² score: 0.9734
Computing delta R² for FM_WeightsLK_2_32...
FM_WeightsLK_2_32 delta R² score: 0.9053
Computing delta R² for FM_WeightsLK_4_32...
FM_WeightsLK_4_32 delta R² score: 0.8875
Computing delta R² for FM_WeightsLK_8_32...
FM_We

  res = np.real(U_slice/U_slice[0])[1:]
  res = np.real(U_slice/U_slice[0])[1:]


SPEX approximation completed.
Running sample size: 724
Initial scores: [11.40440982  0.54176862  0.1663961   0.05923221  0.35352747  0.0995913
  0.3846241   0.1077412   0.01490515  0.29850473]
We are keeping 7 documents
Initial scores: [11.42370399  0.5179778   0.16766007  0.04195839  0.34112995  0.06944337
  0.377348    0.11360033  0.01983012  0.28405852]
We are keeping 1 documents


  res = np.real(U_slice/U_slice[0])[1:]
  res = np.real(U_slice/U_slice[0])[1:]


SPEX approximation completed.
Running sample size: 1024
Initial scores: [11.42888101  0.53408038  0.17793074  0.04166279  0.36515167  0.07962852
  0.40500991  0.10773681  0.01438459  0.28134288]
We are keeping 7 documents
Initial scores: [11.42888101  0.53408038  0.17793074  0.04166279  0.36515167  0.07962852
  0.40500991  0.10773681  0.01438459  0.28134288]
We are keeping 1 documents
SPEX approximation completed.
Computing actual utility deltas for 488 (subset, player) pairs...
Computing delta R² for Exact-Shap...
Exact-Shap delta R² score: 0.9788
Computing delta R² for ContextCite_32...
ContextCite_32 delta R² score: 0.9728
Computing delta R² for FM_WeightsLK_1_32...
FM_WeightsLK_1_32 delta R² score: 0.9672
Computing delta R² for FM_WeightsLK_2_32...
FM_WeightsLK_2_32 delta R² score: 0.9282
Computing delta R² for FM_WeightsLK_4_32...
FM_WeightsLK_4_32 delta R² score: 0.9268
Computing delta R² for FM_WeightsLK_8_32...
FM_WeightsLK_8_32 delta R² score: 0.9647
Computing delta R² for FM_

  res = np.real(U_slice/U_slice[0])[1:]
  res = np.real(U_slice/U_slice[0])[1:]


SPEX approximation completed.
Running sample size: 724
Initial scores: [15.63038206  0.48552029  1.14372302  0.01581309  0.28264554  0.09714203
  0.72609121  0.06067475  0.17261865  0.36918505]
We are keeping 7 documents
Initial scores: [1.55454309e+01 4.52382458e-01 1.17145164e+00 1.40973476e-02
 3.61977761e-01 2.06348948e-01 7.14042291e-01 1.02885073e-01
 1.66150586e-01 3.76369984e-01]
We are keeping 1 documents


  res = np.real(U_slice/U_slice[0])[1:]
  res = np.real(U_slice/U_slice[0])[1:]


SPEX approximation completed.
Running sample size: 1024
Initial scores: [1.55498442e+01 4.93539278e-01 1.17582739e+00 1.52740784e-02
 2.66816118e-01 1.49014209e-01 7.29289776e-01 7.00545249e-02
 1.95419927e-01 4.34110942e-01]
We are keeping 7 documents
Initial scores: [1.55498442e+01 4.93539278e-01 1.17582739e+00 1.52740784e-02
 2.66816118e-01 1.49014209e-01 7.29289776e-01 7.00545249e-02
 1.95419927e-01 4.34110942e-01]
We are keeping 1 documents
SPEX approximation completed.
Computing actual utility deltas for 488 (subset, player) pairs...
Computing delta R² for Exact-Shap...
Exact-Shap delta R² score: 0.9181
Computing delta R² for ContextCite_32...
ContextCite_32 delta R² score: 0.9148
Computing delta R² for FM_WeightsLK_1_32...
FM_WeightsLK_1_32 delta R² score: 0.8747
Computing delta R² for FM_WeightsLK_2_32...
FM_WeightsLK_2_32 delta R² score: 0.7658
Computing delta R² for FM_WeightsLK_4_32...
FM_WeightsLK_4_32 delta R² score: 0.8113
Computing delta R² for FM_WeightsLK_8_32...
FM_We

  res = np.real(U_slice/U_slice[0])[1:]
  res = np.real(U_slice/U_slice[0])[1:]


SPEX approximation completed.
Running sample size: 528
Initial scores: [0.04712354 0.04716923 0.00535015 0.052075   0.03740599 0.00065933
 0.03925264 0.04545412 0.00679067 0.0051981 ]
We are keeping 7 documents
Initial scores: [0.04027093 0.04408468 0.00028566 0.04070522 0.02477992 0.00585873
 0.04001093 0.03640637 0.00343215 0.00021289]
We are keeping 6 documents


  res = np.real(U_slice/U_slice[0])[1:]
  res = np.real(U_slice/U_slice[0])[1:]


SPEX approximation completed.
Running sample size: 724
Initial scores: [0.03309745 0.02963613 0.00719137 0.03674038 0.02427908 0.01025765
 0.03700934 0.03465849 0.00234133 0.0028636 ]
We are keeping 7 documents
Initial scores: [0.03891493 0.0430969  0.0079277  0.03644523 0.0241555  0.00319787
 0.04055481 0.04021064 0.00849529 0.00216838]
We are keeping 5 documents


  res = np.real(U_slice/U_slice[0])[1:]
  res = np.real(U_slice/U_slice[0])[1:]


SPEX approximation completed.
Running sample size: 1024
Initial scores: [0.03378952 0.03386782 0.00629458 0.03386783 0.01539315 0.00464293
 0.03273122 0.03386785 0.00148205 0.00253395]
We are keeping 7 documents
Initial scores: [0.03378952 0.03386782 0.00629458 0.03386783 0.01539315 0.00464293
 0.03273122 0.03386785 0.00148205 0.00253395]
We are keeping 5 documents


  res = np.real(U_slice/U_slice[0])[1:]
  res = np.real(U_slice/U_slice[0])[1:]


SPEX approximation completed.
Computing actual utility deltas for 488 (subset, player) pairs...
Computing delta R² for Exact-Shap...
Exact-Shap delta R² score: 0.0532
Computing delta R² for ContextCite_32...
ContextCite_32 delta R² score: -0.0126
Computing delta R² for FM_WeightsLK_1_32...
FM_WeightsLK_1_32 delta R² score: -1.0278
Computing delta R² for FM_WeightsLK_2_32...
FM_WeightsLK_2_32 delta R² score: -1.2675
Computing delta R² for FM_WeightsLK_4_32...
FM_WeightsLK_4_32 delta R² score: -0.8044
Computing delta R² for FM_WeightsLK_8_32...
FM_WeightsLK_8_32 delta R² score: -0.6208
Computing delta R² for FM_k_dynamic_32...
FM_k_dynamic_32 delta R² score: -4.7489
Computing delta R² for FM_k_dynamice_32...
FM_k_dynamice_32 delta R² score: -0.1411
Computing delta R² for ContextCite_64...
ContextCite_64 delta R² score: 0.0247
Computing delta R² for FM_WeightsLK_1_64...
FM_WeightsLK_1_64 delta R² score: -0.0203
Computing delta R² for FM_WeightsLK_2_64...
FM_WeightsLK_2_64 delta R² score: 



Running sample size: 32
Initial scores: [14.798695    0.21828743  0.08436556  0.3941484   0.29588828  0.1823256
  0.0654932   0.25604685  0.26189357  0.06713071]
We are keeping 7 documents
Initial scores: [14.00545105  0.77341058  0.47736847  0.98322992  0.31215136  0.42596403
  0.08031493  0.11917404  0.09160828  0.3065615 ]
We are keeping 1 documents
Running sample size: 64
Initial scores: [13.83362383  0.36815644  0.56711305  0.86127482  0.44449222  0.20167907
  0.33096348  0.84792929  0.05398353  0.23818652]
We are keeping 7 documents
Initial scores: [14.00841776  0.40313364  0.30454055  0.89595898  0.19377623  0.10711517
  0.07503261  0.54264396  0.46789073  0.21065172]
We are keeping 1 documents
Running sample size: 128
Initial scores: [14.08234739  0.15537999  0.18127124  0.8762362   0.59992284  0.3396797
  0.21021801  0.3712984   0.15839097  0.06110746]
We are keeping 7 documents
Initial scores: [14.11778645  0.19938645  0.12316952  0.86256913  0.64950408  0.14014172
  0.315527

  res = np.real(U_slice/U_slice[0])[1:]
  res = np.real(U_slice/U_slice[0])[1:]


SPEX approximation completed.
Running sample size: 724
Initial scores: [1.25516940e+01 2.64872144e-01 6.18140995e-02 4.62324226e-01
 4.02287256e-02 1.80796200e-02 8.58260365e-02 1.19023208e-01
 1.24125985e-03 1.43264091e-01]
We are keeping 7 documents
Initial scores: [12.53917515  0.25307485  0.05550956  0.4536406   0.04972337  0.01560935
  0.09418128  0.105955    0.01333488  0.13416036]
We are keeping 1 documents


  res = np.real(U_slice/U_slice[0])[1:]
  res = np.real(U_slice/U_slice[0])[1:]


SPEX approximation completed.
Running sample size: 1024
Initial scores: [1.25658789e+01 2.74643799e-01 5.56767716e-02 4.69338881e-01
 3.57838762e-02 3.07763468e-02 8.81856889e-02 1.18250176e-01
 1.91404540e-03 1.19054440e-01]
We are keeping 7 documents
Initial scores: [1.25658789e+01 2.74643799e-01 5.56767716e-02 4.69338881e-01
 3.57838762e-02 3.07763468e-02 8.81856889e-02 1.18250176e-01
 1.91404540e-03 1.19054440e-01]
We are keeping 1 documents
SPEX approximation completed.
Computing actual utility deltas for 488 (subset, player) pairs...
Computing delta R² for Exact-Shap...
Exact-Shap delta R² score: 0.9852
Computing delta R² for ContextCite_32...
ContextCite_32 delta R² score: 0.9841
Computing delta R² for FM_WeightsLK_1_32...
FM_WeightsLK_1_32 delta R² score: 0.9842
Computing delta R² for FM_WeightsLK_2_32...
FM_WeightsLK_2_32 delta R² score: 0.9383
Computing delta R² for FM_WeightsLK_4_32...
FM_WeightsLK_4_32 delta R² score: 0.9686
Computing delta R² for FM_WeightsLK_8_32...
FM_We

  res = np.real(U_slice/U_slice[0])[1:]
  res = np.real(U_slice/U_slice[0])[1:]


SPEX approximation completed.
Running sample size: 724
Initial scores: [2.33810611 0.86134495 0.0750747  0.92181919 0.04507057 0.06589392
 0.0195996  0.07816451 0.07854854 0.07747023]
We are keeping 7 documents
Initial scores: [2.32459888 0.83505855 0.11874243 0.93720164 0.02957152 0.079496
 0.03712428 0.01816954 0.12769314 0.06525612]
We are keeping 3 documents


  res = np.real(U_slice/U_slice[0])[1:]
  res = np.real(U_slice/U_slice[0])[1:]


SPEX approximation completed.
Running sample size: 1024
Initial scores: [2.34628629 0.87877458 0.08920288 0.94124511 0.0292705  0.06658127
 0.05234761 0.05122521 0.10777623 0.07368337]
We are keeping 7 documents
Initial scores: [2.34628629 0.87877458 0.08920288 0.94124511 0.0292705  0.06658127
 0.05234761 0.05122521 0.10777623 0.07368337]
We are keeping 3 documents
SPEX approximation completed.
Computing actual utility deltas for 488 (subset, player) pairs...
Computing delta R² for Exact-Shap...
Exact-Shap delta R² score: 0.4630
Computing delta R² for ContextCite_32...
ContextCite_32 delta R² score: 0.4151
Computing delta R² for FM_WeightsLK_1_32...
FM_WeightsLK_1_32 delta R² score: 0.6015
Computing delta R² for FM_WeightsLK_2_32...
FM_WeightsLK_2_32 delta R² score: 0.5723
Computing delta R² for FM_WeightsLK_4_32...
FM_WeightsLK_4_32 delta R² score: -0.2221
Computing delta R² for FM_WeightsLK_8_32...
FM_WeightsLK_8_32 delta R² score: -3.1637
Computing delta R² for FM_k_dynamic_32...
FM

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


Computing utility for subset (0, 0, 0, 0, 0, 0, 0, 0, 0, 0) in mode 'logit-prob'...
Computing utility for subset (1, 0, 0, 0, 0, 0, 0, 0, 0, 0) in mode 'logit-prob'...
Computing utility for subset (0, 1, 0, 0, 0, 0, 0, 0, 0, 0) in mode 'logit-prob'...
Computing utility for subset (0, 0, 1, 0, 0, 0, 0, 0, 0, 0) in mode 'logit-prob'...
Computing utility for subset (0, 0, 0, 1, 0, 0, 0, 0, 0, 0) in mode 'logit-prob'...
Computing utility for subset (0, 0, 0, 0, 1, 0, 0, 0, 0, 0) in mode 'logit-prob'...
Computing utility for subset (0, 0, 0, 0, 0, 1, 0, 0, 0, 0) in mode 'logit-prob'...
Computing utility for subset (0, 0, 0, 0, 0, 0, 1, 0, 0, 0) in mode 'logit-prob'...
Computing utility for subset (0, 0, 0, 0, 0, 0, 0, 1, 0, 0) in mode 'logit-prob'...
Computing utility for subset (0, 0, 0, 0, 0, 0, 0, 0, 1, 0) in mode 'logit-prob'...
Computing utility for subset (0, 0, 0, 0, 0, 0, 0, 0, 0, 1) in mode 'logit-prob'...
Computing utility for subset (1, 1, 0, 0, 0, 0, 0, 0, 0, 0) in mode 'logit-p

In [None]:
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import ndcg_score
import numpy as np

# Per-method NDCG@4 of each attribution ranking against the Exact-Shap ranking,
# one score per question.
# NOTE: the dict keeps its historical name `spearmans` because the plotting cell
# reads it, but the values are NDCG scores, not Spearman correlations.
spearmans = {}
scaler = MinMaxScaler()

for method_res in all_results:
    methods = method_res['methods']

    # The reference is the same for every method of this question, so scale it once.
    # Attribution values can be negative; min-max scaling to [0, 1] gives ndcg_score
    # non-negative relevance values.
    ref_scaled = scaler.fit_transform(
        np.array(methods["Exact-Shap"]).reshape(-1, 1)
    ).flatten()

    for method, attribution in methods.items():
        if method == "Exact-Shap":
            continue

        att_scaled = scaler.fit_transform(np.array(attribution).reshape(-1, 1)).flatten()

        # setdefault instead of pre-seeding keys from the first question: a method that
        # is missing for question 0 but present later no longer raises KeyError.
        spearmans.setdefault(method, []).append(
            ndcg_score([ref_scaled], [att_scaled], k=4)
        )


In [None]:
import re
import numpy as np
import matplotlib.pyplot as plt

# Split each key of `spearmans` into (method family, budget) and average its scores.
# Keys ending in "_<digits>" are budgeted; all others are treated as budget-free baselines.
parsed = {}
budgets = set()
for key, values in spearmans.items():
    avg_val = np.mean(values)

    match = re.match(r"(.+?)_(\d+)$", key)  # method_budget pattern
    if match:
        method, budget = match.groups()
        budget = int(budget)
        budgets.add(budget)
        parsed.setdefault(method, {})[budget] = avg_val
    else:
        # constant methods (no budget)
        parsed.setdefault(key, {})[None] = avg_val

budgets = sorted(budgets)

# Plot
plt.figure(figsize=(10,6))

for method, results in parsed.items():
    if None in results:  # constant method: horizontal line across the budget range
        plt.hlines(results[None], xmin=min(budgets), xmax=max(budgets),
                   linestyles='--', label=method)
    else:
        xs = sorted(results.keys())
        ys = [results[b] for b in xs]
        plt.plot(xs, ys, marker='o', label=method)

# The plotted values are NDCG@4 scores (computed with ndcg_score(..., k=4)); the labels
# previously said "Spearman" / "Recall", which did not match the metric.
plt.xlabel("Budget")
plt.ylabel("Average NDCG@4")
plt.title("Average NDCG@4 against Exact Shap per Method vs Budget")
plt.legend()
plt.grid(True)
plt.show()


In [None]:
import pandas as pd
import numpy as np
from collections import defaultdict

def summarize_and_print(all_results, k_values=(1, 2, 3, 4, 5)):
    """Aggregate per-question metrics into one row per method and print the table.

    Args:
        all_results: Iterable of result dicts. Each must have a ``"metrics"``
            dict which may contain:
            ``"LDS"``, ``"R2"``, ``"Delta_R2"`` (method -> scalar),
            ``"topk_probability"`` (method -> {k: value}) and
            ``"Recall"`` (method -> sequence indexed by ``k - 1``).
        k_values: The k's to report for top-k and recall columns. For recall,
            entry ``k - 1`` of each method's sequence is read, so the sequence
            is assumed to be ordered by k = 1, 2, ...

    Returns:
        A DataFrame indexed by method name (sorted) holding the NaN-ignoring
        mean of every collected metric, plus ``<metric>_std`` columns for
        LDS, R2 and Delta_R2.
    """
    scalar_metrics = ("LDS", "R2", "Delta_R2")
    # method -> column name -> list of per-question values
    table_data = defaultdict(lambda: defaultdict(list))

    for res in all_results:
        metrics = res["metrics"]

        # Scalar metrics: one value per method per question.
        for metric in scalar_metrics:
            for method, value in metrics.get(metric, {}).items():
                table_data[method][metric].append(value)

        # Top-k: only the k's actually present for a method are recorded.
        for method, k_dict in metrics.get("topk_probability", {}).items():
            for k in k_values:
                if k in k_dict:
                    table_data[method][f"topk_probability_k{k}"].append(k_dict[k])

        # Recall@k: positional lookup (k=1 -> index 0).
        for method, recalls in metrics.get("Recall", {}).items():
            for k in k_values:
                table_data[method][f"Recall@{k}"].append(recalls[k - 1])

    # Averages
    avg_table = {
        method: {metric: np.nanmean(values) for metric, values in metric_dict.items()}
        for method, metric_dict in table_data.items()
    }

    # Standard deviations for the scalar metrics only
    for method, metric_dict in table_data.items():
        for metric in scalar_metrics:
            if metric in metric_dict:
                avg_table[method][f"{metric}_std"] = np.nanstd(metric_dict[metric])

    df_summary = pd.DataFrame.from_dict(avg_table, orient="index").sort_index()

    print("\n=== Metrics Summary Across All Queries ===")
    print(df_summary.to_string(float_format="%.4f"))

    return df_summary
df_res=summarize_and_print(all_results, k_values=[1, 2, 3,4,5])


In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Move the method names out of the index into a regular 'method' column.
df_reset = df_res.reset_index().rename(columns={'index': 'method'})

# Separate constant methods (no budget suffix) from budgeted methods
constant_methods = ['LOO', 'ARC-JSD', 'Exact-FSII', 'Exact-Shap']
is_constant = df_reset['method'].isin(constant_methods)
df_const = df_reset[is_constant]
# .copy() makes df_budgeted an independent frame, so the column assignments below
# do not write into a slice of df_reset (pandas SettingWithCopyWarning).
df_budgeted = df_reset[~is_constant].copy()

# Budgeted method names look like "<family>_<budget>": split off the trailing budget.
df_budgeted['family'] = df_budgeted['method'].apply(lambda x: "_".join(x.split("_")[:-1]))
df_budgeted['budget'] = df_budgeted['method'].apply(lambda x: int(x.split("_")[-1]))
df_budgeted = df_budgeted.sort_values(by=['family', 'budget'])

# Function to plot metric
def plot_metric(metric, ylabel):
    """Plot one summary column against budget, one line per method family.

    Reads the module-level frames ``df_budgeted`` and ``df_const``.

    Args:
        metric: Column name in the summary frames (e.g. "R2", "LDS").
        ylabel: Text used for the y-axis label and in the title.
    """
    plt.figure(figsize=(12, 6))

    # Plot budgeted families
    families = df_budgeted['family'].unique()
    for fam in families:
        # if 'LK' not in fam:
        subset = df_budgeted[df_budgeted['family'] == fam]
        plt.plot(subset['budget'], subset[metric], marker='o', label=fam)

    # Plot constant methods as horizontal lines
    colors = plt.cm.tab10.colors  # categorical palette
    for idx, (_, row) in enumerate(df_const.iterrows()):
        plt.axhline(y=row[metric], color=colors[idx % len(colors)],marker='x', label=row['method'])

    plt.xlabel("Budget")
    plt.ylabel(ylabel)
    plt.title(f"Evolution of {ylabel} with Increasing Budget")
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True)
    plt.tight_layout()
    plt.show()

plot_metric("R2", "R2")
plot_metric("LDS", "LDS")
plot_metric("Delta_R2", "Delta R2")


In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Budget whose Recall@k curves are plotted. The filter, the legend data and the title
# all derive from this one value (previously the code filtered on 1024 while the
# comments said 274 and the title said 264).
RECALL_BUDGET = 1024
df_at_budget = df_budgeted[df_budgeted['budget'] == RECALL_BUDGET]
df_budgeted_264 = df_at_budget  # legacy alias, in case another cell still uses the old name

# Metrics to plot
recall_metrics = [f"Recall@{k}" for k in range(1, 6)]
k_values = list(range(1, 6))

plt.figure(figsize=(10, 6))

# One line per budgeted method family at the selected budget
families = df_at_budget['family'].unique()
for fam in families:
    # if 'LK' not in fam:
    subset = df_at_budget[df_at_budget['family'] == fam]
    if not subset.empty:
        recalls = subset[recall_metrics].values.flatten()
        plt.plot(k_values, recalls, marker='o', label=fam)

# Constant (budget-free) methods as dashed reference curves
for idx, (_, row) in enumerate(df_const.iterrows()):
    recalls = [row[m] for m in recall_metrics]
    plt.plot(k_values, recalls, marker='x', linestyle="--", label=row['method'])

plt.xlabel("k")
plt.ylabel("Recall@k")
plt.title(f"Recall@k for Budget = {RECALL_BUDGET}")
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.grid(True)
plt.tight_layout()
plt.show()


In [None]:
# Budget whose methods are shown in the top-k plot.
# NOTE(review): the Recall@k cell above filters on budget 1024 — confirm whether the
# two plots are meant to use the same budget.
TOPK_BUDGET = 264
topk_ks = [1, 2, 3, 4, 5]
topk_columns = [f"topk_probability_k{k}" for k in topk_ks]

plt.figure(figsize=(8, 5))
for method in df_res.index:
    # Match the budget as the trailing "_<budget>" component rather than as a
    # substring anywhere in the name, so unrelated names containing the digits
    # cannot be picked up by accident.
    if method.endswith(f"_{TOPK_BUDGET}"):
        plt.plot(
            topk_ks,
            df_res.loc[method, topk_columns],
            marker='o',
            label=method
        )

plt.xlabel('k')
plt.ylabel('Logit-Probability Drop')
plt.title('Top-k Logit-Probability Drop')
plt.legend()
plt.grid(True, linestyle='--', alpha=0.6)
plt.tight_layout()
plt.show()

# 1. Exact match with human labels

In [None]:
import numpy as np

def evaluate_methods(extras, k, m, interaction_type="max"):
    """Fraction of experiments in which interaction (k, m) is the extreme entry.

    Args:
        extras: Non-empty list with one dict per experiment, mapping a method
            name to its interaction result. Methods whose name contains
            ``"Fl"`` or ``"FM"`` are read as 2-D arrays (``result[k][m]``);
            all other methods are read as dicts keyed by ``(k, m)`` tuples.
        k, m: Indices of the interaction pair to look up.
        interaction_type: ``"max"`` or ``"min"`` — whether the pair must hold
            the largest or the smallest value among all entries of the result.
            Any other string counts no hits.

    Returns:
        Dict mapping every method name of ``extras[0]`` to
        ``hits / len(extras)``. Experiments in which the method is missing,
        or in which a dict result has no ``(k, m)`` key, count as non-hits
        but still count in the denominator.
    """
    methods = list(extras[0].keys())
    hits = {name: 0 for name in methods}
    n_experiments = len(extras)

    for exp in extras:
        for method in methods:
            # A method can be absent from an experiment (e.g. it failed for that
            # question); treat that as "not recovered" instead of raising KeyError.
            if method not in exp:
                continue

            if "Fl" in method or "FM" in method:
                # Matrix-valued result
                value = exp[method][k][m]
                all_values = exp[method].flatten()
            else:
                # Dictionaries with tuple keys
                d = exp[method]
                value = d.get((k, m))
                if value is None:
                    continue  # skip if (k,m) not found
                all_values = list(d.values())

            if interaction_type == "max":
                is_hit = value == max(all_values)
            elif interaction_type == "min":
                is_hit = value == min(all_values)
            else:
                is_hit = False

            if is_hit:
                hits[method] += 1

    # Convert to fraction of experiments
    return {method: hits[method] / n_experiments for method in methods}

In [None]:
# Recovery rate: for every method, the fraction of experiments in which the
# interaction between items 1 and 5 is the minimum entry of that method's result.
# evaluate_methods already returns a {method: fraction} dict keyed in the order of
# extras[0], so it is used directly instead of being rebuilt through an array.
em = dict(evaluate_methods(extras, k=1, m=5, interaction_type="min"))

In [None]:
# Turn the flat {method_key: recovery} dict into tidy rows (method, [rank], budget, recovery).
rows = []
for key, recovery in em.items():
    parts = key.split("_")
    if parts[0] == "Flk" and parts[1]!='0':
        # "Flk_<rank>_<budget>"
        _, rank, budget = parts
        rows.append({"method": "Flk", "rank": int(rank), "budget": int(budget), "recovery": recovery})
    elif parts[0] == "FM" and parts[3]!='0':
        # "FM_k_dynamic_<budget>" / "FM_k_dynamice_<budget>": keep the full family name so
        # the two pruning strategies are not merged into a single series.
        *family_parts, budget = parts
        rows.append({"method": "_".join(family_parts), "budget": int(budget), "recovery": recovery})
    elif parts[0] == "Int":
        # "Int_<name>_<budget>"
        _, name, budget = parts
        rows.append({"method": name, "budget": int(budget), "recovery": recovery})

# Use a dedicated name: assigning to `df` here would overwrite the dataset frame that
# the experiment cell reads (df.question / df.paragraphs / df.answer).
recovery_df = pd.DataFrame(rows)

# NOTE(review): every curve below is divided by 3, as in the original cell. The values
# in `em` are already fractions of experiments, so confirm what this factor represents.

# Plot
plt.figure(figsize=(10, 6))

# Flk, one line per rank (disabled)
# for rank in sorted(recovery_df[recovery_df["method"]=="Flk"]["rank"].unique()):
#     sub = recovery_df[(recovery_df["method"]=="Flk") & (recovery_df["rank"]==rank)].sort_values("budget")
#     plt.plot(sub["budget"], sub["recovery"]/3, marker="o", label=f"Flk rank {rank}")

# One line per FM dynamic-pruning family
is_fm = recovery_df["method"].str.startswith("FM_")
for family in recovery_df.loc[is_fm, "method"].unique():
    sub = recovery_df[recovery_df["method"] == family].sort_values("budget")
    plt.plot(sub["budget"], sub["recovery"]/3, marker="o", label=family)

# Exact-FSII is budget-free: draw it as a flat baseline over the FM budgets. The length
# comes from the data instead of a hard-coded 7, so x and y always match.
fm_budgets = sorted(recovery_df.loc[is_fm, "budget"].unique())
plt.plot(fm_budgets, np.full(len(fm_budgets), em['Exact-FSII'] / 3), marker="x", linestyle="--", label="Exact-FSII")

# Sampled interaction-index methods
for int_method in recovery_df[recovery_df["method"].isin(["FSII","FBII","Spex"])]['method'].unique():
    sub = recovery_df[recovery_df["method"]==int_method].sort_values("budget")
    plt.plot(sub["budget"], sub["recovery"]/3, marker="+", linestyle="--", label=f"Int {int_method}")

plt.xlabel("Budget")
plt.ylabel("Recovery Rate")
plt.title("Evolution of Exact match with Increasing Budget")
plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
plt.grid(True)
plt.tight_layout()
plt.show()

In [None]:
# Debugging aid: list the interaction-result keys stored for the second question.
extras[1].keys()

In [None]:
# TODO:
# 1. Compare the new and the old FM implementations.
# 2. Iterate on the interactions.
# 3. Shapley is used for first-order attributions and Faith-Shap for pairwise interactions.

# 2. RR@k

In [None]:
def compute_rr_at_k(interaction, ground_truth, k):
    """
    Compute Recovery@k (RR@k) for one method's interaction scores.

    Interactions are ranked by score (descending). For each of the top-k
    interactions, the fraction of its indices that belong to ``ground_truth``
    is computed; these fractions are summed and divided by ``k``.

    interaction: one of
        - dict {index_tuple: value}
        - 2D numpy array / nested list: every off-diagonal entry [i][j] is
          scored as the pair (i, j)
        - 1D numpy array / list of per-item scores: entry i is scored as (i,)
    ground_truth: iterable of ground-truth indices (R^*)
    k: number of top interactions to consider
    Returns: RR@k value (0.0 when k <= 0 or when there is nothing to rank).
        The sum is divided by k even when fewer than k interactions exist.
    """
    if k <= 0:
        return 0.0

    # Accept any iterable of indices, not only a set.
    relevant = set(ground_truth)

    # Normalise every supported input into a list of (index_tuple, score).
    if isinstance(interaction, (np.ndarray, list)):
        mat = np.asarray(interaction)
        if mat.ndim == 1:
            scored = [((i,), mat[i]) for i in range(mat.shape[0])]
        else:
            scored = [
                ((i, j), mat[i][j])
                for i in range(mat.shape[0])
                for j in range(mat.shape[1])
                if i != j
            ]
    else:
        scored = list(interaction.items())

    # An empty index tuple (e.g. a baseline term keyed by ()) has no members to
    # recover and would cause a division by zero below, so it is not ranked.
    scored = [(indices, value) for indices, value in scored if len(indices) > 0]

    # Stable descending sort by score.
    scored.sort(key=lambda item: item[1], reverse=True)

    rr_sum = 0.0
    for indices, _ in scored[:k]:
        members = set(indices)
        rr_sum += len(relevant & members) / len(members)
    return rr_sum / k

In [None]:
# Example ground-truth indices and the number of top interactions to score.
ground_truth = {0, 1, 5}
k = 5

# RR@k of every method stored for the first experiment in `extras`
# (change the index to look at a different experiment).
rr_results = {
    method: compute_rr_at_k(interaction, ground_truth, k)
    for method, interaction in extras[0].items()
}

print(rr_results)

In [None]:
import re
import matplotlib.pyplot as plt
import numpy as np
from collections import defaultdict

# Parse RR@k for all experiments and budgets
def extract_budget(key):
    """Return the integer that follows the last underscore at the end of ``key``.

    Returns None when ``key`` does not end in ``_<digits>``.
    """
    trailing_number = re.search(r'_(\d+)$', key)
    if trailing_number is None:
        return None
    return int(trailing_number.group(1))

def extract_family(key):
    """Map a method key to the family label used for grouping in the RR@k plot.

    Keys starting with "FM_k_dynamic", "Int_FSII" or "Int_FBII" map to that
    prefix; keys starting with "Flk" map to their first two "_"-separated
    fields joined by "_". Any other key maps to None.
    """
    # Families whose label is simply the matched prefix.
    for prefix in ("FM_k_dynamic", "Int_FSII", "Int_FBII"):
        if key.startswith(prefix):
            return prefix

    if key.startswith("Flk"):
        fields = key.split("_")
        return fields[0] + "_" + fields[1]

    return None

def collect_rr_at_k_over_budgets(extras, ground_truth, k):
    """Average RR@k per method family and budget over all experiments.

    extras: iterable of dicts mapping method key -> interaction scores
    ground_truth: ground-truth indices passed on to ``compute_rr_at_k``
    k: number of top interactions passed on to ``compute_rr_at_k``
    Returns: defaultdict {family: {budget: mean RR@k}}. Keys for which no
        budget or no family can be extracted are ignored (a budget of 0 is
        treated like a missing budget).
    """
    # family -> budget -> RR@k values, one per experiment that has the key
    samples = defaultdict(lambda: defaultdict(list))
    for experiment in extras:
        for name, scores in experiment.items():
            budget = extract_budget(name)
            family = extract_family(name)
            if not (budget and family):
                continue
            samples[family][budget].append(compute_rr_at_k(scores, ground_truth, k))

    # Collapse each list of per-experiment values into its mean.
    averaged = defaultdict(dict)
    for family, per_budget in samples.items():
        for budget, observed in per_budget.items():
            averaged[family][budget] = np.mean(observed)
    return averaged
rr_avg = collect_rr_at_k_over_budgets(extras, ground_truth, k)
# Plot RR@k as line chart for each method family, with constant methods as parallel lines
plt.figure(figsize=(10, 6))

# Plot budgeted families
for family, budget_rrs in rr_avg.items():
    budgets = sorted(budget_rrs.keys())
    values = [budget_rrs[b] for b in budgets]
    plt.plot(budgets, values, marker='o', label=family)

# Constant methods (e.g., Exact-FSII, LOO, ARC-JSD): average RR@k over the
# experiments that contain them.
constant_methods = ['Exact-FSII', 'LOO', 'ARC-JSD']
constant_rr = {}
for method in constant_methods:
    rr_vals = [
        compute_rr_at_k(exp[method], ground_truth, k)
        for exp in extras
        if method in exp
    ]
    if rr_vals:
        constant_rr[method] = np.mean(rr_vals)

# axhline does not advance the colour cycle, so without an explicit colour every
# horizontal line would share the same default colour as the first family curve.
# Continue the cycle after the colours already used by the family curves.
for offset, (method, avg_rr) in enumerate(constant_rr.items()):
    plt.axhline(y=avg_rr, color=f"C{len(rr_avg) + offset}", linestyle='--', label=method)

plt.xlabel('Budget')
plt.ylabel(f'RR@{k}')
plt.title(f'Recovery Rate at k={k} vs Budget (Averaged over experiments)')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()


# 3. NDCG to Exact-FSII

In [None]:
# FaithShap (absolute values) → NDCG plotting
import re
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
from sklearn.metrics import ndcg_score

def extract_budget(key):
    """Parse the trailing ``_<digits>`` of a method key as an int (None if absent)."""
    found = re.search(r'_(\d+)$', key)
    return None if not found else int(found.group(1))

def extract_family(key):
    """Map a method key to the family label used in the NDCG plot.

    - keys starting with "FM_k_dynamic" -> "FM_k_dynamic"
    - keys starting with "Flk_"         -> first two "_"-separated fields
                                           concatenated (e.g. "Flk_3_100" -> "Flk3")
    - keys starting with "Int_FSII" or "FSII" -> "FSII"
    - keys starting with "Int_FBII" or "FBII" -> "FBII"
    Returns None for any other key.
    """
    # The prefix must be spelled exactly like the stored keys; a misspelled
    # prefix silently drops the whole family from the plot.
    if key.startswith("FM_k_dynamic"):
        return "FM_k_dynamic"
    if key.startswith("Flk_"):
        # The "Flk_" prefix guarantees at least two fields after splitting.
        parts = key.split("_")
        return parts[0] + parts[1]
    if key.startswith("Int_FSII") or key.startswith("FSII"):
        return "FSII"
    if key.startswith("Int_FBII") or key.startswith("FBII"):
        return "FBII"
    return None

def pairs_from_exact(exp_list):
    """Return a canonical, sorted list of interaction keys.

    The keys are taken from the first experiment whose 'Exact-FSII' entry is a
    non-empty dict. If no experiment has one, the keys of the first non-empty
    dict-valued entry of any experiment are used instead.

    exp_list: iterable of dicts mapping method key -> scores
    Returns: sorted list of keys, or [] when no dict-valued entry is found.
    """
    for exp in exp_list:
        exact = exp.get('Exact-FSII')
        # Check the type before the truthiness: the truth value of a
        # multi-element numpy array is ambiguous and raises ValueError.
        if isinstance(exact, dict) and exact:
            return sorted(exact.keys())
    # fallback: try to infer from any dict-valued method
    for exp in exp_list:
        for v in exp.values():
            if isinstance(v, dict) and v:
                return sorted(v.keys())
    return []

def vector_for_pairs(val, pairs):
    """Return the absolute score of every pair in ``pairs``, in order.

    val: dict {(i, j): score}, or a square matrix given as list / ndarray
    pairs: sequence of (i, j) index pairs
    Returns: list with one absolute value per pair. A pair missing from the
        dict, or lying outside the matrix bounds, contributes 0.0; any other
        type of ``val`` yields all zeros.
    """
    if isinstance(val, dict):
        return [abs(val.get(pair, 0.0)) for pair in pairs]

    if isinstance(val, (list, np.ndarray)):
        grid = np.array(val)
        magnitudes = []
        for (row, col) in pairs:
            inside = 0 <= row < grid.shape[0] and 0 <= col < grid.shape[1]
            magnitudes.append(abs(grid[row][col]) if inside else 0.0)
        return magnitudes

    # Unsupported container type: nothing to read, so every pair scores zero.
    return [0.0 for _ in pairs]

# Build canonical pair list
pairs = pairs_from_exact(extras)
if not pairs:
    print('No pair ordering could be inferred from Exact-FSII or other dicts in extras. Aborting NDCG computation.')
else:
    # Compute per-experiment NDCG scores for each method (relative to Exact-FSII)
    per_method_ndcg = defaultdict(list)
    # method -> number of experiments in which it could not be scored
    skipped_methods = defaultdict(int)
    for exp in extras:
        exact = exp.get('Exact-FSII', {})
        exact_vec = vector_for_pairs(exact, pairs)
        # if exact vector is all zeros, skip this experiment for fairness
        if np.allclose(exact_vec, 0.0):
            continue
        for method, val in exp.items():
            if method == 'Exact-FSII':
                continue
            try:
                vec = vector_for_pairs(val, pairs)
                # ndcg_score expects shape (n_samples, n_labels) for both y_true and y_score
                score = ndcg_score([exact_vec], [vec])
            except Exception:
                # Best effort: methods that cannot be converted or scored are
                # left out, but recorded so the omission is visible.
                skipped_methods[method] += 1
                continue
            per_method_ndcg[method].append(score)

    if skipped_methods:
        print(f'Methods skipped (could not be scored against Exact-FSII): {dict(skipped_methods)}')

    # Average NDCG across experiments for each method
    avg_ndcg = {m: float(np.mean(scores)) for m, scores in per_method_ndcg.items() if len(scores)>0}

    # Group budgeted methods by family and budget
    family_budget = defaultdict(lambda: defaultdict(list))
    for method, score in avg_ndcg.items():
        budget = extract_budget(method)
        family = extract_family(method)
        if budget is not None and family is not None:
            family_budget[family][budget].append(score)

    # Compute mean per family-budget (in case multiple variant keys map to same family-budget)
    family_budget_avg = {}
    for fam, bd in family_budget.items():
        family_budget_avg[fam] = {b: float(np.mean(vals)) for b, vals in bd.items()}

    # Plotting: line per family (budgeted), horizontal lines for constant methods
    plt.figure(figsize=(10,6))
    # Plot budgeted families
    for fam, bd in family_budget_avg.items():
        xs = sorted(bd.keys())
        ys = [bd[x] for x in xs]
        plt.plot(xs, ys, marker='o', label=fam)

    # Constant methods: plot as horizontal lines using avg_ndcg if available.
    # 'Exact-FSII' is skipped in the scoring loop above, so it never appears in
    # avg_ndcg and is therefore never drawn.
    constant_methods = ['Exact-FSII','Exact-Shap','LOO','ARC-JSD']
    plotted_constants = [cm for cm in constant_methods if cm in avg_ndcg]
    # axhline does not advance the colour cycle, so without an explicit colour
    # every horizontal line would share the first family's colour. Continue the
    # cycle after the colours already used by the family curves.
    for offset, cm in enumerate(plotted_constants):
        plt.axhline(y=avg_ndcg[cm], color=f"C{len(family_budget_avg) + offset}", linestyle='--', label=cm)

    plt.xlabel('Budget')
    plt.ylabel('NDCG to Exact-FSII (absolute interaction magnitudes)')
    plt.title('FaithShap — NDCG of absolute interactions vs Exact-FSII')
    plt.legend(bbox_to_anchor=(1.05,1), loc='upper left')
    plt.grid(True)
    plt.tight_layout()
    plt.show()