In [1]:
import numpy as np
import pandas as pd

from attribution.api_attribution import OpenAIAttributor
from attribution.experiment_logger import ExperimentLogger
from attribution.token_perturbation import FixedPerturbationStrategy, NthNearestPerturbationStrategy

# Re-import modified modules without restarting the server (for dev use)
%load_ext autoreload
%autoreload 2

import os

from dotenv import load_dotenv

# Checks if you're using a .env file, and loads it if so.
if os.path.isfile(".env"):
    load_dotenv()

import warnings

# Suppress annoying FutureWarning from huggingface_hub that is not our fault
warnings.filterwarnings("ignore", category=FutureWarning, module="huggingface_hub")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Attributor used by every experiment below; the API key is read from the environment
# (it is None when OPENAI_API_KEY is not set).
gpt4_attributor = OpenAIAttributor(
    openai_model="gpt-4o",
    max_concurrent_requests=5,
    openai_api_key=os.environ.get("OPENAI_API_KEY"),
)

In [7]:
# Benchmark questions. Each entry carries:
#   prompt   - the question text (the brevity suffix below is appended before it is sent)
#   key_strs - substrings of the prompt whose attribution is compared against the overall mean
#   answer   - string used to locate the answer's column(s) in the attribution matrix
samples = [
    dict(prompt="What is the capital of Japan?", key_strs=["capital", "Japan"], answer="Tokyo"),
    # Disabled samples - uncomment an entry to include it in the run:
    # dict(prompt="The clock shows 9:47. How many minutes until 11?", key_strs=["9:47", "11", "minutes"], answer="73"),
    # dict(prompt="Maria is 37 years old today. How many years till she's 50?", key_strs=["37", "50"], answer="13"),
    # dict(
    #     prompt="John has 83 books on his shelf. If he buys 17 more books, how many will he have in total?",
    #     key_strs=["83", "17"],
    #     answer="100",
    # ),
    # dict(prompt="In which continent is Johannesburg?", key_strs=["Johannesburg", "continent"], answer="Africa"),
    # dict(prompt="Which element has the symbol O?", key_strs=["element", "O"], answer="Oxygen"),
    # dict(prompt="What is the largest bird?", key_strs=["largest", "bird"], answer="Ostrich"),
    # dict(prompt="What is the smallest prime number?", key_strs=["smallest", "prime"], answer="2"),
    # dict(prompt="What colour does mixing red and blue create?", key_strs=["red", "blue"], answer="purple"),
    # dict(prompt="What is frozen water called?", key_strs=["frozen", "water"], answer="ice"),
]

# Suffix appended to every prompt (note the leading space).
brevity_prompt = " Answer briefly."

In [9]:
import numpy as np

logger = ExperimentLogger()

saliency_percentages = [0.5, 0.1, 0.01]
algorithms = ["iterative", "hierarchical"]
attr_methods = ["prob_diff", "cosine"]
perturb_methods = ["fixed"]

if not os.path.exists("pizza_baseline.csv"):
    results_df = pd.DataFrame(
        columns=[
            "ix_str",
            "prompt",
            "original_prompt",
            "algorithm",
            "attr_method",
            "perturb_method",
            "target_sal_percentage",
            "prompt_length",
            "salient_percentage",
            "api_calls",
            "duration",
            "mean_attr",
            "mean_key_attr",
            "correct",
        ]
    )
else:
    results_df = pd.read_csv("pizza_baseline.csv")

