In [1]:
import sys
import os
import random
import gc
import time
import torch
import numpy as np
import pandas as pd
import pickle
import ast
from tqdm import tqdm
from scipy.sparse import csr_matrix
import itertools
from scipy.stats import spearmanr, pearsonr, kendalltau, rankdata
from sklearn.metrics import ndcg_score
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import Accelerator
# import nltk
# nltk.download('punkt')
os.environ["CUDA_VISIBLE_DEVICES"] = "0,3"
current_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(current_dir, '..'))
sys.path.append(parent_dir)
from SHapRAG import *
from SHapRAG.utils import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
df=pd.read_json("../data/musique/musique_ans_v1.0_train.jsonl", lines=True)

In [3]:
def get_titles(lst):
    # Titles where is_supporting is True
    supporting = [d['paragraph_text'] for d in lst if d.get('is_supporting') == True]
    # Titles where is_supporting is False or missing AND not already in supporting
    others = [d['paragraph_text'] for d in lst if d.get('is_supporting') != True and d['paragraph_text'] not in supporting]
    # Combine: all supporting + as many others as needed to reach 10
    result = supporting + others
    return result[:10]

df.paragraphs=df.paragraphs.apply(get_titles)

In [None]:
# df['Sentences'] = df['paragraphs'].apply(
#     lambda para_list: [sent for para in para_list for sent in nltk.sent_tokenize(para)]
# )

In [None]:
# df_save=pd.read_csv('../data/musique/sen_labeled.csv',
#     quotechar='"',
#     skipinitialspace=True,
#     engine='python' )

In [4]:
df["paragraphs"] = df["paragraphs"].apply(lambda p: p[:5]+ [p[1]] + p[5:])

In [5]:
SEED = 42
# Initialize Accelerator
accelerator_main = Accelerator(mixed_precision="fp16")

# Load Model
if accelerator_main.is_main_process:
    print("Main Script: Loading model...")
# model_path = "mistralai/Mistral-7B-Instruct-v0.3"
model_path = "meta-llama/Llama-3.1-8B-Instruct"
# model_path = "Qwen/Qwen2.5-3B-Instruct"

model_cpu = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained(model_path)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    model_cpu.config.pad_token_id = tokenizer.pad_token_id
    if hasattr(model_cpu, 'generation_config') and model_cpu.generation_config is not None:
        model_cpu.generation_config.pad_token_id = tokenizer.pad_token_id

if accelerator_main.is_main_process:
    print("Main Script: Preparing model with Accelerator...")
prepared_model = accelerator_main.prepare(model_cpu)
unwrapped_prepared_model = accelerator_main.unwrap_model(prepared_model)
unwrapped_prepared_model.eval()
if accelerator_main.is_main_process:
    print("Main Script: Model prepared and set to eval.")

# Define utility cache

accelerator_main.wait_for_everyone()

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Main Script: Loading model...


2025-10-31 08:46:34.684521: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-10-31 08:46:34.736395: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-10-31 08:46:36.268557: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-10-31 08:46:36.268557: I tensorflow/core/util/port.cc:153] oneDNN custom op

Main Script: Preparing model with Accelerator...
Main Script: Model prepared and set to eval.


In [6]:
with open("answered.txt", "r") as f:
    res = [line.strip() for line in f]

In [None]:
def gtset_k():
    return [0, 1,5]

num_questions_to_run = 50
k_values = [1, 2, 3, 4, 5]
all_results = []
extras = []

