In [1]:
import sys
import os
import random
import gc
import time
import torch
import numpy as np
import pandas as pd
import pickle
import ast
from tqdm import tqdm
from scipy.sparse import csr_matrix
import itertools
from scipy.stats import spearmanr, pearsonr, kendalltau, rankdata
from sklearn.metrics import ndcg_score
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import Accelerator
import nltk
nltk.download('punkt')
os.environ["CUDA_VISIBLE_DEVICES"] = "2" 
current_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(current_dir, '..'))
sys.path.append(parent_dir)
from SHapRAG import *

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
df=pd.read_json("../data/musique/musique_ans_v1.0_train.jsonl", lines=True)

In [3]:
def get_titles(lst):
    # Titles where is_supporting is True
    supporting = [d['paragraph_text'] for d in lst if d.get('is_supporting') == True]
    # Titles where is_supporting is False or missing AND not already in supporting
    others = [d['paragraph_text'] for d in lst if d.get('is_supporting') != True and d['paragraph_text'] not in supporting]
    # Combine: all supporting + as many others as needed to reach 10
    result = supporting + others
    return result[:10]

df.paragraphs=df.paragraphs.apply(get_titles)

In [4]:
df['Sentences'] = df['paragraphs'].apply(
    lambda para_list: [sent for para in para_list for sent in nltk.sent_tokenize(para)]
)

In [5]:
df_save=pd.read_csv('../data/musique/sen_labeled.csv',
    quotechar='"',
    skipinitialspace=True,
    engine='python' )

In [None]:
df["paragraphs"] = df["paragraphs"].apply(lambda p: p[:5]+ [p[1]] + p[5:])

In [6]:
SEED = 42
# Initialize Accelerator
accelerator_main = Accelerator(mixed_precision="fp16")

# Load Model
if accelerator_main.is_main_process:
    print("Main Script: Loading model...")
# model_path = "mistralai/Mistral-7B-Instruct-v0.3"
model_path = "meta-llama/Llama-3.1-8B-Instruct"
# model_path = "Qwen/Qwen2.5-3B-Instruct"

model_cpu = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained(model_path)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    model_cpu.config.pad_token_id = tokenizer.pad_token_id
    if hasattr(model_cpu, 'generation_config') and model_cpu.generation_config is not None:
        model_cpu.generation_config.pad_token_id = tokenizer.pad_token_id

if accelerator_main.is_main_process:
    print("Main Script: Preparing model with Accelerator...")
prepared_model = accelerator_main.prepare(model_cpu)
unwrapped_prepared_model = accelerator_main.unwrap_model(prepared_model)
unwrapped_prepared_model.eval()
if accelerator_main.is_main_process:
    print("Main Script: Model prepared and set to eval.")

# Define utility cache

accelerator_main.wait_for_everyone()

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Main Script: Loading model...


Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.69it/s]


Main Script: Preparing model with Accelerator...
Main Script: Model prepared and set to eval.


In [7]:
num_questions_to_run=50
# num_questions_to_run=1
k_values = [1,2,3,4,5]
all_results=[]

# Define ground truth set of docs for precision (adapt as needed)
# e.g., if first 2 docs are always relevant
# def get_gtset_k():
#     return [0, 1]

