In [1]:
import sys
import os
import random
import gc
import time
import torch
import numpy as np
import pandas as pd
import ast
from tqdm import tqdm
from scipy.sparse import csr_matrix
import itertools
from scipy.stats import spearmanr, pearsonr, kendalltau, rankdata
from sklearn.metrics import ndcg_score
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import Accelerator
import nltk
nltk.download('punkt')
os.environ["CUDA_VISIBLE_DEVICES"] = "2" 
current_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(current_dir, '..'))
sys.path.append(parent_dir)
from SHapRAG import *

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
df=pd.read_json("../data/musique/musique_ans_v1.0_train.jsonl", lines=True)

In [3]:
def get_titles(lst):
    # Titles where is_supporting is True
    supporting = [d['paragraph_text'] for d in lst if d.get('is_supporting') == True]
    # Titles where is_supporting is False or missing AND not already in supporting
    others = [d['paragraph_text'] for d in lst if d.get('is_supporting') != True and d['paragraph_text'] not in supporting]
    # Combine: all supporting + as many others as needed to reach 10
    result = supporting + others
    return result[:10]

df.paragraphs=df.paragraphs.apply(get_titles)

In [4]:
all_sents = []
for i in range(len(df.question)):
    n = 0
    docs=df.paragraphs[i]
    doc_sents = []
    for j in range(len(docs)):
        sents = nltk.sent_tokenize(docs[j])
        new_sents = []
        for s in range(len(sents)):
            new_sents.append(str(n + s) + '-' + str(j) + '-' + sents[s])
        n += len(sents)
        doc_sents.append(new_sents)
    flat_doc_sents = [
    x
    for xs in doc_sents
    for x in xs
]
    all_sents.append(flat_doc_sents)
df['Sentences'] = all_sents

In [5]:
SEED = 42
# Initialize Accelerator
accelerator_main = Accelerator(mixed_precision="fp16")

# Load Model
if accelerator_main.is_main_process:
    print("Main Script: Loading model...")
# model_path = "mistralai/Mistral-7B-Instruct-v0.3"
model_path = "meta-llama/Llama-3.1-8B-Instruct"
# model_path = "Qwen/Qwen2.5-3B-Instruct"

model_cpu = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained(model_path)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    model_cpu.config.pad_token_id = tokenizer.pad_token_id
    if hasattr(model_cpu, 'generation_config') and model_cpu.generation_config is not None:
        model_cpu.generation_config.pad_token_id = tokenizer.pad_token_id

if accelerator_main.is_main_process:
    print("Main Script: Preparing model with Accelerator...")
prepared_model = accelerator_main.prepare(model_cpu)
unwrapped_prepared_model = accelerator_main.unwrap_model(prepared_model)
unwrapped_prepared_model.eval()
if accelerator_main.is_main_process:
    print("Main Script: Model prepared and set to eval.")

# Define utility cache

accelerator_main.wait_for_everyone()

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Main Script: Loading model...


Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.49it/s]


Main Script: Preparing model with Accelerator...
Main Script: Model prepared and set to eval.


