In [20]:
import os
import timeit
from time import time
import sys
from typing import List

import numpy as np
import openai
from dotenv import load_dotenv

from transformers import (
    GPT2LMHeadModel,
    GPT2Tokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)

# Resolve the directory the notebook server was started in and expose its
# parent on sys.path so the local `attribution` package imports below resolve.
notebook_path = os.path.abspath(get_ipython().starting_dir)
parent_path = os.path.dirname(notebook_path)

# Guard the append so re-running this cell does not pile up duplicate entries.
if parent_path not in sys.path:
    sys.path.append(parent_path)
from attribution.experiment_logger import ExperimentLogger
from attribution.token_perturbation import (
    get_replacement_token,
    get_most_similar_token_ids,
)
from attribution.attribution_metrics import (
    cosine_similarity_attribution,
    token_prob_difference,
    token_displacement,
)

In [2]:
# Load environment variables via python-dotenv, then build the OpenAI client
# from OPENAI_API_KEY (the key is never hard-coded in the notebook).
load_dotenv()
client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

In [3]:
# GPT-2 is loaded locally; its tokenizer and embedding tables are passed to the
# token-perturbation and attribution helpers used further down.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")  # or any other checkpoint

# Detach both embedding tables from autograd and expose them as numpy arrays.
word_token_embeddings, position_embeddings = (
    embedding.weight.detach().numpy()
    for embedding in (model.transformer.wte, model.transformer.wpe)
)
token_cosine_distances = None

# Single logger instance shared by every experiment run in this notebook.
logger = ExperimentLogger()

In [4]:
def get_model_output(input: str) -> openai.types.chat.chat_completion.Choice:
    """Send `input` as a single user message to gpt-3.5-turbo.

    The request uses temperature 0.0 and seed 0, and asks for logprobs with
    20 top alternatives per position.

    Args:
        input: Prompt text sent as the content of one "user" message.

    Returns:
        The first choice of the chat completion response.
    """
    # Relies on the module-level `client` created earlier in the notebook.
    request_kwargs = dict(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": input}],
        temperature=0.0,
        seed=0,
        logprobs=True,
        top_logprobs=20,
    )
    completion = client.chat.completions.create(**request_kwargs)
    return completion.choices[0]