for i in tqdm(range(num_questions_to_run), disable=not accelerator_main.is_main_process):
    query = df.Sentences[i]
    gt=ast.literal_eval(df_save.labels[i])
    if accelerator_main.is_main_process:
        print(f"\n--- Question {i+1}/{num_questions_to_run}: {query[:60]}... ---")

    docs=df.paragraphs[i]
    utility_cache_base_dir = f"../Experiment_data/musique/{model_path.split('/')[1]}/sentence"
    utility_cache_filename = f"utilities_q_idx{i}.pkl"
    current_utility_path = os.path.join(utility_cache_base_dir, utility_cache_filename)

    if accelerator_main.is_main_process:
        os.makedirs(os.path.dirname(current_utility_path), exist_ok=True)

    harness = ContextAttribution(
        items=docs,
        query=query,
        prepared_model=prepared_model,
        prepared_tokenizer=tokenizer,
        accelerator=accelerator_main,
        utility_cache_path=current_utility_path
    )

    if accelerator_main.is_main_process:
        methods_results = {}
        metrics_results = {}
        extra_results = {}

        m_samples_map = {"L": 364}

        # Store FM models for later R²/MSE
        fm_models = {}

        for size_key, num_s in m_samples_map.items():
            if 2 ** len(docs) < num_s and size_key != "L":
                actual_samples = max(1, 2 ** len(docs) - 1 if 2 ** len(docs) > 0 else 1)
            else:
                actual_samples = num_s

            if actual_samples > 0:
                methods_results[f"ContextCite{actual_samples}"], model_cc = harness.compute_contextcite(
                    num_samples=actual_samples, seed=SEED
                )

                attributions, _ = harness.compute_spex(sample_budget=actual_samples, max_order=2)
                methods_results[f"FBII{actual_samples}"] = attributions['fbii']
                methods_results[f"Spex{actual_samples}"] = attributions['fourier']
                methods_results[f"FSII{actual_samples}"] = attributions['fsii']

                # methods_results[f"FM_WeightsDU{actual_samples}"], Fdu, modelfmdu = harness.compute_wss(
                #     num_samples=actual_samples, seed=SEED, sampling="uniform",
                #     sur_type="fm", utility_mode="divergence_utility"
                # )
                # methods_results[f"FM_WeightsDK{actual_samples}"], Fdk, modelfmdk = harness.compute_wss(
                #     num_samples=actual_samples, seed=SEED, sampling="kernelshap",
                #     sur_type="fm", utility_mode="divergence_utility"
                # )
                methods_results[f"FM_WeightsLK{actual_samples}"], Flk, modelfmlk = harness.compute_wss(
                    num_samples=actual_samples, seed=SEED, sampling="kernelshap", sur_type="fm"
                )
                methods_results[f"FM_WeightsLU{actual_samples}"], Flu, modelfmlu = harness.compute_wss(
                    num_samples=actual_samples, seed=SEED, sampling="uniform", sur_type="fm"
                )

                # Save FM models
                fm_models.update({
                    # f"FM_WeightsDU{actual_samples}": modelfmdu,
                    # f"FM_WeightsDK{actual_samples}": modelfmdk,
                    f"FM_WeightsLK{actual_samples}": modelfmlk,
                    f"FM_WeightsLU{actual_samples}": modelfmlu
                })

                # Save extra Fs
                extra_results.update({
                    # "Fdu": Fdu,
                    # "Fdk": Fdk,
                    "Flk": Flk,
                    "Flu": Flu
                })

        methods_results["LOO"] = harness.compute_loo()
        methods_results["ARC-JSD"] = harness.compute_arc_jsd()

        # --- Evaluation Metrics ---
        metrics_results["topk_probability"] = harness.evaluate_topk_performance(
            methods_results, k_values, utility_type="probability"
        )
        metrics_results["topk_divergence"] = harness.evaluate_topk_performance(
            methods_results, k_values, utility_type="divergence"
        )
        metrics_results["topk_response_probability"] = harness.top_k_response_probability(
            methods_results, k_values=[1, 3, 5]
        )

        # R² and MSE for ContextCite
        r2, mse = harness.r2_mse(30, 'logit-prob', model_cc, method='cc')
        metrics_results["R2_cc"] = r2
        metrics_results["MSE_cc"] = mse

        # R² and MSE for each FM method that has a model
        for method_name, fm_model in fm_models.items():
            r2, mse = harness.r2_mse(30, 'logit-prob', fm_model, method='fm')
            metrics_results[f"R2_{method_name}"] = r2
            metrics_results[f"MSE_{method_name}"] = mse

        # LDS per method
        LDS = {}
        for method_name, scores in methods_results.items():
            if "FM_WeightsLU" in method_name:
                LDS[method_name] = harness.lds(scores, 30, utl=True, model=modelfmlu)
            else:
                LDS[method_name] = harness.lds(scores, 30)
        metrics_results["LDS"] = LDS

        # Precision per method
        precision_scores = {}
        gtset_k = gt
        for method_name, scores in methods_results.items():
            precision_scores[method_name] = harness.precision(gtset_k, scores)
        metrics_results["precision"] = precision_scores

        harness.save_utility_cache(current_utility_path)

        all_results.append({
            "query_index": i,
            "query": query,
            "ground_truth": df.answer[i],
            "methods": methods_results,
            "metrics": metrics_results,
            "extra": extra_results
        })


  0%|          | 0/50 [00:00<?, ?it/s]


