In [1]:
from attribution.api_attribution import OpenAIAttributor
from attribution.experiment_logger import ExperimentLogger
from attribution.token_perturbation import FixedPerturbationStrategy, NthNearestPerturbationStrategy
import pandas as pd
# Re-import modified modules without restarting the server (for dev use)
%load_ext autoreload
%autoreload 2

import os

from dotenv import load_dotenv

# Checks if you're using a .env file, and loads it if so.
if os.path.isfile(".env"):
    load_dotenv()

import warnings

# Suppress annoying FutureWarning from huggingface_hub that is not our fault
warnings.filterwarnings("ignore", category=FutureWarning, module="huggingface_hub")


  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# GPT-4o-backed attributor. At most 10 API requests are sent concurrently.
# The API key comes from the environment, which the setup cell populates
# from .env when that file exists.
gpt4_attributor = OpenAIAttributor(
    openai_model="gpt-4o",
    max_concurrent_requests=10,
    openai_api_key=os.environ.get("OPENAI_API_KEY"),
)


In [3]:

# Prompt / expected-answer pairs. `key_strs` are the substrings of the prompt
# expected to carry most of the attribution towards `answer`; `answer` is
# matched case-insensitively as a substring of the model's output tokens, so
# it must be spelled exactly as the model would write it.
ALL_SAMPLES = [
    {"prompt": "The clock shows 9:47. How many minutes until 11?", "key_strs": ["9:47", "11", "minutes"], "answer": "73"},
    {"prompt": "Maria is 37 years old today. How many years till she's 50?", "key_strs": ["37", "50"], "answer": "13"},
    {"prompt": "John has 83 books on his shelf. If he buys 17 more books, how many will he have in total?", "key_strs": ["83", "17"], "answer": "100"},
    {"prompt": "What is the capital of Japan?", "key_strs": ["capital", "Japan"], "answer": "Tokyo"},
    {"prompt": "In which continent is Johannesburg?", "key_strs": ["Johannesburg", "continent"], "answer": "Africa"},
    {"prompt": "Which element has the symbol O?", "key_strs": ["element", "O"], "answer": "Oxygen"},
    {"prompt": "What is the largest bird?", "key_strs": ["largest", "bird"], "answer": "Ostrich"},
    {"prompt": "What is the smallest prime number?", "key_strs": ["smallest", "prime"], "answer": "2"},
    {"prompt": "What colour does mixing red and blue create?", "key_strs": ["red", "blue"], "answer": "purple"},
    {"prompt": "What is frozen water called?", "key_strs": ["frozen", "water"], "answer": "ice"},
]

# Indices into ALL_SAMPLES to run. Widen (e.g. `range(len(ALL_SAMPLES))`) for
# a full sweep; each sample costs one full experiment grid of API calls.
ACTIVE_SAMPLE_IDS = [0]
samples = [ALL_SAMPLES[i] for i in ACTIVE_SAMPLE_IDS]

# Appended to every prompt to keep the model's answer short.
brevity_prompt = " Answer briefly."

# Neutral filler prepended to prompts to dilute the share of key words
# (200 space-separated words).
extension_text = (
    "Lorem ipsum dolor sit amet, consectetur adipiscing elit. "
    "Donec velit erat, auctor in nisi ac, porttitor iaculis erat. "
    "Phasellus nec viverra massa. Cras suscipit rutrum elit faucibus consectetur. "
    "Nulla ligula elit, rutrum quis pharetra quis, ornare nec arcu. "
    "Aliquam sagittis ipsum non aliquam varius. "
    "Pellentesque euismod mi dapibus, facilisis sapien venenatis, auctor libero. "
    "Phasellus porta eget orci sed feugiat. Phasellus pharetra sem ullamcorper accumsan ornare. "
    "Quisque eu leo ipsum. Mauris ultricies congue risus, ut tincidunt mi rhoncus eget. "
    "Integer ut accumsan est, pulvinar finibus purus. Morbi quis molestie libero, a convallis risus. "
    "Nullam sit amet maximus sem, vel rutrum ipsum. Sed in condimentum augue. Phasellus vitae est nisi. "
    "Mauris malesuada urna elit, eget volutpat augue dapibus at. "
    "In aliquam nisi purus, tincidunt scelerisque nisi blandit ut. "
    "Aenean sem lacus, dapibus id dui ac, congue finibus nunc. "
    "Phasellus tellus odio, pellentesque eu est eu, vehicula fringilla ante. "
    "Donec vitae iaculis mauris. Nunc mollis feugiat odio ut tincidunt. "
    "Aenean ut bibendum augue, eu suscipit mi. Praesent pretium viverra mollis. "
    "Donec sit amet accumsan dolor. Maecenas ornare elit ac felis feugiat, et dignissim enim accumsan. "
    "Nullam commodo maximus sapien, ut elementum ex porta quis. Donec sit amet dignissim dolor. "
    "Aliquam sed rutrum orci."
)


