In [1]:
import sys
import os
import random
import gc
import time
import torch
import numpy as np
import pandas as pd
import ast
from tqdm import tqdm
from scipy.sparse import csr_matrix
import itertools
from scipy.stats import spearmanr, pearsonr, kendalltau, rankdata
from sklearn.metrics import ndcg_score
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import Accelerator
import nltk
nltk.download('punkt')
from nltk.tokenize import sent_tokenize

current_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(current_dir, '..'))
sys.path.append(parent_dir)
from SHapRAG import *
#os.environ["CUDA_VISIBLE_DEVICES"] = "0" 

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package punkt to
[nltk_data]     /home/ulb/code_wit/ekuzmenk/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
#df=pd.read_json("../data/musique/musique_ans_v1.0_train.jsonl", lines=True)
df= pd.read_csv("../scripts/nq_2_positives.csv",index_col=False)

In [3]:
from os import getenv
from dotenv import load_dotenv

load_dotenv()
HF_TOKEN = os.getenv('HF_TOKEN')

In [4]:
SEED = 42
# Initialize Accelerator
accelerator_main = Accelerator(mixed_precision="fp16")

# Load Model
if accelerator_main.is_main_process:
    print("Main Script: Loading model...")
# model_path = "mistralai/Mistral-7B-Instruct-v0.3"
model_path = "meta-llama/Llama-3.1-8B-Instruct"
# model_path = "Qwen/Qwen2.5-3B-Instruct"


model_cpu = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    token=HF_TOKEN
)
tokenizer = AutoTokenizer.from_pretrained(model_path, token=HF_TOKEN)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    model_cpu.config.pad_token_id = tokenizer.pad_token_id
    if hasattr(model_cpu, 'generation_config') and model_cpu.generation_config is not None:
        model_cpu.generation_config.pad_token_id = tokenizer.pad_token_id

if accelerator_main.is_main_process:
    print("Main Script: Preparing model with Accelerator...")
prepared_model = accelerator_main.prepare(model_cpu)
unwrapped_prepared_model = accelerator_main.unwrap_model(prepared_model)
unwrapped_prepared_model.eval()
if accelerator_main.is_main_process:
    print("Main Script: Model prepared and set to eval.")

# Define utility cache

accelerator_main.wait_for_everyone()

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Main Script: Loading model...


Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████| 4/4 [00:09<00:00,  2.29s/it]


Main Script: Preparing model with Accelerator...
Main Script: Model prepared and set to eval.


In [5]:
all_sents = []
for i in range(len(df.question)):
    n = 0
    docs=ast.literal_eval(df.context[i])
    doc_sents = []
    for j in range(len(docs)):
        sents = nltk.sent_tokenize(docs[j])
        new_sents = []
        for s in range(len(sents)):
            #new_sents.append(str(n + s) + '-' + str(j) + '-' + sents[s])
            new_sents.append(sents[s])
        n += len(sents)
        doc_sents.append(new_sents)
    flat_doc_sents = [
    x
    for xs in doc_sents
    for x in xs
]
    all_sents.append(flat_doc_sents)
df['Sentences'] = all_sents

In [6]:
df.Sentences[0]