--- Question 1/50: ['The Collegian is the bi-weekly official student publication of Houston Baptist University in Houston, Texas.', 'It was founded in 1963 as a newsletter, and adopted the newspaper format in 1990.', "Several private institutions of higher learning—ranging from liberal arts colleges, such as The University of St. Thomas, Houston's only Catholic university, to Rice University, the nationally recognized research university—are located within the city.", 'Rice, with a total enrollment of slightly more than 6,000 students, has a number of distinguished graduate programs and research institutes, such as the James A. Baker Institute for Public Policy.', "Houston Baptist University, affiliated with the Baptist General Convention of Texas, offers bachelor's and graduate degrees.", 'It was founded in 1960 and is located in the Sharpstown area in Southwest Houston.', 'Pakistan Super League (Urdu: پاکستان سپر لیگ \u202c \u200e; PSL) is a Twenty20 cricket league, founded in Lahor

Fetching 3 files: 100%|██████████| 3/3 [00:00<00:00, 13472.07it/s]
Computing utilities for ContextCite: 100%|██████████| 364/364 [01:56<00:00,  3.14it/s]
  0%|          | 0/50 [03:00<?, ?it/s]


KeyboardInterrupt: 

In [None]:
import pandas as pd
import numpy as np
from collections import defaultdict

def summarize_and_print(all_results, k_values=[1, 3, 5]):
    table_data = defaultdict(lambda: defaultdict(list))

    # Mapping for consistency
    method_name_map = {
        "cc": "ContextCite364"  # rename cc to full name
    }

    for res in all_results:
        metrics = res["metrics"]

        # R² / MSE
        for key, val in metrics.items():
            if key.startswith("R2_") or key.startswith("MSE_"):
                raw_method = key.split("_", 1)[1]
                method = method_name_map.get(raw_method, raw_method)  # rename if needed
                metric_name = key.split("_", 1)[0]  # "R2" or "MSE"
                table_data[method][metric_name].append(val)

        # LDS
        for method_name, lds_val in metrics["LDS"].items():
            method = method_name_map.get(method_name, method_name)
            table_data[method]["LDS"].append(lds_val)

        # Precision
        for method_name, prec_val in metrics["precision"].items():
            method = method_name_map.get(method_name, method_name)
            table_data[method]["precision"].append(prec_val)

        # Top-k
        for metric_type in ["topk_probability", "topk_divergence", "topk_response_probability"]:
            for method_name, k_dict in metrics[metric_type].items():
                method = method_name_map.get(method_name, method_name)
                for k in k_values:
                    if k in k_dict:
                        col_name = f"{metric_type}_k{k}"
                        table_data[method][col_name].append(k_dict[k])

    # Averages
    avg_table = {
        method: {metric: np.nanmean(values) for metric, values in metric_dict.items()}
        for method, metric_dict in table_data.items()
    }

    df_summary = pd.DataFrame.from_dict(avg_table, orient="index").sort_index()

    print("\n=== Metrics Summary Across All Queries ===")
    print(df_summary.to_string(float_format="%.4f"))

    return df_summary


In [None]:
summarize_and_print(all_results, k_values=[1, 3, 5])

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def compute_recall_at_k(all_results, k_values=[1,2, 3,4, 5]):
    methods = list(all_results[0]["methods"].keys())
    recall_table = {m: [] for m in methods}

    for k in k_values:
        for method in methods:
            recalls = []
            for p, res in enumerate(all_results):
                gt_indices = set(ast.literal_eval(df_save.labels[p]))  # must exist in each result
                scores = np.array(res["methods"][method])
                topk_indices = set(scores.argsort()[-k:])
                hits = len(gt_indices & topk_indices)
                recalls.append(hits / len(gt_indices) if gt_indices else np.nan)
            recall_table[method].append(np.nanmean(recalls))
    return recall_table


def plot_recall_at_k(recall_table, k_values=[1,2, 3,4, 5]):
    plt.figure(figsize=(8, 5))
    for method, recalls in recall_table.items():
        plt.plot(k_values, recalls, marker="o", label=method)
    plt.xlabel("k")
    plt.ylabel("Recall@k")
    plt.title("Recall@k for All Methods")
    plt.xticks(k_values)
    plt.ylim(0, 1.05)
    plt.grid(True)
    plt.legend()
    plt.show()


In [None]:
recall_table=compute_recall_at_k(all_results, k_values=[1, 2, 3, 4, 5])
plot_recall_at_k(recall_table, k_values=[1, 2, 3, 4, 5])