## PIZZA: An Open Source Library for Closed LLM Attribution (or “why did ChatGPT say that?”)

In [4]:
import os
import asyncio

# Set your OpenAI API key (read from the OPENAI_API_KEY environment variable if it is set)
# BEWARE: This will cost you API credits!
YOUR_OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "your api key")

import warnings
# Suppress annoying FutureWarning from huggingface_hub
warnings.filterwarnings('ignore', category=FutureWarning, module='huggingface_hub')


In [5]:
# Re-import modified modules without restarting the kernel
%load_ext autoreload
%autoreload 2



In [6]:
from attribution.api_attribution import OpenAIAttributor
from attribution.experiment_logger import ExperimentLogger
from attribution.token_perturbation import FixedPerturbationStrategy

gpt3_5_attributor = OpenAIAttributor(openai_api_key=YOUR_OPENAI_API_KEY,
    max_concurrent_requests=5, openai_model="gpt-3.5-turbo")

gpt4_attributor = OpenAIAttributor(openai_api_key=YOUR_OPENAI_API_KEY,
    max_concurrent_requests=5, openai_model="gpt-4o")
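`max_concurrent_requests` caps how many API calls an attributor has in flight at once (the progress bars later in this notebook report "Sending 5 concurrent requests at a time"). We haven't checked how PIZZA implements this internally; a common pattern is an `asyncio.Semaphore`, sketched below with a hypothetical `fake_completion` coroutine standing in for the real OpenAI call.

```python
import asyncio
from typing import Dict, List, Tuple


async def fake_completion(prompt: str, stats: Dict[str, int]) -> str:
    # Hypothetical stand-in for an API call; records how many calls overlap.
    stats["in_flight"] += 1
    stats["peak"] = max(stats["peak"], stats["in_flight"])
    await asyncio.sleep(0.01)  # simulate network latency
    stats["in_flight"] -= 1
    return f"response to {prompt!r}"


async def run_bounded(prompts: List[str], max_concurrent_requests: int = 5) -> Tuple[List[str], int]:
    # At most `max_concurrent_requests` coroutines can hold the semaphore at once.
    semaphore = asyncio.Semaphore(max_concurrent_requests)
    stats = {"in_flight": 0, "peak": 0}

    async def worker(prompt: str) -> str:
        async with semaphore:
            return await fake_completion(prompt, stats)

    # gather() returns results in the same order as the input prompts.
    results = await asyncio.gather(*(worker(p) for p in prompts))
    return list(results), stats["peak"]
```

In a notebook cell, which already runs an event loop, call it as `results, peak = await run_bounded(prompts)`; in a plain script, use `asyncio.run(run_bounded(prompts))` instead.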

# Simple Example

In [70]:
input_str = "It's 10:47. How long until 11?"

gpt3_5_response = await gpt3_5_attributor.get_chat_completion(input_str)
gpt4_response = await gpt4_attributor.get_chat_completion(input_str)

print(input_str)
print("GPT3.5:", gpt3_5_response.message.content)
print("GPT4:", gpt4_response.message.content)

It's 10:47. How long until 11?
GPT3.5: 13 minutes
GPT4: There are 13 minutes until 11:00.


In [71]:
# Initialise a logger to track results. We'll use one for each model.
gpt3_5_logger = ExperimentLogger()
await gpt3_5_attributor.hierarchical_perturbation(
    input_str,
    logger=gpt3_5_logger,
    verbose=3
)

# Let's see...
print("GPT3.5 Total attribution:")
gpt3_5_logger.print_attribution_matrix(show_debug_cols=True)

# # Now try with GPT4 (uncomment to run; later cells rely on gpt4_logger)
# gpt4_logger = ExperimentLogger()
# await gpt4_attributor.hierarchical_perturbation(
#     input_str,
#     init_chunk_size=16,
#     stride=8,
#     perturbation_strategy=FixedPerturbationStrategy(),
#     logger=gpt4_logger
# )

# print("GPT4 Total attribution:")
# gpt4_logger.print_text_total_attribution()

Stage 0: making 11 perturbations
Masked out tokens/words:
["It's"] ["'s 10"] ['10:'] [':47'] ['47.'] ['. How'] ['How long'] ['long until'] ['until 11'] ['11?'] ['?']


Sending 5 concurrent requests at a time: 100%|██████████| 3/3 [00:01<00:00,  1.88it/s]


Stage 1: making 3 perturbations
Masked out tokens/words:
['.'] ['47'] [':']
GPT3.5 Total attribution:


| Input token | 13 (0) | minutes (1) | Perturbed input | Perturbed output |
|---|---|---|---|---|
| It (0) | -0.165798 | 7e-06 | 10:47. How long until 11? | 13 minutes |
| 's (1) | 0.036857 | 0.017162 | It:47. How long until 11? | It is 14 minutes until 11. |
| 10 (2) | 0.01244 | 0.01714 | It's47. How long until 11? | There are 4 hours and 13 minutes until 11. |
| : (3) | -0.036477 | 0.179701 | It's 1047. How long until 11? | There are 13 minutes until 11. |
| 47 (4) | 0.376566 | 0.333293 | It's 10:. How long until 11? | It is 50 minutes until 11. |
| . (5) | 0.185852 | 0.199533 | It's 10:47 How long until 11? | 13 minutes |
| How (6) | 0.03039 | 0.049582 | It's 10:47. until 11? | It is 13 minutes until 11. |
| long (7) | -0.215382 | 0.000303 | It's 10:47. How 11? | It is 13 minutes until 11 o'clock. |
| until (8) | -0.067496 | -2.6e-05 | It's 10:47. How long? | It is currently 10:47, so it has been 47 minutes since 10:00. |
| 11 (9) | -0.069527 | -4.3e-05 | It's 10:47. How long until | 11:00? It is 13 minutes until 11:00. |
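The verbose log above hints at how `hierarchical_perturbation` proceeds: Stage 0 masks overlapping two-token chunks (11 chunks for 11 tokens, the last one just `?`), and Stage 1 re-masks single tokens (`:`, `47`, `.`) taken from the chunks that mattered. The sketch below reproduces that chunking under our own assumptions: the tokenisation, the window size of 2, the stride of 1, and the threshold rule in `refine` are guesses that happen to match the log, not PIZZA's actual parameters.

```python
from typing import List, Sequence, Tuple


def sliding_windows(n_tokens: int, size: int = 2, stride: int = 1) -> List[Tuple[int, int]]:
    # (start, end) index pairs covering the input; windows at the end are
    # truncated, so the last one can be a single token.
    return [(i, min(i + size, n_tokens)) for i in range(0, n_tokens, stride)]


def refine(windows: Sequence[Tuple[int, int]], scores: Sequence[float], threshold: float) -> List[int]:
    # Hypothetical selection rule: every token index covered by a window whose
    # score exceeds `threshold` gets perturbed on its own in the next stage.
    keep = set()
    for (start, end), score in zip(windows, scores):
        if score > threshold:
            keep.update(range(start, end))
    return sorted(keep)


# Assumed GPT-style tokenisation of the prompt (leading spaces belong to tokens).
tokens = ["It", "'s", " 10", ":", "47", ".", " How", " long", " until", " 11", "?"]
windows = sliding_windows(len(tokens))
print(["".join(tokens[s:e]).strip() for s, e in windows])

# Illustrative scores only (not from the run above): favour the ":47" and "47." chunks.
toy_scores = [0.0, 0.0, 0.0, 0.9, 0.8, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
print([tokens[i] for i in refine(windows, toy_scores, threshold=0.5)])
```

With those made-up scores, `refine` selects the same three tokens that Stage 1 masks in the log; the real library decides which chunks to refine from the attribution it actually measures.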


In [55]:
gpt3_5_logger.print_attribution_matrix()

| Input token | App (0) | les (1) |
|---|---|---|
| Mary (0) | 0.405158 | 0.666664 |
| puts (1) | 0.348654 | 0.520831 |
| an (2) | 0.199969 | 0.270832 |
| apple (3) | 0.411566 | 0.499998 |
| in (4) | 0.253139 | 0.249998 |
| the (5) | 0.111852 | 0.083333 |
| box (6) | 0.111852 | 0.083333 |
| . (7) | -0.004795 | -0.0 |
| The (8) | -0.004795 | -0.0 |
| box (9) | 0.107187 | 0.083333 |


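GPT3.5 is not so hot at theory of mind there: it answers "Apples", but John only enters after the apple goes in, so he should expect what the label says: pencils.
Notice how the GPT4 attribution is more diffuse, spread over the entire input. Let's look in more detail.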

In [11]:
print("GPT4 Total attribution:")
gpt4_logger.print_text_total_attribution()
print("GPT4 per-output-token attribution:")
gpt4_logger.print_text_attribution_matrix()

GPT4 Total attribution:


GPT4 per-output-token attribution:


Interesting! It looks like that diffuse attribution mostly informed the full stop, which suggests GPT4 was using sentence structure to decide on the punctuation. "Pencils" is attributed only to "pencils", which makes sense but doesn't tell us much. Let's dig deeper.

The table below shows what's actually happening here: we iteratively remove (_perturb_) input tokens by replacing them with an empty string, and look at how the output changes. So it makes sense that removing the word "pencils" (or rather, its tokens "pen", "cil" and "s") changes the output the most.

In [None]:
gpt4_logger.print_total_attribution()
gpt4_logger.print_attribution_matrix(show_debug_cols=True)

| exp_id | attribution_strategy | perturbation_strategy | perturb_word_wise |
|---|---|---|---|
| 1 | prob_diff | fixed | False |

| Position | Token | Attribution |
|---|---|---|
| 1 | `Mary` | 0.09 |
| 2 | `puts` | 0.30 |
| 3 | `an` | 0.11 |
| 4 | `apple` | 0.30 |
| 5 | `in` | 0.30 |
| 6 | `the` | 0.08 |
| 7 | `box` | 0.06 |
| 8 | `.` | 0.30 |
| 9 | `The` | 0.30 |
| 10 | `box` | 0.11 |
| 11 | `is` | 0.04 |
| 12 | `labelled` | 0.10 |
| 13 | `'` | 0.02 |
| 14 | `pen` | 0.55 |
| 15 | `cil` | 0.76 |
| 16 | `s` | 0.63 |
| 17 | `'.` | 0.30 |
| 18 | `John` | 0.30 |
| 19 | `enters` | 0.11 |
| 20 | `the` | 0.06 |
| 21 | `room` | 0.30 |
| 22 | `.` | 0.06 |
| 23 | `What` | 0.04 |
| 24 | `does` | 0.13 |
| 25 | `he` | 0.09 |
| 26 | `think` | 0.08 |
| 27 | `is` | -0.01 |
| 28 | `in` | -0.02 |
| 29 | `the` | 0.11 |
| 30 | `box` | 0.30 |
| 31 | `?` | 0.01 |
| 32 | `Answer` | 0.30 |
| 33 | `in` | 0.08 |
| 34 | `1` | 0.02 |
| 35 | `word` | -0.03 |
| 36 | `.` | 0.30 |


| Input token | P (0) | encils (1) | `.` (2) | Perturbed input | Perturbed output |
|---|---|---|---|---|---|
| Mary (0) | 1.7e-05 | 0.0 | 0.282188 | puts an apple in the box. The box is labelled 'pencils'. John enters the room. What does he think is in the box? Answer in 1 word. | Pencils. |
| puts (1) | 0.000423 | 0.0 | 0.904625 | Mary an apple in the box. The box is labelled 'pencils'. John enters the room. What does he think is in the box? Answer in 1 word. | Pencils |
| an (2) | 0.000288 | 0.0 | 0.342467 | Mary puts apple in the box. The box is labelled 'pencils'. John enters the room. What does he think is in the box? Answer in 1 word. | Pencils. |
| apple (3) | 0.000101 | 0.0 | 0.904625 | Mary puts an in the box. The box is labelled 'pencils'. John enters the room. What does he think is in the box? Answer in 1 word. | Pencils |
| in (4) | 0.000288 | 0.0 | 0.904625 | Mary puts an apple the box. The box is labelled 'pencils'. John enters the room. What does he think is in the box? Answer in 1 word. | Pencils |
| the (5) | 6.7e-05 | 0.0 | 0.225464 | Mary puts an apple in box. The box is labelled 'pencils'. John enters the room. What does he think is in the box? Answer in 1 word. | Pencils. |
| box (6) | 0.000115 | 0.0 | 0.173589 | Mary puts an apple in the. The box is labelled 'pencils'. John enters the room. What does he think is in the box? Answer in 1 word. | Pencils. |
| . (7) | 0.000172 | 0.0 | 0.904625 | Mary puts an apple in the box The box is labelled 'pencils'. John enters the room. What does he think is in the box? Answer in 1 word. | Pencils |
| The (8) | 8.8e-05 | 0.0 | 0.904625 | Mary puts an apple in the box. box is labelled 'pencils'. John enters the room. What does he think is in the box? Answer in 1 word. | Pencils |
| box (9) | 5.8e-05 | 0.0 | 0.342465 | Mary puts an apple in the box. The is labelled 'pencils'. John enters the room. What does he think is in the box? Answer in 1 word. | Pencils. |
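The removal procedure described above (replace each input token with an empty string, then compare how likely the original answer still is) can be sketched as a single-stage leave-one-out loop. This is a simplified illustration, not PIZZA's implementation: the hypothetical `toy_prob` function stands in for the model's probability of answering "Pencils" (PIZZA's `prob_diff` strategy presumably obtains this from the API), and the real library also works hierarchically and per output token.

```python
from typing import Callable, List


def leave_one_out_attribution(tokens: List[str], prob_of_output: Callable[[str], float]) -> List[float]:
    # Score each input token by how much removing it (replacing it with "")
    # lowers the probability of the original output: p(original) - p(perturbed).
    baseline = prob_of_output("".join(tokens))
    scores = []
    for i in range(len(tokens)):
        perturbed = tokens[:i] + [""] + tokens[i + 1:]
        scores.append(baseline - prob_of_output("".join(perturbed)))
    return scores


def toy_prob(prompt: str) -> float:
    # Hypothetical stand-in for the model: confident in "Pencils" only while
    # the word "pencils" survives intact in the prompt.
    return 0.9 if "pencils" in prompt else 0.1


tokens = ["The", " box", " is", " labelled", " '", "pen", "cil", "s", "'."]
for tok, score in zip(tokens, leave_one_out_attribution(tokens, toy_prob)):
    print(f"{tok!r:>12} {score:+.2f}")
```

In this toy, only removing `pen`, `cil` or `s` breaks the word "pencils", so those three tokens are the only ones with a non-zero score.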


This isn't the only strategy we can use. Let's try token flipping. Here, we'll replace each token with another token chosen by distance (in this case, with `n=-1`, the one that is as far away as possible).

In [None]:
from attribution.token_perturbation import NthNearestPerturbationStrategy

await gpt4_attributor.compute_attributions(
    input_str,
    perturbation_strategy=NthNearestPerturbationStrategy(n=-1),
    logger=gpt4_logger
)

print("GPT4 Total attribution:")
gpt4_logger.print_text_attribution_matrix(exp_id=-1)

Sending 10 concurrent requests at a time: 100%|██████████| 4/4 [00:04<00:00,  1.07s/it]


GPT4 Total attribution:


Note how the logger keeps track of all our experiments! Omit the `exp_id` argument to display them all.

In [None]:
gpt4_logger.print_total_attribution()

| exp_id | attribution_strategy | perturbation_strategy | perturb_word_wise |
|---|---|---|---|
| 1 | prob_diff | fixed | False |
| 2 | prob_diff | nth_nearest (n=-1) | False |

| Position | Token | exp 1: fixed | exp 2: nth_nearest (n=-1) |
|---|---|---|---|
| 1 | `Mary` | 0.09 | -0.10 |
| 2 | `puts` | 0.30 | 0.21 |
| 3 | `an` | 0.11 | 0.21 |
| 4 | `apple` | 0.30 | 0.21 |
| 5 | `in` | 0.30 | 0.04 |
| 6 | `the` | 0.08 | 0.21 |
| 7 | `box` | 0.06 | 0.04 |
| 8 | `.` | 0.30 | 0.04 |
| 9 | `The` | 0.30 | 0.21 |
| 10 | `box` | 0.11 | 0.02 |
| 11 | `is` | 0.04 | -0.02 |
| 12 | `labelled` | 0.10 | 0.02 |
| 13 | `'` | 0.02 | 0.21 |
| 14 | `pen` | 0.55 | 0.21 |
| 15 | `cil` | 0.76 | 0.87 |
| 16 | `s` | 0.63 | 0.47 |
| 17 | `'.` | 0.30 | 0.21 |
| 18 | `John` | 0.30 | -0.04 |
| 19 | `enters` | 0.11 | 0.00 |
| 20 | `the` | 0.06 | 0.21 |
| 21 | `room` | 0.30 | 0.00 |
| 22 | `.` | 0.06 | 0.21 |
| 23 | `What` | 0.04 | 0.21 |
| 24 | `does` | 0.13 | 0.04 |
| 25 | `he` | 0.09 | 0.04 |
| 26 | `think` | 0.08 | -0.08 |
| 27 | `is` | -0.01 | 0.21 |
| 28 | `in` | -0.02 | 0.02 |
| 29 | `the` | 0.11 | 0.22 |
| 30 | `box` | 0.30 | 0.21 |
| 31 | `?` | 0.01 | 0.21 |
| 32 | `Answer` | 0.30 | 0.21 |
| 33 | `in` | 0.08 | -0.02 |
| 34 | `1` | 0.02 | 0.02 |
| 35 | `word` | -0.03 | -0.06 |
| 36 | `.` | 0.30 | 0.21 |


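# What's the point?
We can find out which parts of a prompt a closed model's answer actually depends on, using nothing but API access, and compare how different models (and different perturbation strategies) spread that attribution across the input.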