for i in range(num_questions_to_run):
    query = df.question[i]
    if res[i]=="True":
        if accelerator_main.is_main_process:
            print(f"\n--- Question {i+1}/{num_questions_to_run}: {query[:60]}... ---")

        docs = df.paragraphs[i]
        utility_cache_base_dir = f"../Experiment_data/musique/{model_path.split('/')[1]}/new/duplicate"
        utility_cache_filename = f"utilities_q_idx{i}.pkl"
        current_utility_path = os.path.join(utility_cache_base_dir, utility_cache_filename)

        if accelerator_main.is_main_process:
            os.makedirs(os.path.dirname(current_utility_path), exist_ok=True)

        harness = ContextAttribution(
            items=docs,
            query=query,
            prepared_model=prepared_model,
            prepared_tokenizer=tokenizer,
            accelerator=accelerator_main,
            utility_cache_path=current_utility_path
        )
        full_budget=pow(2,harness.n_items)
        # res = evaluate(df.question[i], harness.target_response, df.answer[i])
        # res='True'
        if accelerator_main.is_main_process:
            methods_results = {}
            metrics_results = {}
            extra_results = {}

            m_samples_map = {"XS": 32, "S":64, "M":128, "L":264, "XL":528, "XXL":724, "XXXL":1024}

            # Store FM models for later R²/MSE
            fm_models = {}
            methods_results['Exact-Shap']=harness._calculate_shapley()
            for size_key, actual_samples in m_samples_map.items():
                print(f"Running sample size: {actual_samples}")
                methods_results[f"ContextCite_{actual_samples}"], fm_models[f"ContextCite_{actual_samples}"] = harness.compute_contextcite(
                    num_samples=actual_samples, seed=SEED
                )
                # FM Weights (loop over ranks 0–5)
                # for rank in range(5, -1, -1):
                    # methods_results[f"FM_WeightsLK_{rank}_{actual_samples}"], extra_results[f"Flk_{rank}_{actual_samples}"], fm_models[f"FM_WeightsLK_{rank}_{actual_samples}"] = harness.compute_wss(
                    #     num_samples=actual_samples,
                    #     seed=SEED,
                    #     sampling="kernelshap",
                    #     sur_type="fm",
                    #     rank=rank
                    # )
                    # methods_results[f"FM_WeightsLU_{rank}_{actual_samples}"], extra_results[f"Flu_{rank}_{actual_samples}"], fm_models[f"FM_WeightsLU_{rank}_{actual_samples}"] = harness.compute_wss(
                    #     num_samples=actual_samples,
                    #     seed=SEED,
                    #     sampling="kernelshap",
                    #     sur_type="fm",
                    #     rank=rank
                    # )
                # methods_results[f"FM_u_dynamic_{actual_samples}"], extra_results[f"FM_u_dynamic_{actual_samples}"], fm_models[f"FM_u_dynamic_{actual_samples}"] = harness.compute_wss_dynamic_pruning_reuse_utility(num_samples=actual_samples)
                methods_results[f"FM_k_dynamic_{actual_samples}"], extra_results[f"FM_k_dynamic_{actual_samples}"], fm_models[f"FM_k_dynamic_{actual_samples}"] = harness.compute_wss_dynamic_pruning_reuse_utility(num_samples=actual_samples)
                # methods_results[f"FM_k_dynamice_{actual_samples}"], extra_results[f"FM_k_dynamice_{actual_samples}"], fm_models[f"FM_k_dynamice_{actual_samples}"] = harness.compute_wss_dynamic_pruning_reuse_utility(num_samples=actual_samples, pruning_strategy='elbow')
                try:
                    # attributionsspex, interactionspex = harness.compute_spex(sample_budget=actual_samples, max_order=2)
                    attributionshap, interactionshap, fm_models[f"FSII_{actual_samples}"] = harness.compute_fsii(sample_budget=actual_samples, max_order=2)
                    # attributionban, interactionban, fm_models[f"FBII_{actual_samples}"] = harness.compute_fbii(sample_budget=actual_samples, max_order=harness.n_items)
                    # methods_results[f"FBII_{actual_samples}"] = attributionban
                    methods_results[f"FSII_{actual_samples}"] = attributionshap
                    # methods_results[f"Spex_{actual_samples}"] = attributionsspex


                    extra_results.update({
                    f"Int_FSII_{actual_samples}":interactionshap,
                    # f"Int_FBII_{actual_samples}":interactionban,
                    # f"Int_Spex_{actual_samples}":interactionspex
                                                                            })
                except Exception: pass


        #     methods_results["LOO"] = harness.compute_loo()
        #     methods_results["ARC-JSD"] = harness.compute_arc_jsd()
            attributionxs, interactionxs, fm_models["Exact-FSII"] = harness.compute_exact_fsii(max_order=2)

            extra_results.update({
            "Exact-FSII": interactionxs
        })
            methods_results["Exact-FSII"]=attributionxs

            # --- Evaluation Metrics ---
            metrics_results["topk_probability"] = harness.evaluate_topk_performance(
                methods_results, fm_models, k_values
            )

            # R²
            metrics_results["R2"] = harness.r2(methods_results,100,mode='logit-prob', models=fm_models)
            metrics_results['Recall']=harness.recall_at_k(gtset_k(), methods_results, k_values)

            # LDS per method
            metrics_results["LDS"] = harness.lds(methods_results,100,mode='logit-prob', models=fm_models)



            all_results.append({
                "query_index": i,
                "query": query,
                "ground_truth": df.answer[i],
                "response": harness.target_response,
                "methods": methods_results,
                "metrics": metrics_results
            })
            extras.append(extra_results)

            # Save utility cache
            harness.save_utility_cache(current_utility_path)
