In [8]:
import os
import timeit
from time import time
import sys
from typing import List

import numpy as np
import openai
from dotenv import load_dotenv

from transformers import (
    GPT2LMHeadModel,
    GPT2Tokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)

notebook_path = os.path.abspath(get_ipython().starting_dir)
parent_path = os.path.dirname(notebook_path)

sys.path.append(parent_path)
from attribution.experiment_logger import ExperimentLogger
from attribution.token_perturbation import (
    get_replacement_token,
    get_most_similar_token_ids,
)
from attribution.attribution_metrics import (
    cosine_similarity_attribution,
    token_prob_difference,
    token_displacement,
)

In [9]:
load_dotenv()
client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

In [10]:
model = GPT2LMHeadModel.from_pretrained("gpt2")  # or any other checkpoint
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

word_token_embeddings = model.transformer.wte.weight.detach().numpy()
position_embeddings = model.transformer.wpe.weight.detach().numpy()
token_cosine_distances = None

logger = ExperimentLogger()

In [11]:
def get_model_output(input: str) -> openai.types.chat.chat_completion.Choice:
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": input}],
        temperature=0.0,
        seed=0,
        logprobs=True,
        top_logprobs=20,
    )
    return response.choices[0]


def calculate_token_importance_in_sequence(
    input_sequence: str,
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    perturbation_strategy: str = "fixed",
    attribution_strategies: List[str] = [
        "cosine",
        "prob_diff",
        "token_displacement",
    ],
    logger: ExperimentLogger = None,
):
    timestamp = time()
    tokens = tokenizer.tokenize(input_sequence)
    token_ids = tokenizer.encode(input_sequence, add_special_tokens=False)
    original_output = get_model_output(input_sequence)
    print(f"Chat Completion - Original: {round(time() - timestamp, 2)}s")

    if logger:
        logger.start_experiment(
            input_sequence, original_output.message.content, perturbation_strategy
        )

    exp_timestamp = time()
    words = input_sequence.split()
    token_index = 0
    for word in words:
        word_tokens = tokenizer.tokenize(word)
        word_token_ids = tokenizer.encode(word, add_special_tokens=False)
        start_word_time = time()
        replacement_token_ids = [
            get_replacement_token(
                word_token_ids[i],
                perturbation_strategy,
                word_token_embeddings,
                tokenizer,
            )
            for i in range(len(word_tokens))
        ]
        print(
            f"\nReplaced word '{word}': {round(time() - start_word_time, 2)}s - get_replacement_token()"
        )

        # Replace the current word with the new tokens
        perturbed_input = tokenizer.decode(
            token_ids[:token_index]
            + replacement_token_ids
            + token_ids[token_index + len(word_tokens) :]
        )

        # Get the output logprobs for the perturbed input
        timestamp = time()
        print('Original: ',input_sequence)
        print('Perturbed: ',perturbed_input)
        perturbed_output = get_model_output(perturbed_input)
        print(f"Chat Completion - Perturbed: {round(time() - timestamp, 2)}s")

        timestamp = time()
        for attribution_strategy in attribution_strategies:
            attributed_tokens = tokens
            if attribution_strategy == "cosine":
                cosine_timestamp = time()
                sentence_attr, token_attributions = cosine_similarity_attribution(
                    original_output, perturbed_output, model, tokenizer
                )
                cosine_timestamp_end = time()
            elif attribution_strategy == "prob_diff":
                prob_diff_timestamp = time()
                sentence_attr, attributed_tokens, token_attributions = (
                    token_prob_difference(
                        original_output.logprobs, perturbed_output.logprobs
                    )
                )
                prob_diff_timestamp_end = time()
            elif attribution_strategy == "token_displacement":
                token_displacement_timestamp = time()
                sentence_attr, attributed_tokens, token_attributions = (
                    token_displacement(
                        original_output.logprobs, perturbed_output.logprobs
                    )
                )
                token_displacement_timestamp_end = time()
            else:
                raise ValueError(
                    f"Unknown attribution strategy: {attribution_strategy}"
                )

            if logger:
                start_logging = time()
                for i in range(len(word_tokens)):
                    logger.log_input_token_attribution(
                        attribution_strategy,
                        token_index + i,
                        word_tokens[i],
                        float(sentence_attr),
                    )
                    for j, attr_score in enumerate(token_attributions):
                        logger.log_token_attribution_matrix(
                            attribution_strategy,
                            token_index + i,
                            j,
                            attributed_tokens[j],
                            attr_score.squeeze(),
                        )
                end_logging = time()
        time_all_attrs = time() - timestamp
        print(f"Attributions computation: {time_all_attrs}s")
        print(f"- Cosine Attr: {round(cosine_timestamp_end - cosine_timestamp, 2)}s")
        print(
            f"- Prob Diff Attr: {round(prob_diff_timestamp_end - prob_diff_timestamp, 2)}s"
        )
        print(
            f"- Token Displacement Attr: {round(token_displacement_timestamp_end - token_displacement_timestamp, 2)}s"
        )
        print(f"- Attr Logging: {round(end_logging - start_logging, 2)}s")
        print(f"Total for word '{word}': {round(time() - start_word_time, 2)}s")

        token_index += len(word_tokens)

    print(f"\n\nExp Total: {time() - exp_timestamp}s\n\n")

    if logger:
        logger.log_perturbation(
            i,
            tokenizer.decode(replacement_token_ids)[0],
            perturbation_strategy,
            input_sequence,
            original_output.message.content,
            perturbed_input,
            perturbed_output.message.content,
        )
        logger.stop_experiment()

    return (original_output.message.content,)