def calculate_token_importance_in_sequence(
    input_sequence: str,
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    perturbation_strategy: str = "fixed",
    attribution_strategies: List[str] = None,
    logger: ExperimentLogger = None,
):
    """Perturb each input token in turn and score its effect on the completion.

    The input is tokenized with `tokenizer`. For every token position a
    replacement token is chosen via `get_replacement_token`, the perturbed
    prompt is sent through `get_model_output`, and the perturbed completion is
    compared to the unperturbed one with each requested attribution strategy.
    Timing information is printed along the way.

    Relies on the module-level `word_token_embeddings` array and on
    `get_model_output` defined in this notebook.

    Args:
        input_sequence: Prompt text to analyse.
        model: Model handed to the cosine attribution metric.
        tokenizer: Tokenizer used to split, perturb and re-decode the prompt.
        perturbation_strategy: Strategy name forwarded to
            `get_replacement_token`.
        attribution_strategies: Any of "cosine", "prob_diff" and
            "token_displacement". Defaults to all three when None.
        logger: Optional experiment logger. When given, the experiment, each
            perturbation and every attribution score are recorded.

    Returns:
        A 1-tuple holding the text of the unperturbed completion.

    Raises:
        ValueError: If an unknown attribution strategy is requested. This is
            checked before any API call is made.
    """
    strategy_labels = {
        "cosine": "Cosine Attr",
        "prob_diff": "Prob Diff Attr",
        "token_displacement": "Token Displacement Attr",
    }
    if attribution_strategies is None:
        attribution_strategies = list(strategy_labels)

    # Fail fast: reject bad strategy names before spending any API calls.
    for attribution_strategy in attribution_strategies:
        if attribution_strategy not in strategy_labels:
            raise ValueError(f"Unknown attribution strategy: {attribution_strategy}")

    timestamp = time()
    tokens = tokenizer.tokenize(input_sequence)
    token_ids = tokenizer.encode(input_sequence, add_special_tokens=False)
    original_output = get_model_output(input_sequence)
    print(f"Chat Completion - Original: {round(time() - timestamp, 2)}s")

    if logger:
        logger.start_experiment(
            input_sequence, original_output.message.content, perturbation_strategy
        )

    exp_timestamp = time()
    for i, token in enumerate(tokens):
        print(i, token)
        start_token_time = time()
        replacement_token_id = get_replacement_token(
            token_ids[i], perturbation_strategy, word_token_embeddings, tokenizer
        )
        print(
            f"\nReplaced token '{token}': {round(time() - start_token_time, 2)}s - get_replacement_token()"
        )

        # Replace the current token with the new token
        perturbed_input = tokenizer.decode(
            token_ids[:i] + [replacement_token_id] + token_ids[i + 1 :]
        )

        # Get the output logprobs for the perturbed input
        timestamp = time()
        perturbed_output = get_model_output(perturbed_input)
        print(f"Chat Completion - Perturbed: {round(time() - timestamp, 2)}s")

        timestamp = time()
        # Durations are keyed by strategy so the summary below only reports
        # strategies that actually ran (no unbound timing variables).
        strategy_durations = {}
        logging_duration = 0.0
        for attribution_strategy in attribution_strategies:
            # NOTE(review): for "cosine" the output positions are labelled with
            # the *input* tokens below; confirm this matches what
            # cosine_similarity_attribution returns.
            attributed_tokens = tokens
            strategy_start = time()
            if attribution_strategy == "cosine":
                sentence_attr, token_attributions = cosine_similarity_attribution(
                    original_output, perturbed_output, model, tokenizer
                )
            elif attribution_strategy == "prob_diff":
                sentence_attr, attributed_tokens, token_attributions = (
                    token_prob_difference(
                        original_output.logprobs, perturbed_output.logprobs
                    )
                )
            else:  # "token_displacement" (names were validated above)
                sentence_attr, attributed_tokens, token_attributions = (
                    token_displacement(
                        original_output.logprobs, perturbed_output.logprobs
                    )
                )
            strategy_durations[attribution_strategy] = (
                strategy_durations.get(attribution_strategy, 0.0)
                + time()
                - strategy_start
            )

            if logger:
                start_logging = time()
                logger.log_input_token_attribution(
                    attribution_strategy, i, token, float(sentence_attr)
                )
                for j, attr_score in enumerate(token_attributions):
                    logger.log_token_attribution_matrix(
                        attribution_strategy,
                        i,
                        j,
                        attributed_tokens[j],
                        attr_score.squeeze(),
                    )
                logging_duration += time() - start_logging

        time_all_attrs = time() - timestamp
        print(f"Attributions computation: {time_all_attrs}s")
        for strategy_name, duration in strategy_durations.items():
            print(f"- {strategy_labels[strategy_name]}: {round(duration, 2)}s")
        if logger:
            print(f"- Attr Logging: {round(logging_duration, 2)}s")

        if logger:
            # Logged once per perturbed position so every perturbation is
            # recorded, not only the last one.
            # NOTE(review): `[0]` keeps only the first character of the decoded
            # replacement token; confirm whether the full string was intended.
            logger.log_perturbation(
                i,
                tokenizer.decode([replacement_token_id])[0],
                perturbation_strategy,
                input_sequence,
                original_output.message.content,
                perturbed_input,
                perturbed_output.message.content,
            )

        print(f"Total for token '{token}': {round(time() - start_token_time, 2)}s")

    print(f"\n\nExp Total: {time() - exp_timestamp}s\n\n")

    if logger:
        logger.stop_experiment()

    return (original_output.message.content,)

In [33]:
# Prompts that state a numeric fact and then ask for it back.
input_sequences = [
    "The clock shows 9:47 PM. What time does the clock show?",
    "The building is 132 meters tall. How tall is the building?",
    "The package weighs 8.6 kilograms. How much does the package weigh?",
    "The thermometer reads 23 degrees Celsius. What is the temperature according to the thermometer?",
    "She drove 157 kilometers to visit her friend. How far did she drive to visit her friend?",
    "John has 83 books on his shelf. How many books does John have on his shelf?",
    "Maria is 37 years old today. How old is Maria?",
    "There are 68 people registered for the webinar. How many people are registered for the webinar?",
    "Alex saved $363 from his birthday gifts. How much money did Alex save?",
    "The recipe requires 14 teaspoons of sugar. How many teaspoons of sugar does the recipe require?",
]