In [6]:
# num_questions_to_run=len(df.question)
num_questions_to_run=30
k_values = [2]
all_metrics_data = []
all_results=[]
LDSs=[]
r2_fm=[]
r2_cc=[]
for i in tqdm(range(num_questions_to_run), disable=not accelerator_main.is_main_process):
    query = df.question[i]
    if accelerator_main.is_main_process:
        print(f"\n--- Question {i+1}/{num_questions_to_run}: {query[:60]}... ---")

    docs=df.paragraphs[i]

    utility_cache_base_dir = "../Experiment_data/musique"
    utility_cache_filename = f"utilities_q_idx{i}.pkl" # More robust naming
    current_utility_path = os.path.join(utility_cache_base_dir, utility_cache_filename)
    
    if accelerator_main.is_main_process: # Only main process creates directories
        os.makedirs(os.path.dirname(current_utility_path), exist_ok=True)
    
    # Initialize Harness
    harness = ContextAttribution(
        items=docs,
        query=query,
        prepared_model_for_harness=prepared_model,
        tokenizer_for_harness=tokenizer,
        accelerator_for_harness=accelerator_main,
        utility_cache_path=current_utility_path
    )

    print(f'Response: {harness.target_response}')
    # Compute metrics
    results_for_query = {}
    if accelerator_main.is_main_process:
        m_samples_map = {"L": 128, "XL":256, "XXL":512} 
        T_iterations_map = {"L":40, "XL":50, "XXL":60} 

        for size_key, num_s in m_samples_map.items():
            if 2**len(docs) < num_s and size_key != "L":
                actual_samples = max(1, 2**len(docs)-1 if 2**len(docs)>0 else 1)
            else:
                actual_samples = num_s

            if actual_samples > 0: 
                results_for_query[f"ContextCite{actual_samples}"], model_cc = harness.compute_contextcite(num_samples=actual_samples, seed=SEED)
                results_for_query[f"FM_Shap{actual_samples}"], results_for_query[f"FM_Weights{actual_samples}"], F, modelfm = harness.compute_wss(num_samples=actual_samples, seed=SEED, sampling="kernelshap",sur_type="fm")
                # results_for_query[f"BetaShap{actual_samples}"] = harness.compute_beta_shap(num_iterations_max=T_iterations_map[size_key], beta_a=16, beta_b=1, max_unique_lookups=actual_samples, seed=SEED)
                # results_for_query[f"TMC{actual_samples}"] = harness.compute_tmc_shap(num_iterations_max=T_iterations_map[size_key], performance_tolerance=0.001, max_unique_lookups=actual_samples, seed=SEED)

        results_for_query["LOO"] = harness.compute_loo()
        results_for_query["ARC-JSD"] = harness.compute_arc_jsd()

        prob_topk = harness.evaluate_topk_performance(
                                                results_for_query, 
                                                k_values, 
                                                utility_type="probability"
                                            )

        div_topk = harness.evaluate_topk_performance(
                                            results_for_query, 
                                            k_values, 
                                            utility_type="divergence"
                                        )
        
        r2_fm.append(harness.r2(30, modelfm, method='fm'))
        r2_cc.append(harness.r2(30, model_cc, method='cc'))

        LDS = {}
        for i in results_for_query:
            if "FM_Shap" in i:
                calculate_LDS = {i:harness.lds(results_for_query[i], 30, utl=True, model=modelfm)}
                LDS.update(calculate_LDS)
            else:
                calculate_LDS = {i:harness.lds(results_for_query[i], 30)}
                LDS.update(calculate_LDS)
        #LDS = [{i:harness.lds(results_for_query[i], 30)} for i in results_for_query]

        results_for_query["topk_probability"] = prob_topk
        results_for_query["topk_divergence"] = div_topk
        results_for_query["LDS"] = LDS
        harness.save_utility_cache(current_utility_path)
        
        all_results.append(results_for_query)

  0%|          | 0/30 [00:00<?, ?it/s]


--- Question 1/30: When was the institute that owned The Collegian founded?... ---
Loading existing utility cache from ../Experiment_data/musique/utilities_q_idx0.pkl...
Successfully loaded 475 cached utilities.


Fetching 3 files: 100%|██████████| 3/3 [00:00<00:00, 13107.20it/s]


Response: Houston Baptist University was founded in 1960.


100%|██████████| 10/10 [00:02<00:00,  4.96it/s]
  3%|▎         | 1/30 [00:57<27:37, 57.16s/it]


--- Question 2/30: What year saw the creation of the region where the county of... ---
Loading existing utility cache from ../Experiment_data/musique/utilities_q_idx1.pkl...
Successfully loaded 475 cached utilities.
Response: 1994.


100%|██████████| 10/10 [00:02<00:00,  3.35it/s]
  7%|▋         | 2/30 [02:09<30:57, 66.35s/it]


--- Question 3/30: When was the abolishment of the studio that distributed The ... ---
Loading existing utility cache from ../Experiment_data/musique/utilities_q_idx2.pkl...
Successfully loaded 249 cached utilities.
Response: 1999


100%|██████████| 10/10 [00:02<00:00,  3.35it/s]
 10%|█         | 3/30 [04:03<39:30, 87.81s/it]


--- Question 4/30: When was the publisher of Crux launched?... ---
Loading existing utility cache from ../Experiment_data/musique/utilities_q_idx3.pkl...
Successfully loaded 249 cached utilities.
Response: May 2001.


100%|██████████| 10/10 [00:02<00:00,  4.68it/s]
 13%|█▎        | 4/30 [05:31<38:09, 88.05s/it]


--- Question 5/30: Jan Šindel's was born in what country?... ---
Loading existing utility cache from ../Experiment_data/musique/utilities_q_idx4.pkl...
Successfully loaded 249 cached utilities.
Response: The Czech Republic.


100%|██████████| 10/10 [00:02<00:00,  3.78it/s]
 17%|█▋        | 5/30 [07:07<37:49, 90.78s/it]