In [8]:
logger = ExperimentLogger()

saliency_percentages = [0.5, 0.1, 0.01]
algorithms = ['iterative', 'hierarchical']
attr_methods = ['prob_diff', 'cosine']
perturb_methods = ['fixed', 'nearest']

if not os.path.exists("pizza_baseline.csv"):
    results_df = pd.DataFrame(columns=["prompt", "original_prompt", "algorithm", "attr_method", "perturb_method", "target_sal_percentage", "prompt_length", "salient_percentage", "api_calls", "duration", "mean_attr", "mean_key_attr", "correct"])
else:
    results_df = pd.read_csv("pizza_baseline.csv")
    results_df = results_df.drop_duplicates()

for sample in samples:
    for saliency_percentage in saliency_percentages:
        for algorithm in algorithms:
            for attr_method in attr_methods:
                for perturb_method in perturb_methods:
                    prompt = sample["prompt"] + brevity_prompt
                    key_strs = sample["key_strs"]
                    
                    prompt_extension = int(max(0, ((1/saliency_percentage) * len(key_strs)) - len(prompt.split(' '))))

                    if prompt_extension > 0:
                        prompt = ' '.join(extension_text.split(' ')[:prompt_extension]) + prompt

                    res = {
                        "algorithm": algorithm,
                        "attr_method": attr_method,
                        "perturb_method": perturb_method,
                        "target_sal_percentage": saliency_percentage,
                        "prompt": prompt,
                        "original_prompt": sample["prompt"],
                        "prompt_length": len(prompt.split(' ')),
                        "salient_percentage": len(key_strs) / len(prompt.split(' '))
                    }

                    if any(results_df[list(res.keys())].apply(lambda x: x.equals(pd.DataFrame([res]).iloc[0]), axis=1)):
                        print("Experiment already done, skipping...")
                        continue
                    print('Starting experiment...')
                    perturbation_strategy = FixedPerturbationStrategy(replacement_token="") if perturb_method == 'fixed' else NthNearestPerturbationStrategy(n=-1)

                    if algorithm == 'iterative':
                        await gpt4_attributor.iterative_perturbation(prompt, logger=logger, attribution_strategies=[attr_method], unit_definition='word', perturbation_strategy=perturbation_strategy)
                    else:
                        await gpt4_attributor.hierarchical_perturbation(prompt, logger=logger, attribution_strategies=[attr_method], unit_definition='word', perturbation_strategy=perturbation_strategy)

                    attr = logger.get_attribution_matrices(exp_id=-1)[0]
                    exps = logger.df_experiments[logger.df_experiments["exp_id"] == logger.df_experiments["exp_id"].max()]
                    sample_copy = {
                        **sample,
                        **res,
                        "api_calls": exps["num_llm_calls"].values[0],
                        "duration": exps["duration"].values[0]
                    }
                    sample_copy['key_strs'] = str(key_strs)

                    target_col = [c for c in attr.columns if sample["answer"].lower() in c.lower()]
                    if target_col:
                        target_attr = attr[target_col].T
                        mean_attr = target_attr.values.mean()
                        key_attr = sum(target_attr[c].values.mean() for key in key_strs for c in target_attr.columns if key in c) / len(key_strs)
                        correct = key_attr > (mean_attr - key_attr)
                    else:
                        mean_attr, key_attr, correct = None, None, None

                    sample_copy.update({"mean_attr": mean_attr, "mean_key_attr": key_attr, "correct": correct})
                    logger.print_text_total_attribution(exp_id=-1)

                    results_df = pd.concat([results_df, pd.DataFrame([sample_copy])], ignore_index=True)
                    results_df.to_csv("pizza_baseline.csv", index=False)


