In [1]:
import os
import sys
import time

import numpy as np
import openai
from dotenv import load_dotenv


from transformers import (
    GPT2LMHeadModel,
    GPT2Tokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)

notebook_path = os.path.abspath(get_ipython().starting_dir)
parent_path = os.path.dirname(notebook_path)

sys.path.append(parent_path)
from attribution.logger import ExperimentLogger
from attribution.token_similarity import (
    get_increasingly_distant_tokens,
    get_most_similar_tokens,
)
from attribution.similarity_metrics import calculate_output_change

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
load_dotenv()
client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

In [3]:
model = GPT2LMHeadModel.from_pretrained("gpt2")  # or any other checkpoint
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

word_token_embeddings = model.transformer.wte.weight.detach().numpy()
position_embeddings = model.transformer.wpe.weight.detach().numpy()
token_cosine_distances = None

In [4]:
def get_model_output(input: str) -> openai.types.chat.chat_completion.Choice:
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": input}],
        logprobs=True,
        top_logprobs=20,
    )

    return response.choices[0]

In [5]:
def get_replacement_token(token_id_to_replace: int, perturbation_strategy: str) -> int:
    if perturbation_strategy == "fixed":
        return tokenizer.encode("the", add_special_tokens=False)[0]
    elif perturbation_strategy == "distant":
        return get_increasingly_distant_tokens(
            token_id_to_replace, word_token_embeddings, n_tokens=4
        )[-1]
    elif perturbation_strategy == "nearest":
        return get_most_similar_tokens(
            token_id_to_replace, word_token_embeddings, tokenizer, 2
        )[1]
    else:
        raise ValueError(f"Unknown perturbation strategy: {perturbation_strategy}")


def calculate_token_importance_in_sequence(
    input_sequence: str,
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    perturbation_strategy: str = "fixed",
    attribution_strategy: str = "cosine",
    logger: ExperimentLogger = None,
):
    tokens = tokenizer.tokenize(input_sequence)
    token_ids = tokenizer.encode(input_sequence, add_special_tokens=False)
    initial_output = get_model_output(input_sequence)

    if logger:
        logger.log_experiment(
            input_sequence,
            initial_output.message.content,
            perturbation_strategy,
            attribution_strategy,
        )

    # Initialize a dictionary to store the importance of each token
    token_importance = {}

    for i, token in enumerate(tokens):
        replacement_token_id = get_replacement_token(
            token_ids[i], perturbation_strategy
        )

        # Replace the current token with the new token
        perturbed_input = tokenizer.decode(
            token_ids[:i] + [replacement_token_id] + token_ids[i + 1 :]
        )

        # Get the output logprobs for the perturbed input
        perturbed_output = get_model_output(perturbed_input)
        token_importance[token] = calculate_output_change(
            initial_output, perturbed_output, attribution_strategy, model, tokenizer
        )
        if logger:
            logger.log_token_attr(i, token, token_importance[token])
            logger.log_perturbation(
                input_sequence,
                perturbed_input,
                i,
                tokenizer.convert_ids_to_tokens(int(replacement_token_id)),
                token_importance[token],
            )

    return initial_output.message.content, token_importance

In [6]:
# tables to store info:
# exp_id, message, perturbation_strategy, attribution_strategy
# exp_id, token_pos, token, attr_score
# exp_id, message, perturbed_message, pos_perturbation, attr_score


logger = ExperimentLogger()

input_sequence = "Translate to French 'I am Mike'"

i = 0
for perturbation_strategy in ["distant", "nearest", "fixed"]:
    for attribution_strategy in ["cosine", "logprob_diff", "token_displacement"]:
        print(i)
        i += 1
        output_message, token_importance = calculate_token_importance_in_sequence(
            input_sequence,
            model,
            tokenizer,
            perturbation_strategy=perturbation_strategy,
            attribution_strategy=attribution_strategy,
            logger=logger,
        )

# # Print the importance of each token
# print(output_message)
# for token, importance in token_importance.items():
#     print(f"Token: {token}, Importance: {importance}")

0
1
2
3
4
5
6
7
8


In [7]:
logger.print_tables()

Message Table:


