## PIZZA: An Open Source Library for Closed LLM Attribution (or “why did ChatGPT say that?”)

In [47]:
# uncomment and run this cell if you're in colab
# !git clone https://github.com/leap-laboratories/PIZZA.git .
# !pip install --quiet openai python-dotenv

In [48]:
# Set your OpenAI API key
# BEWARE: This will cost you API credits!

# Note: if you do not pass an API key to the OpenAIAttributor class, it falls back to
# the OPENAI_API_KEY environment variable. This is preferred for security reasons.
YOUR_OPENAI_API_KEY = None
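
The lookup order described in the comment above (explicit argument first, then the environment) can be sketched like this. `resolve_api_key` is a hypothetical helper written for illustration and is not part of PIZZA; the library's own fallback may differ in detail.

```python
import os
from typing import Optional


def resolve_api_key(explicit_key: Optional[str] = None, env_var: str = "OPENAI_API_KEY") -> str:
    """Return the explicit key if given, otherwise fall back to the environment variable.

    Hypothetical helper for illustration only; not a PIZZA function.
    """
    key = explicit_key or os.environ.get(env_var)
    if not key:
        raise RuntimeError(f"No API key passed and {env_var} is not set.")
    return key
```

Leaving `YOUR_OPENAI_API_KEY = None` and exporting `OPENAI_API_KEY` in your shell (or a `.env` file) keeps the key out of the notebook itself.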

In [49]:
import warnings

# Suppress annoying FutureWarning from huggingface_hub
warnings.filterwarnings("ignore", category=FutureWarning, module="huggingface_hub")

In [50]:
# Re-import modified modules without restarting the server
%load_ext autoreload
%autoreload 2



In [51]:
from attribution.api_attribution import OpenAIAttributor
from attribution.experiment_logger import ExperimentLogger
from attribution.token_perturbation import FixedPerturbationStrategy

In [52]:
gpt3_5_attributor = OpenAIAttributor(
    openai_api_key=YOUR_OPENAI_API_KEY, max_concurrent_requests=5, openai_model="gpt-3.5-turbo"
)

gpt4_attributor = OpenAIAttributor(
    openai_api_key=YOUR_OPENAI_API_KEY, max_concurrent_requests=5, openai_model="gpt-4o"
)

In [53]:
input_str = "Do not go gentle"
gpt3_5_logger = ExperimentLogger()
await gpt3_5_attributor.hierarchical_perturbation(
    input_str, logger=gpt3_5_logger, use_absolute_attribution=True
)

In [54]:
gpt3_5_logger.print_text_total_attribution()
gpt3_5_logger.print_attribution_matrix()

Unnamed: 0,into (0),that (1),good (2),night (3),",  (4)"
Do (0),0.193386,0.052115,0.250078,0.252043,-0.019855
not (1),0.282335,0.193859,0.333483,0.333648,0.181871
go (2),0.59587,0.564501,0.666654,0.664875,0.454984
gentle (3),0.738762,0.832759,0.833318,0.832723,0.634807


# Prompt Engineering

In [55]:
input_str = "Mary puts an apple in the box. The box is labelled 'pencils'. John enters the room. What does he think is in the box? Answer in 1 word."

gpt3_5_response = await gpt3_5_attributor.get_chat_completion(input_str)
gpt4_response = await gpt4_attributor.get_chat_completion(input_str)

print("User:", input_str)
print("GPT3.5:", gpt3_5_response.message.content)
print("GPT4:", gpt4_response.message.content)

User: Mary puts an apple in the box. The box is labelled 'pencils'. John enters the room. What does he think is in the box? Answer in 1 word.
GPT3.5: Apples
GPT4: Pencils.


GPT3.5 not so hot with the theory of mind there. Can we find out what went wrong?

In [56]:
# Bit hacky to get model explanation
user_request = "User: Why did you say that?"
print(user_request)
model_explanation = await gpt3_5_attributor.openai_client.chat.completions.create(
    model=gpt3_5_attributor.openai_model,
    messages=[
        {"role": "user", "content": input_str},
        {"role": "assistant", "content": gpt3_5_response.message.content},
        {"role": "user", "content": user_request},
    ],
    temperature=0.0,
    seed=0,
    logprobs=True,
    top_logprobs=20,
)
print("GPT3.5:", model_explanation.choices[0].message.content)

User: Why did you say that?
GPT3.5: I apologize for the mistake in my response. John would likely think there are pencils in the box, based on the label.


That's not very helpful! We want to know _why_ the mistake was made in the first place.

In [57]:
gpt3_5_logger = ExperimentLogger()
await gpt3_5_attributor.hierarchical_perturbation(
    input_str, logger=gpt3_5_logger, use_absolute_attribution=True
)
print("GPT3.5 Attribution:")
gpt3_5_logger.print_text_total_attribution()
gpt3_5_logger.print_total_attribution()

Sending 5 concurrent requests at a time: 100%|██████████| 2/2 [00:01<00:00,  1.78it/s]
Sending 5 concurrent requests at a time: 100%|██████████| 3/3 [00:01<00:00,  2.26it/s]
Sending 5 concurrent requests at a time: 100%|██████████| 4/4 [00:01<00:00,  2.86it/s]
Sending 5 concurrent requests at a time: 100%|██████████| 2/2 [00:01<00:00,  1.89it/s]

GPT3.5 Attribution:





Unnamed: 0,exp_id,attribution_strategy,perturbation_strategy,unit_definition,token_1,token_2,token_3,token_4,token_5,token_6,token_7,token_8,token_9,token_10,token_11,token_12,token_13,token_14,token_15,token_16,token_17,token_18,token_19,token_20,token_21,token_22,token_23,token_24,token_25,token_26,token_27,token_28,token_29,token_30,token_31,token_32,token_33,token_34,token_35,token_36
0,1,prob_diff,fixed,token,Mary 0.35,puts 0.29,an 0.18,apple 0.39,in 0.19,the 0.19,box 0.09,. 0.09,The 0.09,box 0.11,is 0.11,labelled 0.10,' 0.10,pen 0.10,cil 0.10,s 0.10,'. 0.10,John 0.10,enters 0.04,the 0.04,room 0.04,. 0.04,What 0.04,does 0.04,he 0.04,think 0.04,is 0.04,in 0.32,the 0.16,box 0.19,? 0.16,Answer 0.11,in 0.11,1 0.30,word 0.34,. 0.18


It looks like the request to "Answer in 1 word" is pretty important: "1" and "word" score 0.30 and 0.34, close to "apple" (0.39) and well above the 'pencils' label tokens (about 0.10). Let's try changing it.
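
To make the ranking easier to read than the wide table, here is a small sketch that sorts the printed cells by score. The `cells` list is copied by hand from the `print_total_attribution()` output above; `parse_cell` is a hypothetical helper, and positions are zero-based (the table's `token_1` is position 0).

```python
# Cells copied from the print_total_attribution() output above ("token score").
cells = [
    "Mary 0.35", "puts 0.29", "an 0.18", "apple 0.39", "in 0.19", "the 0.19",
    "box 0.09", ". 0.09", "The 0.09", "box 0.11", "is 0.11", "labelled 0.10",
    "' 0.10", "pen 0.10", "cil 0.10", "s 0.10", "'. 0.10", "John 0.10",
    "enters 0.04", "the 0.04", "room 0.04", ". 0.04", "What 0.04", "does 0.04",
    "he 0.04", "think 0.04", "is 0.04", "in 0.32", "the 0.16", "box 0.19",
    "? 0.16", "Answer 0.11", "in 0.11", "1 0.30", "word 0.34", ". 0.18",
]


def parse_cell(cell):
    """Split a 'token score' cell on its last space (hypothetical helper)."""
    token, score = cell.rsplit(" ", 1)
    return token, float(score)


# (position, token, score) triples, sorted by score, highest first.
scored = [(pos, *parse_cell(cell)) for pos, cell in enumerate(cells)]
top5 = sorted(scored, key=lambda item: item[2], reverse=True)[:5]

for pos, token, score in top5:
    print(f"{pos:>2}  {token!r:<8} {score:.2f}")
```

The top five are "apple", "Mary", "word", the "in" from "in the box?", and "1".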

In [58]:
input_str = "Mary puts an apple in the box. The box is labelled 'pencils'. John enters the room. What does he think is in the box? Answer briefly."

await gpt3_5_attributor.hierarchical_perturbation(
    input_str,
    logger=gpt3_5_logger,
)

# Let's see...
print("GPT3.5 Total attribution:")
# exp_id is the experiment index to print. -1 prints the last experiment.
gpt3_5_logger.print_text_total_attribution(exp_id=-1)

Sending 5 concurrent requests at a time: 100%|██████████| 2/2 [00:01<00:00,  1.34it/s]
Sending 5 concurrent requests at a time: 100%|██████████| 2/2 [00:01<00:00,  1.49it/s]
Sending 5 concurrent requests at a time: 100%|██████████| 3/3 [00:01<00:00,  1.89it/s]
Sending 5 concurrent requests at a time: 100%|██████████| 2/2 [00:01<00:00,  1.56it/s]


GPT3.5 Total attribution:


That's better!

We have a few other attribution and perturbation methods for you, each with different properties. Check out the readme, and do your own experiments – PIZZA is a work in progress.

Hierarchical perturbation can be faster and cheaper than standard iterative perturbation (which is what the compute_attributions function uses) on long inputs with few salient tokens. Most importantly, it can also capture multi-token features.

However, when many tokens are salient, standard iterative perturbation can be faster, and it often highlights individual token contributions more clearly. Someone should do some experiments to quantify these properties...
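
To get a feel for that trade-off, here is a toy sketch that counts scored perturbations for both approaches. It is not PIZZA's implementation: the binary splitting, the `threshold` rule, and the `make_toy_scorer` stand-in (which replaces real model calls) are all assumptions made for illustration.

```python
from typing import Callable, Iterable, List, Set, Tuple

# A scorer maps a set of masked token positions to an attribution score.
# In the notebook this would involve an API call; here it is a toy stand-in.
Scorer = Callable[[Iterable[int]], float]


def make_toy_scorer(salient: Set[int]) -> Scorer:
    """Score 1.0 if masking removes any salient token, else 0.0 (toy assumption)."""
    return lambda masked: 1.0 if any(i in salient for i in masked) else 0.0


def iterative_perturbation(n_tokens: int, score: Scorer) -> Tuple[List[float], int]:
    """Mask one token at a time: exactly one scored call per token."""
    return [score([i]) for i in range(n_tokens)], n_tokens


def hierarchical_perturbation(
    n_tokens: int, score: Scorer, threshold: float = 0.5
) -> Tuple[List[float], int]:
    """Mask chunks, then split in two only the chunks scoring above threshold."""
    attributions = [0.0] * n_tokens
    calls = 0
    pending = [(0, n_tokens)]  # half-open chunks still to be subdivided
    while pending:
        start, end = pending.pop(0)
        mid = (start + end + 1) // 2
        for lo, hi in ((start, mid), (mid, end)):
            if lo == hi:  # skip empty halves (only possible for 1-token inputs)
                continue
            chunk_score = score(range(lo, hi))
            calls += 1
            # Every token in the chunk inherits the chunk's score until refined.
            attributions[lo:hi] = [chunk_score] * (hi - lo)
            if chunk_score > threshold and hi - lo > 1:
                pending.append((lo, hi))
    return attributions, calls


n = 32
_, sparse_calls = hierarchical_perturbation(n, make_toy_scorer({5}))
_, dense_calls = hierarchical_perturbation(n, make_toy_scorer(set(range(n))))
_, iterative_calls = iterative_perturbation(n, make_toy_scorer({5}))
print(sparse_calls, dense_calls, iterative_calls)  # 10 62 32
```

With one salient token out of 32, the hierarchical sketch needs 10 scored perturbations against 32 for the iterative one; when every token is salient it needs 62, which is the case where iterative perturbation wins.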

In [59]:
input_str = "Write a funny, sad haiku."

gpt4_logger = ExperimentLogger()
await gpt4_attributor.compute_attributions(
    input_str, perturbation_strategy=FixedPerturbationStrategy(), logger=gpt4_logger
)
gpt4_logger.print_text_total_attribution(exp_id=-1)

Sending 5 concurrent requests at a time: 100%|██████████| 2/2 [00:02<00:00,  1.10s/it]


Hilarious.

Anyway, we also have some different logging functions to print the results in different ways. You can see how every input token affects every output token, what perturbations are being applied, etc.

In [60]:
gpt4_logger.print_total_attribution(exp_id=-1)
gpt4_logger.print_attribution_matrix(exp_id=-1, show_debug_cols=True)

Unnamed: 0,exp_id,attribution_strategy,perturbation_strategy,unit_definition,token_1,token_2,token_3,token_4,token_5,token_6,token_7,token_8
0,1,prob_diff,fixed,token,Write 0.23,a 0.26,funny 0.46,", 0.33",sad 0.40,ha 0.51,iku 0.32,. 0.28


Unnamed: 0,L (0),aughter (1),fills (2),the (3),room (4),",  (5)",But (6),my (7),heart (8),", (9)",a (10),silent (11),void (12),—  (13),J (14),okes (15),mask (16),tears (17),inside (18),. (19),perturbed_input,perturbed_output
Write (0),0.039804,-0.142343,0.057224,-0.001778,-0.03024,-0.037309,0.268057,0.434467,0.374122,-0.359106,0.837076,0.153095,0.556013,-0.039524,0.248512,0.867602,0.297561,0.6823,0.246142,0.067877,"a funny, sad haiku.","Laughter fills the room, Echoes of joy, then silence— Tears fall, memories."
a (1),0.093839,0.069492,0.367243,0.245398,0.811923,-0.038703,0.398617,0.320857,-0.037455,-0.530175,0.849466,0.153902,0.556647,0.0136,0.297041,0.867602,0.540912,0.530919,-0.32522,0.000696,"Write funny, sad haiku.","Laughter in the rain, Umbrella flips inside out— Wet socks, soggy heart."
funny (2),0.555941,0.48836,0.543126,0.950572,0.813564,0.164068,0.411992,0.435967,0.342823,-0.048644,0.852533,0.159215,0.557642,0.394278,0.316502,0.867602,0.540912,0.705257,0.246142,1.7e-05,"Write a, sad haiku.","Fallen leaves whisper, Empty branches reach for sky— Lonely autumn sighs."
", (3)",0.431884,0.48836,0.543126,0.992153,0.813564,0.277923,0.411992,-0.507599,-0.268366,-0.706949,0.851774,0.159215,0.555785,0.517906,0.316502,0.867602,0.540912,0.079023,0.246014,7.5e-05,Write a funny sad haiku.,"Lost my last donut, Crumbs of joy now tears of woe, Empty box, heart too."
sad (4),0.568136,0.48836,0.543126,0.987652,0.813564,0.015257,0.411992,-0.529928,0.374122,-0.214262,0.838756,0.159215,0.557642,0.164598,0.316502,0.867602,0.540912,0.756861,0.246142,0.182441,"Write a funny, haiku.","Squirrel steals my lunch, Nuts and berries, all are gone— Rodent picnic time."
ha (5),0.568136,0.48836,0.543126,0.995985,0.813564,0.901986,0.411992,0.432854,0.374096,-0.614013,-0.12535,0.159215,0.557642,0.944688,0.316502,0.867602,0.540912,0.756861,0.246142,0.999242,"Write a funny, sadiku.","Sure, here's a light-hearted joke for you: Why don't scientists trust atoms? Because they make up everything!"
iku (6),0.370075,0.48836,0.543126,0.929328,0.798131,-0.014623,0.411992,-0.519512,0.360992,-0.704328,-0.126531,0.153612,0.557642,0.438984,0.302154,0.867602,0.540912,0.70493,0.246142,0.003383,"Write a funny, sad ha.","Sure, here's a funny, sad haiku for you: Lost my only sock, Laundry day, a cruel joke— Foot feels so alone."
. (7),0.183189,0.155627,0.538062,0.147668,-0.184769,0.118032,0.381526,-0.444985,0.374118,-0.012151,0.843897,0.112592,0.55764,0.186961,0.316502,0.867602,0.540912,0.756842,0.246142,0.000265,"Write a funny, sad haiku","Laughed at my own joke, Echoes in an empty room— Tears join in the fun."
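
The two tables above can be related with a quick check. In this run, the total attribution for an input token appears to equal the mean of its row in the matrix. That is an inference from the printed numbers, not something the notebook states about how the library aggregates. The row is copied by hand from the "Write (0)" line.

```python
from statistics import mean

# "Write (0)" row of the attribution matrix above, one value per output token.
write_row = [
    0.039804, -0.142343, 0.057224, -0.001778, -0.03024,
    -0.037309, 0.268057, 0.434467, 0.374122, -0.359106,
    0.837076, 0.153095, 0.556013, -0.039524, 0.248512,
    0.867602, 0.297561, 0.6823, 0.246142, 0.067877,
]

row_mean = round(mean(write_row), 2)
print(row_mean)  # 0.23, matching "Write 0.23" in the totals table
```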


Note how the logger is keeping track of all our experiments.

In [61]:
display(gpt4_logger.df_experiments)

Unnamed: 0,exp_id,original_input,original_output,perturbation_strategy,unit_definition,duration,num_llm_calls
0,1,"Write a funny, sad haiku.","Laughter fills the room,\nBut my heart, a sile...",fixed,token,2.307235,9


# Steganography

Could we use attribution to detect hidden messages between LLMs? Maybe. This is a forced example – you can probably think of better experiments. First let's get GPT4 to secretly encode a message.

In [62]:
# Using older version of GPT4 because the latest is _so_ verbose
attributor = OpenAIAttributor(
    openai_api_key=YOUR_OPENAI_API_KEY, max_concurrent_requests=5, openai_model="gpt-4-0613"
)
logger = ExperimentLogger()

In [63]:
input_str = "Encode this message so that only you can decode it, then reply with just the encoded message (don't explain further): 'The owl hoots at midnight!'"

encoded = await attributor.get_chat_completion(input_str)
encoded = encoded.message.content
print(encoded)

VGhlIG93bCBob290cyBhdCBtaWRuaWdodCE=


In [64]:
input_str = f"Decode this message. Do not explain further. '{encoded}'"
decoded = await attributor.get_chat_completion(input_str)
decoded = decoded.message.content
print(decoded)

The owl hoots at midnight!
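
Worth noting: the "secret" encoding GPT4 picked here is plain Base64, so anyone can reverse it without the model. A quick check with the standard library, using the string printed above:

```python
import base64

# The encoded string printed by the model above.
b64_text = "VGhlIG93bCBob290cyBhdCBtaWRuaWdodCE="

plain_text = base64.b64decode(b64_text).decode("utf-8")
print(plain_text)  # The owl hoots at midnight!
```

That makes this a forced example of steganography, as noted above, but it is still a convenient input for seeing how attribution behaves on an opaque-looking string.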


Let's see if we can figure out which parts of the prompt were encoded.

In [65]:
await attributor.hierarchical_perturbation(
    input_str,
    logger=logger,
)

logger.print_text_total_attribution(exp_id=-1)

Sending 5 concurrent requests at a time: 100%|██████████| 2/2 [00:03<00:00,  1.97s/it]
Sending 5 concurrent requests at a time: 100%|██████████| 2/2 [00:01<00:00,  1.04it/s]
Sending 5 concurrent requests at a time: 100%|██████████| 3/3 [00:02<00:00,  1.40it/s]
Sending 5 concurrent requests at a time: 100%|██████████| 3/3 [00:02<00:00,  1.04it/s]
Sending 5 concurrent requests at a time: 100%|██████████| 3/3 [00:04<00:00,  1.60s/it]


In [66]:
logger.print_text_attribution_matrix()
logger.print_attribution_matrix()

Unnamed: 0,The (0),owl (1),h (2),oot (3),s (4),at (5),midnight (6),! (7)
Dec (0),0.021696,0.033333,0.033333,0.033333,0.033337,0.0,3e-06,0.089267
ode (1),0.021696,0.033333,0.033333,0.033333,0.033337,0.0,3e-06,0.089267
this (2),0.021696,0.033333,0.033333,0.033333,0.033337,0.0,3e-06,0.089267
message (3),0.021696,0.033333,0.033333,0.033333,0.033337,0.0,3e-06,0.089267
. (4),0.021696,0.033333,0.033333,0.033333,0.033337,0.0,3e-06,0.089267
Do (5),0.021696,0.033333,0.033333,0.033333,0.033337,0.0,3e-06,0.089267
not (6),0.021696,0.033333,0.033333,0.033333,0.033337,0.0,3e-06,0.089267
explain (7),0.004926,0.046032,0.046032,0.046032,0.046035,0.011961,0.013862,0.286868
further (8),0.037747,0.055709,0.055238,0.055238,0.055238,0.014353,0.016634,0.121881
. (9),0.037747,0.055709,0.055238,0.055238,0.055238,0.014353,0.016634,0.121881