# Run one logged experiment per (prompt, perturbation strategy) pair.
for input_sequence in input_sequences:
    for perturbation_strategy in ["distant"]:
        original_output = calculate_token_importance_in_sequence(
            input_sequence=input_sequence,
            model=model,
            tokenizer=tokenizer,
            perturbation_strategy=perturbation_strategy,
            attribution_strategies=["cosine", "prob_diff", "token_displacement"],
            logger=logger,
        )

        # `original_output` is the 1-tuple returned by the function above.
        print(input_sequence, original_output)



 The clock shows 9:47 PM. What time does the clock show?

The
clock
shows
9:47
PM.
What
time
does
the
clock
show?


 The building is 132 meters tall. How tall is the building?

The
building
is
132
meters
tall.
How
tall
is
the
building?


 The package weighs 8.6 kilograms. How much does the package weigh?

The
package
weighs
8.6
kilograms.
How
much
does
the
package
weigh?


 The thermometer reads 23 degrees Celsius. What is the temperature according to the thermometer?

The
thermometer
reads
23
degrees
Celsius.
What
is
the
temperature
according
to
the
thermometer?


 She drove 157 kilometers to visit her friend. How far did she drive to visit her friend?

She
drove
157
kilometers
to
visit
her
friend.
How
far
did
she
drive
to
visit
her
friend?


 John has 83 books on his shelf. How many books does John have on his shelf?

John
has
83
books
on
his
shelf.
How
many
books
does
John
have
on
his
shelf?


 Maria is 37 years old today. How old is Maria?

Maria
is
37
years
old
today.
How
old
is

In [30]:
# Sanity check: print the 10 tokens returned by get_most_similar_token_ids for
# a few probe words, using the first token id of each word.
for token in ("red", "car", "fun", "few"):
    token_id = tokenizer.encode(token, add_special_tokens=False)[0]
    similar_token_ids = get_most_similar_token_ids(
        token_id, word_token_embeddings, tokenizer, n_tokens=10
    )
    print(token, tokenizer.decode(similar_token_ids))

red redRed red Red yellowREDyellow reduce blue with
car car carCar Carcars cars CAR automobileCAR vehicle
fun funFun Fun fun FUN functions enjoyable funnyquickShip
few fewFew Few few fewerSeveralMany handful Several


In [9]:
# Experiment overview, followed by the cosine attribution rows of experiment 1.
display(logger.df_experiments)
display(
    logger.df_token_attribution_matrix.loc[
        lambda df: (df.exp_id == 1) & (df.attribution_strategy == "cosine")
    ]
)

Unnamed: 0,exp_id,original_input,original_output,perturbation_strategy,duration
0,1,The clock shows 9:47 PM. What time does the cl...,The clock shows 9:47 PM.,distant,20.620097


Unnamed: 0,exp_id,attribution_strategy,input_token_pos,output_token_pos,output_token,attr_score
0,1,cosine,0,0,The,0.000000e+00
1,1,cosine,0,1,Ġclock,-2.384186e-07
2,1,cosine,0,2,Ġshows,-1.192093e-07
3,1,cosine,0,3,Ġ9,0.000000e+00
4,1,cosine,0,4,:,2.384186e-07
...,...,...,...,...,...,...
367,1,cosine,14,3,Ġ9,7.463263e-01
368,1,cosine,14,4,:,1.000000e+00
369,1,cosine,14,5,47,1.000000e+00
370,1,cosine,14,6,ĠPM,1.000000e+00


In [10]:
# Show the logger's sentence-level attribution summary
# (presumably one row per experiment and attribution strategy — verify against ExperimentLogger).
logger.print_sentence_attribution()

