## Importing libraries

In [1]:
import os
import timeit
from time import time
import sys
from typing import List

import numpy as np
import openai
from dotenv import load_dotenv

from transformers import (
    GPT2LMHeadModel,
    GPT2Tokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)

# Resolve the parent of the directory the notebook kernel was started in, so
# the local `attribution` package imported right below can be found.
notebook_path = os.path.abspath(get_ipython().starting_dir)
parent_path = os.path.dirname(notebook_path)

# Only register the path once: re-running this cell must not keep appending
# duplicate entries to sys.path.
if parent_path not in sys.path:
    sys.path.append(parent_path)
from attribution.experiment_logger import ExperimentLogger
from attribution.token_perturbation import (
    get_replacement_token,
    get_most_similar_token_ids,
)
from attribution.attribution_metrics import (
    cosine_similarity_attribution,
    token_prob_difference,
    token_displacement,
)

In [2]:
# Load variables via python-dotenv before reading the API key from the
# process environment; the key itself is never written into the notebook.
load_dotenv()
client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

## Initialising tokenizer, embeddings and logger

In [3]:
# GPT-2 tokenizer and model; the model's embedding table is what
# calculate_token_importance reads when picking replacement tokens.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", add_prefix_space=True)
model = GPT2LMHeadModel.from_pretrained("gpt2")  # or any other checkpoint

# Word-token and position embedding weights, exported as NumPy arrays.
word_token_embeddings, position_embeddings = (
    embedding.weight.detach().numpy()
    for embedding in (model.transformer.wte, model.transformer.wpe)
)
token_cosine_distances = None

# Collects the experiment / attribution tables displayed later on.
logger = ExperimentLogger()



## Attribution helper functions

In [4]:
def get_model_output(input: str) -> openai.types.chat.chat_completion.Choice:
    """Send ``input`` as a single user message and return the first choice.

    The request goes through the module-level ``client`` to ``gpt-3.5-turbo``
    with ``temperature=0.0`` and ``seed=0``, and asks for log-probabilities
    (``logprobs=True``, ``top_logprobs=20``) so the attribution metrics can
    compare token distributions between runs.

    Args:
        input: Prompt text used as the content of the user message.

    Returns:
        The first element of the response's ``choices``.
    """
    request_options = dict(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": input}],
        temperature=0.0,
        seed=0,
        logprobs=True,
        top_logprobs=20,
    )
    completion = client.chat.completions.create(**request_options)
    return completion.choices[0]