# with open(f"{utility_cache_base_dir}/results.pkl", "wb") as f:
#     pickle.dump(all_results, f)

# with open(f"{utility_cache_base_dir}/extras.pkl", "wb") as f:
#     pickle.dump(extras, f)


--- Question 2/50: What year saw the creation of the region where the county of... ---
Main Process: Attempting to load utility cache from ../Experiment_data/musique/Llama-3.1-8B-Instruct/new/duplicate/utilities_q_idx1.pkl...
Successfully loaded 2048 cached utility entries.
Running sample size: 32
Initial scores: [0.46768362 2.48132338 0.92720086 1.52526116 1.23844564 4.11825813
 1.10390648 1.46441797 0.07421359 0.24894526 0.29845612]
We are keeping 7 documents
Running SPEX with method
SPEX approximator initialized.
Running sample size: 64
Initial scores: [5.03847300e-01 7.33616224e+00 1.38127676e+00 7.44849138e-01
 1.03131567e+00 4.58525371e+00 3.78735932e-03 1.15978803e-01
 6.92608620e-01 1.16109147e+00 8.70767382e-01]
We are keeping 7 documents
Running SPEX with method
SPEX approximator initialized.
Running sample size: 128
Initial scores: [0.12291553 5.42581114 0.90794406 0.34627293 0.21403151 5.88306015
 0.62363095 0.26174867 0.51573362 0.45089137 0.06614052]
We are keeping 7 doc

In [None]:
harness._run_spex("FBII", 524, harness.n_items, 'logit-prob')

In [None]:
len(harness.utility_cache)

In [None]:
with open(f"{utility_cache_base_dir}/results.pkl", "wb") as f:
    pickle.dump(all_results, f)

with open(f"{utility_cache_base_dir}/extras.pkl", "wb") as f:
    pickle.dump(extras, f)

In [None]:

with open(f"../Experiment_data/musique/Llama-3.1-8B-Instruct/new/duplicate/results.pkl", "rb") as f:
    all_results = pickle.load(f)

In [None]:
with open(f"../Experiment_data/musique/Llama-3.1-8B-Instruct/new/duplicate/extras.pkl", "rb") as f:
    extras = pickle.load(f)

In [None]:
harness._get_top_k_utility_subsets(4, 'logit-prob')

In [None]:
spearmans={i:[] for i in all_results[0]['methods'] if i!="Exact-Shap"}
for method_res in all_results:
    for method, attribution in method_res['methods'].items():
        if method!="Exact-Shap":
            spear=len(set(np.array(method_res['methods']["Exact-Shap"]).argsort()[-3:]).intersection(set(np.array(attribution).argsort()[-3:])))/3
            spearmans[method].append(spear)

In [8]:
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import ndcg_score
import numpy as np

spearmans = {i: [] for i in all_results[0]['methods'] if i != "Exact-Shap" and "Wei" not in all_results[0]['methods']}
scaler = MinMaxScaler()

for method_res in all_results:
    for method, attribution in method_res['methods'].items():
        if method != "Exact-Shap" and "Wei" not in method:
            # Convert to numpy arrays for scaling
            ref = np.array(method_res['methods']["Exact-Shap"]).reshape(-1, 1)
            att = np.array(attribution).reshape(-1, 1)
            
            # Scale both reference and attribution to [0, 1]
            ref_scaled = scaler.fit_transform(ref).flatten()
            att_scaled = scaler.fit_transform(att).flatten()
            
            # Compute NDCG score
            spear = ndcg_score([ref_scaled], [att_scaled], k=4)
            spearmans[method].append(spear)


