In [1]:
import sys
import os
import random
import gc
import time
import torch
import numpy as np
import pandas as pd
import ast
from tqdm import tqdm
from scipy.stats import spearmanr, pearsonr, kendalltau, rankdata
from sklearn.metrics import ndcg_score
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import Accelerator

current_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(current_dir, '..'))
sys.path.append(parent_dir)
from SHapRAG import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# MuSiQue training split, stored as JSON Lines (one record per line).
df = pd.read_json(
    path_or_buf="../data/musique/musique_ans_v1.0_train.jsonl",
    lines=True,
)

In [3]:
def get_titles(lst, limit=10):
    """Return up to ``limit`` paragraph texts, supporting paragraphs first.

    Despite its name, this function collects ``paragraph_text`` values, not
    titles.

    Args:
        lst: Iterable of paragraph dicts. Each dict must contain a
            ``'paragraph_text'`` key and may contain an ``'is_supporting'`` key.
        limit: Maximum number of texts to return (default 10, the value that
            was previously hard-coded).

    Returns:
        A list of paragraph texts: every paragraph whose ``is_supporting`` flag
        equals ``True`` (in input order), followed by the remaining paragraphs
        whose text does not duplicate a supporting paragraph, truncated to
        ``limit`` entries.
    """
    # `== True` (rather than `is True`) is kept on purpose so that truthy
    # equivalents such as 1 are classified exactly as before.
    supporting = [d['paragraph_text'] for d in lst if d.get('is_supporting') == True]

    # Set membership replaces the repeated linear scan of the supporting list.
    supporting_texts = set(supporting)
    others = [
        d['paragraph_text']
        for d in lst
        if d.get('is_supporting') != True and d['paragraph_text'] not in supporting_texts
    ]

    # All supporting paragraphs first, then pad with distractors up to `limit`.
    return (supporting + others)[:limit]

# Keep the raw paragraph dicts in their own column the first time this cell
# runs, and always derive `paragraphs` from that copy. Without this, re-running
# the cell would feed already-flattened strings back into get_titles and fail.
if "paragraphs_raw" not in df.columns:
    df["paragraphs_raw"] = df["paragraphs"]
df["paragraphs"] = df["paragraphs_raw"].apply(get_titles)

In [4]:
SEED = 42

# A single Accelerator (fp16 mixed precision) is shared by the rest of the notebook.
accelerator_main = Accelerator(mixed_precision="fp16")


def _announce(message):
    """Print ``message`` on the main process only."""
    if accelerator_main.is_main_process:
        print(message)


# Load Model
_announce("Main Script: Loading model...")
# Alternative checkpoints (disabled):
# model_path = "mistralai/Mistral-7B-Instruct-v0.3"
# model_path = "Qwen/Qwen2.5-3B-Instruct"
model_path = "meta-llama/Llama-3.1-8B-Instruct"

model_cpu = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# When the tokenizer ships without a pad token, reuse EOS and mirror the id
# into the model config and (if present) the generation config.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    model_cpu.config.pad_token_id = tokenizer.pad_token_id
    generation_config = getattr(model_cpu, 'generation_config', None)
    if generation_config is not None:
        generation_config.pad_token_id = tokenizer.pad_token_id

_announce("Main Script: Preparing model with Accelerator...")
prepared_model = accelerator_main.prepare(model_cpu)
# eval() is called on the unwrapped model returned by the accelerator.
unwrapped_prepared_model = accelerator_main.unwrap_model(prepared_model)
unwrapped_prepared_model.eval()
_announce("Main Script: Model prepared and set to eval.")

# Synchronize all processes before any later cell uses the model.
accelerator_main.wait_for_everyone()

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Main Script: Loading model...


Loading checkpoint shards: 100%|██████████| 4/4 [00:16<00:00,  4.01s/it]


Main Script: Preparing model with Accelerator...
Main Script: Model prepared and set to eval.


In [None]:
# Peek at the first ten gold answers.
df["answer"].iloc[:10]

In [5]:
# num_questions_to_run=len(df.question)
num_questions_to_run=20
all_metrics_data = []   # one record per (question, method) with correlation metrics
all_results=[]          # one dict per question: method name -> attribution scores

