In [1]:
from attribution.api_attribution import OpenAIAttributor
from attribution.experiment_logger import ExperimentLogger
from attribution.token_perturbation import FixedPerturbationStrategy, NthNearestPerturbationStrategy
import pandas as pd
# Re-import modified modules without restarting the server (for dev use)
%load_ext autoreload
%autoreload 2

import os

from dotenv import load_dotenv

# Pull environment variables from a local .env file, but only when one exists
# in the current working directory.
if os.path.isfile(".env"):
    load_dotenv()

import warnings

# huggingface_hub raises a FutureWarning that this notebook cannot act on;
# silence it so it does not clutter the cell outputs.
warnings.filterwarnings(
    action="ignore",
    category=FutureWarning,
    module="huggingface_hub",
)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Attributor used by every experiment below. The API key is read from the
# OPENAI_API_KEY environment variable (None if it is not set). Note that the
# configured model is "gpt-4o-mini" despite the variable name.
gpt4_attributor = OpenAIAttributor(
    openai_model="gpt-4o-mini",
    openai_api_key=os.environ.get("OPENAI_API_KEY"),
    max_concurrent_requests=5,
)


In [40]:

# Evaluation samples. Each entry carries:
#   "prompt"   - the question sent to the model (brevity_prompt is appended later),
#   "key_strs" - substrings of the prompt whose attribution is compared against
#                the mean attribution of the run,
#   "answer"   - the reference answer (not read by the experiment loop).
# Only the first sample is active; the others are kept commented out.
samples = [
    {
        "prompt": "The clock shows 9:47. How many minutes until 11?",
        "key_strs": ["9:47", "11"],
        "answer": "73",
    },
    # {"prompt": "Maria is 37 years old today. How many years till she's 50?", "key_strs":["37","50"], "answer":"13"},
    # {"prompt": "John has 83 books on his shelf. If he buys 17 more books, how many will he have in total?", "key_strs":["83","17"], "answer":"100"},
    # {"prompt": "What is the capital of Japan?", "key_strs": ["capital", "Japan"], "answer":"Tokyo"},
    # {"prompt": "In which continent is Johannesburg?", "key_strs": ["Johannesburg", "continent"], "answer":"Africa"},
    # {"prompt": "Which element has the symbol O?", "key_strs": ["element", "O"], "answer":"Oxygen"},
    # {"prompt": "What is the largest bird?", "key_strs": ["largest", "bird"], "answer":"Ostrich"},
    # {"prompt": "What is the smallest prime number?", "key_strs": ["smallest", "prime"], "answer":"2"},
    # {"prompt": "What colour does mixing red and blue create?", "key_strs":["red", "blue"], "answer":"purple"},
    # {"prompt": "What is frozen water called?", "key_strs":["frozen", "water"], "answer":"ice"},
]

# Suffix appended to every prompt before it is sent for attribution.
brevity_prompt = " Answer in one word."

In [42]:
logger = ExperimentLogger()

algos = ['iterative', 'hierarchical']
attr_methods = ['prob_diff', 'cosine']
perturb_methods = ['fixed', 'nth_nearest']

results = []