Unnamed: 0,exp_id,input,output,perturbation_strategy,attribution_strategy
0,1,Translate to French 'I am Mike',Je suis Mike,distant,cosine
1,2,Translate to French 'I am Mike',Je suis Mike.,distant,logprob_diff
2,3,Translate to French 'I am Mike',Je suis Mike.,distant,token_displacement
3,4,Translate to French 'I am Mike',Je suis Mike,nearest,cosine
4,5,Translate to French 'I am Mike',Je suis Mike.,nearest,logprob_diff
5,6,Translate to French 'I am Mike',Je suis Mike.,nearest,token_displacement
6,7,Translate to French 'I am Mike',Je suis Mike.,fixed,cosine
7,8,Translate to French 'I am Mike',Je m'appelle Mike,fixed,logprob_diff
8,9,Translate to French 'I am Mike',Je suis Mike.,fixed,token_displacement



Attribution Table:


Unnamed: 0_level_0,exp_id,token_0,token_1,token_2,token_3,token_4,token_5,token_6,token_7,token_8
token_pos,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
0,1,Trans 0.24,late 0.0,Ġto 0.04,ĠFrench 0.51,Ġ' 0.0,I 0.0,Ġam 0.2,ĠMike 0.16,' 0.07
1,2,Trans 0.9,late 0.31,Ġto 0.09,ĠFrench 0.32,Ġ' 0.5,I 0.76,Ġam 0.34,ĠMike 0.33,' 0.36
2,3,Trans 0.0,late 0.0,Ġto 0.0,ĠFrench 2.0,Ġ' 4.0,I 0.0,Ġam 0.0,ĠMike 2.0,' 0.0
3,4,Trans 0.0,late 0.24,Ġto 0.24,ĠFrench 0.04,Ġ' 0.04,I 0.04,Ġam 0.04,ĠMike 0.04,' 0.04
4,5,Trans 0.34,late 0.19,Ġto 0.34,ĠFrench 0.54,Ġ' 0.45,I 0.43,Ġam 0.67,ĠMike 0.57,' 0.47
5,6,Trans 4.0,late 0.0,Ġto 0.0,ĠFrench 4.0,Ġ' 0.0,I 3.0,Ġam 0.0,ĠMike 0.0,' 0.0
6,7,Trans 0.0,late 0.0,Ġto 0.04,ĠFrench 0.26,Ġ' 0.04,I 0.04,Ġam 0.23,ĠMike 0.11,' 0.12
7,8,Trans 1.77,late 1.72,Ġto 1.74,ĠFrench 1.87,Ġ' 1.73,I 2.01,Ġam 1.72,ĠMike 1.73,' 1.72
8,9,Trans 0.0,late 0.0,Ġto 0.0,ĠFrench 0.0,Ġ' 0.0,I 0.0,Ġam 11.0,ĠMike 0.0,' 0.0



Perturbed Message Table:


Unnamed: 0,exp_id,input,perturbed_input,perturbation_pos,perturbation_token,attr_score
0,1,Translate to French 'I am Mike',thenslate to French 'I am Mike',0,thens,0.236969
1,1,Translate to French 'I am Mike',Transthal to French 'I am Mike',1,thal,0.000000
2,1,Translate to French 'I am Mike',Translateood French 'I am Mike',2,ood,0.039086
3,1,Translate to French 'I am Mike',Translate to Prophet 'I am Mike',3,ĠProphet,0.509364
4,1,Translate to French 'I am Mike',Translate to French WahI am Mike',4,ĠWah,0.000000
...,...,...,...,...,...,...
76,9,Translate to French 'I am Mike',Translate to FrenchtheI am Mike',4,the,0.000000
77,9,Translate to French 'I am Mike',Translate to French 'the am Mike',5,the,0.000000
78,9,Translate to French 'I am Mike',Translate to French 'Ithe Mike',6,the,11.000000
79,9,Translate to French 'I am Mike',Translate to French 'I amthe',7,the,0.000000


In [8]:
for exp_id in logger.df_perturbations.exp_id.unique():
    display(logger.df_perturbations.loc[logger.df_perturbations.exp_id == exp_id])

