In [1]:
import os
import sys
import math
import itertools
from typing import List, Tuple

import numpy as np
import openai
from dotenv import load_dotenv

from transformers import (
    GPT2LMHeadModel,
    GPT2Tokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)

notebook_path = os.path.abspath(get_ipython().starting_dir)
parent_path = os.path.dirname(notebook_path)

sys.path.append(parent_path)
from attribution.experiment_logger import ExperimentLogger
from attribution.token_perturbation import get_replacement_token
from attribution.attribution_metrics import (
    cosine_similarity_attribution,
    token_prob_difference,
    token_displacement,
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
load_dotenv()
client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

In [3]:
model = GPT2LMHeadModel.from_pretrained("gpt2")  # or any other checkpoint
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

word_token_embeddings = model.transformer.wte.weight.detach().numpy()
position_embeddings = model.transformer.wpe.weight.detach().numpy()
token_cosine_distances = None

logger = ExperimentLogger()

In [6]:
def get_model_output(input: str) -> openai.types.chat.chat_completion.Choice:
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": input}],
        logprobs=True,
        top_logprobs=20,
    )

    return response.choices[0]


def calculate_token_importance_in_sequence(
    input_sequence: str,
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    perturbation_strategy: str = "fixed",
    attribution_strategies: List[str] = [
        "cosine",
        "prob_diff",
        "token_displacement",
    ],
    logger: ExperimentLogger = None,
):
    tokens = tokenizer.tokenize(input_sequence)
    token_ids = tokenizer.encode(input_sequence, add_special_tokens=False)
    original_output = get_model_output(input_sequence)

    if logger:
        logger.log_experiment(
            input_sequence, original_output.message.content, perturbation_strategy
        )

    for i, token in enumerate(tokens):
        replacement_token_id = get_replacement_token(
            token_ids[i], perturbation_strategy, word_token_embeddings, tokenizer
        )

        # Replace the current token with the new token
        perturbed_input = tokenizer.decode(
            token_ids[:i] + [replacement_token_id] + token_ids[i + 1 :]
        )

        # Get the output logprobs for the perturbed input
        perturbed_output = get_model_output(perturbed_input)

        for attribution_strategy in attribution_strategies:
            attributed_tokens = tokens
            if attribution_strategy == "cosine":
                sentence_attr, token_attributions = cosine_similarity_attribution(
                    original_output, perturbed_output, model, tokenizer
                )
            elif attribution_strategy == "prob_diff":
                sentence_attr, attributed_tokens, token_attributions = (
                    token_prob_difference(
                        original_output.logprobs, perturbed_output.logprobs
                    )
                )
            elif attribution_strategy == "token_displacement":
                sentence_attr, attributed_tokens, token_attributions = (
                    token_displacement(
                        original_output.logprobs, perturbed_output.logprobs
                    )
                )
            else:
                raise ValueError(
                    f"Unknown attribution strategy: {attribution_strategy}"
                )

            if logger:
                logger.log_input_token_attribution(
                    attribution_strategy, i, token, float(sentence_attr)
                )
                for j, attr_score in enumerate(token_attributions):
                    logger.log_token_attribution_matrix(
                        attribution_strategy,
                        i,
                        j,
                        attributed_tokens[j],
                        attr_score.squeeze(),
                    )

    if logger:
        logger.log_perturbation(
            i,
            tokenizer.decode([replacement_token_id])[0],
            perturbation_strategy,
            input_sequence,
            original_output.message.content,
            perturbed_input,
            perturbed_output.message.content,
        )

    return (original_output.message.content,)

In [21]:
input_sequences = [
    "The clock shows 9:47 PM. What time does the clock show?",
    "The building is 132 meters tall. How tall is the building?",
    "The package weighs 8.6 kilograms. How much does the package weigh?",
    "The thermometer reads 23 degrees Celsius. What is the temperature according to the thermometer?",
    "She drove 157 kilometers to visit her friend. How far did she drive to visit her friend?",
    "John has 83 books on his shelf. How many books does John have on his shelf?",
    "Maria is 37 years old today. How old is Maria?",
    "There are 68 people registered for the webinar. How many people are registered for the webinar?",
    "Alex saved $363 from his birthday gifts. How much money did Alex save?",
    "The recipe requires 14 teaspoons of sugar. How many teaspoons of sugar does the recipe require?",
]