IndexError: list index out of range

In [9]:
import re
import numpy as np
import matplotlib.pyplot as plt

# Parse methods and budgets
parsed = {}
budgets = set()
for key, values in spearmans.items():
    avg_val = np.mean(values)
    
    match = re.match(r"(.+?)_(\d+)$", key)  # method_budget pattern
    if match:
        method, budget = match.groups()
        budget = int(budget)
        budgets.add(budget)
        parsed.setdefault(method, {})[budget] = avg_val
    else:
        # constant methods (no budget)
        parsed.setdefault(key, {})[None] = avg_val

budgets = sorted(budgets)

# Plot
plt.figure(figsize=(10,6))

for method, results in parsed.items():
    if None in results:  # constant method
        plt.hlines(results[None], xmin=min(budgets), xmax=max(budgets), 
                   linestyles='--', label=method)
    else:
        xs = sorted(results.keys())
        ys = [results[b] for b in xs]
        plt.plot(xs, ys, marker='o', label=method)

plt.xlabel("Budget")
plt.ylabel("Average Spearman")
plt.title("Average Recall to Exact Shap per Method vs Budget")
plt.legend()
plt.grid(True)
plt.show()


NameError: name 'spearmans' is not defined

In [None]:
import pandas as pd
import numpy as np
from collections import defaultdict

def summarize_and_print(all_results, k_values=[1, 2, 3,4,5]):
    table_data = defaultdict(lambda: defaultdict(list))

    # Mapping for consistency
    method_name_map = {
        
    }

    for res in all_results:
        metrics = res["metrics"]
        # LDS and R2
        for method_name, lds_val in metrics["LDS"].items():
            method = method_name_map.get(method_name, method_name)
            table_data[method]["LDS"].append(lds_val)

        for method_name, lds_val in metrics["R2"].items():
            method = method_name_map.get(method_name, method_name)
            table_data[method]["R2"].append(lds_val)
        # Top-k
        for method_name, k_dict in metrics["topk_probability"].items():
            method = method_name_map.get(method_name, method_name)
            for k in k_values:
                if k in k_dict:
                    col_name = f"topk_probability_k{k}"
                    table_data[method][col_name].append(k_dict[k])
        
        for method_name, k_dict in metrics["Recall"].items():
            method = method_name_map.get(method_name, method_name)
            for k in k_values:
                col_name = f"Recall@{k}"
                table_data[method][col_name].append(k_dict[k-1])
    # Averages
    avg_table = {
        method: {metric: np.nanmean(values) for metric, values in metric_dict.items()}
        for method, metric_dict in table_data.items()
    }

    # Standard deviations for LDS, R², and MSE
    for method, metric_dict in table_data.items():
        for metric in ["LDS", "R2"]:
            if metric in metric_dict:
                avg_table[method][f"{metric}_std"] = np.nanstd(metric_dict[metric])

    df_summary = pd.DataFrame.from_dict(avg_table, orient="index").sort_index()

    print("\n=== Metrics Summary Across All Queries ===")
    print(df_summary.to_string(float_format="%.4f"))

    return df_summary
df_res=summarize_and_print(all_results, k_values=[1, 2, 3,4,5])


In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Reset index
df_reset = df_res.reset_index().rename(columns={'index': 'method'})

# Separate constant methods (no budget) and budgeted methods
constant_methods = ['LOO', 'ARC-JSD', 'Exact-FSII', 'Exact-Shap']
df_const = df_reset[df_reset['method'].isin(constant_methods)]
df_budgeted = df_reset[~df_reset['method'].isin(constant_methods)]

# Extract family and budget for budgeted methods
df_budgeted['family'] = df_budgeted['method'].apply(lambda x: "_".join(x.split("_")[:-1]))
df_budgeted['budget'] = df_budgeted['method'].apply(lambda x: int(x.split("_")[-1]))
df_budgeted = df_budgeted.sort_values(by=['family', 'budget'])