for sample in samples:
    for algo in algos:
        for attr_method in attr_methods:
            for perturb_method in perturb_methods:
                prompt = sample["prompt"] + brevity_prompt
                key_strs = sample["key_strs"]
                sample_copy = sample.copy()
                sample_copy.update({"algorithm":algo, "attr_method":attr_method, "perturb_method":perturb_method})

                if perturb_method == 'fixed':
                    perturbation_strategy = FixedPerturbationStrategy(replacement_token="")
                elif perturb_method == 'nth_nearest':
                    perturbation_strategy = NthNearestPerturbationStrategy(n=1)
                
                if algo == 'iterative':
                    await gpt4_attributor.iterative_perturbation(
                        prompt,
                        logger=logger,
                        attribution_strategies=[attr_method],
                        unit_definition="word",
                        perturbation_strategy=perturbation_strategy,
                    )
                elif algo == 'hierarchical':
                    await gpt4_attributor.hierarchical_perturbation(
                        prompt,
                        logger=logger,
                        attribution_strategies=[attr_method],
                        unit_definition="word",
                        perturbation_strategy=perturbation_strategy,
                    )
                else:
                    break

                attr = logger.df_input_token_attribution
                exps = logger.df_experiments

                attr = attr[attr["exp_id"] == attr["exp_id"].max()]
                exps = exps[exps["exp_id"] == exps["exp_id"].max()]
                api_calls = exps["num_llm_calls"].values[0]
                duration = exps["duration"].values[0]

                mean_attr = attr['attr_score'].mean()
                key_attr = 0
                for key in key_strs:
                    filtered = attr[attr["input_token"].str.contains(key)]
                    key_attr += filtered["attr_score"].values[0]
                key_attr = key_attr/len(key_strs)
                correct = key_attr > mean_attr
                sample_copy.update({"api_calls":api_calls, "duration":duration,
                               "mean_attr":mean_attr, "mean_key_attr":key_attr, "correct":correct})
                logger.print_text_total_attribution(exp_id=-1)
                print(sample_copy)
                results.append(sample_copy)

Sending 5 concurrent requests at a time: 100%|██████████| 3/3 [00:01<00:00,  2.80it/s]


{'prompt': 'The clock shows 9:47. How many minutes until 11?', 'key_strs': ['9:47', '11'], 'answer': '73', 'algorithm': 'iterative', 'attr_method': 'prob_diff', 'perturb_method': 'fixed', 'api_calls': 14, 'duration': 1.1010041236877441, 'mean_attr': 0.4050741902848444, 'mean_key_attr': 0.9220638977843478, 'correct': True}


Sending 5 concurrent requests at a time: 100%|██████████| 3/3 [00:00<00:00,  3.10it/s]


{'prompt': 'The clock shows 9:47. How many minutes until 11?', 'key_strs': ['9:47', '11'], 'answer': '73', 'algorithm': 'iterative', 'attr_method': 'prob_diff', 'perturb_method': 'nth_nearest', 'api_calls': 14, 'duration': 4.032686948776245, 'mean_attr': 0.26716872286488286, 'mean_key_attr': 0.9996987836547553, 'correct': True}


Sending 5 concurrent requests at a time: 100%|██████████| 3/3 [00:01<00:00,  2.47it/s]


{'prompt': 'The clock shows 9:47. How many minutes until 11?', 'key_strs': ['9:47', '11'], 'answer': '73', 'algorithm': 'iterative', 'attr_method': 'cosine', 'perturb_method': 'fixed', 'api_calls': 14, 'duration': 1.2648849487304688, 'mean_attr': 0.258263872219966, 'mean_key_attr': 0.6294264793395996, 'correct': True}


Sending 5 concurrent requests at a time: 100%|██████████| 3/3 [00:01<00:00,  2.76it/s]


{'prompt': 'The clock shows 9:47. How many minutes until 11?', 'key_strs': ['9:47', '11'], 'answer': '73', 'algorithm': 'iterative', 'attr_method': 'cosine', 'perturb_method': 'nth_nearest', 'api_calls': 14, 'duration': 3.2586612701416016, 'mean_attr': 0.11651701652086698, 'mean_key_attr': 0.47622236609458923, 'correct': True}


Sending 5 concurrent requests at a time: 100%|██████████| 2/2 [00:02<00:00,  1.41s/it]
Sending 5 concurrent requests at a time: 100%|██████████| 2/2 [00:02<00:00,  1.24s/it]


{'prompt': 'The clock shows 9:47. How many minutes until 11?', 'key_strs': ['9:47', '11'], 'answer': '73', 'algorithm': 'hierarchical', 'attr_method': 'prob_diff', 'perturb_method': 'fixed', 'api_calls': 19, 'duration': 7.217117786407471, 'mean_attr': 0.21900062778142615, 'mean_key_attr': 0.16658597016874915, 'correct': False}


