# PIZZA: An Open Source Library for Closed LLM Attribution (or “why did ChatGPT say that?”)

## Setup

**Make sure to uncomment and run the cell below if you're in Colab.**

In [1]:

# !git clone https://github.com/leap-laboratories/PIZZA.git .
# !pip install --quiet -r requirements.txt

In [20]:
import os
import warnings

# Suppress annoying FutureWarning from huggingface_hub that is not our fault
warnings.filterwarnings("ignore", category=FutureWarning, module="huggingface_hub")

from attribution.api_attribution import OpenAIAttributor
from attribution.experiment_logger import ExperimentLogger
from attribution.token_perturbation import FixedPerturbationStrategy, NthNearestPerturbationStrategy

# Re-import modified modules without restarting the server (for dev use)
%load_ext autoreload
%autoreload 2



You need an OpenAI API key to run this! Get one **[here](https://platform.openai.com/api-keys)**. 

You can either set the `OPENAI_API_KEY` environment variable in your notebook runtime (use the _secrets_ panel on the left in Colab), or add it to a `.env` file as described [in the README](../README.md#environment-variables).

If you're _desperate_ to live on the edge, you can also pass your API key directly to the attributor. But we really don't advise this for security reasons!

In [3]:
# Checks if you're using a .env file, and loads it if so.
import os
# Load environment variables from .env file
if os.path.isfile('.env'):
    %load_ext dotenv
    %dotenv
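`os.getenv` returns `None` when a variable is missing, so a missing key may only surface as an error later, at the first API call. A minimal guard you could run before creating the attributors is sketched below; the helper name `require_env` is hypothetical and not part of PIZZA.

```python
import os


def require_env(var_name: str = "OPENAI_API_KEY") -> str:
    """Return an environment variable's value, failing fast if it is unset or empty.

    Hypothetical helper for this notebook; not part of the PIZZA library.
    """
    value = os.environ.get(var_name)
    if not value:
        raise RuntimeError(
            f"{var_name} is not set. Add it to your environment or to a .env file."
        )
    return value


# Usage: pass the checked key to the attributor instead of os.getenv(...), e.g.
# OpenAIAttributor(openai_api_key=require_env(), max_concurrent_requests=5, openai_model="gpt-3.5-turbo")
```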

Set up some attributors, and a logger to keep track of and visualise results.

In [4]:
gpt3_5_attributor = OpenAIAttributor(
    openai_api_key=os.getenv("OPENAI_API_KEY"),
    max_concurrent_requests=5,
    openai_model="gpt-3.5-turbo",
)

# Using a slightly older GPT4 model, because the latest is absurdly verbose.
gpt4_attributor = OpenAIAttributor(
    openai_api_key=os.getenv("OPENAI_API_KEY"),
    max_concurrent_requests=5,
    openai_model="gpt-4-0613",
)

logger = ExperimentLogger()


Quickstart example showing the different attribution and perturbation strategies:

In [23]:
input_str = "Do not go gentle"

await gpt3_5_attributor.hierarchical_perturbation(
    input_str, logger=logger, attribution_strategies=['cosine'], perturbation_strategy=FixedPerturbationStrategy(replacement_token='')
)
print('Cosine similarity attribution, fixed token perturbation')
logger.print_text_total_attribution(exp_id=-1)
logger.print_attribution_matrix(exp_id=-1)

await gpt3_5_attributor.hierarchical_perturbation(
    input_str, logger=logger, attribution_strategies=['prob_diff'], perturbation_strategy=NthNearestPerturbationStrategy(n=-1)
)
print('Probability difference attribution, nth nearest perturbation')
logger.print_text_total_attribution(exp_id=-1)
logger.print_attribution_matrix(exp_id=-1)

Cosine similarity attribution, fixed token perturbation


| Input token | (0) | (1) | (2) | (3) | (4) | (5) |
|---|---|---|---|---|---|---|
| Do (0) | 0.119439 | 0.082215 | 0.142095 | 0.170312 | -0.0 | 0.0 |
| not (1) | 0.241483 | 0.180648 | 0.298473 | 0.358309 | 0.035173 | 0.111476 |
| go (2) | 0.181471 | 0.251721 | 0.379319 | 0.253093 | 0.046897 | 0.285465 |
| gentle (3) | 0.500543 | 0.376894 | 0.514032 | 0.618028 | 0.117242 | 0.379743 |


Probability difference attribution, nth nearest perturbation


| Input token | into (0) | that (1) | good (2) | night (3) | , (4) |
|---|---|---|---|---|---|
| Do (0) | 0.044249 | 0.195552 | 0.249892 | 0.007031 | 0.169077 |
| not (1) | 0.224787 | 0.443345 | 0.499826 | 0.255028 | 0.380285 |
| go (2) | 0.627748 | 0.647806 | 0.666659 | 0.666031 | 0.505593 |
| gentle (3) | 0.784686 | 0.833061 | 0.833324 | 0.832539 | 0.634203 |
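The notebook doesn't spell out what these strategies compute, so here is a minimal sketch of one plausible reading based only on their names and parameters; it is not PIZZA's code. We assume that fixed perturbation swaps a token for `replacement_token` (here the empty string); that nth-nearest perturbation swaps it for the n-th most similar token in embedding space, indexed Python-style so `n=-1` would pick the *least* similar one; that `prob_diff` is the drop in probability of an original output token after perturbation, read from top-k logprobs with a missing token counted as probability 0; and that `cosine` is a cosine distance between embeddings of original and perturbed outputs. The embeddings and every helper name below are made up for illustration.

```python
import math
from typing import Dict, List, Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


# --- Perturbation strategies (assumed behaviour) ---
def fixed_perturbation(tokens: List[str], index: int, replacement: str = "") -> List[str]:
    """Replace the token at `index` with a fixed string."""
    perturbed = list(tokens)
    perturbed[index] = replacement
    return perturbed


def nth_nearest_perturbation(
    tokens: List[str], index: int, embeddings: Dict[str, Sequence[float]], n: int
) -> List[str]:
    """Replace the token at `index` with the n-th most similar other token.

    Assumption: candidates are sorted from most to least similar and indexed
    Python-style, so n=0 is the nearest neighbour and n=-1 the farthest.
    """
    original = tokens[index]
    candidates = [t for t in embeddings if t != original]
    candidates.sort(key=lambda t: cosine_similarity(embeddings[original], embeddings[t]), reverse=True)
    perturbed = list(tokens)
    perturbed[index] = candidates[n]
    return perturbed


# --- Attribution scores (assumed behaviour) ---
def prob_diff(original_logprob: float, perturbed_top_logprobs: Dict[str, float], token: str) -> float:
    """Drop in probability of `token` between the original and perturbed runs.

    Assumption: a token missing from the perturbed top-k logprobs has probability 0.
    """
    p_original = math.exp(original_logprob)
    perturbed_logprob = perturbed_top_logprobs.get(token)
    p_perturbed = math.exp(perturbed_logprob) if perturbed_logprob is not None else 0.0
    return p_original - p_perturbed


def cosine_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """1 - cosine similarity: 0 for identical directions, 2 for opposite ones."""
    return 1.0 - cosine_similarity(a, b)


# Made-up 2-D embeddings, purely illustrative.
toy_embeddings = {"gentle": [1.0, 0.0], "kind": [0.9, 0.1], "soft": [0.8, 0.3], "loud": [-1.0, 0.0]}
tokens = ["Do", "not", "go", "gentle"]

print(fixed_perturbation(tokens, 3))                           # ['Do', 'not', 'go', '']
print(nth_nearest_perturbation(tokens, 3, toy_embeddings, 0))   # ['Do', 'not', 'go', 'kind']
print(nth_nearest_perturbation(tokens, 3, toy_embeddings, -1))  # ['Do', 'not', 'go', 'loud']
# " night" had probability 0.9 originally and 0.3 after perturbation.
print(round(prob_diff(math.log(0.9), {" night": math.log(0.3)}, " night"), 2))  # 0.6
```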


## Prompt Engineering

In [6]:
input_str = "Mary puts an apple in the box. The box is labelled 'pencils'. John enters the room. What does he think is in the box? Answer in 1 word."

gpt3_5_response = await gpt3_5_attributor.get_chat_completion(input_str)
gpt4_response = await gpt4_attributor.get_chat_completion(input_str)

print("User:", input_str)
print("GPT3.5:", gpt3_5_response.message.content)
print("GPT4:", gpt4_response.message.content)

User: Mary puts an apple in the box. The box is labelled 'pencils'. John enters the room. What does he think is in the box? Answer in 1 word.
GPT3.5: Apples
GPT4: Pencils


GPT3.5 isn't so hot with theory of mind there. Can we find out what went wrong?

In [7]:
# Bit hacky to get model explanation
user_request = "User: Why did you say that?"
print(user_request)
model_explanation = await gpt3_5_attributor.openai_client.chat.completions.create(
    model=gpt3_5_attributor.openai_model,
    messages=[
        {"role": "user", "content": input_str},
        {"role": "assistant", "content": gpt3_5_response.message.content},
        {"role": "user", "content": user_request},
    ],
    temperature=0.0,
    seed=0,
    logprobs=True,
    top_logprobs=20,
)
print("GPT3.5:", model_explanation.choices[0].message.content)

User: Why did you say that?
GPT3.5: I apologize for the mistake in my response. John would likely think there are pencils in the box, based on the label.


That's not very helpful! We want to know _why_ the mistake was made in the first place, so we can fix it.

In [8]:
await gpt3_5_attributor.hierarchical_perturbation(
    input_str, logger=logger, use_absolute_attribution=True
)
print("GPT3.5 Attribution:")
logger.print_text_total_attribution()
logger.print_total_attribution()



GPT3.5 Attribution:


| exp_id | attribution_strategy | perturbation_strategy | unit_definition | Token attributions |
|---|---|---|---|---|
| 1 | prob_diff | fixed | token | `Do` 0.07, `not` 0.25, `go` 0.48, `gentle` 0.78 |
| 2 | prob_diff | fixed | token | `Mary` 0.29, `puts` 0.29, `an` 0.14, `apple` 0.33, `in` 0.16, `the` 0.16, `box` 0.07, `.` 0.07, `The` 0.07, `box` 0.05, `is` 0.05, `labelled` 0.05, `'` 0.05, `pen` 0.05, `cil` 0.05, `s` 0.05, `'.` 0.05, `John` 0.05, `enters` 0.05, `the` 0.05, `room` 0.05, `.` 0.05, `What` 0.05, `does` 0.05, `he` 0.05, `think` 0.05, `is` 0.05, `in` 0.32, `the` 0.15, `box` 0.19, `?` 0.14, `Answer` 0.13, `in` 0.26, `1` 0.28, `word` 0.32, `.` 0.18 |


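It looks like the request to "Answer in 1 word" is pretty important: its tokens score about as highly as `apple`, the actual contents of the box (0.32 for `word` vs 0.33 for `apple`), and far higher than the `'pencils'` label (0.05 per token). Let's try changing it.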
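To check a comparison like this without eyeballing the table, you can rank tokens and average scores over a phrase. The sketch below copies the rounded `(token, score)` pairs for experiment 2 from the table above into a plain list; `mean_score` and `top_tokens` are hypothetical helpers, not logger methods.

```python
from statistics import mean
from typing import List, Tuple

# Rounded per-token attributions for experiment 2, copied from the table above.
pairs: List[Tuple[str, float]] = [
    ("Mary", 0.29), ("puts", 0.29), ("an", 0.14), ("apple", 0.33), ("in", 0.16),
    ("the", 0.16), ("box", 0.07), (".", 0.07), ("The", 0.07), ("box", 0.05),
    ("is", 0.05), ("labelled", 0.05), ("'", 0.05), ("pen", 0.05), ("cil", 0.05),
    ("s", 0.05), ("'.", 0.05), ("John", 0.05), ("enters", 0.05), ("the", 0.05),
    ("room", 0.05), (".", 0.05), ("What", 0.05), ("does", 0.05), ("he", 0.05),
    ("think", 0.05), ("is", 0.05), ("in", 0.32), ("the", 0.15), ("box", 0.19),
    ("?", 0.14), ("Answer", 0.13), ("in", 0.26), ("1", 0.28), ("word", 0.32),
    (".", 0.18),
]


def mean_score(start: int, end: int) -> float:
    """Mean attribution over pairs[start:end] (end exclusive)."""
    return mean(score for _, score in pairs[start:end])


def top_tokens(k: int) -> List[Tuple[str, float]]:
    """The k highest-attributed tokens; the stable sort keeps ties in input order."""
    return sorted(pairs, key=lambda p: p[1], reverse=True)[:k]


print(top_tokens(3))                 # [('apple', 0.33), ('in', 0.32), ('word', 0.32)]
print(round(mean_score(31, 36), 3))  # "Answer in 1 word." -> 0.234
print(round(mean_score(12, 17), 3))  # "'pencils'." -> 0.05
```

On these rounded numbers, the instruction averages about 0.23 per token while the quoted label averages 0.05.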

In [9]:
input_str = "Mary puts an apple in the box. The box is labelled 'pencils'. John enters the room. What does he think is in the box? Answer briefly."

await gpt3_5_attributor.hierarchical_perturbation(
    input_str,
    logger=logger,
)

# Let's see...
print("GPT3.5 Total attribution:")
# exp_id is the experiment index to print. -1 prints the last experiment.
logger.print_text_total_attribution(exp_id=-1)

GPT3.5 Total attribution:


That's better!

Above we've been using hierarchical perturbation, which can be faster and cheaper than standard iterative perturbation on long inputs where only a few tokens are salient. Most importantly, it can also capture multi-token features, which iterative perturbation cannot: for example, if two tokens are redundant, removing either one alone may change nothing, while removing both does.

However, when many tokens are salient, standard iterative perturbation can be faster, and it often highlights individual token contributions more clearly.
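To make the cost trade-off concrete, here is a simplified, self-contained sketch of binary hierarchical perturbation against a toy scoring function. It is not PIZZA's implementation (the notebook doesn't show it): chunks of tokens are perturbed together, and only chunks whose score reaches a `threshold` are split further. `score_fn` is a hypothetical stand-in for "call the LLM with these tokens perturbed and measure how much the output changes".

```python
from typing import Callable, List, Set, Tuple

# score_fn(perturbed_indices) -> how much the output changes when those tokens are perturbed.
ScoreFn = Callable[[Set[int]], float]


def iterative_attribution(n_tokens: int, score_fn: ScoreFn) -> Tuple[List[float], int]:
    """Perturb one token at a time: always exactly n_tokens calls."""
    scores = [score_fn({i}) for i in range(n_tokens)]
    return scores, n_tokens


def hierarchical_attribution(n_tokens: int, score_fn: ScoreFn, threshold: float) -> Tuple[List[float], int]:
    """Perturb whole chunks, and only split chunks whose score reaches `threshold`."""
    scores = [0.0] * n_tokens
    calls = 0
    stack = [(0, n_tokens)]  # half-open chunks [start, end)
    while stack:
        start, end = stack.pop()
        score = score_fn(set(range(start, end)))
        calls += 1
        if score < threshold or end - start == 1:
            # Not salient, or already a single token: stop and record the score.
            for i in range(start, end):
                scores[i] = score
            continue
        mid = (start + end) // 2
        stack.append((start, mid))
        stack.append((mid, end))
    return scores, calls


def toy_score_fn(importance: List[float]) -> ScoreFn:
    """Toy stand-in for an LLM call: output change = summed importance of perturbed tokens."""
    return lambda indices: sum(importance[i] for i in indices)


sparse = [0.0] * 16
sparse[5] = 1.0     # a single salient token
dense = [1.0] * 16  # every token salient

for name, importance in [("sparse", sparse), ("dense", dense)]:
    fn = toy_score_fn(importance)
    _, hier_calls = hierarchical_attribution(16, fn, threshold=0.5)
    _, iter_calls = iterative_attribution(16, fn)
    print(f"{name}: hierarchical={hier_calls} calls, iterative={iter_calls} calls")
# sparse: hierarchical=9 calls, iterative=16 calls
# dense: hierarchical=31 calls, iterative=16 calls
```

With one salient token, pruning skips most of the input; when every token matters, the full binary tree costs 2n - 1 calls instead of n. Real implementations also differ in how they batch calls, so call counts are only a rough proxy for wall-clock time.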

In [10]:
input_str = "Write a funny, sad haiku."

await gpt4_attributor.compute_attributions(
    input_str, perturbation_strategy=FixedPerturbationStrategy(), logger=logger
)
logger.print_text_total_attribution(exp_id=-1)



Hilarious.

Anyway, the logger also has several functions for printing results in different ways: you can see how every input token affects every output token, which perturbations were applied, and so on.

Note how strongly the poem's punctuation is attributed to "haiku" in the input, particularly its first token `ha` (e.g. 0.93 for the comma at output position 11).

In [11]:
logger.print_text_attribution_matrix(exp_id=-1)
logger.print_total_attribution(exp_id=-1)
logger.print_attribution_matrix(exp_id=-1, show_debug_cols=True)

| exp_id | attribution_strategy | perturbation_strategy | unit_definition | Token attributions |
|---|---|---|---|---|
| 4 | prob_diff | fixed | token | `Write` 0.11, `a` 0.25, `funny` 0.35, `,` 0.07, `sad` 0.38, `ha` 0.61, `iku` 0.38, `.` -0.01 |


| Input token | Lost (0) | my (1) | favorite (2) | sock (3) | , (4) | In (5) | the (6) | dryer (7) | it (8) | did (9) | hide (10) | , (11) | One (12) | foot (13) | 's (14) | cold (15) | , (16) | oh (17) | mock (18) | . (19) | perturbed_input | perturbed_output |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Write (0) | 0.045611 | -0.022529 | -0.152083 | 0.013202 | -0.180615 | 0.101832 | 0.11384 | -0.03958 | 0.214361 | 0.455224 | 0.284522 | -0.008823 | 0.090791 | -0.04983 | -0.014249 | 0.309977 | 0.168954 | 0.244889 | 0.552804 | 0.010587 | a funny, sad haiku. | Lost my favorite sock, In the dryer's black abyss, One foot's cold, how cruel. |
| a (1) | 0.152306 | 0.081739 | 0.009176 | 0.054552 | 0.438785 | 0.365687 | 0.722434 | 0.76557 | -0.575506 | 0.455224 | 0.284522 | 0.610577 | 0.220086 | 0.035214 | 0.186948 | -0.116867 | 0.062402 | 0.345416 | 0.564919 | 0.262429 | Write funny, sad haiku. | Lost my favorite sock Washing machine ate it up One foot's cold, how cruel! |
| funny (2) | 0.2197 | 0.752869 | 0.468045 | 0.813814 | -0.224673 | 0.455883 | -0.095466 | 0.76557 | 0.408366 | 0.455224 | 0.284522 | -0.052881 | 0.591472 | 0.819486 | -0.533522 | 0.585147 | 0.687886 | 0.417288 | 0.633392 | -0.40184 | Write a, sad haiku. | Tears fall like raindrops, Joy lost in the heart's deep well, Laughter's echo fades. |
| , (3) | -0.451376 | -0.077176 | 0.055517 | 0.014464 | -0.19434 | -0.105164 | -0.015943 | 0.030605 | 0.21085 | 0.455224 | 0.284522 | -0.022547 | 0.249823 | -0.03457 | -0.089076 | 0.209367 | 0.115242 | 0.172741 | 0.606028 | 0.026791 | Write a funny sad haiku. | Lost my favorite sock, In the dryer's black abyss, One foot's cold, how cruel. |
| sad (4) | 0.245711 | 0.111075 | 0.468045 | 0.813814 | -0.170961 | 0.473496 | 0.392539 | 0.76557 | 0.370276 | 0.455224 | 0.284522 | 0.000831 | 0.591472 | 0.819486 | 0.271866 | 0.64769 | 0.065406 | 0.414777 | 0.633392 | -0.06865 | Write a funny, haiku. | Coffee in my cup, Forgot to put pants on, oops, Zoom call, good luck up. |
| ha (5) | 0.250576 | 0.752869 | 0.468045 | 0.813814 | 0.761467 | 0.514976 | 0.733716 | 0.76557 | 0.408364 | 0.428421 | 0.284522 | 0.933259 | 0.58918 | 0.819486 | 0.460378 | 0.64769 | 0.89593 | 0.417288 | 0.633392 | 0.585247 | Write a funny, sadiku. | Why don't scientists trust atoms? Because they make up everything! |
| iku (6) | -0.197836 | 0.737268 | 0.468045 | 0.813814 | -0.206039 | 0.410463 | 0.509082 | 0.76557 | 0.402918 | 0.455224 | 0.284522 | -0.034247 | 0.591472 | 0.819486 | 0.45744 | 0.64769 | 0.186527 | 0.178921 | 0.633392 | -0.224822 | Write a funny, sad ha. | Once a jester, full of glee, Lost his laugh, oh, woe is he, His funny, sad "ha", echoes free. |
| . (7) | 0.072369 | -0.017174 | -0.046424 | 0.011004 | -0.144121 | -0.06264 | -0.05112 | 0.039896 | -0.001573 | 0.016982 | -0.099608 | 0.027671 | -0.007068 | -0.039205 | 0.062229 | -0.034565 | 0.02176 | -0.026363 | 0.093976 | 0.017941 | Write a funny, sad haiku | Lost my favorite sock, In the dryer it did hide, One foot's cold, oh mock. |
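In this run, each score in the total-attribution table above appears to equal the mean of the corresponding row of the attribution matrix. Below is a quick check for three rows, with values copied from the matrix; this is an observation about these numbers, not a documented definition of `print_total_attribution`.

```python
from statistics import mean

# Three rows of the attribution matrix above: input token -> scores for the 20 output tokens.
matrix_rows = {
    "Write": [0.045611, -0.022529, -0.152083, 0.013202, -0.180615, 0.101832, 0.11384,
              -0.03958, 0.214361, 0.455224, 0.284522, -0.008823, 0.090791, -0.04983,
              -0.014249, 0.309977, 0.168954, 0.244889, 0.552804, 0.010587],
    "ha": [0.250576, 0.752869, 0.468045, 0.813814, 0.761467, 0.514976, 0.733716,
           0.76557, 0.408364, 0.428421, 0.284522, 0.933259, 0.58918, 0.819486,
           0.460378, 0.64769, 0.89593, 0.417288, 0.633392, 0.585247],
    ".": [0.072369, -0.017174, -0.046424, 0.011004, -0.144121, -0.06264, -0.05112,
          0.039896, -0.001573, 0.016982, -0.099608, 0.027671, -0.007068, -0.039205,
          0.062229, -0.034565, 0.02176, -0.026363, 0.093976, 0.017941],
}
# Rounded totals from the total-attribution table above, for the same input tokens.
reported_totals = {"Write": 0.11, "ha": 0.61, ".": -0.01}

for token, scores in matrix_rows.items():
    print(f"{token!r}: row mean {mean(scores):.3f}, reported total {reported_totals[token]}")
```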


## Steganography

Could we use attribution to detect hidden messages between LLMs? Maybe. This is a forced example – you can probably think of better experiments. First let's get GPT4 to secretly encode a message.

In [12]:
input_str = "Encode this message so that only you can decode it, then reply with just the encoded message (don't explain further): 'The owl hoots at midnight!'"

encoded = await gpt4_attributor.get_chat_completion(input_str)
encoded = encoded.message.content
print(encoded)

VGhlIG93bCBob290cyBhdCBtaWRuaWdodCE=
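The "secret" encoding here is plain Base64, which anyone can reverse, so it isn't much of a hidden channel. You can confirm this with Python's standard library:

```python
import base64

# The string GPT4 returned above.
message = "VGhlIG93bCBob290cyBhdCBtaWRuaWdodCE="
decoded_text = base64.b64decode(message).decode("utf-8")
print(decoded_text)  # The owl hoots at midnight!
```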


In [13]:
input_str = f"Decode this message. Do not explain further. '{encoded}'"
decoded = await gpt4_attributor.get_chat_completion(input_str)
decoded = decoded.message.content
print(decoded)

The owl hoots at midnight!


Let's see if attribution can tell us which parts of the encoded message map to which words of the decoded output.

In [14]:
await gpt4_attributor.hierarchical_perturbation(
    input_str,
    logger=logger,
)

logger.print_text_total_attribution(exp_id=-1)



In [15]:
logger.print_text_attribution_matrix()
logger.print_attribution_matrix()

| Input token | The (0) | owl (1) | h (2) | oot (3) | s (4) | at (5) | midnight (6) | ! (7) |
|---|---|---|---|---|---|---|---|---|
| Dec (0) | 0.00927 | 0.033333 | 0.033243 | 0.033333 | 8e-06 | 0.0 | 0.0 | 0.088838 |
| ode (1) | 0.00927 | 0.033333 | 0.033243 | 0.033333 | 8e-06 | 0.0 | 0.0 | 0.088838 |
| this (2) | 0.00927 | 0.033333 | 0.033243 | 0.033333 | 8e-06 | 0.0 | 0.0 | 0.088838 |
| message (3) | 0.00927 | 0.033333 | 0.033243 | 0.033333 | 8e-06 | 0.0 | 0.0 | 0.088838 |
| . (4) | 0.00927 | 0.033333 | 0.033243 | 0.033333 | 8e-06 | 0.0 | 0.0 | 0.088838 |
| Do (5) | 0.00927 | 0.033333 | 0.033243 | 0.033333 | 8e-06 | 0.0 | 0.0 | 0.088838 |
| not (6) | 0.00927 | 0.033333 | 0.033243 | 0.033333 | 8e-06 | 0.0 | 0.0 | 0.088838 |
| explain (7) | 0.001192 | 0.046032 | 0.046001 | 0.046032 | 0.034925 | 0.012701 | 0.014304 | 0.286723 |
| further (8) | 0.049451 | 0.057035 | 0.055202 | 0.055238 | 0.041905 | 0.015241 | 0.017164 | 0.121874 |
| . (9) | 0.049451 | 0.057035 | 0.055202 | 0.055238 | 0.041905 | 0.015241 | 0.017164 | 0.121874 |
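The first seven rows of this matrix (`Dec` to `not`) are identical, as are the last two. That pattern is consistent with hierarchical perturbation scoring those tokens as a single chunk that was never split further, though we haven't confirmed this against the library internals. A small hypothetical helper (not part of the logger) for spotting such runs, using the values from the matrix above:

```python
from typing import List, Sequence, Tuple


def identical_row_runs(rows: Sequence[Sequence[float]]) -> List[Tuple[int, int]]:
    """Return (first, last) row indices for each run of consecutive identical rows."""
    runs = []
    start = 0
    for i in range(1, len(rows) + 1):
        # Close the current run at the end of the matrix or when a row differs.
        if i == len(rows) or list(rows[i]) != list(rows[start]):
            runs.append((start, i - 1))
            start = i
    return runs


shared = [0.00927, 0.033333, 0.033243, 0.033333, 8e-06, 0.0, 0.0, 0.088838]
matrix = [shared] * 7 + [
    [0.001192, 0.046032, 0.046001, 0.046032, 0.034925, 0.012701, 0.014304, 0.286723],
    [0.049451, 0.057035, 0.055202, 0.055238, 0.041905, 0.015241, 0.017164, 0.121874],
    [0.049451, 0.057035, 0.055202, 0.055238, 0.041905, 0.015241, 0.017164, 0.121874],
]
print(identical_row_runs(matrix))  # [(0, 6), (7, 7), (8, 9)]
```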



That's all for now. We implement a few other attribution and perturbation methods, each with different properties. Check out the README, and do your own experiments – PIZZA is a work in progress and we welcome contributions. 

In [16]:
display(logger.df_experiments)

| exp_id | original_input | original_output | perturbation_strategy | unit_definition | duration | num_llm_calls |
|---|---|---|---|---|---|---|
| 1 | Do not go gentle | into that good night,\n | fixed | token | 4.64566 | 9 |
| 2 | Mary puts an apple in the box. The box is labe... | Apples | fixed | token | 8.145625 | 34 |
| 3 | Mary puts an apple in the box. The box is labe... | John would likely think there are pencils in t... | fixed | token | 10.192401 | 40 |
| 4 | Write a funny, sad haiku. | Lost my favorite sock,\nIn the dryer it did hi... | fixed | token | 2.711276 | 9 |
| 5 | Decode this message. Do not explain further. '... | The owl hoots at midnight! | fixed | token | 12.353817 | 55 |