# Function to plot metric
def plot_metric(metric, ylabel):
    plt.figure(figsize=(12, 6))

    # Plot budgeted families
    families = df_budgeted['family'].unique()
    for fam in families:
        if 'LK' not in fam and "Wei" not in method:
            subset = df_budgeted[df_budgeted['family'] == fam]
            plt.plot(subset['budget'], subset[metric], marker='o', label=fam)

    # Plot constant methods as horizontal lines
    colors = plt.cm.tab10.colors  # categorical palette
    for idx, (_, row) in enumerate(df_const.iterrows()):
        plt.axhline(y=row[metric], color=colors[idx % len(colors)],marker='x', label=row['method'])

    plt.xlabel("Budget")
    plt.ylabel(ylabel)
    plt.title(f"Evolution of {ylabel} with Increasing Budget")
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True)
    plt.tight_layout()
    plt.show()

# Plot LDS
# plot_metric("LDS", "LDS")

# Plot R²
# plot_metric("R2", "R²")

# plot_metric("Recall@1", "Recall 1")
# plot_metric("Recall@2", "Recall 2")
# plot_metric("Recall@3", "Recall 3")
# plot_metric("Recall@4", "Recall 4")
# plot_metric("Recall@5", "Recall 5")


In [None]:
plot_metric("R2", "R2")

In [None]:
plot_metric("LDS", "LDS")

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Filter budgeted methods at budget = 274
df_budgeted_264 = df_budgeted[df_budgeted['budget'] == 1024]

# Metrics to plot
recall_metrics = [f"Recall@{k}" for k in range(1, 6)]
k_values = list(range(1, 6))

plt.figure(figsize=(10, 6))

# Plot budgeted families at budget 274
families = df_budgeted_264['family'].unique()
for fam in families:
    if 'LK' not in fam:
        subset = df_budgeted_264[df_budgeted_264['family'] == fam]
        if not subset.empty:
            recalls = subset[recall_metrics].values.flatten()
            plt.plot(k_values, recalls, marker='o', label=fam)

# Plot constant methods
for idx, (_, row) in enumerate(df_const.iterrows()):
    recalls = [row[m] for m in recall_metrics]
    plt.plot(k_values, recalls, marker='x', linestyle="--", label=row['method'])

plt.xlabel("k")
plt.ylabel("Recall@k")
plt.title("Recall@k for Budget = 264")
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.grid(True)
plt.tight_layout()
plt.show()


In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Reset index
df_reset = df_res.reset_index().rename(columns={'index': 'method'})

# Parse FM methods (rank + budget)
def parse_fm(method):
    parts = method.split("_")
    if parts[0] == "FM" and parts[1].startswith("Weights"):
        # Example: FM_WeightsLU_5_32 → parts = ['FM', 'WeightsLU', '5', '32']
        variant = parts[1]        # 'WeightsLU' or 'WeightsLK'
        rank = int(parts[2])
        budget = int(parts[-1])
        return f"{parts[0]}_{variant}", rank, budget
    return None, None, None

df_reset[["family", "rank", "budget"]] = pd.DataFrame(
    df_reset["method"].apply(parse_fm).tolist(),
    index=df_reset.index
)
# Separate FM and non-FM methods
df_fm = df_reset[df_reset['family'].notnull()]
df_nonfm = df_reset[df_reset['family'].isnull()]

# Keep only ContextCite baselines
# df_rest = df_nonfm[df_nonfm['method'].str.startswith("ContextCite")]

# Function to plot R² for a given budget
def plot_r2_for_budget(budget):
    plt.figure(figsize=(10, 5))

    for family, subset in df_fm[df_fm['budget'] == budget].groupby("family"):
        subset = subset.sort_values(by="rank")
        plt.plot(subset["rank"], subset["R2"], marker='o', label=f"{family} (R²)")

    cmap = plt.cm.get_cmap("tab10", len(df_nonfm['method'].unique()))

    # Assign each unique method a color
    method_to_color = {
        method: cmap(i) for i, method in enumerate(df_nonfm['method'].unique())
    }

    # Plot
    for _, row in df_nonfm.iterrows():
        if row['method'].endswith(f"_{budget}"):
            plt.axhline(
                y=row['R2'],
                alpha=0.7,
                label=row['method'],
                color=method_to_color[row['method']]
            )

    plt.xlabel("Rank")
    plt.ylabel("R²")
    plt.title(f"Evolution of R² with Rank (Budget = {budget})")
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True)
    plt.tight_layout()
    plt.show()

