## PIZZA: An Open Source Library for Closed LLM Attribution (or “why did ChatGPT say that?”)

In [110]:
# Set your OpenAI API key
# BEWARE: This will cost you API credits!
YOUR_OPENAI_API_KEY = None

In [111]:
import warnings

# Suppress annoying FutureWarning from huggingface_hub
warnings.filterwarnings("ignore", category=FutureWarning, module="huggingface_hub")

In [112]:
# Re-import modified modules without restarting the server
%load_ext autoreload
%autoreload 2


In [113]:
from attribution.api_attribution import OpenAIAttributor
from attribution.experiment_logger import ExperimentLogger
from attribution.token_perturbation import FixedPerturbationStrategy

In [114]:
gpt3_5_attributor = OpenAIAttributor(
    openai_api_key=YOUR_OPENAI_API_KEY, max_concurrent_requests=5, openai_model="gpt-3.5-turbo"
)

gpt4_attributor = OpenAIAttributor(
    openai_api_key=YOUR_OPENAI_API_KEY, max_concurrent_requests=5, openai_model="gpt-4o"
)
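
Both attributors cap the number of in-flight API calls with `max_concurrent_requests=5`. How PIZZA enforces that cap internally isn't shown in this notebook; as a rough sketch, one common asyncio pattern is a semaphore. Here `fake_request` and `run_limited` are made-up names standing in for real API calls, not part of the library.

```python
import asyncio

in_flight = 0  # requests currently running
peak = 0       # highest number of simultaneous requests seen


async def fake_request(i):
    """Stand-in for an API call (hypothetical): sleeps briefly, returns i."""
    global in_flight, peak
    in_flight += 1
    peak = max(peak, in_flight)
    await asyncio.sleep(0.01)
    in_flight -= 1
    return i


async def run_limited(coros, max_concurrent=5):
    """Run all coroutines, with at most max_concurrent active at once."""
    semaphore = asyncio.Semaphore(max_concurrent)

    async def guarded(coro):
        async with semaphore:
            return await coro

    # gather returns results in the same order as the inputs
    return await asyncio.gather(*(guarded(c) for c in coros))


# In a notebook cell, use `await run_limited(...)` instead of asyncio.run,
# because an event loop is already running there.
results = asyncio.run(run_limited([fake_request(i) for i in range(12)]))
print(results, "peak concurrency:", peak)
```

A lower cap is gentler on rate limits; a higher one finishes sooner if the API allows it.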

In [115]:
input_str = "Do not go gentle"
gpt3_5_logger = ExperimentLogger()
await gpt3_5_attributor.hierarchical_perturbation(
    input_str, logger=gpt3_5_logger, use_absolute_attribution=True
)

In [116]:
gpt3_5_logger.print_text_total_attribution()
gpt3_5_logger.print_attribution_matrix()

|  | into (0) | that (1) | good (2) | night (3) | , (4) |
|---|---|---|---|---|---|
| Do (0) | 0.201465 | 0.173144 | 0.250031 | 0.25142 | -0.127938 |
| not (1) | 0.423131 | 0.279929 | 0.333457 | 0.333603 | 0.142983 |
| go (2) | 0.612701 | 0.608604 | 0.666652 | 0.664875 | 0.491336 |
| gentle (3) | 0.75275 | 0.832859 | 0.833315 | 0.832626 | 0.690851 |


# Prompt Engineering

In [117]:
input_str = "Mary puts an apple in the box. The box is labelled 'pencils'. John enters the room. What does he think is in the box? Answer in 1 word."

gpt3_5_response = await gpt3_5_attributor.get_chat_completion(input_str)
gpt4_response = await gpt4_attributor.get_chat_completion(input_str)

print("User:", input_str)
print("GPT3.5:", gpt3_5_response.message.content)
print("GPT4:", gpt4_response.message.content)

User: Mary puts an apple in the box. The box is labelled 'pencils'. John enters the room. What does he think is in the box? Answer in 1 word.
GPT3.5: Apples
GPT4: Pencils.


GPT3.5 is not so hot on theory of mind there. Can we find out what went wrong?

In [118]:
# Bit hacky to get model explanation
user_request = "User: Why did you say that?"
print(user_request)
model_explanation = await gpt3_5_attributor.openai_client.chat.completions.create(
    model=gpt3_5_attributor.openai_model,
    messages=[
        {"role": "user", "content": input_str},
        {"role": "assistant", "content": gpt3_5_response.message.content},
        {"role": "user", "content": user_request},
    ],
    temperature=0.0,
    seed=0,
    logprobs=True,
    top_logprobs=20,
)
print("GPT3.5:", model_explanation.choices[0].message.content)

User: Why did you say that?
GPT3.5: I apologize for the mistake in my response. John would likely think there are pencils in the box, based on the label.


That's not very helpful! We want to know _why_ the mistake was made in the first place.
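
One way to get at the _why_ is perturbation-based attribution: change part of the prompt, resend it, and measure how much less likely the model becomes to produce its original answer. A minimal sketch of a probability-difference score follows, assuming we have the log-probability of the original output token and the `top_logprobs` returned for the perturbed prompt. The helper name `prob_diff_score` and the numbers are invented for illustration, and PIZZA's own scoring may differ in detail.

```python
import math


def prob_diff_score(original_logprob, perturbed_top_logprobs, token):
    """Drop in probability of `token` after perturbing the prompt.

    perturbed_top_logprobs maps candidate tokens to log-probabilities.
    If the original token is absent from that map, its probability is
    treated as 0 (an assumption; it is really just below the cutoff).
    """
    p_original = math.exp(original_logprob)
    if token in perturbed_top_logprobs:
        p_perturbed = math.exp(perturbed_top_logprobs[token])
    else:
        p_perturbed = 0.0
    return p_original - p_perturbed


# Invented numbers: the answer token had probability 0.9 originally
# and 0.3 once one prompt token was replaced.
score = prob_diff_score(math.log(0.9), {"Apples": math.log(0.3)}, "Apples")
print(round(score, 2))
```

A score near 1 means the perturbed token was essential to that output token. A score near 0 means it barely mattered, and a negative score means the perturbation made the original output more likely.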

In [119]:
gpt3_5_logger = ExperimentLogger()
await gpt3_5_attributor.hierarchical_perturbation(
    input_str, logger=gpt3_5_logger, use_absolute_attribution=True
)
print("GPT3.5 Attribution:")
gpt3_5_logger.print_text_total_attribution()
gpt3_5_logger.print_total_attribution()

Sending 5 concurrent requests at a time: 100%|██████████| 2/2 [00:01<00:00,  1.43it/s]
Sending 5 concurrent requests at a time: 100%|██████████| 2/2 [00:01<00:00,  1.49it/s]
Sending 5 concurrent requests at a time: 100%|██████████| 4/4 [00:01<00:00,  2.54it/s]
Sending 5 concurrent requests at a time: 100%|██████████| 2/2 [00:01<00:00,  1.60it/s]

GPT3.5 Attribution:





Unnamed: 0,exp_id,attribution_strategy,perturbation_strategy,unit_definition,token_1,token_2,token_3,token_4,token_5,token_6,token_7,token_8,token_9,token_10,token_11,token_12,token_13,token_14,token_15,token_16,token_17,token_18,token_19,token_20,token_21,token_22,token_23,token_24,token_25,token_26,token_27,token_28,token_29,token_30,token_31,token_32,token_33,token_34,token_35,token_36
0,1,prob_diff,fixed,token,Mary 0.32,puts 0.32,an 0.15,apple 0.36,in 0.16,the 0.16,box 0.08,. 0.08,The 0.08,box 0.09,is 0.09,labelled 0.09,' 0.09,pen 0.25,cil 0.14,s 0.13,'. 0.13,John 0.13,enters 0.03,the 0.03,room 0.03,. 0.03,What 0.03,does 0.03,he 0.03,think 0.03,is 0.03,in 0.06,the 0.06,box 0.06,? 0.06,Answer 0.13,in 0.26,1 0.29,word 0.31,. 0.18


It looks like the request to "Answer in 1 word" is pretty important: "in", "1" and "word" score 0.26 to 0.31, nearly as high as "apple" (0.36) and higher than any token of the label 'pencils' (0.13 to 0.25). Let's try changing it.

In [120]:
input_str = "Mary puts an apple in the box. The box is labelled 'pencils'. John enters the room. What does he think is in the box? Answer briefly."

await gpt3_5_attributor.hierarchical_perturbation(
    input_str,
    logger=gpt3_5_logger,
)

# Let's see...
print("GPT3 Total attribution:")
# exp_id is the experiment index to print. -1 prints the last experiment.
gpt3_5_logger.print_text_total_attribution(exp_id=-1)

Sending 5 concurrent requests at a time: 100%|██████████| 2/2 [00:01<00:00,  1.20it/s]
Sending 5 concurrent requests at a time: 100%|██████████| 2/2 [00:01<00:00,  1.41it/s]
Sending 5 concurrent requests at a time: 100%|██████████| 3/3 [00:07<00:00,  2.60s/it]
Sending 5 concurrent requests at a time: 100%|██████████| 2/2 [00:01<00:00,  1.47it/s]


GPT3 Total attribution:


That's better!

We have a few other attribution and perturbation methods for you, each with different properties. Check out the readme, and do your own experiments – PIZZA is a work in progress.

Hierarchical perturbation can be faster and cheaper than standard iterative perturbation (which is what the `compute_attributions` function uses) on long inputs with few salient tokens. Most importantly, it can also capture multi-token features.

However, when many tokens are salient, standard iterative perturbation can be faster, and often highlights individual token contributions more clearly. Someone should do some experiments to quantify these properties...
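
To make the cost trade-off concrete, here is a toy sketch that counts how many "model calls" each approach needs. It is not PIZZA's implementation: `toy_iterative`, `toy_hierarchical` and `make_scorer` are invented names, and the scorer is a stand-in that returns the largest hidden weight in the perturbed chunk, where a real scorer would call the API.

```python
def toy_iterative(n_tokens, score):
    """Perturb one token at a time: always n_tokens scorer calls."""
    return [score([i]) for i in range(n_tokens)], n_tokens


def toy_hierarchical(n_tokens, score, threshold=0.0):
    """Perturb big chunks first; only subdivide chunks scoring above threshold."""
    attributions = [0.0] * n_tokens
    calls = 0
    mid = n_tokens // 2
    stack = [(0, mid), (mid, n_tokens)]  # start from the two halves
    while stack:
        start, end = stack.pop()
        if start == end:
            continue
        chunk_score = score(list(range(start, end)))
        calls += 1
        # Each token keeps the score of the smallest chunk tested around it.
        for i in range(start, end):
            attributions[i] = chunk_score
        if chunk_score > threshold and end - start > 1:
            split = (start + end) // 2
            stack += [(start, split), (split, end)]
    return attributions, calls


def make_scorer(weights):
    """Toy scorer (assumption): a chunk matters as much as its top token."""
    def score(indices):
        return max(weights[i] for i in indices)
    return score


sparse = [0.0] * 32
sparse[5] = 0.9        # one salient token out of 32
dense = [0.5] * 32     # every token salient

for name, weights in [("sparse", sparse), ("dense", dense)]:
    _, iterative_calls = toy_iterative(32, make_scorer(weights))
    _, hierarchical_calls = toy_hierarchical(32, make_scorer(weights))
    print(name, "iterative:", iterative_calls, "hierarchical:", hierarchical_calls)
```

With one salient token the hierarchical version needs 10 calls against 32; with every token salient it needs 62 against 32, because it pays for every intermediate chunk on the way down.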

In [121]:
input_str = "Write a funny, sad haiku."
gpt4_logger = ExperimentLogger()

await gpt4_attributor.compute_attributions(
    input_str, perturbation_strategy=FixedPerturbationStrategy(), logger=gpt4_logger
)
gpt4_logger.print_text_total_attribution(exp_id=-1)

Sending 5 concurrent requests at a time: 100%|██████████| 2/2 [00:04<00:00,  2.10s/it]


Wow.

Anyway, we also have some different logging functions to print the results in different ways. You can see how every input token affects every output token, what perturbations are being applied, etc.

In [122]:
gpt4_logger.print_total_attribution(exp_id=-1)
gpt4_logger.print_attribution_matrix(exp_id=-1, show_debug_cols=True)

|  | exp_id | attribution_strategy | perturbation_strategy | unit_definition | token_1 | token_2 | token_3 | token_4 | token_5 | token_6 | token_7 | token_8 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | prob_diff | fixed | token | Write 0.35 | a 0.43 | funny 0.61 | , 0.51 | sad 0.54 | ha 0.70 | iku 0.50 | . 0.01 |


Unnamed: 0,L (0),aug (1),hed (2),at (3),my (4),own (5),joke (6),",  (7)",Echo (8),es (9),in (10),an (11),empty (12),room (13),—  (14),T (15),ears (16),join (17),the (18),fun (19),. (20),perturbed_input,perturbed_output
Write (0),-0.046561,0.2084,0.999999,0.512887,0.8559,0.905674,0.969828,-0.138348,0.412777,-0.203895,0.372531,0.609898,0.999986,0.154084,-0.225627,-0.023468,0.012984,0.221096,-0.126626,0.708753,0.067443,"a funny, sad haiku.","Laughter fills the room, Echoes of joy, then silence— Tears fall, memories."
a (1),0.007725,0.043493,0.999999,0.473525,0.434422,0.905674,0.993411,-0.143543,0.628514,0.781461,0.419873,0.573964,0.995957,0.139587,-0.174126,0.239227,0.994614,0.227244,-0.120442,0.708753,0.000917,"Write funny, sad haiku.","Laughter fills the room, But my heart, a silent void— Jokes mask tears inside."
funny (2),0.468072,0.413159,0.999999,0.517739,0.857401,0.905674,0.993411,0.01189,0.54025,0.781461,0.549788,0.609898,0.999992,0.997888,0.149782,0.291617,0.994614,0.227244,0.806412,0.708753,-0.000116,"Write a, sad haiku.","Fallen leaves whisper, Empty branches reach for sky— Lonely winds reply."
", (3)",0.202336,0.085495,0.999999,0.518285,0.856986,0.905674,0.993411,-0.002323,0.654951,0.781461,-0.089957,0.609898,0.999992,0.997888,-0.134929,0.211928,0.994614,0.227244,0.288023,0.708753,0.000124,Write a funny sad haiku.,"Lone sock in the wash, Its mate lost to dryer’s maw— Single life is rough."
sad (4),0.481771,0.413159,0.999999,0.515353,-0.116794,0.900803,0.993376,-0.068184,0.654951,0.781461,0.14087,0.609898,0.999992,0.997888,-0.040828,0.302647,0.994614,0.227244,0.863722,0.708753,0.017919,"Write a funny, haiku.","Squirrel steals my lunch, Nuts and berries, gone in flash— Nature's tiny thief."
ha (5),0.481771,0.413159,0.999999,0.518282,0.855508,0.905674,0.338497,0.800946,0.654951,0.781461,0.566249,0.600287,0.999992,0.997888,0.758619,0.302647,0.994614,0.227244,0.871141,0.70402,0.999064,"Write a funny, sadiku.","Sure, here's a light-hearted joke for you: Why don't scientists trust atoms? Because they make up everything!"
iku (6),0.279193,0.413159,0.999999,0.518285,-0.093024,0.905674,0.398868,-0.118682,0.654951,0.781461,0.341602,0.594157,0.999992,0.981992,-0.089408,0.228123,0.994614,0.227244,0.804419,0.707892,0.007168,"Write a funny, sad ha.","Sure, here's a funny, sad haiku for you: Lost my only sock, Laundry day, a tragic joke— Foot's now in a croc."
. (7),0.096824,-0.014078,-0.0,0.162448,-0.023551,-0.010537,-0.000758,0.016993,-0.038451,0.029928,-0.097705,-0.002997,-0.0,-0.000434,0.000877,-0.05671,0.003495,0.019715,0.016822,0.115674,0.000173,"Write a funny, sad haiku","Laughed at my own joke, Echoes in an empty room— Tears join in the fun."
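
In the output above, each per-token total matches the mean of that input token's row in the matrix. For example, averaging the `. (7)` row reproduces its total of 0.01. A quick check with the values copied from that row (this is an observation about the printed numbers, not a statement about how the logger computes totals internally):

```python
from statistics import mean

# Attribution of the final "." to each of the 21 output tokens,
# copied from the ". (7)" row of the matrix above.
period_row = [
    0.096824, -0.014078, -0.0, 0.162448, -0.023551, -0.010537, -0.000758,
    0.016993, -0.038451, 0.029928, -0.097705, -0.002997, -0.0, -0.000434,
    0.000877, -0.05671, 0.003495, 0.019715, 0.016822, 0.115674, 0.000173,
]

total = mean(period_row)
print(f"{total:.2f}")  # 0.01
```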


Note how the logger is keeping track of all our experiments.

In [123]:
display(gpt4_logger.df_experiments)

|  | exp_id | original_input | original_output | perturbation_strategy | unit_definition | duration | num_llm_calls |
|---|---|---|---|---|---|---|---|
| 0 | 1 | Write a funny, sad haiku. | Laughed at my own joke,\nEchoes in an empty ro... | fixed | token | 4.304062 | 9 |


# Steganography

Could we use attribution to detect hidden messages between LLMs? Maybe. This is a forced example – you can probably think of better experiments. First, let's get GPT4 to secretly encode a message.

In [124]:
# Using older version of GPT4 because the latest is _so_ verbose
attributor = OpenAIAttributor(
    openai_api_key=YOUR_OPENAI_API_KEY, max_concurrent_requests=5, openai_model="gpt-4-0613"
)
logger = ExperimentLogger()

In [125]:
input_str = "Encode this message so that only you can decode it, then reply with just the encoded message (don't explain further): 'The owl hoots at midnight!'"

encoded = await attributor.get_chat_completion(input_str)
encoded = encoded.message.content
print(encoded)

VGhlIG93bCBob290cyBhdCBtaWRuaWdodCE=


In [126]:
input_str = f"Decode this message. Do not explain further. '{encoded}'"
decoded = await attributor.get_chat_completion(input_str)
decoded = decoded.message.content
print(decoded)

The owl hoots at midnight!
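
The "secret" encoding is plain Base64, so no model is needed to read it. A quick check with Python's standard library (the variable names are chosen so they don't overwrite `encoded` and `decoded` from the cells above):

```python
import base64

# The string GPT4 returned as its "encoded" message
encoded_message = "VGhlIG93bCBob290cyBhdCBtaWRuaWdodCE="
decoded_message = base64.b64decode(encoded_message).decode("utf-8")
print(decoded_message)  # The owl hoots at midnight!
```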


Well, I suppose it's taking part that counts. Let's see if we can figure out which parts of the prompt were encoded.

In [127]:
await attributor.hierarchical_perturbation(
    input_str,
    logger=logger,
)

logger.print_text_total_attribution(exp_id=-1)

Sending 5 concurrent requests at a time: 100%|██████████| 2/2 [00:04<00:00,  2.24s/it]
Sending 5 concurrent requests at a time: 100%|██████████| 2/2 [00:02<00:00,  1.04s/it]
Sending 5 concurrent requests at a time: 100%|██████████| 3/3 [00:01<00:00,  1.66it/s]
Sending 5 concurrent requests at a time: 100%|██████████| 4/4 [00:02<00:00,  1.53it/s]
Sending 5 concurrent requests at a time: 100%|██████████| 3/3 [00:02<00:00,  1.35it/s]


In [128]:
logger.print_text_attribution_matrix()

In [129]:
await attributor.compute_attributions(
    input_str,
    logger=logger,
)

logger.print_text_total_attribution(exp_id=-1)
logger.print_text_attribution_matrix()

Sending 5 concurrent requests at a time: 100%|██████████| 7/7 [00:09<00:00,  1.38s/it]


In [131]:
display(logger.df_experiments)

|  | exp_id | original_input | original_output | perturbation_strategy | unit_definition | duration | num_llm_calls |
|---|---|---|---|---|---|---|---|
| 0 | 1 | Decode this message. Do not explain further. '... | The owl hoots at midnight! | fixed | token | 14.043019 | 59 |
| 1 | 2 | Decode this message. Do not explain further. '... | The owl hoots at midnight! | fixed | token | 9.857404 | 32 |
