In [1]:
import os
import timeit
from time import time
import sys
from typing import List

import numpy as np
import openai
from dotenv import load_dotenv

from transformers import (
    GPT2LMHeadModel,
    GPT2Tokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)

# Resolve the parent of the directory reported by IPython's `starting_dir`
# and put it on sys.path so the `attribution` imports that follow can be found.
notebook_path = os.path.abspath(get_ipython().starting_dir)
parent_path = os.path.dirname(notebook_path)

# Guarded so that re-running this cell does not add duplicate sys.path entries.
if parent_path not in sys.path:
    sys.path.append(parent_path)
from attribution.experiment_logger import ExperimentLogger
from attribution.token_perturbation import (
    get_replacement_token,
    get_most_similar_token_ids,
)
from attribution.attribution_metrics import (
    cosine_similarity_attribution,
    token_prob_difference,
    token_displacement,
)

In [2]:
# Call load_dotenv() first so OPENAI_API_KEY can be supplied via a .env file,
# then build the OpenAI client from that environment variable.
load_dotenv()
client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

In [3]:
# Local GPT-2 tokenizer and model ("gpt2" can be swapped for any other checkpoint).
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", add_prefix_space=True)
model = GPT2LMHeadModel.from_pretrained("gpt2")

# Token and position embedding weights of the transformer, as NumPy arrays.
word_token_embeddings, position_embeddings = (
    table.weight.detach().numpy()
    for table in (model.transformer.wte, model.transformer.wpe)
)
token_cosine_distances = None

# Logger instance shared by the cells that follow.
logger = ExperimentLogger()



In [4]:
def get_model_output(input: str) -> openai.types.chat.chat_completion.Choice:
    """Send ``input`` as a single user message to gpt-3.5-turbo and return the first choice.

    The request uses temperature 0.0 and seed 0, and asks for log-probabilities
    together with the top 20 alternatives per generated token.
    """
    request = dict(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": input}],
        temperature=0.0,
        seed=0,
        logprobs=True,
        top_logprobs=20,
    )
    completion = client.chat.completions.create(**request)
    first_choice = completion.choices[0]
    return first_choice