['on death row in the United States on January 1, 2013.',
 'Since 1977, the states of Texas (464), Virginia (108) and Oklahoma (94) have executed the most death row inmates.',
 ', California (683), Florida (390), Texas (330) and Pennsylvania (218) housed more than half of all inmates pending on death row.',
 ', the longest-serving prisoner on death row in the US who has been executed was Jack Alderman who served over 33 years.',
 'He was executed in Georgia in 2008.',
 'However, Alderman only holds the distinction of being the longest-serving "executed" inmate so far.',
 'A Florida inmate, Gary Alvord, arrived',
 'punishable by death penalty.',
 'But perhaps the best indicator that this law is not a deterrent to criminality is the ever-increasing number of death convicts.',
 'From 1994 to 1995 the number of persons on death row increased from 12 to 104.',
 'From 1995 to 1996 it increased to 182.',
 'In 1997 the total death convicts was at 520 and in 1998 the inmates in death row was at

In [8]:
# SENTENCE LEVEL
num_questions_to_run = 100
k_values = [1, 2]
all_metrics_data = []
all_results=[]
LDSs=[]
r2s = []
RMSEs = []

for i in tqdm(range(num_questions_to_run), disable=not accelerator_main.is_main_process):
    query = df.question[i]
    if accelerator_main.is_main_process:
        print(f"\n--- Question {i+1}/{num_questions_to_run}: {query[:60]}... ---")

    #docs=df.paragraphs[i]
    #docs=ast.literal_eval(df.context[i])
    #docs = [sent[4:] for sent in df.Sentences[i]][:20]
    docs = [sent for sent in df.Sentences[i]]

    utility_cache_base_dir = "../Experiment_data/NQ_sents"
    utility_cache_filename = f"utilities_q_idx{i}.pkl" # More robust naming
    current_utility_path = os.path.join(utility_cache_base_dir, utility_cache_filename)
    
    if accelerator_main.is_main_process: # Only main process creates directories
        os.makedirs(os.path.dirname(current_utility_path), exist_ok=True)
    
    # Initialize Harness
    harness = ContextAttribution(
        items=docs,
        query=query,
        prepared_model_for_harness=prepared_model,
        tokenizer_for_harness=tokenizer,
        accelerator_for_harness=accelerator_main,
        utility_cache_path=current_utility_path
    )

    print(f'Response: {harness.target_response}')
    # Compute metrics
    results_for_query = {}
    if accelerator_main.is_main_process:
        m_samples_map = {"L": 128, "XL": 256}
        T_iterations_map = {"L":40} 

        for size_key, num_s in m_samples_map.items():
            if 2**len(docs) < num_s and size_key != "L":
                actual_samples = max(1, 2**len(docs)-1 if 2**len(docs)>0 else 1)
            else:
                actual_samples = num_s

            if actual_samples > 0: 
                results_for_query[f"ContextCite{actual_samples}"], model_cc = harness.compute_contextcite(num_samples=actual_samples, seed=SEED)
                #results_for_query[f"FM_Shap{actual_samples}"], _, F, modelfm = harness.compute_wss(num_samples=actual_samples, seed=SEED, sampling="kernelshap",sur_type="fm", k=4)
                #results_for_query[f"BetaShap{actual_samples}"] = harness.compute_beta_shap(num_iterations_max=T_iterations_map[size_key], beta_a=16, beta_b=1, max_unique_lookups=actual_samples, seed=SEED)
                #results_for_query[f"TMC{actual_samples}"] = harness.compute_tmc_shap(num_iterations_max=T_iterations_map[size_key], performance_tolerance=0.001, max_unique_lookups=actual_samples, seed=SEED)

        #results_for_query["LOO"] = harness.compute_loo()
        #results_for_query["ARC-JSD"] = harness.compute_arc_jsd()

        #prob_topk = harness.evaluate_topk_performance(
        #                                        results_for_query, 
        #                                        k_values, 
        #                                        utility_type="probability"
        #                                    )
    
        #div_topk = harness.evaluate_topk_performance(
        #                                    results_for_query, 
        #                                    k_values, 
        #                                    utility_type="divergence"
        #                                )
        LDS = []
        r2 = []
        RMSE = []
        for i in results_for_query:
            if "FM_Shap" in i:
                calculate_LDS = {i:harness.lds(results_for_query[i], 30, utl=True, model=modelfm)}
                calculate_r2 = {i:harness.r2(30, model=modelfm)}
                calculate_RMSE = {i:harness.RMSE(30, model=modelfm)}
                LDS.append(calculate_LDS)
                r2.append(calculate_r2)
                RMSE.append(calculate_RMSE)
            elif "ContextCite" in i:
                calculate_LDS = {i:harness.lds(results_for_query[i], 30)}
                calculate_r2 = {i:harness.r2(30, model=model_cc, method="notfm")}
                calculate_RMSE = {i:harness.RMSE(30, model=model_cc, method="notfm")}
                LDS.append(calculate_LDS)
                r2.append(calculate_r2)
                RMSE.append(calculate_RMSE)
        #LDS = [{i:harness.lds(results_for_query[i], 30)} for i in results_for_query]
        LDSs.append(LDS)
        r2s.append(r2)
        RMSEs.append(RMSE)

        #results_for_query["topk_probability"] = prob_topk
        #results_for_query["topk_divergence"] = div_topk
        harness.save_utility_cache(current_utility_path)
        
        all_results.append(results_for_query)

  0%|                                                  | 0/100 [00:00<?, ?it/s]


--- Question 1/100: total number of death row inmates in the us... ---
Loading existing utility cache from ../Experiment_data/NQ_sents/utilities_q_idx0.pkl...
Successfully loaded 1034 cached utilities.
Response: As of November 1999, there were 956 death convicts at the National Bilibid Prisons and at the Correctional Institute for Women.


  1%|▍                                       | 1/100 [00:50<1:23:35, 50.66s/it]


--- Question 2/100: big little lies season 2 how many episodes... ---
Loading existing utility cache from ../Experiment_data/NQ_sents/utilities_q_idx1.pkl...
Successfully loaded 847 cached utilities.
Response: 7


  2%|▊                                       | 2/100 [01:41<1:22:46, 50.68s/it]


--- Question 3/100: who sang waiting for a girl like you... ---
Loading existing utility cache from ../Experiment_data/NQ_sents/utilities_q_idx2.pkl...
Successfully loaded 847 cached utilities.
Response: Foreigner.


  3%|█▏                                      | 3/100 [02:32<1:22:27, 51.01s/it]


--- Question 4/100: where do you cross the arctic circle in norway... ---
Loading existing utility cache from ../Experiment_data/NQ_sents/utilities_q_idx3.pkl...
Successfully loaded 847 cached utilities.
Response: Saltfjellet.


  4%|█▌                                      | 4/100 [03:23<1:21:17, 50.81s/it]


--- Question 5/100: who is the main character in green eggs and ham... ---
Loading existing utility cache from ../Experiment_data/NQ_sents/utilities_q_idx4.pkl...
Successfully loaded 847 cached utilities.
Response: The main character in Green Eggs and Ham is a strange creature.


  5%|██                                      | 5/100 [04:16<1:21:43, 51.61s/it]


--- Question 6/100: do veins carry blood to the heart or away... ---
Loading existing utility cache from ../Experiment_data/NQ_sents/utilities_q_idx5.pkl...
Successfully loaded 847 cached utilities.
Response: Veins carry blood toward the heart.


  6%|██▍                                     | 6/100 [05:06<1:20:07, 51.14s/it]


--- Question 7/100: who played charlie bucket in the original charlie and the ch... ---
Loading existing utility cache from ../Experiment_data/NQ_sents/utilities_q_idx6.pkl...
Successfully loaded 847 cached utilities.
Response: Peter Ostrum.


  7%|██▊                                     | 7/100 [06:01<1:20:57, 52.24s/it]


--- Question 8/100: what is 1 radian in terms of pi... ---
Loading existing utility cache from ../Experiment_data/NQ_sents/utilities_q_idx7.pkl...
Successfully loaded 847 cached utilities.
Response: 1 radian is equal to π/180.


  8%|███▏                                    | 8/100 [06:53<1:20:08, 52.26s/it]


--- Question 9/100: when does season 5 of bates motel come out... ---
Loading existing utility cache from ../Experiment_data/NQ_sents/utilities_q_idx8.pkl...
Successfully loaded 847 cached utilities.
Response: September 19, 2017.


  9%|███▌                                    | 9/100 [07:46<1:19:49, 52.64s/it]


--- Question 10/100: how many episodes are in series 7 game of thrones... ---
Loading existing utility cache from ../Experiment_data/NQ_sents/utilities_q_idx9.pkl...
Successfully loaded 847 cached utilities.
Response: 7


 10%|███▉                                   | 10/100 [08:38<1:18:39, 52.44s/it]


--- Question 11/100: who is next in line to be the monarch of england... ---
Loading existing utility cache from ../Experiment_data/NQ_sents/utilities_q_idx10.pkl...
Successfully loaded 847 cached utilities.
Response: Prince William, Duke of Cambridge.


 11%|████▎                                  | 11/100 [09:31<1:17:52, 52.50s/it]


--- Question 12/100: who is in charge of enforcing the pendleton act of 1883... ---
Loading existing utility cache from ../Experiment_data/NQ_sents/utilities_q_idx11.pkl...
Successfully loaded 847 cached utilities.
Response: The Civil Service Commission.


 12%|████▋                                  | 12/100 [10:21<1:15:47, 51.67s/it]


--- Question 13/100: what is the name of latest version of android... ---
Loading existing utility cache from ../Experiment_data/NQ_sents/utilities_q_idx12.pkl...
Successfully loaded 847 cached utilities.
Response: Android Pie.


 13%|█████                                  | 13/100 [11:14<1:15:43, 52.23s/it]


--- Question 14/100: why was there so much interest in cuba both before and after... ---
Loading existing utility cache from ../Experiment_data/NQ_sents/utilities_q_idx13.pkl...
Successfully loaded 756 cached utilities.
Response: Historians have debated America's intentions in Cuba, with some initially believing it was due to humanitarian interest in the Cuban people.


 14%|█████▍                                 | 14/100 [12:08<1:15:28, 52.65s/it]


--- Question 15/100: when did veterans day start being called veterans day... ---
Loading existing utility cache from ../Experiment_data/NQ_sents/utilities_q_idx14.pkl...
Successfully loaded 847 cached utilities.
Response: May 26, 1954.


 15%|█████▊                                 | 15/100 [13:01<1:14:47, 52.80s/it]


--- Question 16/100: when did big air snowboarding become an olympic sport... ---
Loading existing utility cache from ../Experiment_data/NQ_sents/utilities_q_idx15.pkl...
Successfully loaded 847 cached utilities.
Response: 2018.


 16%|██████▏                                | 16/100 [13:52<1:13:05, 52.20s/it]


--- Question 17/100: who played in the most world series games... ---
Loading existing utility cache from ../Experiment_data/NQ_sents/utilities_q_idx16.pkl...
Successfully loaded 847 cached utilities.
Response: The New York Yankees.


 17%|██████▋                                | 17/100 [14:45<1:12:45, 52.60s/it]


--- Question 18/100: who sings i can't stop this feeling anymore... ---
Loading existing utility cache from ../Experiment_data/NQ_sents/utilities_q_idx17.pkl...
Successfully loaded 848 cached utilities.
Response: Justin Timberlake.


 18%|███████                                | 18/100 [15:37<1:11:35, 52.39s/it]


--- Question 19/100: who is the month of may named after... ---
Loading existing utility cache from ../Experiment_data/NQ_sents/utilities_q_idx18.pkl...
Successfully loaded 847 cached utilities.
Response: The month of May is named after the goddess Maia, a Greek and Roman goddess of fertility.


 19%|███████▍                               | 19/100 [16:30<1:11:00, 52.60s/it]


--- Question 20/100: who has the most petroleum in the world... ---
Loading existing utility cache from ../Experiment_data/NQ_sents/utilities_q_idx19.pkl...
Successfully loaded 847 cached utilities.
Response: Venezuela.


 20%|███████▊                               | 20/100 [17:22<1:09:38, 52.23s/it]


--- Question 21/100: who is the sister of for king and country... ---
Loading existing utility cache from ../Experiment_data/NQ_sents/utilities_q_idx20.pkl...
Successfully loaded 847 cached utilities.
Response: Rebecca St. James.


 21%|████████▏                              | 21/100 [18:15<1:09:12, 52.56s/it]


--- Question 22/100: who developed the first periodic table with 8 columns... ---
Loading existing utility cache from ../Experiment_data/NQ_sents/utilities_q_idx21.pkl...
Successfully loaded 847 cached utilities.
Response: Gilbert N. Lewis and Irving Langmuir.


 22%|████████▌                              | 22/100 [19:08<1:08:37, 52.79s/it]


--- Question 23/100: who plays skyler on lab rats elite force... ---
Loading existing utility cache from ../Experiment_data/NQ_sents/utilities_q_idx22.pkl...
Successfully loaded 847 cached utilities.
Response: Paris Berelc plays Skylar Storm on Lab Rats: Elite Force.


 23%|████████▉                              | 23/100 [20:02<1:07:55, 52.93s/it]


--- Question 24/100: when is season seven of game of thrones coming out... ---
Loading existing utility cache from ../Experiment_data/NQ_sents/utilities_q_idx23.pkl...
Successfully loaded 847 cached utilities.
Response: July 16, 2017.


 24%|█████████▎                             | 24/100 [20:54<1:06:46, 52.71s/it]


--- Question 25/100: who went home on rupaul's drag race season 10 episode 4... ---
Loading existing utility cache from ../Experiment_data/NQ_sents/utilities_q_idx24.pkl...
Successfully loaded 847 cached utilities.
Response: Dusty Ray Bottoms.


 24%|█████████▎                             | 24/100 [21:16<1:07:22, 53.20s/it]


KeyboardInterrupt: 

In [None]:
# DOC LEVEL
num_questions_to_run = 100
k_values = [1, 2]
all_metrics_data = []
all_results=[]
LDSs=[]

for i in tqdm(range(num_questions_to_run), disable=not accelerator_main.is_main_process):
    query = df.question[i]
    if accelerator_main.is_main_process:
        print(f"\n--- Question {i+1}/{num_questions_to_run}: {query[:60]}... ---")

    #docs=df.paragraphs[i]
    docs=ast.literal_eval(df.context[i])

    utility_cache_base_dir = "../Experiment_data/NQ"
    utility_cache_filename = f"utilities_q_idx{i}.pkl" # More robust naming
    current_utility_path = os.path.join(utility_cache_base_dir, utility_cache_filename)
    
    if accelerator_main.is_main_process: # Only main process creates directories
        os.makedirs(os.path.dirname(current_utility_path), exist_ok=True)
    
    # Initialize Harness
    harness = ContextAttribution(
        items=docs,
        query=query,
        prepared_model_for_harness=prepared_model,
        tokenizer_for_harness=tokenizer,
        accelerator_for_harness=accelerator_main,
        utility_cache_path=current_utility_path
    )

    print(f'Response: {harness.target_response}')
    # Compute metrics
    results_for_query = {}
    if accelerator_main.is_main_process:
        m_samples_map = {"L": 100} 
        T_iterations_map = {"L":40} 

        for size_key, num_s in m_samples_map.items():
            if 2**len(docs) < num_s and size_key != "L":
                actual_samples = max(1, 2**len(docs)-1 if 2**len(docs)>0 else 1)
            else:
                actual_samples = num_s

            if actual_samples > 0: 
                results_for_query[f"ContextCite{actual_samples}"], model_cc = harness.compute_contextcite(num_samples=actual_samples, seed=SEED)
                results_for_query[f"FM_Shap{actual_samples}"], _, F, modelfm = harness.compute_wss(num_samples=actual_samples, seed=SEED, sampling="kernelshap",sur_type="fm")
                #results_for_query[f"BetaShap{actual_samples}"] = harness.compute_beta_shap(num_iterations_max=T_iterations_map[size_key], beta_a=16, beta_b=1, max_unique_lookups=actual_samples, seed=SEED)
                #results_for_query[f"TMC{actual_samples}"] = harness.compute_tmc_shap(num_iterations_max=T_iterations_map[size_key], performance_tolerance=0.001, max_unique_lookups=actual_samples, seed=SEED)

        results_for_query["LOO"] = harness.compute_loo()
        results_for_query["ARC-JSD"] = harness.compute_arc_jsd()

        prob_topk = harness.evaluate_topk_performance(
                                                results_for_query, 
                                                k_values, 
                                                utility_type="probability"
                                            )
    
        div_topk = harness.evaluate_topk_performance(
                                            results_for_query, 
                                            k_values, 
                                            utility_type="divergence"
                                        )
        LDS = []
        for i in results_for_query:
            if "FM_Shap" in i:
                calculate_LDS = {i:harness.lds(results_for_query[i], 30, utl=True, model=modelfm)}
                LDS.append(calculate_LDS)
            else:
                calculate_LDS = {i:harness.lds(results_for_query[i], 30)}
                LDS.append(calculate_LDS)
        #LDS = [{i:harness.lds(results_for_query[i], 30)} for i in results_for_query]
        LDSs.append(LDS)

        results_for_query["topk_probability"] = prob_topk
        results_for_query["topk_divergence"] = div_topk
        harness.save_utility_cache(current_utility_path)
        
        all_results.append(results_for_query)

In [None]:
all_results

In [6]:
import json
import copy
all_results_json = copy.deepcopy(all_results)

    # Recursively convert NumPy arrays to lists
def convert_numpy_to_list(obj):
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, (np.integer, np.floating, np.bool_)):
        # Convert NumPy scalar types to Python native types
        return obj.item()
    elif isinstance(obj, dict):
        return {k: convert_numpy_to_list(v) for k, v in obj.items()}
    elif isinstance(obj, list):
            return [convert_numpy_to_list(elem) for elem in obj]
    else:
        return obj