--- Question 6/30: What city is the person who broadened the doctrine of philos... ---
Loading existing utility cache from ../Experiment_data/musique/utilities_q_idx5.pkl...
Successfully loaded 249 cached utilities.
Response: Copenhagen.


100%|██████████| 10/10 [00:03<00:00,  2.89it/s]
 20%|██        | 6/30 [09:06<40:11, 100.48s/it]


--- Question 7/30: When was the baseball team winning the world series in 2015 ... ---
Loading existing utility cache from ../Experiment_data/musique/utilities_q_idx6.pkl...
Successfully loaded 249 cached utilities.
Response: The Kansas City Royals were founded in 1969.


100%|██████████| 10/10 [00:03<00:00,  2.51it/s]
 23%|██▎       | 7/30 [11:16<42:08, 109.94s/it]


--- Question 8/30: Where did the Baldevins bryllup director die?... ---
Loading existing utility cache from ../Experiment_data/musique/utilities_q_idx7.pkl...
Successfully loaded 249 cached utilities.
Response: There is no information about the director George Schnéevoigt's death in the provided context.


100%|██████████| 10/10 [00:02<00:00,  4.31it/s]
 27%|██▋       | 8/30 [12:48<38:15, 104.36s/it]


--- Question 9/30: Who was thee first president of the association that wrote t... ---
Loading existing utility cache from ../Experiment_data/musique/utilities_q_idx8.pkl...
Successfully loaded 249 cached utilities.
Response: G. Stanley Hall.


100%|██████████| 10/10 [00:03<00:00,  2.61it/s]
 30%|███       | 9/30 [14:53<38:46, 110.81s/it]


--- Question 10/30: Which major Russian city borders the body of water in which ... ---
Loading existing utility cache from ../Experiment_data/musique/utilities_q_idx9.pkl...
Successfully loaded 249 cached utilities.
Response: The major Russian city that borders the Baltic Sea, in which Saaremaa is located, is Saint Petersburg.


100%|██████████| 10/10 [00:02<00:00,  3.40it/s]
 33%|███▎      | 10/30 [16:42<36:46, 110.32s/it]


--- Question 11/30: When was the employer of John J. Collins established?... ---
Loading existing utility cache from ../Experiment_data/musique/utilities_q_idx10.pkl...
Successfully loaded 249 cached utilities.
Response: Yale Divinity School.


100%|██████████| 10/10 [00:03<00:00,  2.86it/s]
 37%|███▋      | 11/30 [18:42<35:52, 113.29s/it]


--- Question 12/30: When did Bush declare the war causing Kerry to criticize him... ---
Loading existing utility cache from ../Experiment_data/musique/utilities_q_idx11.pkl...
Successfully loaded 249 cached utilities.
Response: Bush relied on a resolution Kerry voted for in 2002 to order the 2003 invasion of Iraq.


100%|██████████| 10/10 [00:05<00:00,  1.93it/s]
 40%|████      | 12/30 [21:21<38:10, 127.24s/it]


--- Question 13/30: What is the college Francis Walsingham attended an instance ... ---
Loading existing utility cache from ../Experiment_data/musique/utilities_q_idx12.pkl...
Successfully loaded 249 cached utilities.
Response: King's College, Cambridge.


100%|██████████| 10/10 [00:02<00:00,  4.02it/s]
 43%|████▎     | 13/30 [22:55<33:09, 117.00s/it]


--- Question 14/30: What type of university is the college Kyeon Mi-ri attended?... ---
Loading existing utility cache from ../Experiment_data/musique/utilities_q_idx13.pkl...
Successfully loaded 249 cached utilities.
Response: Sejong University.


100%|██████████| 10/10 [00:02<00:00,  4.77it/s]
 47%|████▋     | 14/30 [24:22<28:49, 108.06s/it]


--- Question 15/30: In what year was the author of The Insider's Guide to the Co... ---
Loading existing utility cache from ../Experiment_data/musique/utilities_q_idx14.pkl...
Successfully loaded 249 cached utilities.
Response: The text does not mention the author of The Insider's Guide to the Colleges. However, it does mention that the guide has been published annually by the student editorial staff of the "Yale Daily News" for over four decades.


100%|██████████| 10/10 [00:02<00:00,  3.38it/s]
 50%|█████     | 15/30 [26:07<26:44, 107.00s/it]


--- Question 16/30: When was the territory covered by RIBA's Cambridge branch of... ---
Loading existing utility cache from ../Experiment_data/musique/utilities_q_idx15.pkl...
Successfully loaded 249 cached utilities.
Response: 1994.