# Loop-invariant configuration, hoisted out of the per-question loop.
utility_cache_base_dir = "../Experiment_data/synergy"
m_samples_map = {"S": 32, "M": 64, "L": 100}   # sampling budget per size key
T_iterations_map = {"S": 5, "M": 10, "L":20}   # max iterations per size key (BetaShap / TMC)


def _agreement_with_exact(exact_scores, approx_scores):
    """Return ``(pearson, kendall_tau)`` between two equal-length score vectors.

    If either vector is constant, correlation is undefined; in that case both
    metrics are 1.0 when the vectors are numerically close and 0.0 otherwise.
    Previously only the Pearson value was set on that path, so the Kendall
    value recorded was stale from an earlier iteration (or undefined).
    """
    # Coerce to arrays so the element-wise constant check also works for lists.
    exact_arr = np.asarray(exact_scores)
    approx_arr = np.asarray(approx_scores)

    if np.all(exact_arr == exact_arr[0]) or np.all(approx_arr == approx_arr[0]):
        agreement = 1.0 if np.allclose(exact_arr, approx_arr) else 0.0
        return agreement, agreement

    pearson_value, _pearson_p = pearsonr(exact_arr, approx_arr)
    # Rank descending (largest score gets rank 1); ties receive the average rank.
    exact_ranks = rankdata(-exact_arr, method="average")
    approx_ranks = rankdata(-approx_arr, method="average")
    kendall_value, _kendall_p = kendalltau(exact_ranks, approx_ranks)
    return pearson_value, kendall_value


for i in tqdm(range(num_questions_to_run), desc="Processing Questions", disable=not accelerator_main.is_main_process):
    query = df.question[i]
    if accelerator_main.is_main_process:
        print(f"\n--- Question {i+1}/{num_questions_to_run}: {query[:60]}... ---")

    docs=df.paragraphs[i]

    # Per-question utility cache file, keyed by question index and document count.
    utility_cache_filename = f"utilities_q_idx{i}_n{len(docs)}.pkl"
    current_utility_path = os.path.join(utility_cache_base_dir, utility_cache_filename)

    if accelerator_main.is_main_process: # Only main process creates directories
        os.makedirs(os.path.dirname(current_utility_path), exist_ok=True)
        print(f"  Instantiating ShapleyExperimentHarness for Q{i} (n={len(docs)} docs)...")

    # Initialize Harness (on every process; it receives the shared accelerator)
    harness = ShapleyExperimentHarness(
        items=docs,
        query=query,
        prepared_model_for_harness=prepared_model,
        tokenizer_for_harness=tokenizer,
        accelerator_for_harness=accelerator_main,
        verbose=True,
        utility_path=current_utility_path
    )
    # Compute metrics
    results_for_query = {}

    if accelerator_main.is_main_process:
        results_for_query["ExactLinear"] = harness.compute_exact_linear_shap()
        results_for_query["ExactInter"], pairs = harness.compute_exact_inter_shap()

        for size_key, num_s in m_samples_map.items():
            # For the smaller budgets, cap the sample count when the number of
            # subsets (2**n) is below the requested budget.
            if 2**len(docs) < num_s and size_key != "L":
                actual_samples = max(1, 2**len(docs) - 1)
            else:
                actual_samples = num_s

            if actual_samples > 0:
                results_for_query[f"ContextCite{actual_samples}"] = harness.compute_contextcite_weights(num_samples=actual_samples, seed=SEED)
                results_for_query[f"WSS_FM{actual_samples}"], F = harness.compute_wss(num_samples=actual_samples, seed=SEED, distil=None, sampling="kernelshap",sur_type="fm", util='pure-surrogate', pairchecking=False)
                results_for_query[f"BetaShap{actual_samples}"] = harness.compute_beta_shap(num_iterations_max=T_iterations_map[size_key], beta_a=0.5, beta_b=0.5, max_unique_lookups=actual_samples, seed=SEED)
                results_for_query[f"TMC{actual_samples}"] = harness.compute_tmc_shap(num_iterations_max=T_iterations_map[size_key], performance_tolerance=0.001, max_unique_lookups=actual_samples, seed=SEED)

        results_for_query["LOO"] = harness.compute_loo()

        # "ExactInter" is the reference every method is compared against.
        exact_scores = results_for_query.get("ExactInter")
        all_results.append(results_for_query)
        if exact_scores is not None:
            # NOTE(review): the reference method itself is also scored (trivially
            # 1.0), exactly as before -- the old guard compared against "Exact",
            # a key that never exists. Confirm whether it should be excluded.
            for method, approx_scores in results_for_query.items():
                if approx_scores is None:
                    continue
                if len(approx_scores) != len(exact_scores):
                    print(f"    Score length mismatch for method {method} (Exact: {len(exact_scores)}, Approx: {len(approx_scores)}). Skipping metrics.")
                    continue

                pearson_c, kendall_c = _agreement_with_exact(exact_scores, approx_scores)
                all_metrics_data.append({
                    "Question_Index": i, "Query": query, "Method": method,
                    "Pearson": pearson_c, "KendallTau" : kendall_c,
                })
        else:
            print(f"    Skipping metric calculation for Q{i} as Exact Shapley was not computed or failed.")

    # Non-main processes wait here while the main process computes the scores.
    accelerator_main.wait_for_everyone()

    if torch.cuda.is_available():
        if accelerator_main.is_main_process: # Print from one process
            print(f"Attempting to empty CUDA cache on rank {accelerator_main.process_index} after Q{i}")
        torch.cuda.empty_cache()
        gc.collect()
        if accelerator_main.is_main_process:
            print(f"CUDA cache empty attempt complete on rank {accelerator_main.process_index}.")
    accelerator_main.wait_for_everyone()