Unnamed: 0,exp_id,attribution_strategy,perturbation_strategy,token_1,token_2,token_3,token_4,token_5,token_6,token_7,token_8,token_9,token_10,token_11,token_12,token_13,token_14,token_15
0,1,cosine,distant,The 0.00,clock 0.00,shows 0.00,9 0.01,: 0.00,47 0.02,PM 0.02,. 0.00,What 0.00,time 0.06,does 0.00,the 0.09,clock 0.07,show 0.03,? 0.11
1,1,prob_diff,distant,The 0.01,clock 0.08,shows 0.01,9 0.12,: 0.02,47 0.11,PM 0.27,. 0.02,What 0.13,time 0.67,does 0.03,the 0.89,clock 0.90,show 0.09,? 0.96
2,1,token_displacement,distant,The 4.00,clock 4.00,shows 4.00,9 3.67,: 4.00,47 4.78,PM 6.78,. 4.00,What 4.00,time 11.56,does 4.00,the 11.22,clock 12.00,show 3.89,? 16.67


In [18]:
# Print the attribution matrix of experiment 1 for every strategy via the
# logger method; no free function `print_attribution_matrix` is defined or
# imported in this notebook, so the previous call failed on a fresh kernel.
for attribution_strategy in ("cosine", "prob_diff", "token_displacement"):
    logger.print_attribution_matrix(1, attribution_strategy=attribution_strategy)

Attribution matrix for cosine with perturbation strategy distant:
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,The (0),clock (1),shows (2),9 (3),: (4),47 (5),PM (6),. (7)
The (0),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
clock (1),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
shows (2),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
9 (3),0.0,-0.0,-0.0,0.289177,0.0,-0.0,0.0,0.0
: (4),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
47 (5),0.0,-0.0,-0.0,0.0,0.0,0.540257,0.0,0.0
PM (6),0.0,-0.0,0.0,0.0,0.0,-0.0,0.8215,1.0
. (7),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
What (8),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
time (9),0.0,-0.0,-0.0,0.667719,0.602138,0.754805,0.831338,0.630817


Attribution matrix for prob_diff with perturbation strategy distant:
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,The (0),clock (1),shows (2),(3),9 (4),: (5),47 (6),PM (7),. (8)
The (0),0.033395,0.009052,0.000873,0.003682,0.00036,7.4e-05,0.000147,1.4e-05,0.006235
clock (1),0.055332,0.250573,0.067855,0.012731,0.153478,7.4e-05,0.148615,0.010486,0.029216
shows (2),0.083218,0.007991,0.008208,0.008611,0.000254,8.5e-05,1.6e-05,0.000275,0.006587
9 (3),0.154698,0.003391,0.002432,0.000598,0.939173,8.4e-05,7e-05,0.000113,0.009089
: (4),0.179471,0.003536,0.003959,0.00348,0.002277,7e-06,0.001425,7.7e-05,0.008447
47 (5),0.019816,0.002488,0.000742,1e-06,0.000146,0.00043,0.999462,4e-06,0.006009
PM (6),0.156422,0.002369,0.041553,0.003346,0.115171,5.4e-05,0.085338,0.999874,0.990127
. (7),0.143831,0.010591,0.001918,0.000544,0.000139,9.1e-05,5.4e-05,4.7e-05,0.009459
What (8),0.150989,0.008399,0.062355,0.439929,0.001447,7.2e-05,2.1e-05,0.000632,0.531042
time (9),0.074004,0.10446,0.031845,0.866577,0.999771,0.999908,0.999894,0.999874,0.990127


Attribution matrix for token_displacement with perturbation strategy distant:
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,The (0),clock (1),shows (2),(3),9 (4),: (5),47 (6),PM (7),. (8)
The (0),0.0,1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0
clock (1),0.0,1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0
shows (2),0.0,1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0
9 (3),0.0,1.0,2.0,3.0,1.0,5.0,6.0,7.0,8.0
: (4),0.0,1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0
47 (5),0.0,1.0,2.0,3.0,4.0,5.0,13.0,7.0,8.0
PM (6),0.0,1.0,2.0,3.0,4.0,5.0,6.0,20.0,20.0
. (7),0.0,1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0
What (8),0.0,1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0
time (9),0.0,1.0,2.0,1.0,20.0,20.0,20.0,20.0,20.0


In [16]:
# Cosine attribution matrix for experiment 1.
logger.print_attribution_matrix(1, attribution_strategy="cosine")