Experiment already done, skipping...
Experiment already done, skipping...
Experiment already done, skipping...
Experiment already done, skipping...
Experiment already done, skipping...
Experiment already done, skipping...
Experiment already done, skipping...
Experiment already done, skipping...
Starting experiment...


Sending 10 concurrent requests at a time: 100%|██████████| 3/3 [00:03<00:00,  1.09s/it]


Starting experiment...


Sending 10 concurrent requests at a time: 100%|██████████| 3/3 [00:03<00:00,  1.14s/it]


Starting experiment...


Sending 10 concurrent requests at a time: 100%|██████████| 3/3 [00:04<00:00,  1.41s/it]


Starting experiment...


Sending 10 concurrent requests at a time: 100%|██████████| 3/3 [00:03<00:00,  1.30s/it]


Starting experiment...


Sending 10 concurrent requests at a time: 100%|██████████| 2/2 [00:03<00:00,  1.76s/it]


Starting experiment...


CancelledError: 

  results_df = results_df.applymap(lambda x: str(x) if type(x) is list else x)


Unnamed: 0,prompt,original_prompt,algorithm,attr_method,perturb_method,target_sal_percentage,prompt_length,salient_percentage,api_calls,duration,mean_attr,mean_key_attr,correct,key_strs,answer
0,The clock shows 9:47. How many minutes until 1...,The clock shows 9:47. How many minutes until 11?,iterative,prob_diff,fixed,0.5,11,0.272727,12,2.909489,0.38886,0.953486,True,"['9:47', '11', 'minutes']",73
1,The clock shows 9:47. How many minutes until 1...,The clock shows 9:47. How many minutes until 11?,iterative,prob_diff,nearest,0.5,11,0.272727,12,4.438298,0.290838,0.77983,True,"['9:47', '11', 'minutes']",73
2,The clock shows 9:47. How many minutes until 1...,The clock shows 9:47. How many minutes until 11?,iterative,cosine,fixed,0.5,11,0.272727,12,15.607642,0.225492,0.543647,True,"['9:47', '11', 'minutes']",73
3,The clock shows 9:47. How many minutes until 1...,The clock shows 9:47. How many minutes until 11?,iterative,cosine,nearest,0.5,11,0.272727,12,5.147579,0.142162,0.419924,True,"['9:47', '11', 'minutes']",73
4,The clock shows 9:47. How many minutes until 1...,The clock shows 9:47. How many minutes until 11?,hierarchical,prob_diff,fixed,0.5,11,0.272727,16,6.211689,0.204888,0.290897,True,"['9:47', '11', 'minutes']",73
5,The clock shows 9:47. How many minutes until 1...,The clock shows 9:47. How many minutes until 11?,hierarchical,prob_diff,nearest,0.5,11,0.272727,29,22.641201,0.235105,0.386029,True,"['9:47', '11', 'minutes']",73
6,The clock shows 9:47. How many minutes until 1...,The clock shows 9:47. How many minutes until 11?,hierarchical,cosine,fixed,0.5,11,0.272727,16,8.281292,0.134561,0.162158,True,"['9:47', '11', 'minutes']",73
7,The clock shows 9:47. How many minutes until 1...,The clock shows 9:47. How many minutes until 11?,hierarchical,cosine,nearest,0.5,11,0.272727,25,23.23092,0.126784,0.212061,True,"['9:47', '11', 'minutes']",73
8,"Lorem ipsum dolor sit amet, consectetur adipis...",The clock shows 9:47. How many minutes until 11?,iterative,prob_diff,fixed,0.1,29,0.103448,30,2.590039,0.134239,0.964428,True,"['9:47', '11', 'minutes']",73
9,"Lorem ipsum dolor sit amet, consectetur adipis...",The clock shows 9:47. How many minutes until 11?,iterative,prob_diff,nearest,0.1,29,0.103448,30,14.169133,0.118374,0.952819,True,"['9:47', '11', 'minutes']",73