for input_sequence in input_sequences:
    for perturbation_strategy in ["distant"]:
        original_output = calculate_token_importance_in_sequence(
            input_sequence,
            model,
            tokenizer,
            perturbation_strategy,
            attribution_strategies=["cosine", "prob_diff", "token_displacement"],
            logger=logger,
        )

        print(
            input_sequence,
            original_output,
        )

The clock shows 9:47 PM. What time does the clock show? ('9:47 PM',)
The building is 132 meters tall. How tall is the building? ('The building is 132 meters tall.',)
The package weighs 8.6 kilograms. How much does the package weigh? ('The package weighs 8.6 kilograms.',)
The thermometer reads 23 degrees Celsius. What is the temperature according to the thermometer? ('The temperature according to the thermometer is 23 degrees Celsius.',)
She drove 157 kilometers to visit her friend. How far did she drive to visit her friend? ('She drove 157 kilometers to visit her friend.',)
John has 83 books on his shelf. How many books does John have on his shelf? ('John has 83 books on his shelf.',)
Maria is 37 years old today. How old is Maria? ('Maria is 37 years old.',)
There are 68 people registered for the webinar. How many people are registered for the webinar? ('There are 68 people registered for the webinar.',)
Alex saved $363 from his birthday gifts. How much money did Alex save? ('Alex save

In [22]:
logger.df_experiments

Unnamed: 0,exp_id,original_input,original_output,perturbation_strategy
0,1,The clock shows 9:47 PM. What time does the cl...,The clock shows 9:47 PM.,distant
1,2,The clock shows 9:47 PM. What time does the cl...,The clock shows 9:47 PM.,distant
2,3,The clock shows 9:47 PM. What time does the cl...,The clock shows 9:47 PM.,distant
3,4,The building is 132 meters tall. How tall is t...,The building is 132 meters tall.,distant
4,5,The package weighs 8.6 kilograms. How much doe...,The package weighs 8.6 kilograms.,distant
5,6,The clock shows 9:47 PM. What time does the cl...,9:47 PM,distant
6,7,The building is 132 meters tall. How tall is t...,The building is 132 meters tall.,distant
7,8,The package weighs 8.6 kilograms. How much doe...,The package weighs 8.6 kilograms.,distant
8,9,The thermometer reads 23 degrees Celsius. What...,The temperature according to the thermometer i...,distant
9,10,She drove 157 kilometers to visit her friend. ...,She drove 157 kilometers to visit her friend.,distant


In [27]:
logger.print_sentence_attribution()

Unnamed: 0,exp_id,attribution_strategy,perturbation_strategy,token_1,token_2,token_3,token_4,token_5,token_6,token_7,token_8,token_9,token_10,token_11,token_12,token_13,token_14,token_15,token_16,token_17,token_18,token_19,token_20
0,1,cosine,distant,The 0.00,clock 0.00,shows 0.11,9 0.01,: 0.00,47 0.01,PM 0.02,. 0.11,What 0.09,time 0.03,does 0.00,the 0.13,clock 0.08,show 0.03,? 0.12,,,,,
1,2,cosine,distant,The 0.11,clock 0.00,shows 0.00,9 0.01,: 0.00,47 0.02,PM 0.02,. 0.00,What 0.00,time 0.00,does 0.00,the 0.07,clock 0.10,show 0.00,? 0.11,,,,,
2,3,cosine,distant,The 0.11,clock 0.01,shows 0.00,9 0.01,: 0.00,47 0.06,PM 0.02,. 0.00,What 0.08,time 0.19,does 0.00,the 0.08,clock 0.08,show 0.03,? 0.11,,,,,
3,3,prob_diff,distant,The 0.91,clock 0.07,shows 0.01,9 0.12,: 0.02,47 0.12,PM 0.25,. 0.00,What 0.63,time 0.89,does 0.03,the 0.88,clock 0.90,show 0.08,? 0.97,,,,,
4,3,token_displacement,distant,The 16.33,clock 4.00,shows 4.00,9 3.67,: 4.00,47 4.67,PM 6.78,. 4.00,What 10.22,time 16.44,does 4.00,the 10.89,clock 13.67,show 3.89,? 16.44,,,,,
5,4,cosine,distant,The 0.00,building 0.16,is 0.00,132 0.12,meters 0.00,tall 0.00,. 0.00,How 0.00,tall 0.11,is 0.00,the 0.16,building 0.14,? 0.17,,,,,,,
6,4,prob_diff,distant,The 0.00,building 1.00,is 0.01,132 1.00,meters 0.06,tall 0.00,. 0.00,How 0.00,tall 0.55,is 0.01,the 0.95,building 0.98,? 0.99,,,,,,,
7,4,token_displacement,distant,The 3.50,building 16.88,is 3.50,132 16.00,meters 3.50,tall 3.50,. 3.50,How 3.50,tall 6.75,is 3.50,the 10.62,building 17.88,? 16.38,,,,,,,
8,5,cosine,distant,The 0.00,package 0.00,weighs 0.00,8 0.01,. 0.00,6 0.03,kilograms 0.02,. 0.00,How 0.00,much 0.00,does 0.00,the 0.02,package 0.11,weigh 0.00,? 0.00,,,,,
9,5,prob_diff,distant,The 0.00,package 0.10,weighs 0.00,8 0.11,. 0.02,6 0.42,kilograms 0.10,. 0.00,How 0.00,much 0.01,does 0.01,the 0.79,package 1.00,weigh 0.01,? 0.05,,,,,


In [28]:
logger.print_attribution_matrix(1, "cosine")

Attribution matrix for cosine with perturbation strategy distant:
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,The (0),clock (1),shows (2),9 (3),: (4),47 (5),PM (6),. (7)
The (0),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
clock (1),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
shows (2),0.56058,0.85697,0.696821,0.746326,1.0,1.0,1.0,1.0
9 (3),0.0,-0.0,-0.0,0.289177,0.0,-0.0,0.0,0.0
: (4),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
47 (5),0.0,-0.0,-0.0,0.0,0.0,0.325577,0.0,0.0
PM (6),0.0,-0.0,0.0,0.0,0.0,-0.0,0.8215,1.0
. (7),0.56058,0.85697,0.696821,0.746326,1.0,1.0,1.0,1.0
What (8),0.748034,0.777353,0.741282,0.701788,0.727512,0.693778,0.781776,0.75865
time (9),0.0,-0.0,-0.0,0.639039,0.69659,0.746812,0.746326,0.27142


In [29]:
logger.print_attribution_matrix(3, "prob_diff")

Attribution matrix for prob_diff with perturbation strategy distant:
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,The (0),clock (1),shows (2),(3),9 (4),: (5),47 (6),PM (7),. (8)
The (0),0.173069,0.990708,0.997162,0.999496,0.999765,0.999947,0.999916,0.999918,0.989687
clock (1),0.004415,0.209209,0.083948,0.011617,0.134388,3.7e-05,0.140484,0.012522,0.005322
shows (2),0.068842,0.005182,0.007144,0.005228,0.000216,4.6e-05,3.4e-05,0.000187,0.011756
9 (3),0.108911,0.006162,0.002425,0.0012,0.945123,4.2e-05,0.000185,4.4e-05,0.009254
: (4),0.133495,0.010202,0.004138,0.003494,0.004689,7.6e-05,0.001983,1.2e-05,0.008104
47 (5),0.059414,0.019443,0.001198,0.000893,0.000277,0.000759,0.998938,3e-05,0.010867
PM (6),0.125587,0.004899,0.021221,0.006192,0.067328,1.4e-05,0.064784,0.999918,0.989687
. (7),0.008513,0.006089,0.001306,0.000153,0.00016,5.1e-05,4.7e-05,2.1e-05,0.00432
What (8),0.101163,0.002933,0.075282,0.484466,0.999765,0.999947,0.999916,0.999918,0.989685
time (9),0.04091,0.990708,0.997162,0.996478,0.999765,0.999947,0.999916,0.999918,0.989687


In [30]:
logger.print_attribution_matrix(2, attribution_strategy="cosine")
logger.print_attribution_matrix(2, attribution_strategy="prob_diff")

Attribution matrix for cosine with perturbation strategy distant:
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,The (0),clock (1),shows (2),9 (3),: (4),47 (5),PM (6),. (7)
The (0),0.56058,0.85697,0.696821,0.746326,1.0,1.0,1.0,1.0
clock (1),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
shows (2),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
9 (3),0.0,-0.0,-0.0,0.289177,0.0,-0.0,0.0,0.0
: (4),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
47 (5),0.0,-0.0,-0.0,0.0,0.0,0.540257,0.0,0.0
PM (6),0.0,-0.0,0.0,0.0,0.0,-0.0,0.8215,1.0
. (7),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
What (8),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
time (9),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0


Attribution matrix for prob_diff with perturbation strategy distant:
Input Tokens (Rows) vs. Output Tokens (Columns)


In [31]:
logger.print_attribution_matrix(3, attribution_strategy="cosine")
logger.print_attribution_matrix(3, attribution_strategy="prob_diff")

Attribution matrix for cosine with perturbation strategy distant:
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,The (0),clock (1),shows (2),9 (3),: (4),47 (5),PM (6),. (7)
The (0),0.56058,0.85697,0.696821,0.746326,1.0,1.0,1.0,1.0
clock (1),0.0,-0.0,-0.0,0.0,0.0,0.307193,0.0,0.0
shows (2),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
9 (3),0.0,-0.0,-0.0,0.265841,0.0,-0.0,0.0,0.0
: (4),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
47 (5),0.0,-0.0,-0.0,0.0,0.0,0.540257,0.0,0.423717
PM (6),0.0,-0.0,0.0,0.0,0.0,-0.0,0.8215,1.0
. (7),0.0,-0.0,-0.0,0.0,0.0,-0.0,0.0,0.0
What (8),0.0,-0.0,-0.0,0.639039,0.759923,0.742686,0.83404,0.836332
time (9),0.226659,0.77783,0.732814,0.801206,0.554864,0.730914,0.799345,0.417455


Attribution matrix for prob_diff with perturbation strategy distant:
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,The (0),clock (1),shows (2),(3),9 (4),: (5),47 (6),PM (7),. (8)
The (0),0.173069,0.990708,0.997162,0.999496,0.999765,0.999947,0.999916,0.999918,0.989687
clock (1),0.004415,0.209209,0.083948,0.011617,0.134388,3.7e-05,0.140484,0.012522,0.005322
shows (2),0.068842,0.005182,0.007144,0.005228,0.000216,4.6e-05,3.4e-05,0.000187,0.011756
9 (3),0.108911,0.006162,0.002425,0.0012,0.945123,4.2e-05,0.000185,4.4e-05,0.009254
: (4),0.133495,0.010202,0.004138,0.003494,0.004689,7.6e-05,0.001983,1.2e-05,0.008104
47 (5),0.059414,0.019443,0.001198,0.000893,0.000277,0.000759,0.998938,3e-05,0.010867
PM (6),0.125587,0.004899,0.021221,0.006192,0.067328,1.4e-05,0.064784,0.999918,0.989687
. (7),0.008513,0.006089,0.001306,0.000153,0.00016,5.1e-05,4.7e-05,2.1e-05,0.00432
What (8),0.101163,0.002933,0.075282,0.484466,0.999765,0.999947,0.999916,0.999918,0.989685
time (9),0.04091,0.990708,0.997162,0.996478,0.999765,0.999947,0.999916,0.999918,0.989687