In [12]:
input_sequences = [
    "The clock shows 9:47 PM. How many minutes 'til 10?"]
#     "The building is 132 meters tall. How tall is the building?",
#     "The package weighs 8.6 kilograms. How much does the package weigh?",
#     "The thermometer reads 23 degrees Celsius. What is the temperature according to the thermometer?",
#     "She drove 157 kilometers to visit her friend. How far did she drive to visit her friend?",
#     "John has 83 books on his shelf. How many books does John have on his shelf?",
#     "Maria is 37 years old today. How old is Maria?",
#     "There are 68 people registered for the webinar. How many people are registered for the webinar?",
#     "Alex saved $363 from his birthday gifts. How much money did Alex save?",
#     "The recipe requires 14 teaspoons of sugar. How many teaspoons of sugar does the recipe require?",
# ]


for input_sequence in input_sequences:
    for perturbation_strategy in ["distant"]:
        original_output = calculate_token_importance_in_sequence(
            input_sequence,
            model,
            tokenizer,
            perturbation_strategy,
            attribution_strategies=["cosine", "prob_diff", "token_displacement"],
            logger=logger,
        )

        print(
            input_sequence,
            original_output,
        )

Chat Completion - Original: 0.73s

Replaced word 'The': 0.06s - get_replacement_token()
Original:  The clock shows 9:47 PM. How many minutes 'til 10?
Perturbed:  exp clock shows 9:47 PM. How many minutes 'til 10?
Chat Completion - Perturbed: 0.73s
Attributions computation: 0.009429931640625s
- Cosine Attr: 0.0s
- Prob Diff Attr: 0.0s
- Token Displacement Attr: 0.0s
- Attr Logging: 0.0s
Total for word 'The': 0.8s

Replaced word 'clock': 0.05s - get_replacement_token()
Original:  The clock shows 9:47 PM. How many minutes 'til 10?
Perturbed:  Thework shows 9:47 PM. How many minutes 'til 10?
Chat Completion - Perturbed: 0.91s
Attributions computation: 0.008873939514160156s
- Cosine Attr: 0.0s
- Prob Diff Attr: 0.0s
- Token Displacement Attr: 0.0s
- Attr Logging: 0.0s
Total for word 'clock': 0.97s

Replaced word 'shows': 0.05s - get_replacement_token()
Original:  The clock shows 9:47 PM. How many minutes 'til 10?
Perturbed:  The clock cheesy 9:47 PM. How many minutes 'til 10?
Chat Completio

In [13]:
display(logger.df_experiments)

Unnamed: 0,exp_id,original_input,original_output,perturbation_strategy,duration
0,1,The clock shows 9:47 PM. How many minutes 'til...,13 minutes.,distant,9.658044


In [14]:
logger.print_sentence_attribution()