def calculate_token_importance(
    input_text: str,
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    perturbation_strategy: str = "fixed",
    attribution_strategies: List[str] = None,
    logger: ExperimentLogger = None,
    perturb_word_wise: bool = False,
    n: int = -1,
):
    """Perturb each unit of ``input_text`` and score how the completion changes.

    The prompt is completed once unperturbed via ``get_model_output``. Then,
    for every unit (a whitespace-separated word when ``perturb_word_wise`` is
    true, otherwise a single token), the unit's tokens are swapped for the
    ids returned by ``get_replacement_token``, the perturbed prompt is
    completed again, and each requested attribution metric compares the two
    completions. Results are written to ``logger`` when one is given.

    Reads the module-level ``word_token_embeddings`` array, which is forwarded
    to ``get_replacement_token``.

    Args:
        input_text: Prompt to analyse.
        model: Forwarded to ``cosine_similarity_attribution``.
        tokenizer: Used to tokenize/encode the prompt and to decode the
            perturbed token ids back to text.
        perturbation_strategy: Forwarded to ``get_replacement_token`` and to
            the logger.
        attribution_strategies: Any of ``"cosine"``, ``"prob_diff"``,
            ``"token_displacement"``. ``None`` (the default) runs all three.
        logger: Optional experiment logger; nothing is logged when falsy.
        perturb_word_wise: Perturb whole words instead of single tokens.
        n: Forwarded unchanged to ``get_replacement_token``.

    Returns:
        A one-element tuple holding the text of the unperturbed completion.

    Raises:
        ValueError: If an unknown attribution strategy is requested. This is
            checked up front, before any completion request is made.
    """
    known_strategies = ("cosine", "prob_diff", "token_displacement")
    if attribution_strategies is None:
        # Build the default per call instead of sharing one mutable list
        # object across calls through the signature.
        attribution_strategies = list(known_strategies)
    for attribution_strategy in attribution_strategies:
        if attribution_strategy not in known_strategies:
            raise ValueError(
                f"Unknown attribution strategy: {attribution_strategy}"
            )

    timestamp = time()
    original_output = get_model_output(input_text)
    print(f"Chat Completion - Original: {round(time() - timestamp, 2)}s")

    if logger:
        logger.start_experiment(
            input_text,
            original_output.message.content,
            perturbation_strategy,
            perturb_word_wise,
        )

    exp_timestamp = time()

    # A unit is either a word or a single token
    if perturb_word_wise:
        words = input_text.split()
        tokens_per_unit = [tokenizer.tokenize(word) for word in words]
        token_ids_per_unit = [
            tokenizer.encode(word, add_special_tokens=False) for word in words
        ]
    else:
        tokens_per_unit = [[token] for token in tokenizer.tokenize(input_text)]
        token_ids_per_unit = [
            [token_id]
            for token_id in tokenizer.encode(input_text, add_special_tokens=False)
        ]

    # Tokens of the unperturbed completion; identical for every unit and
    # strategy, so they are collected once.
    original_output_tokens = [
        token_logprob.token for token_logprob in original_output.logprobs.content
    ]

    # Position of the current unit's first token within the flat token list.
    unit_offset = 0
    for i_unit, unit_tokens in enumerate(tokens_per_unit):
        start_word_time = time()
        replacement_token_ids = [
            get_replacement_token(
                token_id,
                perturbation_strategy,
                word_token_embeddings,
                tokenizer,
                n,
            )
            for token_id in token_ids_per_unit[i_unit]
        ]
        print(
            f"\nReplaced word '{''.join(unit_tokens)}': {round(time() - start_word_time, 2)}s - get_replacement_token()"
        )

        # Rebuild the prompt with only the current unit swapped out.
        left_token_ids = [
            token_id
            for unit_token_ids in token_ids_per_unit[:i_unit]
            for token_id in unit_token_ids
        ]
        right_token_ids = [
            token_id
            for unit_token_ids in token_ids_per_unit[i_unit + 1 :]
            for token_id in unit_token_ids
        ]
        perturbed_input = tokenizer.decode(
            left_token_ids + replacement_token_ids + right_token_ids
        )

        # Get the output logprobs for the perturbed input
        timestamp = time()
        print("Original: ", input_text)
        print("Perturbed: ", perturbed_input)
        perturbed_output = get_model_output(perturbed_input)
        print(f"Chat Completion - Perturbed: {round(time() - timestamp, 2)}s")

        for attribution_strategy in attribution_strategies:
            # The metrics below may replace this with their own token list.
            attributed_tokens = original_output_tokens
            print(attribution_strategy, "attributed_tokens", attributed_tokens)
            if attribution_strategy == "cosine":
                sentence_attr, token_attributions = cosine_similarity_attribution(
                    original_output, perturbed_output, model, tokenizer
                )
            elif attribution_strategy == "prob_diff":
                sentence_attr, attributed_tokens, token_attributions = (
                    token_prob_difference(
                        original_output.logprobs, perturbed_output.logprobs
                    )
                )
            else:  # "token_displacement" (names were validated above)
                sentence_attr, attributed_tokens, token_attributions = (
                    token_displacement(
                        original_output.logprobs, perturbed_output.logprobs
                    )
                )

            if logger:
                # Every token of the unit receives the unit's scores.
                for i, unit_token in enumerate(unit_tokens):
                    logger.log_input_token_attribution(
                        attribution_strategy,
                        unit_offset + i,
                        unit_token,
                        float(sentence_attr),
                    )
                    for j, attr_score in enumerate(token_attributions):
                        logger.log_token_attribution_matrix(
                            attribution_strategy,
                            unit_offset + i,
                            j,
                            attributed_tokens[j],
                            attr_score.squeeze(),
                        )

        if logger:
            # Record each perturbation while its data is in scope. Previously
            # this ran once after the loop, relied on an index leaked from the
            # inner logging loop (NameError when there were no units), and
            # kept only the first character of the replacement text.
            # NOTE(review): the position passed is the unit's first token
            # position, matching the other logger calls - confirm against
            # ExperimentLogger.log_perturbation.
            logger.log_perturbation(
                unit_offset,
                tokenizer.decode(replacement_token_ids),
                perturbation_strategy,
                input_text,
                original_output.message.content,
                perturbed_input,
                perturbed_output.message.content,
            )

        unit_offset += len(unit_tokens)

    print(f"\n\nExp Total: {time() - exp_timestamp}s\n\n")

    if logger:
        logger.stop_experiment()

    return (original_output.message.content,)