Sending 5 concurrent requests at a time: 100%|██████████| 2/2 [00:01<00:00,  1.69it/s]
Sending 5 concurrent requests at a time: 100%|██████████| 2/2 [00:01<00:00,  1.90it/s]


{'prompt': 'The clock shows 9:47. How many minutes until 11?', 'key_strs': ['9:47', '11'], 'answer': '73', 'algorithm': 'hierarchical', 'attr_method': 'prob_diff', 'perturb_method': 'nth_nearest', 'api_calls': 19, 'duration': 12.134681940078735, 'mean_attr': 0.19298478293716287, 'mean_key_attr': 0.16658597016874915, 'correct': False}


Sending 5 concurrent requests at a time: 100%|██████████| 2/2 [00:02<00:00,  1.20s/it]
Sending 5 concurrent requests at a time: 100%|██████████| 2/2 [00:01<00:00,  1.94it/s]


{'prompt': 'The clock shows 9:47. How many minutes until 11?', 'key_strs': ['9:47', '11'], 'answer': '73', 'algorithm': 'hierarchical', 'attr_method': 'cosine', 'perturb_method': 'fixed', 'api_calls': 17, 'duration': 4.353753089904785, 'mean_attr': 0.10917184935059658, 'mean_key_attr': 0.11076738437016806, 'correct': True}


Sending 5 concurrent requests at a time: 100%|██████████| 2/2 [00:01<00:00,  1.62it/s]
Sending 5 concurrent requests at a time: 100%|██████████| 2/2 [00:00<00:00,  2.21it/s]


{'prompt': 'The clock shows 9:47. How many minutes until 11?', 'key_strs': ['9:47', '11'], 'answer': '73', 'algorithm': 'hierarchical', 'attr_method': 'cosine', 'perturb_method': 'nth_nearest', 'api_calls': 19, 'duration': 12.298116207122803, 'mean_attr': 0.08260827040185734, 'mean_key_attr': 0.07204000155131023, 'correct': False}


In [44]:

# One row per (sample, algorithm, attribution metric, perturbation strategy)
# run, built in a single pass from the list of result dicts.
results_df = pd.DataFrame(data=results)
display(results_df)

Unnamed: 0,prompt,key_strs,answer,algorithm,attr_method,perturb_method,api_calls,duration,mean_attr,mean_key_attr,correct
0,The clock shows 9:47. How many minutes until 11?,"[9:47, 11]",73,iterative,prob_diff,fixed,14,1.101004,0.405074,0.922064,True
1,The clock shows 9:47. How many minutes until 11?,"[9:47, 11]",73,iterative,prob_diff,nth_nearest,14,4.032687,0.267169,0.999699,True
2,The clock shows 9:47. How many minutes until 11?,"[9:47, 11]",73,iterative,cosine,fixed,14,1.264885,0.258264,0.629426,True
3,The clock shows 9:47. How many minutes until 11?,"[9:47, 11]",73,iterative,cosine,nth_nearest,14,3.258661,0.116517,0.476222,True
4,The clock shows 9:47. How many minutes until 11?,"[9:47, 11]",73,hierarchical,prob_diff,fixed,19,7.217118,0.219001,0.166586,False
5,The clock shows 9:47. How many minutes until 11?,"[9:47, 11]",73,hierarchical,prob_diff,nth_nearest,19,12.134682,0.192985,0.166586,False
6,The clock shows 9:47. How many minutes until 11?,"[9:47, 11]",73,hierarchical,cosine,fixed,17,4.353753,0.109172,0.110767,True
7,The clock shows 9:47. How many minutes until 11?,"[9:47, 11]",73,hierarchical,cosine,nth_nearest,19,12.298116,0.082608,0.07204,False