def calculate_token_importance(
    input_text: str,
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    perturbation_strategy: str = "fixed",
    attribution_strategies: List[str] = None,
    logger: ExperimentLogger = None,
    perturb_word_wise: bool = False,
    n: int = -1,
):
    """Perturb each unit of ``input_text`` and score how the chat output changes.

    The original completion is fetched once with ``get_model_output``. Then,
    for every unit (a whitespace-separated word when ``perturb_word_wise`` is
    true, otherwise a single tokenizer token), each token id of that unit is
    swapped via ``get_replacement_token``, the perturbed ids are decoded and
    sent to the chat model, and every requested attribution metric is computed
    between the original and the perturbed output.

    Note: replacement tokens are looked up against the module-level
    ``word_token_embeddings`` array, not against the ``model`` argument.

    Args:
        input_text: Prompt to analyse.
        model: Passed through to ``cosine_similarity_attribution``.
        tokenizer: Tokenizes/encodes the prompt and decodes perturbed ids.
        perturbation_strategy: Forwarded to ``get_replacement_token``.
        attribution_strategies: Any of ``"cosine"``, ``"prob_diff"`` and
            ``"token_displacement"``. Defaults to all three.
        logger: Optional experiment logger; nothing is logged when it is falsy.
        perturb_word_wise: Perturb whole words instead of single tokens.
        n: Forwarded to ``get_replacement_token``.

    Returns:
        A 1-tuple holding the text of the unperturbed completion.

    Raises:
        ValueError: If an unknown attribution strategy is requested.
    """
    # Resolve the default here instead of using a shared mutable default list.
    if attribution_strategies is None:
        attribution_strategies = ["cosine", "prob_diff", "token_displacement"]

    timestamp = time()
    original_output = get_model_output(input_text)
    print(f"Chat Completion - Original: {round(time() - timestamp, 2)}s")

    if logger:
        logger.start_experiment(
            input_text,
            original_output.message.content,
            perturbation_strategy,
            perturb_word_wise,
        )

    exp_timestamp = time()

    # A unit is either a word or a single token
    unit_offset = 0
    if perturb_word_wise:
        words = input_text.split()
        tokens_per_unit = [tokenizer.tokenize(word) for word in words]
        token_ids_per_unit = [
            tokenizer.encode(word, add_special_tokens=False) for word in words
        ]
    else:
        tokens_per_unit = [[token] for token in tokenizer.tokenize(input_text)]
        token_ids_per_unit = [
            [token_id]
            for token_id in tokenizer.encode(input_text, add_special_tokens=False)
        ]

    # State of the most recent perturbation, needed by the summary log after
    # the loop. Initialised so that an input with no units (or an empty list
    # of strategies) cannot trigger a NameError there.
    last_logged_token_index = None
    replacement_token_ids = []
    perturbed_input = None
    perturbed_output = None

    for i_unit, unit_tokens in enumerate(tokens_per_unit):
        start_word_time = time()
        replacement_token_ids = [
            get_replacement_token(
                token_id,
                perturbation_strategy,
                word_token_embeddings,
                tokenizer,
                n,
            )
            for token_id in token_ids_per_unit[i_unit]
        ]
        print(
            f"\nReplaced word '{''.join(unit_tokens)}': {round(time() - start_word_time, 2)}s - get_replacement_token()"
        )

        # Replace the current unit with the new tokens, keeping both sides intact
        left_token_ids = [
            token_id
            for unit_token_ids in token_ids_per_unit[:i_unit]
            for token_id in unit_token_ids
        ]
        right_token_ids = [
            token_id
            for unit_token_ids in token_ids_per_unit[i_unit + 1 :]
            for token_id in unit_token_ids
        ]
        perturbed_input = tokenizer.decode(
            left_token_ids + replacement_token_ids + right_token_ids
        )

        # Get the output logprobs for the perturbed input
        timestamp = time()
        print("Original: ", input_text)
        print("Perturbed: ", perturbed_input)
        perturbed_output = get_model_output(perturbed_input)
        print(f"Chat Completion - Perturbed: {round(time() - timestamp, 2)}s")

        for attribution_strategy in attribution_strategies:
            # Default column labels: the tokens of the original completion.
            # The prob_diff and token_displacement metrics return their own.
            attributed_tokens = [
                token_logprob.token
                for token_logprob in original_output.logprobs.content
            ]
            print(attribution_strategy, "attributed_tokens", attributed_tokens)
            if attribution_strategy == "cosine":
                sentence_attr, token_attributions = cosine_similarity_attribution(
                    original_output, perturbed_output, model, tokenizer
                )
            elif attribution_strategy == "prob_diff":
                sentence_attr, attributed_tokens, token_attributions = (
                    token_prob_difference(
                        original_output.logprobs, perturbed_output.logprobs
                    )
                )
            elif attribution_strategy == "token_displacement":
                sentence_attr, attributed_tokens, token_attributions = (
                    token_displacement(
                        original_output.logprobs, perturbed_output.logprobs
                    )
                )
            else:
                raise ValueError(
                    f"Unknown attribution strategy: {attribution_strategy}"
                )

            if logger:
                # Every token of the unit is logged with the unit's scores.
                for token_offset, unit_token in enumerate(unit_tokens):
                    last_logged_token_index = token_offset
                    logger.log_input_token_attribution(
                        attribution_strategy,
                        unit_offset + token_offset,
                        unit_token,
                        float(sentence_attr),
                    )
                    for j, attr_score in enumerate(token_attributions):
                        logger.log_token_attribution_matrix(
                            attribution_strategy,
                            unit_offset + token_offset,
                            j,
                            attributed_tokens[j],
                            attr_score.squeeze(),
                        )

        unit_offset += len(unit_tokens)

    print(f"\n\nExp Total: {time() - exp_timestamp}s\n\n")

    if logger:
        # Only the final perturbation is summarised, and only if something
        # was actually perturbed and logged.
        if last_logged_token_index is not None:
            # NOTE(review): the index passed here is the within-unit index of
            # the last token logged, and `[0]` keeps only the first character
            # of the decoded replacement text. Both are preserved from the
            # previous implementation; confirm against ExperimentLogger.
            logger.log_perturbation(
                last_logged_token_index,
                tokenizer.decode(replacement_token_ids)[0],
                perturbation_strategy,
                input_text,
                original_output.message.content,
                perturbed_input,
                perturbed_output.message.content,
            )
        # Always close the experiment, even when nothing was perturbed.
        logger.stop_experiment()

    return (original_output.message.content,)

In [7]:
# Prompts to analyse. Alternative prompts kept for reference:
#     "Complete: Rose are red"
#     "The building is 132 meters tall. How tall is the building?",
#     "The package weighs 8.6 kilograms. How much does the package weigh?",
#     "The thermometer reads 23 degrees Celsius. What is the temperature according to the thermometer?",
#     "She drove 157 kilometers to visit her friend. How far did she drive to visit her friend?",
#     "John has 83 books on his shelf. How many books does John have on his shelf?",
#     "Maria is 37 years old today. How old is Maria?",
#     "There are 68 people registered for the webinar. How many people are registered for the webinar?",
#     "Alex saved $363 from his birthday gifts. How much money did Alex save?",
#     "The recipe requires 14 teaspoons of sugar. How many teaspoons of sugar does the recipe require?",
input_texts = ["The clock shows 9:47 PM. How many minutes 'til 10 PM?"]