# Summarize per-method agreement with the exact scores (main process only).
if accelerator_main.is_main_process:
    if not all_metrics_data:
        print("\nNo metrics were collected. This might be due to all calculations failing or only non-main processes running sections.")
    else:
        metrics_df_all_questions = pd.DataFrame(all_metrics_data)
        print("\n\n--- Average Correlation Metrics Across All Questions ---")

        # Mean Pearson / Kendall per method, plus the number of distinct queries scored.
        per_method = metrics_df_all_questions.groupby("Method")
        average_metrics = per_method.agg(
            Avg_Pearson=("Pearson", "mean"),
            Avg_Kendall =("KendallTau", "mean"),
            Num_Valid_Queries=("Query", "nunique")
        )
        average_metrics = average_metrics.sort_values(by="Avg_Pearson", ascending=False)

        print(average_metrics.round(4))

# Final barrier so every process reaches the end of the run together.
accelerator_main.wait_for_everyone()
if accelerator_main.is_main_process:
    print("Script finished.")

# Tear down the torch.distributed process group if one was initialized;
# progress messages are printed by local main processes only.
distributed_active = torch.distributed.is_available() and torch.distributed.is_initialized()
if not distributed_active:
    if accelerator_main.is_local_main_process:
        print(f"Rank {accelerator_main.process_index} (Local Main): Distributed environment not initialized or not available, skipping destroy_process_group.")
else:
    if accelerator_main.is_local_main_process:
        print(f"Rank {accelerator_main.process_index} (Local Main): Manually destroying process group...")
    torch.distributed.destroy_process_group()
    if accelerator_main.is_local_main_process:
        print(f"Rank {accelerator_main.process_index} (Local Main): Process group destroyed.")

if accelerator_main.is_main_process:
    print("Script fully exited.")

Processing Questions:   0%|          | 0/20 [00:00<?, ?it/s]


--- Question 1/20: When was the institute that owned The Collegian founded?... ---
  Instantiating ShapleyExperimentHarness for Q0 (n=10 docs)...
Generating target response based on full context...


Fetching 3 files: 100%|██████████| 3/3 [00:00<00:00, 11037.64it/s]


Target response: '1960'
Successfully loaded utilities from ../Experiment_data/synergy/utilities_q_idx0_n10.pkl. Found 1024 entries.
Broadcasted loaded utilities to all processes.




Attempting to empty CUDA cache on rank 0 after Q0
CUDA cache empty attempt complete on rank 0.

Processing Questions:   5%|▌         | 1/20 [00:02<00:48,  2.53s/it]



--- Question 2/20: What year saw the creation of the region where the county of... ---
  Instantiating ShapleyExperimentHarness for Q1 (n=10 docs)...
Generating target response based on full context...