converted_all_results = convert_numpy_to_list(all_results_json)

with open('NQ_sents_50_FM20_128.json', 'w') as f:
    json.dump(converted_all_results, f, indent=4)

In [13]:
with open('NQ_sents_50_FM5_128.json', 'r') as f:
    reloaded_all_results = json.load(f)

In [None]:
consolidated_RMSE

In [8]:
consolidated_LDS = [{key: value for d in inner_list for key, value in d.items()} for inner_list in LDSs]
consolidated_R2 = [{key: value for d in inner_list for key, value in d.items()} for inner_list in r2s]
consolidated_RMSE = [{key: value for d in inner_list for key, value in d.items()} for inner_list in RMSEs]
import collections
import math

def average_list_of_dicts(list_of_dicts):
    """
    Averages numeric values across dictionaries in a list, based on common keys.
    Handles NaN values by skipping them.

    Args:
        list_of_dicts (list): A list where each element is a dictionary.
                              All dictionaries are expected to have the same keys.
                              Values may include NaNs.

    Returns:
        dict: A dictionary where keys are the original keys and values are
              the averages of the corresponding numeric values from the input dictionaries.
              NaN values are ignored for averaging.
              Returns an empty dictionary if the input list is empty.
    """
    if not list_of_dicts:
        return {} # Return empty dict for empty input, no print statement as requested

    sum_values = collections.defaultdict(float)
    count_values = collections.defaultdict(int)

    for d in list_of_dicts:
        for key, value in d.items():
            # Check if the value is numeric (int or float) AND not NaN
            if isinstance(value, (int, float)) and not math.isnan(value):
                sum_values[key] += value
                count_values[key] += 1
            # Other types (strings, bools, etc.) and NaNs are simply skipped

    averaged_dict = {}
    for key in list_of_dicts[0].keys(): # Iterate through all expected keys from the first dict
        if count_values[key] > 0:
            averaged_dict[key] = sum_values[key] / count_values[key]
        else:
            # If no valid numeric values were found for a key (e.g., all were NaN or non-numeric)
            averaged_dict[key] = float('nan') # Or 0.0, or None, depending on desired output

    return averaged_dict

