In [1]:
import os
import timeit
from time import time
import sys
from typing import List

import numpy as np
import openai
from dotenv import load_dotenv

from transformers import (
    GPT2LMHeadModel,
    GPT2Tokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)

notebook_path = os.path.abspath(get_ipython().starting_dir)
parent_path = os.path.dirname(notebook_path)

sys.path.append(parent_path)

from attribution.experiment_logger import ExperimentLogger
from attribution.token_perturbation import (
    get_replacement_token,
    get_most_similar_token_ids,
)
from attribution.attribution_metrics import (
    cosine_similarity_attribution,
    token_prob_difference,
    token_displacement,
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
load_dotenv()
client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

In [3]:
model = GPT2LMHeadModel.from_pretrained("gpt2")  # or any other checkpoint
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", add_prefix_space=True)

word_token_embeddings = model.transformer.wte.weight.detach().numpy()
position_embeddings = model.transformer.wpe.weight.detach().numpy()
token_cosine_distances = None

logger = ExperimentLogger()



In [7]:
def get_model_output(input: str) -> openai.types.chat.chat_completion.Choice:
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": input}],
        temperature=0.0,
        seed=0,
        logprobs=True,
        top_logprobs=20,
    )
    return response.choices[0]


def calculate_token_importance(
    input_text: str,
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    perturbation_strategy: str = "fixed",
    attribution_strategies: List[str] = [
        "cosine",
        "prob_diff",
        "token_displacement",
    ],
    logger: ExperimentLogger = None,
    perturb_word_wise: bool = False,
):
    timestamp = time()
    original_output = get_model_output(input_text)
    print(f"Chat Completion - Original: {round(time() - timestamp, 2)}s")

    if logger:
        logger.start_experiment(
            input_text,
            original_output.message.content,
            perturbation_strategy,
            perturb_word_wise,
        )

    exp_timestamp = time()

    # A unit is either a word or a single token
    unit_offset = 0
    if perturb_word_wise:
        words = input_text.split()
        tokens_per_unit = [tokenizer.tokenize(word) for word in words]
        token_ids_per_unit = [
            tokenizer.encode(word, add_special_tokens=False) for word in words
        ]
    else:
        tokens_per_unit = [[token] for token in tokenizer.tokenize(input_text)]
        token_ids_per_unit = [
            [token_id]
            for token_id in tokenizer.encode(input_text, add_special_tokens=False)
        ]

    for i_unit, unit_tokens in enumerate(tokens_per_unit):
        start_word_time = time()
        replacement_token_ids = [
            get_replacement_token(
                token_id,
                perturbation_strategy,
                word_token_embeddings,
                tokenizer,
            )
            for token_id in token_ids_per_unit[i_unit]
        ]
        print(
            f"\nReplaced word '{''.join(unit_tokens)}': {round(time() - start_word_time, 2)}s - get_replacement_token()"
        )

        # Replace the current word with the new tokens
        left_token_ids = [
            token_id
            for unit_token_ids in token_ids_per_unit[:i_unit]
            for token_id in unit_token_ids
        ]
        right_token_ids = [
            token_id
            for unit_token_ids in token_ids_per_unit[i_unit + 1 :]
            for token_id in unit_token_ids
        ]
        perturbed_input = tokenizer.decode(
            left_token_ids + replacement_token_ids + right_token_ids
        )
        print(
            "before ", unit_tokens, " after ", tokenizer.decode(replacement_token_ids)
        )

        # Get the output logprobs for the perturbed input
        timestamp = time()
        print("Original: ", input_text)
        print("Perturbed: ", perturbed_input)
        perturbed_output = get_model_output(perturbed_input)
        print(f"Chat Completion - Perturbed: {round(time() - timestamp, 2)}s")

        timestamp = time()
        for attribution_strategy in attribution_strategies:
            attributed_tokens = [
                token_logprob.token
                for token_logprob in original_output.logprobs.content
            ]
            print(attribution_strategy, "attributed_tokens", attributed_tokens)
            if attribution_strategy == "cosine":
                cosine_timestamp = time()
                sentence_attr, token_attributions = cosine_similarity_attribution(
                    original_output, perturbed_output, model, tokenizer
                )
                cosine_timestamp_end = time()
            elif attribution_strategy == "prob_diff":
                prob_diff_timestamp = time()
                sentence_attr, attributed_tokens, token_attributions = (
                    token_prob_difference(
                        original_output.logprobs, perturbed_output.logprobs
                    )
                )
                prob_diff_timestamp_end = time()
            elif attribution_strategy == "token_displacement":
                token_displacement_timestamp = time()
                sentence_attr, attributed_tokens, token_attributions = (
                    token_displacement(
                        original_output.logprobs, perturbed_output.logprobs
                    )
                )
                token_displacement_timestamp_end = time()
            else:
                raise ValueError(
                    f"Unknown attribution strategy: {attribution_strategy}"
                )

            if logger:
                start_logging = time()
                for i, unit_token in enumerate(unit_tokens):
                    logger.log_input_token_attribution(
                        attribution_strategy,
                        unit_offset + i,
                        unit_token,
                        float(sentence_attr),
                    )
                    for j, attr_score in enumerate(token_attributions):
                        logger.log_token_attribution_matrix(
                            attribution_strategy,
                            unit_offset + i,
                            j,
                            attributed_tokens[j],
                            attr_score.squeeze(),
                        )
                end_logging = time()
        time_all_attrs = time() - timestamp
        # print(f"Attributions computation: {time_all_attrs}s")
        # print(f"- Cosine Attr: {round(cosine_timestamp_end - cosine_timestamp, 2)}s")
        # print(
        #     f"- Prob Diff Attr: {round(prob_diff_timestamp_end - prob_diff_timestamp, 2)}s"
        # )
        # print(
        #     f"- Token Displacement Attr: {round(token_displacement_timestamp_end - token_displacement_timestamp, 2)}s"
        # )
        # print(f"- Attr Logging: {round(end_logging - start_logging, 2)}s")
        # print(f"Total for word '{word}': {round(time() - start_word_time, 2)}s")

        unit_offset += len(unit_tokens)

    print(f"\n\nExp Total: {time() - exp_timestamp}s\n\n")

    if logger:
        logger.log_perturbation(
            i,
            tokenizer.decode(replacement_token_ids)[0],
            perturbation_strategy,
            input_text,
            original_output.message.content,
            perturbed_input,
            perturbed_output.message.content,
        )
        logger.stop_experiment()

    return (original_output.message.content,)