Target response: '1994'
Successfully loaded utilities from ../Experiment_data/synergy/utilities_q_idx1_n10.pkl. Found 1024 entries.
Broadcasted loaded utilities to all processes.




Attempting to empty CUDA cache on rank 0 after Q1


Processing Questions:  10%|█         | 2/20 [00:03<00:26,  1.46s/it]

CUDA cache empty attempt complete on rank 0.

--- Question 3/20: When was the abolishment of the studio that distributed The ... ---
  Instantiating ShapleyExperimentHarness for Q2 (n=10 docs)...
Generating target response based on full context...




Target response: '1999'
Successfully loaded utilities from ../Experiment_data/synergy/utilities_q_idx2_n10.pkl. Found 1024 entries.
Broadcasted loaded utilities to all processes.




Attempting to empty CUDA cache on rank 0 after Q2


Processing Questions:  15%|█▌        | 3/20 [00:03<00:19,  1.12s/it]

CUDA cache empty attempt complete on rank 0.

--- Question 4/20: When was the publisher of Crux launched?... ---
  Instantiating ShapleyExperimentHarness for Q3 (n=10 docs)...
Generating target response based on full context...




Target response: '2001'
Successfully loaded utilities from ../Experiment_data/synergy/utilities_q_idx3_n10.pkl. Found 1024 entries.
Broadcasted loaded utilities to all processes.




Attempting to empty CUDA cache on rank 0 after Q3
CUDA cache empty attempt complete on rank 0.

Processing Questions:  20%|██        | 4/20 [00:04<00:15,  1.05it/s]



--- Question 5/20: Jan Šindel's was born in what country?... ---
  Instantiating ShapleyExperimentHarness for Q4 (n=10 docs)...
Generating target response based on full context...




Target response: 'Czech Republic.'
Successfully loaded utilities from ../Experiment_data/synergy/utilities_q_idx4_n10.pkl. Found 1024 entries.
Broadcasted loaded utilities to all processes.




Attempting to empty CUDA cache on rank 0 after Q4
CUDA cache empty attempt complete on rank 0.

Processing Questions:  25%|██▌       | 5/20 [00:05<00:13,  1.10it/s]



--- Question 6/20: What city is the person who broadened the doctrine of philos... ---
  Instantiating ShapleyExperimentHarness for Q5 (n=10 docs)...
Generating target response based on full context...




Target response: 'Copenhagen.'
Successfully loaded utilities from ../Experiment_data/synergy/utilities_q_idx5_n10.pkl. Found 1024 entries.
Broadcasted loaded utilities to all processes.




Attempting to empty CUDA cache on rank 0 after Q5
CUDA cache empty attempt complete on rank 0.

Processing Questions:  30%|███       | 6/20 [00:06<00:12,  1.14it/s]



--- Question 7/20: When was the baseball team winning the world series in 2015 ... ---
  Instantiating ShapleyExperimentHarness for Q6 (n=10 docs)...
Generating target response based on full context...




Target response: '1969.'
Successfully loaded utilities from ../Experiment_data/synergy/utilities_q_idx6_n10.pkl. Found 1024 entries.
Broadcasted loaded utilities to all processes.




Attempting to empty CUDA cache on rank 0 after Q6
CUDA cache empty attempt complete on rank 0.

Processing Questions:  35%|███▌      | 7/20 [00:07<00:11,  1.13it/s]



--- Question 8/20: Where did the Baldevins bryllup director die?... ---
  Instantiating ShapleyExperimentHarness for Q7 (n=10 docs)...
Generating target response based on full context...




Target response: 'I couldn't find any information about the director of Baldevins bryllup dying.'
Successfully loaded utilities from ../Experiment_data/synergy/utilities_q_idx7_n10.pkl. Found 1024 entries.
Broadcasted loaded utilities to all processes.




Attempting to empty CUDA cache on rank 0 after Q7
CUDA cache empty attempt complete on rank 0.


Processing Questions:  40%|████      | 8/20 [00:08<00:12,  1.07s/it]


--- Question 9/20: Who was thee first president of the association that wrote t... ---
  Instantiating ShapleyExperimentHarness for Q8 (n=10 docs)...
Generating target response based on full context...




Target response: 'G. Stanley Hall.'
Successfully loaded utilities from ../Experiment_data/synergy/utilities_q_idx8_n10.pkl. Found 1024 entries.
Broadcasted loaded utilities to all processes.