# One logged experiment per prompt and per value of `n`.
for input_text in input_texts:
    for neighbour in (0, 100, 1000, 10000, -1):
        original_output = calculate_token_importance(
            input_text,
            model,
            tokenizer,
            n=neighbour,
            perturb_word_wise=True,
            logger=logger,
            # "prob_diff" and "token_displacement" are switched off for this run.
            attribution_strategies=["cosine"],
            perturbation_strategy="nearest",
        )

        print(input_text, original_output)

Chat Completion - Original: 0.79s

Replaced word 'ĠThe': 0.17s - get_replacement_token()
Original:  The clock shows 9:47 PM. How many minutes 'til 10 PM?
Perturbed:   The clock shows 9:47 PM. How many minutes 'til 10 PM?
Chat Completion - Perturbed: 0.48s
cosine attributed_tokens ['13', ' minutes', '.']

Replaced word 'Ġclock': 0.16s - get_replacement_token()
Original:  The clock shows 9:47 PM. How many minutes 'til 10 PM?
Perturbed:   The clock shows 9:47 PM. How many minutes 'til 10 PM?
Chat Completion - Perturbed: 1.17s
cosine attributed_tokens ['13', ' minutes', '.']

Replaced word 'Ġshows': 0.16s - get_replacement_token()
Original:  The clock shows 9:47 PM. How many minutes 'til 10 PM?
Perturbed:   The clock shows 9:47 PM. How many minutes 'til 10 PM?
Chat Completion - Perturbed: 0.92s
cosine attributed_tokens ['13', ' minutes', '.']

Replaced word 'Ġ9:47': 0.58s - get_replacement_token()
Original:  The clock shows 9:47 PM. How many minutes 'til 10 PM?
Perturbed:   The clock shows

In [8]:
# Show the logger's table of experiments recorded so far.
display(logger.df_experiments)

Unnamed: 0,exp_id,original_input,original_output,perturbation_strategy,perturb_word_wise,duration
0,1,The clock shows 9:47 PM. How many minutes 'til...,13 minutes.,[nearest],True,
1,2,The clock shows 9:47 PM. How many minutes 'til...,13 minutes.,[nearest],True,
2,3,The clock shows 9:47 PM. How many minutes 'til...,13 minutes.,nearest,True,12.072218
3,4,The clock shows 9:47 PM. How many minutes 'til...,13 minutes.,nearest,True,10.696381
4,5,The clock shows 9:47 PM. How many minutes 'til...,13 minutes.,nearest,True,12.739646
5,6,The clock shows 9:47 PM. How many minutes 'til...,13 minutes.,nearest,True,9.977688
6,7,The clock shows 9:47 PM. How many minutes 'til...,13 minutes.,nearest,True,10.965171


In [9]:
# Print the logger's per-experiment table of input tokens with their scores.
logger.print_sentence_attribution()

Unnamed: 0,exp_id,attribution_strategy,perturbation_strategy,perturb_word_wise,token_1,token_2,token_3,token_4,token_5,token_6,token_7,token_8,token_9,token_10,token_11,token_12,token_13,token_14,token_15,token_16
0,3,cosine,nearest,True,The 0.00,clock 0.13,shows 0.13,9 0.13,: 0.13,47 0.13,PM 0.13,. 0.13,How 0.13,many 0.13,minutes 0.13,' 0.13,til 0.13,10 0.13,PM 0.13,? 0.13
1,4,cosine,nearest,True,The 0.13,clock 0.13,shows 0.00,9 0.06,: 0.06,47 0.06,PM 0.13,. 0.13,How 0.13,many 0.13,minutes 0.14,' 0.14,til 0.14,10 0.00,PM 0.14,? 0.14
2,5,cosine,nearest,True,The 0.13,clock 0.00,shows 0.13,9 0.15,: 0.15,47 0.15,PM 0.13,. 0.13,How 0.14,many 0.13,minutes 0.14,' 0.11,til 0.11,10 0.00,PM 0.11,? 0.11
3,6,cosine,nearest,True,The 0.13,clock 0.00,shows 0.13,9 0.19,: 0.19,47 0.19,PM 0.13,. 0.13,How 0.25,many 0.00,minutes 0.12,' 0.11,til 0.11,10 0.00,PM 0.00,? 0.00
4,7,cosine,nearest,True,The 0.00,clock 0.13,shows 0.00,9 0.15,: 0.15,47 0.15,PM 0.13,. 0.13,How 0.13,many 0.13,minutes 0.11,' 0.11,til 0.11,10 0.00,PM 0.11,? 0.11