import pprint
averaged_LDS = average_list_of_dicts(consolidated_LDS)
print("\n--- Averaged Results ---")
pprint.pprint(averaged_LDS)

averaged_R2 = average_list_of_dicts(consolidated_R2)
print("\n--- Averaged Results ---")
pprint.pprint(averaged_R2)

averaged_RMSE = average_list_of_dicts(consolidated_RMSE)
print("\n--- Averaged Results ---")
pprint.pprint(averaged_RMSE)


--- Averaged Results ---
{'ContextCite128': 0.7727334949984583, 'FM_Shap128': 0.49181042131650154}

--- Averaged Results ---
{'ContextCite128': 0.6258846358556795, 'FM_Shap128': -31.20602288561671}

--- Averaged Results ---
{'ContextCite128': 1.8267026093929886, 'FM_Shap128': 4.762287752280098}


In [12]:
import matplotlib.pyplot as plt
def plot_bar_chart(data_dict, title="Averaged R2 for NQ sent-level, FM20",
                   x_label="Methods", y_label="R2"):

    plot_data = []
    for key, value in data_dict.items():
        if not math.isnan(value):
            plot_data.append((key, value))
        else:
            print(f"Skipping '{key}' for plotting as its value is NaN.")

    if not plot_data:
        print("No valid numeric data to plot after filtering NaNs.")
        return

    # --- NEW: Sort the data ---
    # Sort by the value (index 1 of the tuple), in descending order
    plot_data.sort(key=lambda item: item[1], reverse=True)

    # Unpack sorted data into separate lists for plotting
    labels = [item[0] for item in plot_data]
    values = [item[1] for item in plot_data]
    
    # --- NEW: Define a list of colors ---
    # You can customize this list with any valid matplotlib color names or hex codes.
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']
    # If you have more bars than colors, you can use matplotlib.cm for a colormap:
    # import matplotlib.cm as cm
    # colors = cm.viridis(np.linspace(0, 1, len(labels)))

    plt.figure(figsize=(12, 7))
    # --- NEW: Pass the list of colors to plt.bar ---
    bars = plt.bar(labels, values, color=colors[:len(labels)]) # Slice to match number of bars

    # Add titles and labels
    plt.xlabel(x_label, fontsize=12)
    plt.ylabel(y_label, fontsize=12)
    plt.title(title, fontsize=14)

    # Rotate x-axis labels if they are long to prevent overlap
    plt.xticks(rotation=45, ha='right', fontsize=10) # 'ha' is horizontal alignment

    # Add value labels on top of the bars
    for bar in bars:
        yval = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2, yval + 0.01, # Position text slightly above bar
                 round(yval, 4), # Format value to 4 decimal places
                 ha='center', va='bottom', fontsize=9)

    # Add a grid for easier reading of values
    plt.grid(axis='y', linestyle='--', alpha=0.7)

    # Adjust layout to prevent labels from being cut off
    plt.tight_layout()

    # Display the plot
    #plt.show()
    plt.savefig(f'nq_sents_100_fm20_R2_2307.png')