## Testing on crafted samples

In [5]:
input_texts = ["The clock shows 9:47 PM. How many minutes 'til 10 PM?"]

# One logged experiment per (prompt, perturbation strategy) pair, perturbing
# whole words and computing all three attribution metrics.
for input_text in input_texts:
    for perturbation_strategy in ["distant", "fixed", "nearest"]:
        original_output = calculate_token_importance(
            input_text=input_text,
            model=model,
            tokenizer=tokenizer,
            perturbation_strategy=perturbation_strategy,
            attribution_strategies=["cosine", "prob_diff", "token_displacement"],
            logger=logger,
            perturb_word_wise=True,
        )

        print(input_text, original_output)

Chat Completion - Original: 0.45s

Replaced word 'ĠThe': 0.15s - get_replacement_token()
Original:  The clock shows 9:47 PM. How many minutes 'til 10 PM?
Perturbed:   Limit clock shows 9:47 PM. How many minutes 'til 10 PM?
Chat Completion - Perturbed: 0.81s
cosine attributed_tokens ['13', ' minutes', '.']
prob_diff attributed_tokens ['13', ' minutes', '.']
token_displacement attributed_tokens ['13', ' minutes', '.']

Replaced word 'Ġclock': 0.15s - get_replacement_token()
Original:  The clock shows 9:47 PM. How many minutes 'til 10 PM?
Perturbed:   The Penny shows 9:47 PM. How many minutes 'til 10 PM?
Chat Completion - Perturbed: 0.37s
cosine attributed_tokens ['13', ' minutes', '.']
prob_diff attributed_tokens ['13', ' minutes', '.']
token_displacement attributed_tokens ['13', ' minutes', '.']

Replaced word 'Ġshows': 0.15s - get_replacement_token()
Original:  The clock shows 9:47 PM. How many minutes 'til 10 PM?
Perturbed:   The clockmedia 9:47 PM. How many minutes 'til 10 PM?
Chat C

In [34]:
input_texts = ["The clock shows 9:47 PM. How many minutes 'til 10 PM?"]

# Same prompt as before, restricted to the "nearest" perturbation strategy.
for input_text in input_texts:
    for perturbation_strategy in ["nearest"]:
        original_output = calculate_token_importance(
            input_text=input_text,
            model=model,
            tokenizer=tokenizer,
            perturbation_strategy=perturbation_strategy,
            attribution_strategies=["cosine", "prob_diff", "token_displacement"],
            logger=logger,
            perturb_word_wise=True,
        )

        print(input_text, original_output)

Chat Completion - Original: 0.6s

Replaced word 'ĠThe': 5.99s - get_replacement_token()
Original:  The clock shows 9:47 PM. How many minutes 'til 10 PM?
Perturbed:  The clock shows 9:47 PM. How many minutes 'til 10 PM?
Chat Completion - Perturbed: 0.5s
cosine attributed_tokens ['13', ' minutes', '.']
prob_diff attributed_tokens ['13', ' minutes', '.']
token_displacement attributed_tokens ['13', ' minutes', '.']

Replaced word 'Ġclock': 6.09s - get_replacement_token()
Original:  The clock shows 9:47 PM. How many minutes 'til 10 PM?
Perturbed:   The Clock shows 9:47 PM. How many minutes 'til 10 PM?
Chat Completion - Perturbed: 0.56s
cosine attributed_tokens ['13', ' minutes', '.']
prob_diff attributed_tokens ['13', ' minutes', '.']
token_displacement attributed_tokens ['13', ' minutes', '.']

Replaced word 'Ġshows': 6.08s - get_replacement_token()
Original:  The clock shows 9:47 PM. How many minutes 'til 10 PM?
Perturbed:   The clock show 9:47 PM. How many minutes 'til 10 PM?
Chat Comple

In [6]:
# Show the per-experiment table collected by the logger (one row per experiment).
display(logger.df_experiments)