In [8]:
input_texts = ["The clock shows 9:47 PM. How many minutes 'til 10?"]
#     "The building is 132 meters tall. How tall is the building?",
#     "The package weighs 8.6 kilograms. How much does the package weigh?",
#     "The thermometer reads 23 degrees Celsius. What is the temperature according to the thermometer?",
#     "She drove 157 kilometers to visit her friend. How far did she drive to visit her friend?",
#     "John has 83 books on his shelf. How many books does John have on his shelf?",
#     "Maria is 37 years old today. How old is Maria?",
#     "There are 68 people registered for the webinar. How many people are registered for the webinar?",
#     "Alex saved $363 from his birthday gifts. How much money did Alex save?",
#     "The recipe requires 14 teaspoons of sugar. How many teaspoons of sugar does the recipe require?",
# ]


for input_text in input_texts:
    for perturbation_strategy in ["distant"]:
        original_output = calculate_token_importance(
            input_text,
            model,
            tokenizer,
            perturbation_strategy,
            attribution_strategies=["cosine", "prob_diff", "token_displacement"],
            logger=logger,
            perturb_word_wise=True,
        )

        print(
            input_text,
            original_output,
        )

Chat Completion - Original: 1.31s

Replaced word 'ĠThe': 0.07s - get_replacement_token()
before  ['ĠThe']  after  Streamer
Original:  The clock shows 9:47 PM. How many minutes 'til 10?
Perturbed:  Streamer clock shows 9:47 PM. How many minutes 'til 10?
Chat Completion - Perturbed: 0.72s
cosine attributed_tokens ['There', ' are', ' ', '13', ' minutes', ' until', ' ', '10', ':', '00', ' PM', '.']
prob_diff attributed_tokens ['There', ' are', ' ', '13', ' minutes', ' until', ' ', '10', ':', '00', ' PM', '.']
token_displacement attributed_tokens ['There', ' are', ' ', '13', ' minutes', ' until', ' ', '10', ':', '00', ' PM', '.']

Replaced word 'Ġclock': 0.05s - get_replacement_token()
before  ['Ġclock']  after  ur
Original:  The clock shows 9:47 PM. How many minutes 'til 10?
Perturbed:   Theur shows 9:47 PM. How many minutes 'til 10?
Chat Completion - Perturbed: 0.91s
cosine attributed_tokens ['There', ' are', ' ', '13', ' minutes', ' until', ' ', '10', ':', '00', ' PM', '.']
prob_diff att

In [18]:
display(logger.df_experiments)

Unnamed: 0,exp_id,original_input,original_output,perturbation_strategy,perturb_word_wise,duration
0,1,The clock shows 9:47 PM. How many minutes 'til...,13 minutes.,distant,True,18.055734


