In [1]:
# Pin the visible GPU *before* torch (or any library that imports it) is
# loaded: CUDA reads CUDA_VISIBLE_DEVICES once, at driver initialisation, and
# ignores later changes. Previously this was set after importing torch,
# transformers and accelerate, so any import-time CUDA probe made it a no-op.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import sys
import random
import gc
import time
import torch
import numpy as np
import pandas as pd
import pickle
import ast
from tqdm import tqdm
from scipy.sparse import csr_matrix
import itertools
from scipy.stats import spearmanr, pearsonr, kendalltau, rankdata
from sklearn.metrics import ndcg_score
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import Accelerator
import nltk
nltk.download('punkt')

# Make the project package in the parent directory importable.
current_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(current_dir, '..'))
sys.path.append(parent_dir)
# ContextAttribution (used below) is not imported by name, so it comes from one
# of these star imports. NOTE(review): replace with explicit imports once the
# exact module is confirmed.
from SHapRAG import *
from SHapRAG.utils import *

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
# Toy retrieval context for the attribution experiment. Answering `question`
# takes two hops: "Chorvoq is the capital of Suvsambil." (index 5) plus a
# Chorvoq weather passage (index 0, restated at index 8). The other passages
# are distractors about Berlin, Paris, Germany and Suvsambil's currency.
context = [
    "The weather in Chorvoq is sunny today.",                # 0
    "Berlin is the capital of Germany.",                     # 1: irrelevant
    "The Eiffel Tower is located in Paris, France.",         # 2
    "France borders several countries including Germany.",  # 3
    "The currency used in Suvsambil is the chaqa.",          # 4
    "Chorvoq is the capital of Suvsambil.",                  # 5
    "Paris hosted the Summer Olympics in 1900 and 1924.",    # 6
    "Germany uses the Euro as well.",                        # 7: redundant info
    "The sun is shining in Chorvoq today",                   # 8
    "It is cloudy in Berlin today.",                         # 9
]
question = "what is the weather like in the capital of Suvsambil"

In [3]:
SEED = 42
# Seed every global RNG once, before any sampling happens, so a Restart & Run
# All reproduces the same numbers. SEED was previously only referenced from
# commented-out code, so nothing was actually seeded.
# NOTE(review): this makes the harness deterministic only if it draws from
# these global generators rather than its own RNG -- confirm in SHapRAG.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Initialize Accelerator
accelerator_main = Accelerator(mixed_precision="fp16")

# Load Model
if accelerator_main.is_main_process:
    print("Main Script: Loading model...")
# Alternative checkpoints (swap in to compare):
# model_path = "mistralai/Mistral-7B-Instruct-v0.3"
model_path = "meta-llama/Llama-3.1-8B-Instruct"
# model_path = "Qwen/Qwen2.5-3B-Instruct"

model_cpu = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained(model_path)
# If the tokenizer defines no pad token, fall back to EOS and mirror the id
# into the model config (and generation config, when present).
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    model_cpu.config.pad_token_id = tokenizer.pad_token_id
    if hasattr(model_cpu, 'generation_config') and model_cpu.generation_config is not None:
        model_cpu.generation_config.pad_token_id = tokenizer.pad_token_id

if accelerator_main.is_main_process:
    print("Main Script: Preparing model with Accelerator...")
# Hand the model to Accelerate; eval() is applied to the unwrapped module.
prepared_model = accelerator_main.prepare(model_cpu)
unwrapped_prepared_model = accelerator_main.unwrap_model(prepared_model)
unwrapped_prepared_model.eval()
if accelerator_main.is_main_process:
    print("Main Script: Model prepared and set to eval.")

# Barrier: wait for all processes before continuing.
accelerator_main.wait_for_everyone()

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Main Script: Loading model...


2025-10-07 08:32:47.616857: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-10-07 08:32:47.666392: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-10-07 08:32:48.992629: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.36it/s]