for sample in samples:
    for saliency_percentage in saliency_percentages:
        for algorithm in algorithms:
            for attr_method in attr_methods:
                for perturb_method in perturb_methods:
                    prompt = sample["prompt"] + brevity_prompt
                    key_strs = sample["key_strs"]

                    prompt_extension = int(
                        max(0, ((1 / saliency_percentage) * len(key_strs)) - len(prompt.split(" ")))
                    )

                    if prompt_extension > 0:
                        prompt = ". " * prompt_extension + prompt

                    res = {
                        "algorithm": algorithm,
                        "attr_method": attr_method,
                        "perturb_method": perturb_method,
                        "target_sal_percentage": saliency_percentage,
                        "prompt": prompt,
                        "original_prompt": sample["prompt"],
                        "prompt_length": len(prompt.split(" ")),
                        "salient_percentage": len(key_strs) / len(prompt.split(" ")),
                    }

                    ix_str = str([v for v in res.values()])
                    res["ix_str"] = ix_str

                    if ix_str in results_df["ix_str"].values:
                        print("Skipping experiment...")
                        continue
                    print("Starting experiment...")
                    perturbation_strategy = (
                        FixedPerturbationStrategy(replacement_token="")
                        if perturb_method == "fixed"
                        else NthNearestPerturbationStrategy(n=-1)
                    )

                    if algorithm == "iterative":
                        await gpt4_attributor.iterative_perturbation(
                            prompt,
                            logger=logger,
                            attribution_strategies=[attr_method],
                            unit_definition="word",
                            perturbation_strategy=perturbation_strategy,
                        )
                    else:
                        await gpt4_attributor.hierarchical_perturbation(
                            prompt,
                            logger=logger,
                            attribution_strategies=[attr_method],
                            unit_definition="word",
                            perturbation_strategy=perturbation_strategy,
                        )

                    attr = logger.get_attribution_matrices(exp_id=-1)[0]
                    exps = logger.df_experiments[
                        logger.df_experiments["exp_id"] == logger.df_experiments["exp_id"].max()
                    ]
                    sample_copy = {
                        **sample,
                        **res,
                        "api_calls": exps["num_llm_calls"].values[0],
                        "duration": exps["duration"].values[0],
                        "completion": exps["original_output"].values[0],
                        "unit": exps["unit_definition"].values[0],
                    }
                    sample_copy["key_strs"] = key_strs

                    target_col = [c for c in attr.columns if sample["answer"].lower() in c.lower()]
                    if target_col:
                        target_attr = attr[target_col].T
                        mean_attr = target_attr.values.mean()
                        key_attr = {
                            key: target_attr[c].values.mean()
                            for key in key_strs
                            for c in target_attr.columns
                            if key in c
                        }
                        mean_key_attr = np.array(list(key_attr.values())).mean()
                        correct = mean_key_attr > mean_attr
                        target_attr = target_attr.to_dict()
                    else:
                        target_attr, mean_attr, key_attr, mean_key_attr, correct = (
                            None,
                            None,
                            None,
                            None,
                            None,
                        )

                    sample_copy.update(
                        {
                            "mean_attr": mean_attr,
                            "mean_key_attr": mean_key_attr,
                            "target_attr": target_attr,
                            "key_attr": key_attr,
                            "correct": correct,
                        }
                    )
                    logger.print_text_total_attribution(exp_id=-1)

                    results_df = pd.concat(
                        [results_df, pd.DataFrame([sample_copy])], ignore_index=True
                    )
                    results_df.to_csv("pizza_baseline.csv", index=False)

Skipping experiment...
Skipping experiment...
Skipping experiment...
Skipping experiment...
Starting experiment...


Sending 5 concurrent requests at a time: 100%|██████████| 4/4 [00:04<00:00,  1.14s/it]


Starting experiment...


Sending 5 concurrent requests at a time: 100%|██████████| 4/4 [00:01<00:00,  2.47it/s]


  results_df = pd.concat(


Starting experiment...


Sending 5 concurrent requests at a time: 100%|██████████| 2/2 [00:02<00:00,  1.37s/it]
Sending 5 concurrent requests at a time: 100%|██████████| 2/2 [00:03<00:00,  1.56s/it]
Sending 5 concurrent requests at a time: 100%|██████████| 2/2 [00:02<00:00,  1.42s/it]


Starting experiment...


Sending 5 concurrent requests at a time: 100%|██████████| 2/2 [00:02<00:00,  1.12s/it]
Sending 5 concurrent requests at a time: 100%|██████████| 2/2 [00:01<00:00,  1.48it/s]


  results_df = pd.concat(


Starting experiment...


Sending 5 concurrent requests at a time: 100%|██████████| 40/40 [00:08<00:00,  4.89it/s]


Starting experiment...


Sending 5 concurrent requests at a time: 100%|██████████| 40/40 [00:07<00:00,  5.69it/s]


  results_df = pd.concat(


Starting experiment...


Starting experiment...


  results_df = pd.concat(