# Example: plot for budget = 528
plot_r2_for_budget(32)
plot_r2_for_budget(64)
plot_r2_for_budget(128)
plot_r2_for_budget(264)
plot_r2_for_budget(528)
plot_r2_for_budget(724)

In [None]:
plt.figure(figsize=(8, 5))
for method in df_res.index:
    if "264" in method and "Wei" not in method:
        plt.plot(
            [1, 2, 3,4,5],
            df_res.loc[method, ['topk_probability_k1', 'topk_probability_k2', 'topk_probability_k3', 'topk_probability_k4', 'topk_probability_k5']],
            marker='o',
            label=method
        )

plt.xlabel('k')
plt.ylabel('Probability Drop')
plt.title('Top-k Probability Drop')
plt.legend()
plt.grid(True, linestyle='--', alpha=0.6)
plt.tight_layout()
plt.show()

In [None]:
extras[2].keys()

In [None]:
import numpy as np

def evaluate_methods(extras, k, m, interaction_type="max"):

    methods = extras[0].keys()
    scores = {m: 0 for m in methods}
    n_experiments = len(extras)

    for exp in extras:
        for method in methods:
            if "Fl" in method or "FM" in method:
                # Flu is a matrix
                value = exp[method][k][m]
                all_values = exp[method].flatten()
            else:
                # Dictionaries with tuple keys
                d = exp[method]
                value = None
                for key, v in d.items():
                    if key == (k, m):
                        value = v
                        break
                if value is None:
                    continue  # skip if (k,m) not found
                all_values = list(d.values())

            if interaction_type == "max":
                if value == max(all_values):
                    scores[method] += 1
            elif interaction_type == "min":
                if value == min(all_values):
                    scores[method] += 1

    # Convert to fraction of experiments
    results = {method: scores[method] / n_experiments for method in methods}
    return results

# Exact match

In [None]:
# Recovery rate
em={}
for i, j in enumerate(np.array(list(evaluate_methods(extras, k=0, m=1, interaction_type="max").values()))+np.array(list(evaluate_methods(extras, k=0, m=5, interaction_type="max").values()))+np.array(list(evaluate_methods(extras, k=1, m=5, interaction_type="min").values()))):
    em.update({list(extras[0].keys())[i]:j})

In [None]:
rows = []
for k, v in em.items():
    parts = k.split("_")
    if parts[0] == "Flu" and parts[1]!='0':
        _, rank, budget = parts
        rows.append({"method": "Flu", "rank": int(rank), "budget": int(budget), "recovery": v})
    elif parts[0] == "FM"and parts[3]!='0':
        _,_,_,budget = parts
        rows.append({"method": f'FM_k_dynamic', "budget": int(budget), "recovery": v})
    elif parts[0] == "Int":
        _, name, budget = parts
        rows.append({"method": name, "budget": int(budget), "recovery": v})

df = pd.DataFrame(rows)

# Plot
plt.figure(figsize=(10, 6))

# Plot Flu (different ranks as lines)
# for rank in sorted(df[df["method"]=="Flu"]["rank"].unique()):
#     sub = df[(df["method"]=="Flu") & (df["rank"]==rank)].sort_values("budget")
#     plt.plot(sub["budget"], sub["recovery"]/3, marker="o", label=f"Flu rank {rank}")

# for rank in sorted(df[df["method"]=="FM_k_dynamic"]["rank"].unique()):
#     sub = df[(df["method"]=="FM_k_dynamic") & (df["rank"]==rank)].sort_values("budget")
#     plt.plot(sub["budget"], sub["recovery"]/3, marker="o", label=f"FM_k_dynamic rank {rank}")
# for rank in sorted(df[df["method"]=="FM_r_dynamic"]):
sub1 = df[(df["method"]=="FM_k_dynamic")].sort_values("budget")
plt.plot(sub1["budget"], sub1["recovery"], marker="o", label=f"FM_k_dynamic")
plt.plot(sub1["budget"],em['Exact-FSII']*np.ones(7), marker="x", linestyle="--", label=f"Exact-FSII")

