In [None]:
import sys
import os
import random
import gc
import time
import torch
import numpy as np
import pandas as pd
import ast
from tqdm import tqdm
from scipy.stats import spearmanr, pearsonr, kendalltau, rankdata
from sklearn.metrics import ndcg_score
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import Accelerator
# Expose only CUDA device index 3 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

# Put the parent of the current working directory on the import path
# (presumably where the SHapRAG package lives - confirm against the repo layout).
current_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
sys.path.append(parent_dir)
from SHapRAG import *

In [None]:
# Load the synthetic benchmark (semicolon-separated). Later cells read its
# `question`, `context` and `answer` columns.
df = pd.read_csv(
    "../data/synthetic_data/20_synergy_hard_negatives.csv",
    sep=";",
    index_col=False,
)
# Alternative dataset:
# df= pd.read_csv("../data/complementary.csv")

In [None]:
# Seed handed to the sampling-based attribution methods in the attribution loop.
SEED = 42

# Accelerator configured for fp16 mixed precision.
accelerator_main = Accelerator(mixed_precision="fp16")


def _log_main(message):
    """Print `message` on the main process only."""
    if accelerator_main.is_main_process:
        print(message)


_log_main("Main Script: Loading model...")
# Other checkpoints that have been used with this notebook:
# model_path = "mistralai/Mistral-7B-Instruct-v0.3"
# model_path = "meta-llama/Llama-3.1-8B-Instruct"
model_path = "Qwen/Qwen2.5-3B-Instruct"

model_cpu = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# When the tokenizer has no pad token, reuse EOS and mirror the id into the
# model config and (if present) the generation config.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    model_cpu.config.pad_token_id = tokenizer.pad_token_id
    _generation_config = getattr(model_cpu, 'generation_config', None)
    if _generation_config is not None:
        _generation_config.pad_token_id = tokenizer.pad_token_id

_log_main("Main Script: Preparing model with Accelerator...")
prepared_model = accelerator_main.prepare(model_cpu)
unwrapped_prepared_model = accelerator_main.unwrap_model(prepared_model)
unwrapped_prepared_model.eval()
_log_main("Main Script: Model prepared and set to eval.")

# Synchronise all processes before the attribution loop starts.
accelerator_main.wait_for_everyone()

In [None]:
# Attribution sweep.
# For every question: build a ContextAttribution harness over its documents, run
# each attribution method, score the attributions (top-k and LDS), persist the
# utility cache, and append the per-question dict to `all_results`.
# All scoring happens on the main process only.
#
# Names deliberately left in the notebook namespace for later cells:
# `all_results`, `actual_samples`, `harness`, `F`, `modelfm`, `model_cc`.
num_questions_to_run=len(df.question)
# num_questions_to_run=50
k_values = [1,2,3,4,5]
all_results=[]
LDSs=[]
r2_fm=[]
r2_cc=[]

# Sample budget per size key; only "L" is active.
m_samples_map = {"L": 364}
# m_samples_map = {"L": 128, "XL":256, "XXL":512}
# Only referenced by the disabled BetaShap / TMC baselines inside the loop.
T_iterations_map = {"L":40, "XL":50, "XXL":60}

# The cache directory depends only on the model, so resolve and create it once.
utility_cache_base_dir = f"../Experiment_data/synthetic/{model_path.split('/')[1]}"
if accelerator_main.is_main_process: # Only main process creates directories
    os.makedirs(utility_cache_base_dir, exist_ok=True)

for q_idx in tqdm(range(num_questions_to_run), disable=not accelerator_main.is_main_process):
    query = df.question[q_idx]
    if accelerator_main.is_main_process:
        print(f"\n--- Question {q_idx+1}/{num_questions_to_run}: {query[:60]}... ---")

    # The context column stores a Python literal; parse it into the document collection.
    docs=ast.literal_eval(df.context[q_idx])
    current_utility_path = os.path.join(utility_cache_base_dir, f"utilities_q_idx{q_idx}.pkl")

    # Initialize Harness (constructed on every process).
    harness = ContextAttribution(
        items=docs,
        query=query,
        prepared_model=prepared_model,
        prepared_tokenizer=tokenizer,
        accelerator=accelerator_main,
        utility_cache_path=current_utility_path
    )

    # Compute metrics
    results_for_query = {}
    if accelerator_main.is_main_process:
        # Guarded like every other print so multi-process runs do not duplicate output.
        print(f'Response: {harness.target_response}')
        print(f'GT: {df.answer[q_idx]}')

        num_subsets = 2**len(docs)
        for size_key, num_s in m_samples_map.items():
            # Budgets other than "L" are capped at the number of non-empty subsets.
            if num_subsets < num_s and size_key != "L":
                actual_samples = max(1, num_subsets - 1)
            else:
                actual_samples = num_s

            if actual_samples > 0:
                results_for_query[f"ContextCite{actual_samples}"], model_cc = harness.compute_contextcite(num_samples=actual_samples, seed=SEED)
                attributions, ints=harness.compute_spex(sample_budget=actual_samples,max_order=2)
                results_for_query[f"FBII{actual_samples}"]=attributions['fbii']
                results_for_query[f"Spex{actual_samples}"]=attributions['fourier']
                results_for_query[f"FSII{actual_samples}"]=attributions['fsii']
                results_for_query[f"FM_WeightsD{actual_samples}"], F, modelfm = harness.compute_wss(num_samples=actual_samples, seed=SEED, sampling="kernelshap",sur_type="fm", utility_mode="divergence_utility")
                # NOTE: this call rebinds F and modelfm, so after the loop they refer to
                # the fit from this second call (no utility_mode argument).
                results_for_query[f"FM_Weights{actual_samples}"], F, modelfm = harness.compute_wss(num_samples=actual_samples, seed=SEED, sampling="kernelshap",sur_type="fm")
                # results_for_query[f"BetaShap{actual_samples}"] = harness.compute_beta_shap(num_iterations_max=T_iterations_map[size_key], beta_a=16, beta_b=1, max_unique_lookups=actual_samples, seed=SEED)
                # results_for_query[f"TMC{actual_samples}"] = harness.compute_tmc_shap(num_iterations_max=T_iterations_map[size_key], performance_tolerance=0.001, max_unique_lookups=actual_samples, seed=SEED)

        results_for_query["LOO"] = harness.compute_loo()
        results_for_query["ARC-JSD"] = harness.compute_arc_jsd()

        prob_topk = harness.evaluate_topk_performance(
            results_for_query,
            k_values,
            utility_type="probability"
        )

        div_topk = harness.evaluate_topk_performance(
            results_for_query,
            k_values,
            utility_type="divergence"
        )

        # r2_fm.append([harness.r2_mse(30, modelfm, method='fm')])
        # r2_cc.append([harness.r2_mse(30, model_cc, method='cc')])

        # One single-key dict per method, which is the shape the aggregation cell reads.
        # A previous FM-aware pass (harness.lds(..., utl=True, model=modelfm)) was computed
        # here and then immediately overwritten by this list, doubling the LDS cost for no
        # effect; it also reused the question index name `i` as its loop variable.
        # NOTE(review): if FM methods should be scored through their surrogate, reinstate
        # that variant here with the surrogate matching each method - `modelfm` only holds
        # the last compute_wss fit.
        LDS = [
            {method_name: harness.lds(method_scores_for_query, 30)}
            for method_name, method_scores_for_query in results_for_query.items()
        ]

        results_for_query["topk_probability"] = prob_topk
        results_for_query["topk_divergence"] = div_topk
        results_for_query["LDS"] = LDS
        harness.save_utility_cache(current_utility_path)

        all_results.append(results_for_query)