In [19]:
logger.print_sentence_attribution()

Unnamed: 0,exp_id,attribution_strategy,perturbation_strategy,perturb_word_wise,token_1,token_2,token_3,token_4,token_5,token_6,token_7,token_8,token_9,token_10,token_11,token_12,token_13,token_14,token_15
0,1,cosine,distant,True,The 0.00,clock 0.00,shows 0.00,9 0.15,: 0.15,47 0.15,PM 0.15,. 0.15,How 0.13,many 0.13,minutes 0.13,' 0.00,til 0.00,10 0.11,? 0.11
1,1,prob_diff,distant,True,The 0.08,clock 0.12,shows 0.11,9 0.82,: 0.82,47 0.82,PM 0.81,. 0.81,How 0.81,many 0.80,minutes 0.82,' 0.05,til 0.05,10 0.24,? 0.24
2,1,token_displacement,distant,True,The 0.00,clock 0.00,shows 0.00,9 19.33,: 19.33,47 19.33,PM 12.67,. 12.67,How 13.67,many 13.67,minutes 13.67,' 0.00,til 0.00,10 6.67,? 6.67


In [23]:
logger.print_attribution_matrix(exp_id=1)

Attribution matrix for cosine with perturbation strategy distant:
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,13 (0),minutes (1),. (2)
The (0),0.0,0.0,0.0
clock (1),0.0,0.0,0.0
shows (2),0.0,0.0,0.0
9 (3),0.688956,0.759586,0.74748
: (4),0.688956,0.759586,0.74748
47 (5),0.688956,0.759586,0.74748
PM (6),0.688956,0.759586,0.701984
. (7),0.688956,0.759586,0.701984
How (8),0.688956,0.759586,0.701984
many (9),0.688956,0.759586,0.701984


Attribution matrix for prob_diff with perturbation strategy distant:
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,13 (0),minutes (1),. (2)
The (0),0.239891,1.2e-05,0.004187
clock (1),0.339138,6.1e-05,0.025875
shows (2),0.299949,2.9e-05,0.024715
9 (3),0.817992,0.999872,0.640487
: (4),0.817992,0.999872,0.640487
47 (5),0.817992,0.999872,0.640487
PM (6),0.804387,0.999872,0.640487
. (7),0.804387,0.999872,0.640487
How (8),0.787602,0.999872,0.640487
many (9),0.758278,0.999872,0.640487


Attribution matrix for token_displacement with perturbation strategy distant:
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,13 (0),minutes (1),. (2)
The (0),0.0,0.0,0.0
clock (1),0.0,0.0,0.0
shows (2),0.0,0.0,0.0
9 (3),20.0,18.0,20.0
: (4),20.0,18.0,20.0
47 (5),20.0,18.0,20.0
PM (6),1.0,17.0,20.0
. (7),1.0,17.0,20.0
How (8),1.0,20.0,20.0
many (9),1.0,20.0,20.0


In [21]:
logger.print_attribution_matrix(exp_id=1, attribution_strategy="prob_diff")

Attribution matrix for prob_diff with perturbation strategy distant:
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,13 (0),minutes (1),. (2)
The (0),0.239891,1.2e-05,0.004187
clock (1),0.339138,6.1e-05,0.025875
shows (2),0.299949,2.9e-05,0.024715
9 (3),0.817992,0.999872,0.640487
: (4),0.817992,0.999872,0.640487
47 (5),0.817992,0.999872,0.640487
PM (6),0.804387,0.999872,0.640487
. (7),0.804387,0.999872,0.640487
How (8),0.787602,0.999872,0.640487
many (9),0.758278,0.999872,0.640487


In [22]:
logger.print_attribution_matrix(1, "cosine")

Attribution matrix for cosine with perturbation strategy distant:
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,13 (0),minutes (1),. (2)
The (0),0.0,0.0,0.0
clock (1),0.0,0.0,0.0
shows (2),0.0,0.0,0.0
9 (3),0.688956,0.759586,0.74748
: (4),0.688956,0.759586,0.74748
47 (5),0.688956,0.759586,0.74748
PM (6),0.688956,0.759586,0.701984
. (7),0.688956,0.759586,0.701984
How (8),0.688956,0.759586,0.701984
many (9),0.688956,0.759586,0.701984


In [15]:
logger.print_attribution_matrix(2, attribution_strategy="cosine")
logger.print_attribution_matrix(2, attribution_strategy="prob_diff")

IndexError: index 0 is out of bounds for axis 0 with size 0