Attempting to empty CUDA cache on rank 0 after Q8


Processing Questions:  45%|████▌     | 9/20 [00:09<00:11,  1.03s/it]

CUDA cache empty attempt complete on rank 0.

--- Question 10/20: Which major Russian city borders the body of water in which ... ---
  Instantiating ShapleyExperimentHarness for Q9 (n=10 docs)...
Generating target response based on full context...




Target response: 'The Baltic Sea.'
Successfully loaded utilities from ../Experiment_data/synergy/utilities_q_idx9_n10.pkl. Found 1024 entries.
Broadcasted loaded utilities to all processes.




Attempting to empty CUDA cache on rank 0 after Q9
CUDA cache empty attempt complete on rank 0.

Processing Questions:  50%|█████     | 10/20 [00:10<00:09,  1.04it/s]



--- Question 11/20: When was the employer of John J. Collins established?... ---
  Instantiating ShapleyExperimentHarness for Q10 (n=10 docs)...
Generating target response based on full context...




Target response: 'Yale Divinity School.'
Successfully loaded utilities from ../Experiment_data/synergy/utilities_q_idx10_n10.pkl. Found 1024 entries.
Broadcasted loaded utilities to all processes.




Attempting to empty CUDA cache on rank 0 after Q10


Processing Questions:  55%|█████▌    | 11/20 [00:11<00:08,  1.04it/s]

CUDA cache empty attempt complete on rank 0.

--- Question 12/20: When did Bush declare the war causing Kerry to criticize him... ---
  Instantiating ShapleyExperimentHarness for Q11 (n=10 docs)...
Generating target response based on full context...




Target response: 'March 2003.'
Successfully loaded utilities from ../Experiment_data/synergy/utilities_q_idx11_n10.pkl. Found 1024 entries.
Broadcasted loaded utilities to all processes.




Attempting to empty CUDA cache on rank 0 after Q11
CUDA cache empty attempt complete on rank 0.

Processing Questions:  60%|██████    | 12/20 [00:12<00:08,  1.01s/it]



--- Question 13/20: What is the college Francis Walsingham attended an instance ... ---
  Instantiating ShapleyExperimentHarness for Q12 (n=10 docs)...
Generating target response based on full context...




Target response: 'King's College.'
Successfully loaded utilities from ../Experiment_data/synergy/utilities_q_idx12_n10.pkl. Found 1024 entries.
Broadcasted loaded utilities to all processes.




Attempting to empty CUDA cache on rank 0 after Q12
CUDA cache empty attempt complete on rank 0.


Processing Questions:  65%|██████▌   | 13/20 [00:13<00:06,  1.07it/s]


--- Question 14/20: What type of university is the college Kyeon Mi-ri attended?... ---
  Instantiating ShapleyExperimentHarness for Q13 (n=10 docs)...
Generating target response based on full context...




Target response: 'Private university.'
Successfully loaded utilities from ../Experiment_data/synergy/utilities_q_idx13_n10.pkl. Found 1024 entries.
Broadcasted loaded utilities to all processes.




Attempting to empty CUDA cache on rank 0 after Q13
CUDA cache empty attempt complete on rank 0.

Processing Questions:  70%|███████   | 14/20 [00:13<00:05,  1.16it/s]



--- Question 15/20: In what year was the author of The Insider's Guide to the Co... ---
  Instantiating ShapleyExperimentHarness for Q14 (n=10 docs)...
Generating target response based on full context...




Target response: '1878'
Successfully loaded utilities from ../Experiment_data/synergy/utilities_q_idx14_n10.pkl. Found 1024 entries.
Broadcasted loaded utilities to all processes.




Attempting to empty CUDA cache on rank 0 after Q14
CUDA cache empty attempt complete on rank 0.

Processing Questions:  75%|███████▌  | 15/20 [00:14<00:04,  1.22it/s]



--- Question 16/20: When was the territory covered by RIBA's Cambridge branch of... ---
  Instantiating ShapleyExperimentHarness for Q15 (n=10 docs)...
Generating target response based on full context...




Target response: '1994'
Successfully loaded utilities from ../Experiment_data/synergy/utilities_q_idx15_n10.pkl. Found 1024 entries.
Broadcasted loaded utilities to all processes.