In [None]:
# Result keys to aggregate. The sampled methods are suffixed with the sample
# budget left in `actual_samples` by the attribution loop.
_sampled_prefixes = ('ContextCite', 'FM_Weights', 'FM_WeightsD', 'Spex', 'FBII', 'FSII')
methods = [f'{prefix}{actual_samples}' for prefix in _sampled_prefixes] + ['LOO', 'ARC-JSD']

# One list of per-question values for each method and metric.
topk_probs = {name: [] for name in methods}
topk_divs = {name: [] for name in methods}
LDSs = {name: [] for name in methods}

for entry in all_results:
    lds_records = entry['LDS']
    for method in methods:
        # Entry [2] of each method's top-k result is the value that gets averaged.
        topk_probs[method].append(entry['topk_probability'][method][2])
        topk_divs[method].append(entry['topk_divergence'][method][2])
        # 'LDS' is a list of single-key dicts; take the first one carrying this method.
        first_hit = next((record for record in lds_records if method in record), None)
        if first_hit is not None:
            LDSs[method].append(first_hit[method])


def _mean_per_method(collected):
    """Return {method: mean of its collected values}, keeping the key order."""
    return {name: np.mean(values) for name, values in collected.items()}


mean_topk_probs = _mean_per_method(topk_probs)
mean_topk_divs = _mean_per_method(topk_divs)
mean_LDSs = _mean_per_method(LDSs)

print("Mean topk_probability:", mean_topk_probs)
print("Mean topk_divergence:", mean_topk_divs)
print("Mean LDS:", mean_LDSs)

In [None]:
# precs[q, m]: precision of method m's attribution on question q against the
# index list [0, 1] (presumably the ground-truth relevant documents - confirm).
# NOTE(review): `harness` is the instance left over from the last loop iteration.
precs = np.zeros((len(all_results), len(methods)))
for row, result in enumerate(all_results):
    for col, method_name in enumerate(methods):
        precs[row, col] = harness.precision([0,1], result[method_name])

In [None]:
# Average `precs` over axis 0 (the per-question rows): one mean precision per entry of `methods`.
precs.mean(axis=0)

In [None]:
import matplotlib.pyplot as plt

# Entries of a per-question result dict that hold evaluation metrics (dicts or a
# list of dicts) rather than attribution scores. np.round cannot process them and
# raises TypeError, so they are excluded from the plots.
NON_ATTRIBUTION_KEYS = {"topk_probability", "topk_divergence", "LDS"}

# Later questions overwrite earlier ones, so each method ends up with the scores
# from the last question that produced a non-None result for it.
method_scores = {}

for result in all_results:
    for method, scores in result.items():
        if method in NON_ATTRIBUTION_KEYS or scores is None:
            continue
        method_scores[method] = np.round(scores, 4)

# One bar chart per attribution method (assumes each score vector is 1-D - TODO confirm).
for method, scores in method_scores.items():
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.bar(range(len(scores)), scores, color='skyblue')
    ax.set_title(f"Approximate Scores: {method}")
    ax.set_xlabel("Index")
    ax.set_ylabel("Score")
    ax.grid(axis='y', linestyle='--', alpha=0.7)
    ax.set_xticks(range(len(scores)))
    fig.tight_layout()
    plt.show()
    plt.close(fig)  # release the figure so many methods do not accumulate open figures


In [None]:
# Inspect element 1 of `F`, the second value returned by the last harness.compute_wss call in the attribution loop.
F[1]

In [None]:
# Show the raw `context` string for the row labelled 19 (the attribution loop parses this column with ast.literal_eval).
df.context[19]