plot_bar_chart(averaged_R2)

In [None]:
# Probability-based results
print("\nProbability-based Top-k Performance:")
for q in reloaded_all_results:
    for method, drops in q['topk_probability'].items():
        print(f"  {method}:")
        for k, drop in drops.items():
            print(f"    k={k}: Drop = {drop:.4f}")

# Divergence-based results
print("\nDivergence-based Top-k Performance:")
for q in reloaded_all_results:
    for method, jsds in q['topk_divergence'].items():
        print(f"  {method}:")
        for k, jsd in jsds.items():
            print(f"    k={k}: JSD = {jsd:.4f}")

In [None]:
reloaded_all_results

In [14]:
topk_drops = {'ContextCite128': {}, 'FM_Shap128': {}, 'ContextCite256': {}, 'FM_Shap256': {}, 'LOO': {}, 'ARC-JSD': {}}
div_drops = {'ContextCite128': {}, 'FM_Shap128': {}, 'ContextCite256': {}, 'FM_Shap256': {}, 'LOO': {}, 'ARC-JSD': {}}
count = 0
for q in reloaded_all_results:
    for method in q:
        if method == 'topk_probability':
            drops = q[method]
            for m in drops:
                if m not in ('TMC100', 'BetaShap100'):
                    values = drops[m]
                    try:
                        topk_drops[m]['1'] += values['1']
                    except:
                        topk_drops[m]['1'] = values['1']
                    try:
                        topk_drops[m]['2'] += values['2']
                    except:
                        topk_drops[m]['2'] = values['2']
        elif method == 'topk_divergence':
            drops = q[method]
            for m in drops:
                if m not in ('TMC100', 'BetaShap100'):
                    values = drops[m]
                    try:
                        div_drops[m]['1'] += values['1']
                    except:
                        div_drops[m]['1'] = values['1']
                    try:
                        div_drops[m]['2'] += values['2']
                    except:
                        div_drops[m]['2'] = values['2']
    count += 1