In [17]:
# Input-token (rows) vs. output-token (columns) attribution matrix for experiment 3.
logger.print_attribution_matrix(exp_id=3)

Attribution matrix for cosine with perturbation strategy nearest:
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,13 (0),minutes (1),. (2)
The (0),0.0,-0.0,0.0
clock (1),0.688956,0.759586,0.701984
shows (2),0.688956,0.759586,0.701984
9 (3),0.688956,0.759586,0.701984
: (4),0.688956,0.759586,0.701984
47 (5),0.688956,0.759586,0.701984
PM (6),0.688956,0.759586,0.701984
. (7),0.688956,0.759586,0.701984
How (8),0.688956,0.759586,0.701984
many (9),0.688956,0.759586,0.701984


In [104]:
# prob_diff attribution matrix for experiment 10.
# NOTE(review): experiment 10 is not in the experiments table displayed above
# (ids 1-7), so this saved output appears to come from an earlier session;
# re-run after Restart & Run All to confirm.
logger.print_attribution_matrix(exp_id=10, attribution_strategy="prob_diff")

Attribution matrix for prob_diff with perturbation strategy nearest:
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,157 (0),", (1)",000 (2),meters (3)
She (0),3e-06,0.002872,0.0,0.000102
drove (1),4.7e-05,0.003232,1.2e-05,9.9e-05
157 (2),0.999992,0.008727,8e-06,0.000166
kilometers (3),1e-05,0.001221,2e-06,0.000249
to (4),6.3e-05,0.001631,2e-06,0.000233
visit (5),7.7e-05,0.001165,2e-06,0.000284
her (6),2.6e-05,5.5e-05,1.5e-05,0.000192
friend (7),7.3e-05,0.001867,3.1e-05,0.00013
. (8),7.3e-05,0.001867,3.1e-05,0.00013
How (9),0.000122,0.009969,9e-06,0.000138


In [36]:
# Cosine attribution matrix for experiment 4.
# NOTE(review): the saved output reports perturbation strategy "distant" and a
# different prompt than experiment 4 in the experiments table displayed above,
# so it appears to come from an earlier session; confirm by re-running.
logger.print_attribution_matrix(4, "cosine")

Attribution matrix for cosine with perturbation strategy distant:
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,157 (0),", (1)",000 (2),meters (3)
She (0),-0.0,0.0,-0.0,0.0
drove (1),-0.0,0.0,-0.0,0.0
157 (2),0.748192,0.843399,0.731184,0.886992
kilometers (3),-0.0,0.926494,0.739293,0.71063
to (4),-0.0,0.0,-0.0,0.0
visit (5),-0.0,0.0,-0.0,0.0
her (6),-0.0,0.0,-0.0,0.0
friend (7),-0.0,0.0,-0.0,0.0
. (8),-0.0,0.0,-0.0,0.0
How (9),-0.0,0.0,-0.0,0.0


In [37]:
# Cosine and prob_diff attribution matrices for experiment 4, one after the other.
# NOTE(review): the saved output reports perturbation strategy "distant", which
# does not match experiment 4 in the experiments table displayed above; confirm
# by re-running.
logger.print_attribution_matrix(4, attribution_strategy="cosine")
logger.print_attribution_matrix(4, attribution_strategy="prob_diff")

Attribution matrix for cosine with perturbation strategy distant:
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,157 (0),", (1)",000 (2),meters (3)
She (0),-0.0,0.0,-0.0,0.0
drove (1),-0.0,0.0,-0.0,0.0
157 (2),0.748192,0.843399,0.731184,0.886992
kilometers (3),-0.0,0.926494,0.739293,0.71063
to (4),-0.0,0.0,-0.0,0.0
visit (5),-0.0,0.0,-0.0,0.0
her (6),-0.0,0.0,-0.0,0.0
friend (7),-0.0,0.0,-0.0,0.0
. (8),-0.0,0.0,-0.0,0.0
How (9),-0.0,0.0,-0.0,0.0