# Plot Int methods (evolve with budget, start at 264)
for m in df[df["method"].isin(["FSII","FBII","Spex"])]['method'].unique():
    sub = df[df["method"]==m].sort_values("budget")
    plt.plot(sub["budget"], sub["recovery"]/3, marker="+", linestyle="--", label=f"Int {m}")



plt.xlabel("Budget")
plt.ylabel("Recovery Rate")
plt.title("Evolution of Recovery Rate with Increasing Budget")
plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
plt.grid(True)
plt.tight_layout()
plt.show()

In [None]:
extras[1].keys()

In [None]:
1. compare the new and old fm
2. iteration with the interactions
3. shapley is for first order and faithshap for pairwise

# FaithShap recall

In [None]:
import re
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

def get_top2_min_from_dict(d):
    """Return top-2 and min interaction pairs from a dict {(i, j): val}."""
    if not d:
        return set(), None
    sorted_items = sorted(d.items(), key=lambda x: x[1], reverse=True)
    top2 = {sorted_items[0][0], sorted_items[1][0]} if len(sorted_items) >= 2 else {sorted_items[0][0]}
    min_pair = min(d.items(), key=lambda x: x[1])[0]
    return top2, min_pair

def get_top2_min_from_matrix(mat):
    """Return top-2 and min interaction pairs from a numpy or list matrix."""
    mat = np.array(mat)
    pairs = {(i, j): mat[i][j] for i in range(mat.shape[0]) for j in range(mat.shape[1]) if i != j}
    return get_top2_min_from_dict(pairs)

def extract_budget(key):
    match = re.search(r'_(\d+)$', key)
    return int(match.group(1)) if match else None

def extract_family(key):

    if key.startswith("Flu_"):
        # e.g. Flu_5_128 → Flu_5
        m = re.match(r"(Flu_\d+)_\d+", key)
        return m.group(1) if m else None
    elif key.startswith("FM_k_dynamic"):
        return "FM_k_dynamic"
    elif key.startswith("FSII"):
        return "FSII"
    elif key.startswith("FBII"):
        return "FBII"
    return None

def compare_methods_concordance(extras):
    concordance = defaultdict(list)

    for d in extras:
        exact = d.get('Exact-FSII', {})
        exact_top2, exact_min = get_top2_min_from_dict(exact)

        for key, val in d.items():
            if key == 'Exact-FSII':
                continue

            try:
                if isinstance(val, dict):
                    top2, min_pair = get_top2_min_from_dict(val)
                else:
                    top2, min_pair = get_top2_min_from_matrix(val)
            except Exception:
                continue

            matches = len(exact_top2.intersection(top2))
            if exact_min == min_pair:
                matches += 1

            score = matches / 3
            concordance[key].append(score)

    avg_concordance = {k: np.mean(v) for k, v in concordance.items() if v}
    return avg_concordance

def group_by_family_and_budget(avg_concordance):
    grouped = defaultdict(list)
    for key, score in avg_concordance.items():
        budget = extract_budget(key)
        family = extract_family(key)
        if family and budget:
            grouped[family].append((budget, score))

    # Sort budgets
    for family in grouped:
        grouped[family].sort(key=lambda x: x[0])
    return grouped

def plot_concordance(grouped):
    plt.figure(figsize=(9,6))
    for family, data in grouped.items():
        budgets, scores = zip(*data)
        plt.plot(budgets, scores, marker='o', label=family)

    plt.xlabel("Budget")
    plt.ylabel("Average Concordance to Exact-FSII")
    plt.title("Concordance vs Budget (Top-2 + Min Matching)")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

avg_concordance = compare_methods_concordance(extras)
grouped = group_by_family_and_budget(avg_concordance)
plot_concordance(grouped)


# RR@k

In [None]:
def compute_rr_at_k(interaction, ground_truth, k):
    """
    Compute Recovery@k for a method's interaction dict or matrix.
    interaction: dict {(i, j): value} or 2D numpy array/matrix
    ground_truth: set of ground-truth indices (R^*)
    k: number of top interactions to consider
    Returns: RR@k value
    """
    # Convert matrix to dict if needed
    if isinstance(interaction, (np.ndarray, list)):
        mat = np.array(interaction)
        pairs = {(i, j): mat[i][j] for i in range(mat.shape[0]) for j in range(mat.shape[1]) if i != j}
    else:
        pairs = interaction

    # Sort pairs by value (descending)
    sorted_pairs = sorted(pairs.items(), key=lambda x: x[1], reverse=True)
    rr_sum = 0.0
    for i in range(min(k, len(sorted_pairs))):
        pair_indices = set(sorted_pairs[i][0])
        rr_sum += len(ground_truth & pair_indices) / len(pair_indices)
    return rr_sum / k if k > 0 else 0.0