topk_drops

{'ContextCite128': {'1': 270.12265995144844, '2': 392.120825111866},
 'FM_Shap128': {'1': 278.2387535870075, '2': 416.97775742411613},
 'ContextCite256': {},
 'FM_Shap256': {},
 'LOO': {'1': 278.2387535870075, '2': 328.6395903378725},
 'ARC-JSD': {'1': 270.42414382100105, '2': 302.5687276571989}}

In [15]:
for m in topk_drops:
    for k in topk_drops[m]:
        topk_drops[m][k] = topk_drops[m][k]/count
for m in div_drops:
    for k in div_drops[m]:
        div_drops[m][k] = div_drops[m][k]/count

In [16]:
topk_drops

{'ContextCite128': {'1': 5.402453199028969, '2': 7.84241650223732},
 'FM_Shap128': {'1': 5.564775071740151, '2': 8.339555148482322},
 'ContextCite256': {},
 'FM_Shap256': {},
 'LOO': {'1': 5.564775071740151, '2': 6.57279180675745},
 'ARC-JSD': {'1': 5.408482876420021, '2': 6.051374553143978}}

In [17]:
div_drops

{'ContextCite128': {'1': 0.7963868528121247, '2': 1.180995022549329},
 'FM_Shap128': {'1': 0.8192394025804025, '2': 1.2128547834401404},
 'ContextCite256': {},
 'FM_Shap256': {},
 'LOO': {'1': 0.8137818653616875, '2': 1.0525908829749526},
 'ARC-JSD': {'1': 1.0260680706004086, '2': 1.2847744130069372}}

In [16]:
import matplotlib.pyplot as plt
import numpy as np
import math # For handling potential NaN values if they were in the input