100%|██████████| 10/10 [00:02<00:00,  4.14it/s]
 53%|█████▎    | 16/30 [27:39<23:56, 102.59s/it]


--- Question 17/30: What's the meaning of the name of the school that does not i... ---
Loading existing utility cache from ../Experiment_data/musique/utilities_q_idx16.pkl...
Successfully loaded 249 cached utilities.
Response: Theravada.


100%|██████████| 10/10 [00:04<00:00,  2.23it/s]
 57%|█████▋    | 17/30 [30:05<25:03, 115.64s/it]


--- Question 18/30: Where did the director who provided the lyrics to A Time for... ---
Loading existing utility cache from ../Experiment_data/musique/utilities_q_idx17.pkl...
Successfully loaded 249 cached utilities.
Response: University College London.


100%|██████████| 10/10 [00:04<00:00,  2.31it/s]
 60%|██████    | 18/30 [32:28<24:47, 123.99s/it]


--- Question 19/30: When did the country formerly known as Zaire become independ... ---
Loading existing utility cache from ../Experiment_data/musique/utilities_q_idx18.pkl...
Successfully loaded 249 cached utilities.
Response: 1960.


100%|██████████| 10/10 [00:03<00:00,  2.89it/s]
 63%|██████▎   | 19/30 [34:28<22:30, 122.79s/it]


--- Question 20/30: Where did Peter and Paul Fortress' designer die?... ---
Loading existing utility cache from ../Experiment_data/musique/utilities_q_idx19.pkl...
Successfully loaded 249 cached utilities.
Response: Domenico Trezzini's death location is not mentioned in the provided context.


100%|██████████| 10/10 [00:02<00:00,  3.60it/s]
 67%|██████▋   | 20/30 [36:07<19:14, 115.43s/it]


--- Question 21/30: When did the network which airs Alt for Norge start?... ---
Loading existing utility cache from ../Experiment_data/musique/utilities_q_idx20.pkl...
Successfully loaded 249 cached utilities.
Response: 5 December 1988.


100%|██████████| 10/10 [00:02<00:00,  3.73it/s]
 70%|███████   | 21/30 [37:43<16:28, 109.80s/it]


--- Question 22/30: Who failed to take back what the French believed instrumenta... ---
Loading existing utility cache from ../Experiment_data/musique/utilities_q_idx21.pkl...
Successfully loaded 249 cached utilities.
Response: The Russians failed to retake the Malakoff.


100%|██████████| 10/10 [00:03<00:00,  2.97it/s]
 73%|███████▎  | 22/30 [39:43<15:00, 112.62s/it]


--- Question 23/30: What is the field of work of the proposer of the modern synt... ---
Loading existing utility cache from ../Experiment_data/musique/utilities_q_idx22.pkl...
Successfully loaded 249 cached utilities.
Response: Evolutionary biology.


100%|██████████| 10/10 [00:04<00:00,  2.38it/s]
 77%|███████▋  | 23/30 [41:58<13:56, 119.43s/it]


--- Question 24/30: When was the season of Greys Anatomy when Derek died filmed?... ---
Loading existing utility cache from ../Experiment_data/musique/utilities_q_idx23.pkl...
Successfully loaded 249 cached utilities.
Response: July 2014.


100%|██████████| 10/10 [00:04<00:00,  2.44it/s]
 80%|████████  | 24/30 [44:08<12:15, 122.52s/it]


--- Question 25/30: When did the manufacturer of a pedometer accessory for the i... ---
Loading existing utility cache from ../Experiment_data/musique/utilities_q_idx24.pkl...
Successfully loaded 249 cached utilities.
Response: The text does not mention when the manufacturer of the Nike+iPod pedometer became a publicly traded company.


100%|██████████| 10/10 [00:02<00:00,  3.36it/s]
 83%|████████▎ | 25/30 [45:59<09:55, 119.19s/it]


--- Question 26/30: What is the record label for the person who sang Beauty and ... ---
Loading existing utility cache from ../Experiment_data/musique/utilities_q_idx25.pkl...
Successfully loaded 249 cached utilities.
Response: Peabo Bryson.


100%|██████████| 10/10 [00:06<00:00,  1.64it/s]
 87%|████████▋ | 26/30 [49:04<09:15, 138.89s/it]


--- Question 27/30: Who is the employer of the Iranian scientist who co-invented... ---
Loading existing utility cache from ../Experiment_data/musique/utilities_q_idx26.pkl...
Successfully loaded 249 cached utilities.
Response: MIT laboratory.