In [None]:
ground_truth = set([0, 1, 5])  # Example ground-truth indices
k = 5  # Number of top interactions to consider

rr_results = {}
for method, interaction in extras[0].items():  # Use the correct experiment index
    rr_results[method] = compute_rr_at_k(interaction, ground_truth, k)

print(rr_results)

In [None]:
compute_rr_at_k(extras[0]['FM_k_dynamic_1024'], set([0, 1, 5]), 3)

In [None]:
import re
import matplotlib.pyplot as plt
import numpy as np
from collections import defaultdict

# Parse RR@k for all experiments and budgets
def extract_budget(key):
    match = re.search(r'_(\d+)$', key)
    return int(match.group(1)) if match else None

def extract_family(key):
    if key.startswith("FM_k_dynamic"):
        return "FM_k_dynamic"
    elif key.startswith("FSII"):
        return "FSII"
    elif key.startswith("FBII"):
        return "FBII"
    return None

def collect_rr_at_k_over_budgets(extras, ground_truth, k):
    # For each experiment, for each method, collect RR@k by budget
    rr_by_method_budget = defaultdict(lambda: defaultdict(list))
    for exp in extras:
        for method, interaction in exp.items():
            budget = extract_budget(method)
            family = extract_family(method)
            if budget and family:
                rr = compute_rr_at_k(interaction, ground_truth, k)
                rr_by_method_budget[family][budget].append(rr)
    # Average over experiments
    rr_avg = defaultdict(dict)
    for family, budgets in rr_by_method_budget.items():
        for budget, vals in budgets.items():
            rr_avg[family][budget] = np.mean(vals)
    return rr_avg
# Plot RR@k as line chart for each method family, with constant methods as parallel lines
plt.figure(figsize=(10, 6))

# Plot budgeted families
for family, budget_rrs in rr_avg.items():
    budgets = sorted(budget_rrs.keys())
    values = [budget_rrs[b] for b in budgets]
    plt.plot(budgets, values, marker='o', label=family)

# Plot constant methods (e.g., Exact-FSII, LOO, ARC-JSD) as horizontal lines
constant_methods = ['Exact-FSII', 'LOO', 'ARC-JSD']
for method in constant_methods:
    # Collect RR@k for each experiment and average
    rr_vals = []
    for exp in extras:
        if method in exp:
            rr_vals.append(compute_rr_at_k(exp[method], ground_truth, k))
    if rr_vals:
        avg_rr = np.mean(rr_vals)
        plt.axhline(y=avg_rr, color=None, linestyle='--', label=method)

plt.xlabel('Budget')
plt.ylabel(f'RR@{k}')
plt.title(f'Recovery Rate at k={k} vs Budget (Averaged over experiments)')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()


In [None]:
# Plot RR@k as line chart for each method family, with constant methods as parallel lines
plt.figure(figsize=(10, 6))

# Plot budgeted families
for family, budget_rrs in rr_avg.items():
    budgets = sorted(budget_rrs.keys())
    values = [budget_rrs[b] for b in budgets]
    plt.plot(budgets, values, marker='o', label=family)

# Plot constant methods (e.g., Exact-FSII, LOO, ARC-JSD) as horizontal lines
constant_methods = ['Exact-FSII', 'LOO', 'ARC-JSD']
for method in constant_methods:
    # Collect RR@k for each experiment and average
    rr_vals = []
    for exp in extras:
        if method in exp:
            rr_vals.append(compute_rr_at_k(exp[method], ground_truth, k))
    if rr_vals:
        avg_rr = np.mean(rr_vals)
        plt.axhline(y=avg_rr, color=None, linestyle='--', label=method)

plt.xlabel('Budget')
plt.ylabel(f'RR@{k}')
plt.title(f'Recovery Rate at k={k} vs Budget (Averaged over experiments)')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()