# --- Plotting Function for Grouped Bar Chart ---
def plot_grouped_bar_chart(data_for_plotting, title="Divergence Climb",
                           x_label="Methods", y_label="Averaged Divergence Climb"):
    """
    Draws a grouped bar chart for the given nested dictionary data.

    Args:
        data_for_plotting (dict): A dictionary where keys are metric names,
                                  and values are dictionaries containing numeric
                                  values for categories like '1' and '2'.
                                  e.g., {'MetricA': {'1': 0.1, '2': 0.2}}
        title (str): The title of the plot.
        x_label (str): Label for the x-axis.
        y_label (str): Label for the y-axis.
    """
    if not data_for_plotting:
        print("No data to plot. The input dictionary is empty.")
        return

    # Extract metric names and ensure a consistent order (alphabetical for clarity)
    metric_names = sorted(data_for_plotting.keys())

    # Prepare data for plotting
    values_cat1 = []
    values_cat2 = []
    
    # Store labels for the actual metrics being plotted (in case some have NaNs)
    plot_metric_labels = []

    for metric in metric_names:
        cat_data = data_for_plotting[metric]
        val1 = cat_data.get('1', float('nan')) # Use .get() to handle missing keys
        val2 = cat_data.get('2', float('nan'))

        # Only include metrics where at least one category has a valid number
        if (isinstance(val1, (int, float)) and not math.isnan(val1)) or \
           (isinstance(val2, (int, float)) and not math.isnan(val2)):
            values_cat1.append(val1 if (isinstance(val1, (int, float)) and not math.isnan(val1)) else 0)
            values_cat2.append(val2 if (isinstance(val2, (int, float)) and not math.isnan(val2)) else 0)
            plot_metric_labels.append(metric)
        else:
            print(f"Skipping metric '{metric}' as both '1' and '2' values are NaN or non-numeric.")


    if not plot_metric_labels:
        print("No valid data points found to plot after processing categories.")
        return

    # Set up positions for the bars
    bar_width = 0.35
    index = np.arange(len(plot_metric_labels)) # The x locations for the groups

    plt.figure(figsize=(14, 8)) # Adjust figure size

    # Plotting the bars
    bar1 = plt.bar(index - bar_width/2, values_cat1, bar_width, label='1', color='skyblue')
    bar2 = plt.bar(index + bar_width/2, values_cat2, bar_width, label='2', color='lightcoral')

    # Add labels, title, and legend
    plt.xlabel(x_label, fontsize=14)
    plt.ylabel(y_label, fontsize=14)
    plt.title(title, fontsize=16)
    plt.xticks(index, plot_metric_labels, rotation=45, ha='right', fontsize=12) # Set metric labels at group center
    plt.yticks(fontsize=12)
    plt.legend(fontsize=12)

    # Add value labels on top of the bars
    def autolabel(bars):
        for bar in bars:
            height = bar.get_height()
            plt.text(bar.get_x() + bar.get_width()/2, height + 0.005, # Position text slightly above bar
                     f'{height:.4f}', # Format value to 4 decimal places
                     ha='center', va='bottom', fontsize=9, rotation=0)

    autolabel(bar1)
    autolabel(bar2)

    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.tight_layout() # Adjust layout to prevent labels from being cut off
    #plt.show()
    plt.savefig(f'nq_sents_100_fm20_div_drop_2307.png')

plot_grouped_bar_chart(div_drops)

In [9]:
[i for i in results_for_query]

['ContextCite100',
 'FM_Shap100',
 'FM_Weights100',
 'BetaShap100',
 'TMC100',
 'LOO',
 'ARC-JSD',
 'topk_probability',
 'topk_divergence']

In [10]:
df.question[9]

'how many episodes are in series 7 game of thrones'

In [11]:
harness._generate_sampled_ablations(4, sampling_method='uniform', seed=2)

In [12]:
df.context[9]

'[\'Game of Thrones (season 7) The seventh and penultimate season of the fantasy drama television series "Game of Thrones" premiered on HBO on July 16, 2017, and concluded on August 27, 2017. Unlike previous seasons that consisted of ten episodes each, the seventh season consisted of only seven. Like the previous season, it largely consisted of original content not found in George R. R. Martin\\\'s "A Song of Ice and Fire" series, while also incorporating material Martin revealed to showrunners about the upcoming novels in the series. The series was adapted for television by David Benioff and D. B. Weiss.\', \'Bender, who worked on the show\\\'s sixth season, said that the seventh season would consist of seven episodes. Benioff and Weiss stated that they were unable to produce 10 episodes in the show\\\'s usual 12 to 14 month time frame, as Weiss said "It\\\'s crossing out of a television schedule into more of a mid-range movie schedule." HBO confirmed on July 18, 2016, that the sevent

In [37]:
docs[6]

'genuine, though very obscure, saying, "only fools and horses work for a living", which had its origins in 19th-century American vaudeville. "Only Fools and Horses" had also been the title of an episode of "Citizen Smith", and Sullivan liked the expression and thought it was suited to the new sitcom. He also thought longer titles would attract attention. He was first overruled on the grounds that the audience would not understand the title, but he eventually got his way. Filming of the first series began in May 1981, and the first episode, "Big Brother", was transmitted on BBC One at\''

In [15]:
all_results[0]