Unnamed: 0,exp_id,attribution_strategy,perturbation_strategy,token_1,token_2,token_3,token_4,token_5,token_6,token_7,token_8,token_9,token_10,token_11,token_12,token_13,token_14,token_15,token_16
0,1,cosine,distant,The 0.14,clock 0.14,shows 0.10,9 0.20,: 0.20,47 0.20,PM 0.17,. 0.17,How 0.14,many 0.14,min 0.14,utes 0.14,'t 0.10,il 0.10,10 0.00,? 0.00
1,1,prob_diff,distant,The 0.76,clock 0.70,shows 0.28,9 0.80,: 0.80,47 0.80,PM 0.78,. 0.78,How 0.79,many 0.72,min 0.80,utes 0.80,'t 0.29,il 0.29,10 0.10,? 0.10
2,1,token_displacement,distant,The 13.67,clock 13.00,shows 6.67,9 20.00,: 20.00,47 20.00,PM 13.33,. 13.33,How 13.67,many 13.67,min 16.67,utes 16.67,'t 6.67,il 6.67,10 0.00,? 0.00


In [15]:
logger.print_attribution_matrix(exp_id=1)

Attribution matrix for cosine with perturbation strategy distant:
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,The (0),clock (1),shows (2)
The (0),0.600346,0.759586,0.701984
clock (1),0.600346,0.759586,0.701984
shows (2),-0.0,-0.0,1.0
9 (3),0.600346,0.759586,0.74748
: (4),0.600346,0.759586,0.74748
47 (5),0.600346,0.759586,0.74748
PM (6),0.600346,0.759586,0.701984
. (7),0.600346,0.759586,0.701984
How (8),0.600346,0.759586,0.701984
many (9),0.600346,0.759586,0.701984


Attribution matrix for prob_diff with perturbation strategy distant:
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,13 (0),minutes (1),. (2)
The (0),0.646714,0.999919,0.619983
clock (1),0.489712,0.999919,0.619983
shows (2),0.231651,3.3e-05,0.619983
9 (3),0.785889,0.999919,0.619983
: (4),0.785889,0.999919,0.619983
47 (5),0.785889,0.999919,0.619983
PM (6),0.707377,0.999919,0.619983
. (7),0.707377,0.999919,0.619983
How (8),0.754307,0.999919,0.619983
many (9),0.535038,0.999919,0.619983


Attribution matrix for token_displacement with perturbation strategy distant:
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,13 (0),minutes (1),. (2)
The (0),1.0,20.0,20.0
clock (1),1.0,18.0,20.0
shows (2),0.0,0.0,20.0
9 (3),20.0,20.0,20.0
: (4),20.0,20.0,20.0
47 (5),20.0,20.0,20.0
PM (6),1.0,19.0,20.0
. (7),1.0,19.0,20.0
How (8),1.0,20.0,20.0
many (9),1.0,20.0,20.0


In [None]:
logger.print_attribution_matrix(1, "cosine")

Attribution matrix for cosine with perturbation strategy distant:
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,The (0),clock (1),shows (2),9 (3),: (4),47 (5),PM (6),. (7)
The (0),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
clock (1),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
shows (2),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
9 (3),0.0,-0.0,-0.0,0.289177,0.0,0.274936,0.0,0.0
: (4),0.0,-0.0,-0.0,0.289177,0.0,0.274936,0.0,0.0
47 (5),0.0,-0.0,-0.0,0.289177,0.0,0.274936,0.0,0.0
PM (6),0.0,-0.0,0.0,0.0,0.0,-0.0,0.8215,1.0
. (7),0.0,-0.0,0.0,0.0,0.0,-0.0,0.8215,1.0
What (8),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
time (9),0.0,-0.0,0.0,0.488054,0.0,-0.0,0.8215,1.0


In [None]:
logger.print_attribution_matrix(2, attribution_strategy="cosine")
logger.print_attribution_matrix(2, attribution_strategy="prob_diff")

Attribution matrix for cosine with perturbation strategy distant:
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,The (0),clock (1),shows (2),9 (3),: (4),47 (5),PM (6),. (7)
The (0),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
clock (1),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
shows (2),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
9 (3),0.0,-0.0,-0.0,0.289177,0.0,0.274936,0.0,0.0
: (4),0.0,-0.0,-0.0,0.289177,0.0,0.274936,0.0,0.0
47 (5),0.0,-0.0,-0.0,0.289177,0.0,0.274936,0.0,0.0
PM (6),0.0,-0.0,0.0,0.0,0.0,-0.0,0.8215,1.0
. (7),0.0,-0.0,0.0,0.0,0.0,-0.0,0.8215,1.0
What (8),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
time (9),0.0,-0.0,0.0,0.488054,0.0,-0.0,0.8215,1.0