Attribution matrix for prob_diff with perturbation strategy distant:
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,157 (0),", (1)",000 (2),meters (3)
She (0),4.7e-05,0.006988,2.1e-05,0.000936
drove (1),0.001236,0.003509,1.8e-05,0.000198
157 (2),0.999995,0.997071,0.999996,0.999673
kilometers (3),0.153902,0.995923,0.999996,0.999673
to (4),8e-06,0.009697,4e-06,0.000471
visit (5),1.1e-05,0.008603,5e-06,0.000178
her (6),7.4e-05,0.009388,5e-06,0.000369
friend (7),0.000526,0.004561,5e-05,0.00907
. (8),0.000526,0.004561,5e-05,0.00907
How (9),0.000899,0.023841,3e-05,0.00021


## Quantitative metric

In [119]:
# Query the chat model for the clock prompt and keep only the text of its answer.
input_str = "The clock shows 9:47 PM. How many minutes 'til 10 PM?"
output_str = get_model_output(input_str).message.content

In [120]:
# Display the model's answer text.
output_str

'13 minutes.'

In [121]:
# Re-create the GPT-2 tokenizer; same call and arguments as in the setup cell
# near the top of the notebook.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", add_prefix_space=True)




In [122]:
# Decode each token id of the prompt on its own to see how the prompt is split.
list(map(tokenizer.decode, tokenizer(input_str).input_ids))

[' The',
 ' clock',
 ' shows',
 ' 9',
 ':',
 '47',
 ' PM',
 '.',
 ' How',
 ' many',
 ' minutes',
 " '",
 'til',
 ' 10',
 ' PM',
 '?']

In [123]:
# Decode each token id of the model answer on its own.
list(map(tokenizer.decode, tokenizer(output_str).input_ids))

[' 13', ' minutes', '.']

In [124]:
# Positions in the tokenized prompt listed above: 3 -> ' 9', 5 -> '47', 13 -> ' 10'.
relevant_input_ids = [3,5,13]
# Position in the tokenized answer listed above: 0 -> ' 13'.
relevant_output_ids = [0]

In [125]:
# Hit rate of the hand-picked relevant input tokens: a relevant input token
# counts as a success when its attribution score for a relevant output token is
# at least the mean score of all logged rows for that output token.
#
# NOTE(review): experiments 14, 15 and 16 are assumed to have been run with the
# strategies below, in this order; they are not in the experiments table
# displayed earlier in the notebook, so confirm against the logger.
perturbation_strategy = ['distant', 'fixed', 'nearest']
experiment_ids = [14, 15, 16]
attr_matrix = logger.df_token_attribution_matrix

for exp_id, strategy_name in zip(experiment_ids, perturbation_strategy):
    for attribution_strategy in ['cosine', 'prob_diff', 'token_displacement']:
        # Filter once per (experiment, attribution strategy) instead of
        # rebuilding the full mask for every lookup.
        strategy_rows = attr_matrix[
            (attr_matrix['exp_id'] == exp_id)
            & (attr_matrix['attribution_strategy'] == attribution_strategy)
        ]

        # Counters cover every relevant output token, so the printed score is
        # not just that of the last output token visited.
        success = 0
        total = 0
        for output_id in relevant_output_ids:
            output_rows = strategy_rows[strategy_rows['output_token_pos'] == output_id]
            mean_attr_score = output_rows['attr_score'].mean().item()
            for input_id in relevant_input_ids:
                # .item() raises unless exactly one row matches this input position.
                attr_score = output_rows[output_rows['input_token_pos'] == input_id]['attr_score'].item()
                if attr_score >= mean_attr_score:
                    success += 1
                total += 1

        # With no relevant tokens configured there is nothing to score.
        metric_score = success / total if total else float('nan')
        print(f'Metric score for perturbation_strategy: {strategy_name} and attribution_strategy: {attribution_strategy} - {metric_score}')

Metric score for perturbation_strategy: distant and attribution_strategy: cosine - 0.6666666666666666
Metric score for perturbation_strategy: distant and attribution_strategy: prob_diff - 0.6666666666666666
Metric score for perturbation_strategy: distant and attribution_strategy: token_displacement - 0.6666666666666666
Metric score for perturbation_strategy: fixed and attribution_strategy: cosine - 0.6666666666666666
Metric score for perturbation_strategy: fixed and attribution_strategy: prob_diff - 0.6666666666666666
Metric score for perturbation_strategy: fixed and attribution_strategy: token_displacement - 0.6666666666666666
Metric score for perturbation_strategy: nearest and attribution_strategy: cosine - 1.0
Metric score for perturbation_strategy: nearest and attribution_strategy: prob_diff - 1.0
Metric score for perturbation_strategy: nearest and attribution_strategy: token_displacement - 0.6666666666666666