100%|██████████| 10/10 [00:03<00:00,  3.02it/s]
 90%|█████████ | 27/30 [50:59<06:35, 131.81s/it]


--- Question 28/30: How many championships in a row were won by the person who p... ---
Loading existing utility cache from ../Experiment_data/musique/utilities_q_idx27.pkl...
Successfully loaded 249 cached utilities.
Response: Bill Russell played in 70 NBA Finals games and won 11 championships.


100%|██████████| 10/10 [00:04<00:00,  2.32it/s]
 93%|█████████▎| 28/30 [53:21<04:29, 134.77s/it]


--- Question 29/30: In what language is the star of Koyelaanchal fluent?... ---
Loading existing utility cache from ../Experiment_data/musique/utilities_q_idx28.pkl...
Successfully loaded 249 cached utilities.
Response: Hindi.


100%|██████████| 10/10 [00:02<00:00,  3.44it/s]
 97%|█████████▋| 29/30 [55:11<02:07, 127.43s/it]


--- Question 30/30: What instrument did the artiste for Vi skall fara bortom mån... ---
Loading existing utility cache from ../Experiment_data/musique/utilities_q_idx29.pkl...
Successfully loaded 249 cached utilities.
Response: Guitar.


100%|██████████| 10/10 [00:02<00:00,  4.96it/s]
100%|██████████| 30/30 [56:36<00:00, 113.22s/it]


In [7]:
methods = ['ContextCite128', 'FM_Shap128', 'FM_Weights128', 'LOO', 'ARC-JSD']

# Initialize lists
topk_probs = {method: [] for method in methods}
topk_divs = {method: [] for method in methods}
LDSs = {method: [] for method in methods}

# Collect values
for entry in all_results:
    for method in methods:
        topk_probs[method].append(entry['topk_probability'][method][2])
        topk_divs[method].append(entry['topk_divergence'][method][2])
        LDSs[method].append(entry['LDS'][method])
        

# Compute means
mean_topk_probs = {method: np.mean(topk_probs[method]) for method in methods}
mean_topk_divs = {method: np.mean(topk_divs[method]) for method in methods}
LDSs = {method: np.mean(LDS[method]) for method in methods}

print("Mean topk_probability:", mean_topk_probs)
print("Mean topk_divergence:", mean_topk_divs)
print("Mean LDS:", LDSs)

Mean topk_probability: {'ContextCite128': 11.26526796023051, 'FM_Shap128': 11.506810601552328, 'FM_Weights128': 11.506810601552328, 'LOO': 10.867941546440125, 'ARC-JSD': 11.371177625656127}
Mean topk_divergence: {'ContextCite128': 1.863587933303997, 'FM_Shap128': 1.9740796185450222, 'FM_Weights128': 1.9740796185450222, 'LOO': 1.7984966605805124, 'ARC-JSD': 1.9053148663624333}
Mean LDS: {'ContextCite128': 0.9339624334057718, 'FM_Shap128': 0.933754632216548, 'FM_Weights128': 0.9292772664423138, 'LOO': 0.9162929056970347, 'ARC-JSD': 0.8755488771515031}


In [10]:
sum(r2_cc)

25.48468959174888

In [8]:
for i in range(len(r2_cc)):
    print(r2_cc[i], r2_fm[i])

0.9375165280647976 0.9578649778482946
0.9577211160365078 0.9832881269772584
0.9869235469851904 0.9906930025271654
0.9879418265086173 0.9962197116887972
0.8708815579602052 0.9104545483631397
0.9165571635505991 0.958990092821317
0.9779624854055823 0.9909412767565823
0.9984259508437362 0.9983177402450156
0.8646305527989073 0.9449403286164748
0.8149501499526735 0.917248494102503
0.9949613745656555 0.9949135504274558
0.7944212166536255 0.9255864490814676
0.7641135045200679 0.8932989498600641
0.9028182826351002 0.9618971708354147
0.9540364432743449 0.9840080891320755
0.7484211795381923 0.9239748229907752
0.5015238688883894 0.8916022430989997
0.9434515751259799 0.98322799968967
0.6362546826621223 0.9056974640132549
0.9740730635035899 0.9925223713515443
0.7015307980047503 0.9313213821033098
0.886598227774964 0.9917409244393396
0.7648391242905223 0.8668265075520394
0.6980593645577506 0.9593063121715885
0.9973644265827782 0.9980497724821059
0.5769914509764489 0.8106384691832434
0.757573591878190