Unnamed: 0,exp_id,original_input,original_output,perturbation_strategy,perturb_word_wise,duration
0,1,The clock shows 9:47 PM. How many minutes 'til...,13 minutes.,distant,True,11.168126
1,2,The clock shows 9:47 PM. How many minutes 'til...,13 minutes.,fixed,True,7.974742
2,3,The clock shows 9:47 PM. How many minutes 'til...,13 minutes.,nearest,True,106.18574


## Quantitative Metric

In [11]:
# Token positions the metric cells use as filters on the `input_token_pos`
# and `output_token_pos` columns of the attribution matrix.
# presumably the input positions of " 9", "47" and " 10" in the clock prompt
# under word-wise GPT-2 tokenization - TODO confirm
relevant_input_ids = [3,5,13]
# Output position 0; in the clock runs the first output token is '13'.
relevant_output_ids = [0]

In [17]:
# Metric: for each (perturbation strategy, attribution strategy) pair, the
# fraction of relevant (input, output) token pairs whose attribution score is
# at least the mean score over all input tokens for that output token.
perturbation_strategy = ['distant', 'fixed', 'nearest']
attr_matrix = logger.df_token_attribution_matrix

# Experiment ids 1..3 correspond to the strategies above, in order.
for exp_id, strategy_name in enumerate(perturbation_strategy, start=1):
    exp_rows = attr_matrix[attr_matrix['exp_id'] == exp_id]
    for attribution_strategy in ['cosine', 'prob_diff', 'token_displacement']:
        strategy_rows = exp_rows[exp_rows['attribution_strategy'] == attribution_strategy]

        # Counters cover every relevant output token, so the printed score is
        # no longer just the result for the last output id (and is always
        # defined, even when relevant_output_ids is empty).
        success = 0
        total = 0
        for output_id in relevant_output_ids:
            output_rows = strategy_rows[strategy_rows['output_token_pos'] == output_id]
            mean_attr_score = output_rows['attr_score'].mean()
            for input_id in relevant_input_ids:
                # .item() raises unless exactly one row matches, which guards
                # against missing or duplicated attribution entries.
                attr_score = output_rows[output_rows['input_token_pos'] == input_id]['attr_score'].item()
                if attr_score >= mean_attr_score:
                    success += 1
                total += 1

        # Avoid ZeroDivisionError when no relevant pairs are configured.
        score = success / total if total else float('nan')
        print(f'Metric score for perturbation_strategy: {strategy_name} and attribution_strategy: {attribution_strategy} - {score}')

Metric score for perturbation_strategy: distant and attribution_strategy: cosine - 0.6666666666666666
Metric score for perturbation_strategy: distant and attribution_strategy: prob_diff - 0.6666666666666666
Metric score for perturbation_strategy: distant and attribution_strategy: token_displacement - 0.6666666666666666
Metric score for perturbation_strategy: fixed and attribution_strategy: cosine - 0.6666666666666666
Metric score for perturbation_strategy: fixed and attribution_strategy: prob_diff - 0.6666666666666666
Metric score for perturbation_strategy: fixed and attribution_strategy: token_displacement - 0.6666666666666666
Metric score for perturbation_strategy: nearest and attribution_strategy: cosine - 0.3333333333333333
Metric score for perturbation_strategy: nearest and attribution_strategy: prob_diff - 1.0
Metric score for perturbation_strategy: nearest and attribution_strategy: token_displacement - 0.6666666666666666


## More crafted examples

In [27]:
input_texts = ["Maria is 37 years old today. How many years till she's 50?"]

# Fresh logger for this prompt, replacing the one used by the earlier runs.
logger = ExperimentLogger()

for input_text in input_texts:
    for perturbation_strategy in ["distant", "fixed", "nearest"]:
        original_output = calculate_token_importance(
            input_text=input_text,
            model=model,
            tokenizer=tokenizer,
            perturbation_strategy=perturbation_strategy,
            attribution_strategies=["cosine", "prob_diff", "token_displacement"],
            logger=logger,
            perturb_word_wise=True,
        )

        print(input_text, original_output)

Chat Completion - Original: 2.31s

Replaced word 'ĠMaria': 0.15s - get_replacement_token()
Original:  Maria is 37 years old today. How many years till she's 50?
Perturbed:   folders is 37 years old today. How many years till she's 50?
Chat Completion - Perturbed: 0.47s
cosine attributed_tokens ['Maria', ' is', ' ', '13', ' years', ' away', ' from', ' turning', ' ', '50', '.']
prob_diff attributed_tokens ['Maria', ' is', ' ', '13', ' years', ' away', ' from', ' turning', ' ', '50', '.']
token_displacement attributed_tokens ['Maria', ' is', ' ', '13', ' years', ' away', ' from', ' turning', ' ', '50', '.']