Attribution matrix for cosine with perturbation strategy distant:
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,The (0),clock (1),shows (2),9 (3),: (4),47 (5),PM (6),. (7)
The (0),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
clock (1),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
shows (2),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
9 (3),0.0,-0.0,-0.0,0.289177,0.0,-0.0,0.0,0.0
: (4),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
47 (5),0.0,-0.0,-0.0,0.0,0.0,0.540257,0.0,0.0
PM (6),0.0,-0.0,0.0,0.0,0.0,-0.0,0.8215,1.0
. (7),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
What (8),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
time (9),0.0,-0.0,-0.0,0.667719,0.602138,0.754805,0.831338,0.630817


In [None]:
# Cosine and prob_diff attribution matrices for experiment 2.
for attribution_strategy in ("cosine", "prob_diff"):
    logger.print_attribution_matrix(2, attribution_strategy=attribution_strategy)

Attribution matrix for cosine with perturbation strategy distant:
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,The (0),clock (1),shows (2),9 (3),: (4),47 (5),PM (6),. (7)
The (0),0.56058,0.85697,0.696821,0.746326,1.0,1.0,1.0,1.0
clock (1),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
shows (2),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
9 (3),0.0,-0.0,-0.0,0.289177,0.0,-0.0,0.0,0.0
: (4),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
47 (5),0.0,-0.0,-0.0,0.0,0.0,0.540257,0.0,0.0
PM (6),0.0,-0.0,0.0,0.0,0.0,-0.0,0.8215,1.0
. (7),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
What (8),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
time (9),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0


Attribution matrix for prob_diff with perturbation strategy distant:
Input Tokens (Rows) vs. Output Tokens (Columns)


In [None]:
# Cosine and prob_diff attribution matrices for experiment 3.
for attribution_strategy in ("cosine", "prob_diff"):
    logger.print_attribution_matrix(3, attribution_strategy=attribution_strategy)

Attribution matrix for cosine with perturbation strategy distant:
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,The (0),clock (1),shows (2),9 (3),: (4),47 (5),PM (6),. (7)
The (0),0.56058,0.85697,0.696821,0.746326,1.0,1.0,1.0,1.0
clock (1),0.0,-0.0,-0.0,0.0,0.0,0.307193,0.0,0.0
shows (2),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
9 (3),0.0,-0.0,-0.0,0.265841,0.0,-0.0,0.0,0.0
: (4),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
47 (5),0.0,-0.0,-0.0,0.0,0.0,0.540257,0.0,0.423717
PM (6),0.0,-0.0,0.0,0.0,0.0,-0.0,0.8215,1.0
. (7),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
What (8),0.0,-0.0,-0.0,0.639039,0.759923,0.742686,0.83404,0.836332
time (9),0.226659,0.77783,0.732814,0.801206,0.554864,0.730914,0.799345,0.417455


Attribution matrix for prob_diff with perturbation strategy distant:
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,The (0),clock (1),shows (2),(3),9 (4),: (5),47 (6),PM (7),. (8)
The (0),0.173069,0.990708,0.997162,0.999496,0.999765,0.999947,0.999916,0.999918,0.989687
clock (1),0.004415,0.209209,0.083948,0.011617,0.134388,3.7e-05,0.140484,0.012522,0.005322
shows (2),0.068842,0.005182,0.007144,0.005228,0.000216,4.6e-05,3.4e-05,0.000187,0.011756
9 (3),0.108911,0.006162,0.002425,0.0012,0.945123,4.2e-05,0.000185,4.4e-05,0.009254
: (4),0.133495,0.010202,0.004138,0.003494,0.004689,7.6e-05,0.001983,1.2e-05,0.008104
47 (5),0.059414,0.019443,0.001198,0.000893,0.000277,0.000759,0.998938,3e-05,0.010867
PM (6),0.125587,0.004899,0.021221,0.006192,0.067328,1.4e-05,0.064784,0.999918,0.989687
. (7),0.008513,0.006089,0.001306,0.000153,0.00016,5.1e-05,4.7e-05,2.1e-05,0.00432
What (8),0.101163,0.002933,0.075282,0.484466,0.999765,0.999947,0.999916,0.999918,0.989685
time (9),0.04091,0.990708,0.997162,0.996478,0.999765,0.999947,0.999916,0.999918,0.989687