{'ContextCite100': array([ 9.42799874e+00,  4.95333848e+00,  1.54510382e-01, -1.18733185e+00,
        -1.97362536e-01, -1.63601493e-03, -4.20977587e-01, -1.71711992e+00,
         5.13743546e-01, -3.61054749e-01]),
 'FM_Shap100': array([10.32068284,  5.69851447,  0.11906173, -0.44636073, -0.88331575,
        -1.00641049, -0.22007863, -0.79765982,  0.29558806, -0.81720759]),
 'FM_Weights100': array([10.32068284,  5.69851447,  0.11906173, -0.44636073, -0.88331575,
        -1.00641049, -0.22007863, -0.79765982,  0.29558806, -0.81720759]),
 'BetaShap100': array([12.18496471,  6.09588892,  0.22350156,  0.12756782, -0.06228456,
         0.05961801,  0.48444461, -0.42519495,  0.84019263, -0.43572434]),
 'TMC100': array([ 9.96529419,  5.98631053, -0.16527917, -0.53574083, -0.08729146,
        -0.09466684,  0.01608808, -0.76350622,  0.2977114 , -0.34037049]),
 'LOO': array([12.21575832,  8.61024857,  1.24591446,  0.25884056, -0.06205273,
         0.06636238,  0.56908131, -0.40280247,  0.83529949

In [16]:
F

array([[ 0.        , -2.71718673,  2.64412498,  0.42945254,  0.83703511,
         0.81296934,  2.79770344, -0.70547725, -1.78606628, -1.45755979],
       [-2.71718673,  0.        , -1.55238252,  0.47399397,  0.30587813,
        -1.0472584 , -3.32381804, -0.22556415,  1.52053671,  2.70219414],
       [ 2.64412498, -1.55238252,  0.        , -0.05849546,  0.24429583,
         1.2087218 , -1.96988449,  0.43088249,  1.2081562 ,  1.73971653],
       [ 0.42945254,  0.47399397, -0.05849546,  0.        ,  0.45773645,
         0.56765182,  0.01962721, -0.01922445,  0.25747797, -0.84770228],
       [ 0.83703511,  0.30587813,  0.24429583,  0.45773645,  0.        ,
         0.23349889,  0.07346147, -0.04195382,  0.03793556, -0.31164835],
       [ 0.81296934, -1.0472584 ,  1.2087218 ,  0.56765182,  0.23349889,
         0.        ,  0.17034947,  0.01708509,  0.06055724, -0.52206366],
       [ 2.79770344, -3.32381804, -1.96988449,  0.01962721,  0.07346147,
         0.17034947,  0.        , -0.15618031

In [17]:
all_subsets = list(itertools.product([0, 1], repeat=4))

In [18]:
sampled_tuples = harness._generate_sampled_ablations(10, sampling_method='uniform', seed=2)

In [17]:
d = {}
type(d)

dict

In [18]:
import matplotlib.pyplot as plt

for result in range(len(all_results)):
    method_scores = {}
    for method, scores in all_results[result].items():
        if scores is not None and type(scores) is not dict:
            print(method, scores)
            method_scores[method] = np.round(scores, 4)

    for method, scores in method_scores.items():
        plt.figure(figsize=(10, 4))
        plt.bar(range(len(scores)), scores, color='skyblue')
        plt.title(f"Approximate Scores: {method}")
        plt.xlabel("Index")
        plt.ylabel("Score")
        plt.grid(axis='y', linestyle='--', alpha=0.7)
        plt.xticks(range(len(scores)))
        plt.tight_layout()
        plt.savefig(f'nq_doc_plots/{result}_{method}.png')

ContextCite100 [ 9.42799874e+00  4.95333848e+00  1.54510382e-01 -1.18733185e+00
 -1.97362536e-01 -1.63601493e-03 -4.20977587e-01 -1.71711992e+00
  5.13743546e-01 -3.61054749e-01]
FM_Shap100 [10.32068284  5.69851447  0.11906173 -0.44636073 -0.88331575 -1.00641049
 -0.22007863 -0.79765982  0.29558806 -0.81720759]
FM_Weights100 [10.32068284  5.69851447  0.11906173 -0.44636073 -0.88331575 -1.00641049
 -0.22007863 -0.79765982  0.29558806 -0.81720759]
BetaShap100 [12.18172271  7.73213164  0.78913852  0.1642193  -0.06000717  0.05865822
  0.56100273 -0.43111163  0.83715638 -0.46782496]
TMC100 [ 9.96529419  5.98631053 -0.16527917 -0.53574083 -0.08729146 -0.09466684
  0.01608808 -0.76350622  0.2977114  -0.34037049]
LOO [12.21575832  8.61024857  1.24591446  0.25884056 -0.06205273  0.06636238
  0.56908131 -0.40280247  0.83529949 -0.49150467]
ARC-JSD [ 0.34588716 10.75895722  0.03822068  0.01981782  0.03414572  0.03501194
  0.05906658  0.02726513  0.04695243  0.04338138]
ContextCite100 [ 0.53418168

  plt.figure(figsize=(10, 4))