Main Script: Preparing model with Accelerator...
Main Script: Model prepared and set to eval.


In [4]:
# Wrap the question, the context passages, and the prepared model/tokenizer
# in the attribution harness used by the sweep below.
# utility_cache_path=None: presumably no on-disk utility cache -- confirm.
harness = ContextAttribution(
    query=question,
    items=context,
    accelerator=accelerator_main,
    prepared_model=prepared_model,
    prepared_tokenizer=tokenizer,
    utility_cache_path=None,
)

Fetching 3 files: 100%|██████████| 3/3 [00:00<00:00, 12458.33it/s]


In [5]:
def gtset_k():
    """Return the ground-truth indices of the passages in `context` needed to answer `question`.

    Answering requires passage 5 ("Chorvoq is the capital of Suvsambil.")
    plus the Chorvoq weather passage 0, which passage 8 restates. Index 1
    ("Berlin is the capital of Germany.") is explicitly marked irrelevant in
    the context cell, so it must not be in the ground truth.

    Returns:
        list[int]: sorted passage indices.
    """
    return [0, 5, 8]

# Experiment configuration.
# NOTE(review): num_questions_to_run, k_values, all_results and extras are not
# used in the cells shown here; kept in case later cells rely on them.
num_questions_to_run = 50
k_values = [1, 2, 3, 4, 5]
all_results = []
extras = []

# Result containers are created on every process. The loop below runs on all
# ranks, and creating these only under `is_main_process` raised NameError on
# non-main ranks.
methods_results = {}  # attribution vector per sample budget
metrics_results = {}
extra_results = {}
# Store FM models for later R²/MSE
fm_models = {}

# Sample budgets to sweep: label -> num_samples passed to the harness.
m_samples_map = {"XS": 32, "S": 64, "M": 128, "L": 264, "XL": 528, "XXL": 724}

# (A fixed-rank sweep via harness.compute_wss(num_samples=..., seed=SEED,
# sampling="kernelshap", sur_type="fm", rank=r) was tried previously.)
for size_key, actual_samples in m_samples_map.items():
    (
        methods_results[f"FM_dynamic_{actual_samples}"],
        extra_results[f"Flu_dynamic_{actual_samples}"],
        fm_models[f"FM_dynamic_{actual_samples}"],
    ) = harness.compute_wss_dynamic_pruning_reuse_utility(
        num_samples=actual_samples, pruning_strategy="elbow"
    )

We are keeping 9 documents
We are keeping 4 documents
We are keeping 9 documents
We are keeping 3 documents
We are keeping 3 documents
We are keeping 3 documents


In [7]:
# Attribution vector per sample budget: keys are "FM_dynamic_<num_samples>",
# each value holds one score per passage in `context`.
methods_results

{'FM_dynamic_32': array([ 9.70341203, -0.01213078,  0.03787155, -0.05379619,  1.15143543,
         1.34389732, -0.98866751,  0.51069104,  1.50685881,  0.        ]),
 'FM_dynamic_64': array([7.20226027, 0.        , 0.        , 0.        , 1.37971299,
        3.23773816, 0.        , 0.        , 3.35193817, 0.        ]),
 'FM_dynamic_128': array([ 6.03576339,  0.        ,  0.47841423,  0.67405757,  1.62342594,
         3.17548893,  0.26009093,  0.36703115,  3.62462117, -0.65197333]),
 'FM_dynamic_264': array([7.18764405, 0.        , 0.        , 0.        , 0.        ,
        3.73139437, 0.        , 0.        , 3.83193414, 0.        ]),
 'FM_dynamic_528': array([7.03012382, 0.        , 0.        , 0.        , 0.        ,
        3.80119356, 0.        , 0.        , 4.01609434, 0.        ]),
 'FM_dynamic_724': array([7.28277142, 0.        , 0.        , 0.        , 0.        ,
        3.74381958, 0.        , 0.        , 3.80350204, 0.        ])}