Attribution matrix for prob_diff with perturbation strategy distant:
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,The (0),clock (1),shows (2),(3),9 (4),: (5),47 (6),PM (7),. (8)
The (0),0.064448,0.007869,0.003085,0.002989,0.000108,0.000128,0.000118,5.4e-05,0.007132
clock (1),0.136458,0.033954,0.00381,0.002736,0.000285,0.000101,0.001414,0.000432,0.000383
shows (2),0.151668,0.006804,0.00929,0.006151,7.9e-05,0.000141,0.000733,0.000125,0.016142
9 (3),0.191758,0.008194,0.016663,0.040286,0.951041,0.000315,0.958258,0.000138,0.001367
: (4),0.191758,0.008194,0.016663,0.040286,0.951041,0.000315,0.958258,0.000138,0.001367
47 (5),0.191758,0.008194,0.016663,0.040286,0.951041,0.000315,0.958258,0.000138,0.001367
PM (6),0.171069,0.000117,0.001073,0.000121,0.000113,0.000139,1.6e-05,0.999858,0.994526
. (7),0.171069,0.000117,0.001073,0.000121,0.000113,0.000139,1.6e-05,0.999858,0.994526
What (8),0.198133,0.007814,0.148431,0.018417,0.003066,6.5e-05,0.00312,0.001621,0.126054
time (9),0.19684,0.005341,0.009895,0.077305,0.884765,5.2e-05,0.000132,0.925403,0.994526


In [None]:
logger.print_attribution_matrix(3, attribution_strategy="cosine")
logger.print_attribution_matrix(3, attribution_strategy="prob_diff")

Attribution matrix for cosine with perturbation strategy distant:
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,The (0),building (1),is (2),132 (3),meters (4),tall (5),. (6)
The (0),0.0,0.0,-0.0,0.0,0.0,0.0,0.0
building (1),0.226659,0.705013,0.520012,0.806286,0.855425,0.756068,0.46633
is (2),0.0,0.0,-0.0,0.0,0.0,0.0,0.0
132 (3),0.226659,0.770341,0.621099,0.807823,0.785745,0.767478,0.480238
met (4),0.226659,0.705013,0.520012,0.78801,0.798989,0.771913,0.412676
ers (5),0.226659,0.705013,0.520012,0.78801,0.798989,0.771913,0.412676
tall (6),0.0,0.0,-0.0,0.0,0.0,0.0,0.0
. (7),0.0,0.0,-0.0,0.0,0.0,0.0,0.0
How (8),0.0,0.0,-0.0,0.0,0.0,0.0,0.0
tall (9),0.0,0.0,-0.0,0.0,0.0,0.0,0.0


Attribution matrix for prob_diff with perturbation strategy distant:
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,The (0),building (1),is (2),(3),132 (4),meters (5),tall (6),. (7)
The (0),0.018736,0.000113,0.000309,7.9e-05,5.3e-05,6e-06,0.000121,2.2e-05
building (1),0.777744,0.999023,0.999883,0.999955,0.999962,0.999989,0.999822,0.999864
is (2),0.015364,0.000713,0.008316,0.000166,1.6e-05,7e-06,0.000146,0.000301
132 (3),0.985685,0.999023,0.999883,0.999893,0.999962,0.999989,0.999822,0.999864
met (4),0.963465,0.999023,0.999883,0.999942,0.999962,0.999989,0.999822,0.999864
ers (5),0.963465,0.999023,0.999883,0.999942,0.999962,0.999989,0.999822,0.999864
tall (6),0.057443,3.9e-05,1e-06,0.001233,0.000211,1.3e-05,0.001011,0.000231
. (7),0.057443,3.9e-05,1e-06,0.001233,0.000211,1.3e-05,0.001011,0.000231
How (8),0.017349,0.013034,0.006253,0.351449,0.024528,0.000197,0.034164,0.228418
tall (9),0.028325,0.024104,0.00219,0.004585,5.1e-05,1.6e-05,0.0001,0.001761