Attempting to empty CUDA cache on rank 0 after Q15


Processing Questions:  80%|████████  | 16/20 [00:15<00:03,  1.30it/s]

CUDA cache empty attempt complete on rank 0.

--- Question 17/20: What's the meaning of the name of the school that does not i... ---
  Instantiating ShapleyExperimentHarness for Q16 (n=10 docs)...
Generating target response based on full context...




Target response: 'Theravada.'
Successfully loaded utilities from ../Experiment_data/synergy/utilities_q_idx16_n10.pkl. Found 1024 entries.
Broadcasted loaded utilities to all processes.




Attempting to empty CUDA cache on rank 0 after Q16


Processing Questions:  85%|████████▌ | 17/20 [00:16<00:02,  1.20it/s]

CUDA cache empty attempt complete on rank 0.

--- Question 18/20: Where did the director who provided the lyrics to A Time for... ---
  Instantiating ShapleyExperimentHarness for Q17 (n=10 docs)...
Generating target response based on full context...




Target response: 'University College London.'
Successfully loaded utilities from ../Experiment_data/synergy/utilities_q_idx17_n10.pkl. Found 1024 entries.
Broadcasted loaded utilities to all processes.




Attempting to empty CUDA cache on rank 0 after Q17


Processing Questions:  90%|█████████ | 18/20 [00:17<00:01,  1.16it/s]

CUDA cache empty attempt complete on rank 0.

--- Question 19/20: When did the country formerly known as Zaire become independ... ---
  Instantiating ShapleyExperimentHarness for Q18 (n=10 docs)...
Generating target response based on full context...




Target response: '1960.'
Successfully loaded utilities from ../Experiment_data/synergy/utilities_q_idx18_n10.pkl. Found 1024 entries.
Broadcasted loaded utilities to all processes.




Attempting to empty CUDA cache on rank 0 after Q18


Processing Questions:  95%|█████████▌| 19/20 [00:18<00:00,  1.18it/s]

CUDA cache empty attempt complete on rank 0.

--- Question 20/20: Where did Peter and Paul Fortress' designer die?... ---
  Instantiating ShapleyExperimentHarness for Q19 (n=10 docs)...
Generating target response based on full context...




Target response: 'Domenico Trezzini.'
Successfully loaded utilities from ../Experiment_data/synergy/utilities_q_idx19_n10.pkl. Found 1024 entries.
Broadcasted loaded utilities to all processes.




Attempting to empty CUDA cache on rank 0 after Q19


Processing Questions: 100%|██████████| 20/20 [00:18<00:00,  1.06it/s]

CUDA cache empty attempt complete on rank 0.


--- Average Correlation Metrics Across All Questions ---
                Avg_Pearson  Avg_Kendall  Num_Valid_Queries
Method                                                     
ExactInter           1.0000       1.0000                 20
ExactLinear          1.0000       1.0000                 20
ContextCite100       0.9951       0.7622                 20
WSS_FM100            0.9943       0.7200                 20
ContextCite64        0.9899       0.7422                 20
WSS_FM64             0.9891       0.6933                 20
TMC100               0.9838       0.6422                 20
TMC64                0.9779       0.5933                 20
LOO                  0.9741       0.4778                 20
ContextCite32        0.9552       0.6778                 20
TMC32                0.9486       0.5222                 20
WSS_FM32             0.9154       0.4867                 20
BetaShap100          0.8623       0.4067                




In [None]:
import matplotlib.pyplot as plt

# Collect one score vector per method. Later questions overwrite earlier ones,
# so each method ends up with the scores from the last result in which it was
# not None.
method_scores = {}
for result in all_results:
    method_scores.update(
        {method: np.round(scores, 4) for method, scores in result.items() if scores is not None}
    )

# One bar chart per method, with one bar per score index.
for method, scores in method_scores.items():
    positions = range(len(scores))
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.bar(positions, scores, color='skyblue')
    ax.set_title(f"Approximate Scores: {method}")
    ax.set_xlabel("Index")
    ax.set_ylabel("Score")
    ax.grid(axis='y', linestyle='--', alpha=0.7)
    ax.set_xticks(positions)
    fig.tight_layout()
    plt.show()