Unnamed: 0,exp_id,input,perturbed_input,perturbation_pos,perturbation_token,attr_score
0,1,Translate to French 'I am Mike',thenslate to French 'I am Mike',0,thens,0.236969
1,1,Translate to French 'I am Mike',Transthal to French 'I am Mike',1,thal,0.0
2,1,Translate to French 'I am Mike',Translateood French 'I am Mike',2,ood,0.039086
3,1,Translate to French 'I am Mike',Translate to Prophet 'I am Mike',3,ĠProphet,0.509364
4,1,Translate to French 'I am Mike',Translate to French WahI am Mike',4,ĠWah,0.0
5,1,Translate to French 'I am Mike',"Translate to French '""), am Mike'",5,"""),",0.0
6,1,Translate to French 'I am Mike',Translate to French 'I version Mike',6,Ġversion,0.204911
7,1,Translate to French 'I am Mike',Translate to French 'I am commissioners',7,Ġcommissioners,0.161087
8,1,Translate to French 'I am Mike',Translate to French 'I am Mike 480,8,Ġ480,0.066813


Unnamed: 0,exp_id,input,perturbed_input,perturbation_pos,perturbation_token,attr_score
9,2,Translate to French 'I am Mike',thenslate to French 'I am Mike',0,thens,0.900278
10,2,Translate to French 'I am Mike',Transthal to French 'I am Mike',1,thal,0.306872
11,2,Translate to French 'I am Mike',Translateood French 'I am Mike',2,ood,0.088229
12,2,Translate to French 'I am Mike',Translate to Prophet 'I am Mike',3,ĠProphet,0.324504
13,2,Translate to French 'I am Mike',Translate to French WahI am Mike',4,ĠWah,0.496912
14,2,Translate to French 'I am Mike',"Translate to French '""), am Mike'",5,"""),",0.757303
15,2,Translate to French 'I am Mike',Translate to French 'I version Mike',6,Ġversion,0.340392
16,2,Translate to French 'I am Mike',Translate to French 'I am commissioners',7,Ġcommissioners,0.326719
17,2,Translate to French 'I am Mike',Translate to French 'I am Mike 480,8,Ġ480,0.363484


Unnamed: 0,exp_id,input,perturbed_input,perturbation_pos,perturbation_token,attr_score
18,3,Translate to French 'I am Mike',thenslate to French 'I am Mike',0,thens,0.0
19,3,Translate to French 'I am Mike',Transthal to French 'I am Mike',1,thal,0.0
20,3,Translate to French 'I am Mike',Translateood French 'I am Mike',2,ood,0.0
21,3,Translate to French 'I am Mike',Translate to Prophet 'I am Mike',3,ĠProphet,2.0
22,3,Translate to French 'I am Mike',Translate to French WahI am Mike',4,ĠWah,4.0
23,3,Translate to French 'I am Mike',"Translate to French '""), am Mike'",5,"""),",0.0
24,3,Translate to French 'I am Mike',Translate to French 'I version Mike',6,Ġversion,0.0
25,3,Translate to French 'I am Mike',Translate to French 'I am commissioners',7,Ġcommissioners,2.0
26,3,Translate to French 'I am Mike',Translate to French 'I am Mike 480,8,Ġ480,0.0


Unnamed: 0,exp_id,input,perturbed_input,perturbation_pos,perturbation_token,attr_score
27,4,Translate to French 'I am Mike',Translate to French 'I am Mike',0,ĠTrans,0.0
28,4,Translate to French 'I am Mike',TransLate to French 'I am Mike',1,Late,0.236969
29,4,Translate to French 'I am Mike',Translate in French 'I am Mike',2,Ġin,0.236969
30,4,Translate to French 'I am Mike',Translate toFrench 'I am Mike',3,French,0.039086
31,4,Translate to French 'I am Mike',"Translate to French ""I am Mike'",4,"Ġ""",0.039086
32,4,Translate to French 'I am Mike',Translate to French'I am Mike',5,ĠI,0.039086
33,4,Translate to French 'I am Mike',Translate to French 'I'm Mike',6,'m,0.039086
34,4,Translate to French 'I am Mike',Translate to French 'I amMike',7,Mike,0.039086
35,4,Translate to French 'I am Mike',"Translate to French 'I am Mike',",8,"',",0.039086


Unnamed: 0,exp_id,input,perturbed_input,perturbation_pos,perturbation_token,attr_score
36,5,Translate to French 'I am Mike',Translate to French 'I am Mike',0,ĠTrans,0.337132
37,5,Translate to French 'I am Mike',TransLate to French 'I am Mike',1,Late,0.193868
38,5,Translate to French 'I am Mike',Translate in French 'I am Mike',2,Ġin,0.34292
39,5,Translate to French 'I am Mike',Translate toFrench 'I am Mike',3,French,0.538351
40,5,Translate to French 'I am Mike',"Translate to French ""I am Mike'",4,"Ġ""",0.449304
41,5,Translate to French 'I am Mike',Translate to French'I am Mike',5,ĠI,0.433526
42,5,Translate to French 'I am Mike',Translate to French 'I'm Mike',6,'m,0.674879
43,5,Translate to French 'I am Mike',Translate to French 'I amMike',7,Mike,0.570292
44,5,Translate to French 'I am Mike',"Translate to French 'I am Mike',",8,"',",0.465202