Replaced word 'Ġis': 0.15s - get_replacement_token()
Original:  Maria is 37 years old today. How many years till she's 50?
Perturbed:   Maria testified 37 years old today. How many years till she's 50?
Chat Completion - Perturbed: 1.38s
cosine attributed_tokens ['Maria', ' is', ' ', '13', ' years', ' away', ' from', ' turning', ' ', '50', '.']
prob_diff attributed_tokens ['Maria', ' i

In [28]:
# Show the per-experiment table collected by the logger for the Maria prompt.
display(logger.df_experiments)

Unnamed: 0,exp_id,original_input,original_output,perturbation_strategy,perturb_word_wise,duration
0,1,Maria is 37 years old today. How many years ti...,Maria is 13 years away from turning 50.,distant,True,14.900491
1,2,Maria is 37 years old today. How many years ti...,Maria is 13 years away from turning 50.,fixed,True,11.295896
2,3,Maria is 37 years old today. How many years ti...,Maria is 13 years away from turning 50.,nearest,True,105.708576


In [29]:
# Metric: for each (perturbation strategy, attribution strategy) pair, the
# fraction of relevant (input, output) token pairs whose attribution score is
# at least the mean score over all input tokens for that output token.
#
# NOTE(review): relevant_input_ids / relevant_output_ids are reused unchanged
# from the clock example. For this prompt the logged completion starts with
# 'Maria' and the answer '13' is the fourth output token, so output position 0
# is probably not the answer token - TODO choose ids for this example.
perturbation_strategy = ['distant', 'fixed', 'nearest']
attr_matrix = logger.df_token_attribution_matrix

# Experiment ids 1..3 correspond to the strategies above, in order.
for exp_id, strategy_name in enumerate(perturbation_strategy, start=1):
    exp_rows = attr_matrix[attr_matrix['exp_id'] == exp_id]
    for attribution_strategy in ['cosine', 'prob_diff', 'token_displacement']:
        strategy_rows = exp_rows[exp_rows['attribution_strategy'] == attribution_strategy]

        # Counters cover every relevant output token, so the printed score is
        # no longer just the result for the last output id (and is always
        # defined, even when relevant_output_ids is empty).
        success = 0
        total = 0
        for output_id in relevant_output_ids:
            output_rows = strategy_rows[strategy_rows['output_token_pos'] == output_id]
            mean_attr_score = output_rows['attr_score'].mean()
            for input_id in relevant_input_ids:
                # .item() raises unless exactly one row matches, which guards
                # against missing or duplicated attribution entries.
                attr_score = output_rows[output_rows['input_token_pos'] == input_id]['attr_score'].item()
                if attr_score >= mean_attr_score:
                    success += 1
                total += 1

        # Avoid ZeroDivisionError when no relevant pairs are configured.
        score = success / total if total else float('nan')
        print(f'Metric score for perturbation_strategy: {strategy_name} and attribution_strategy: {attribution_strategy} - {score}')

Metric score for perturbation_strategy: distant and attribution_strategy: cosine - 0.3333333333333333
Metric score for perturbation_strategy: distant and attribution_strategy: prob_diff - 0.3333333333333333
Metric score for perturbation_strategy: distant and attribution_strategy: token_displacement - 0.3333333333333333
Metric score for perturbation_strategy: fixed and attribution_strategy: cosine - 0.3333333333333333
Metric score for perturbation_strategy: fixed and attribution_strategy: prob_diff - 0.3333333333333333
Metric score for perturbation_strategy: fixed and attribution_strategy: token_displacement - 0.3333333333333333
Metric score for perturbation_strategy: nearest and attribution_strategy: cosine - 1.0
Metric score for perturbation_strategy: nearest and attribution_strategy: prob_diff - 0.0
Metric score for perturbation_strategy: nearest and attribution_strategy: token_displacement - 1.0


## All examples

In [None]:
input_texts = ["The clock shows 9:47 PM. How many minutes 'til 10 PM?", "Maria is 37 years old today. How many years till she's 50?"]