Unnamed: 0,exp_id,input,perturbed_input,perturbation_pos,perturbation_token,attr_score
45,6,Translate to French 'I am Mike',Translate to French 'I am Mike',0,ĠTrans,4.0
46,6,Translate to French 'I am Mike',TransLate to French 'I am Mike',1,Late,0.0
47,6,Translate to French 'I am Mike',Translate in French 'I am Mike',2,Ġin,0.0
48,6,Translate to French 'I am Mike',Translate toFrench 'I am Mike',3,French,4.0
49,6,Translate to French 'I am Mike',"Translate to French ""I am Mike'",4,"Ġ""",0.0
50,6,Translate to French 'I am Mike',Translate to French'I am Mike',5,ĠI,3.0
51,6,Translate to French 'I am Mike',Translate to French 'I'm Mike',6,'m,0.0
52,6,Translate to French 'I am Mike',Translate to French 'I amMike',7,Mike,0.0
53,6,Translate to French 'I am Mike',"Translate to French 'I am Mike',",8,"',",0.0


Unnamed: 0,exp_id,input,perturbed_input,perturbation_pos,perturbation_token,attr_score
54,7,Translate to French 'I am Mike',thelate to French 'I am Mike',0,the,0.0
55,7,Translate to French 'I am Mike',Transthe to French 'I am Mike',1,the,0.0
56,7,Translate to French 'I am Mike',Translatethe French 'I am Mike',2,the,0.039086
57,7,Translate to French 'I am Mike',Translate tothe 'I am Mike',3,the,0.258362
58,7,Translate to French 'I am Mike',Translate to FrenchtheI am Mike',4,the,0.039086
59,7,Translate to French 'I am Mike',Translate to French 'the am Mike',5,the,0.039086
60,7,Translate to French 'I am Mike',Translate to French 'Ithe Mike',6,the,0.229843
61,7,Translate to French 'I am Mike',Translate to French 'I amthe',7,the,0.109907
62,7,Translate to French 'I am Mike',Translate to French 'I am Mikethe,8,the,0.117137


Unnamed: 0,exp_id,input,perturbed_input,perturbation_pos,perturbation_token,attr_score
63,8,Translate to French 'I am Mike',thelate to French 'I am Mike',0,the,1.765328
64,8,Translate to French 'I am Mike',Transthe to French 'I am Mike',1,the,1.723864
65,8,Translate to French 'I am Mike',Translatethe French 'I am Mike',2,the,1.73523
66,8,Translate to French 'I am Mike',Translate tothe 'I am Mike',3,the,1.874508
67,8,Translate to French 'I am Mike',Translate to FrenchtheI am Mike',4,the,1.726647
68,8,Translate to French 'I am Mike',Translate to French 'the am Mike',5,the,2.010869
69,8,Translate to French 'I am Mike',Translate to French 'Ithe Mike',6,the,1.721684
70,8,Translate to French 'I am Mike',Translate to French 'I amthe',7,the,1.733511
71,8,Translate to French 'I am Mike',Translate to French 'I am Mikethe,8,the,1.72447


Unnamed: 0,exp_id,input,perturbed_input,perturbation_pos,perturbation_token,attr_score
72,9,Translate to French 'I am Mike',thelate to French 'I am Mike',0,the,0.0
73,9,Translate to French 'I am Mike',Transthe to French 'I am Mike',1,the,0.0
74,9,Translate to French 'I am Mike',Translatethe French 'I am Mike',2,the,0.0
75,9,Translate to French 'I am Mike',Translate tothe 'I am Mike',3,the,0.0
76,9,Translate to French 'I am Mike',Translate to FrenchtheI am Mike',4,the,0.0
77,9,Translate to French 'I am Mike',Translate to French 'the am Mike',5,the,0.0
78,9,Translate to French 'I am Mike',Translate to French 'Ithe Mike',6,the,11.0
79,9,Translate to French 'I am Mike',Translate to French 'I amthe',7,the,0.0
80,9,Translate to French 'I am Mike',Translate to French 'I am Mikethe,8,the,0.0
