In [1]:
import os
import sys
import timeit
from time import time
from typing import List

import openai
from dotenv import load_dotenv
from transformers import (
    GPT2LMHeadModel,
    GPT2Tokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)

notebook_path = os.path.abspath(get_ipython().starting_dir)
parent_path = os.path.dirname(notebook_path)

sys.path.append(parent_path)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from attribution.api_attribution import APILLMAttributor
from attribution.experiment_logger import ExperimentLogger
from attribution.attribution_metrics import (
    cosine_similarity_attribution,
    token_displacement,
    token_prob_difference,
)
from attribution.token_perturbation import FixedPerturbationStrategy, NthNearestPerturbationStrategy

In [6]:
attributor = APILLMAttributor()

input_texts = ["The clock shows 9:47 PM. How many minutes 'til 10?"]

perturbation_strategies = [
    FixedPerturbationStrategy(),
    NthNearestPerturbationStrategy(n=0),
    NthNearestPerturbationStrategy(n=-1),
    # (PerturbationStrategy.NTH_NEAREST, {"n": 10}),
]

# Initialize the logger
logger = ExperimentLogger()

# Perform the experiment
for input_text in input_texts:
    for perturbation_strategy in perturbation_strategies:
        original_output = attributor.compute_attributions(
            input_text,
            perturbation_strategy=perturbation_strategy,
            attribution_strategies=["cosine", "prob_diff", "token_displacement"],
            logger=logger,
            perturb_word_wise=True,
        )

        print(
            input_text,
            original_output,
        )

# Display the results
display(logger.df_experiments)
logger.print_sentence_attribution()



The clock shows 9:47 PM. How many minutes 'til 10? None




The clock shows 9:47 PM. How many minutes 'til 10? None




The clock shows 9:47 PM. How many minutes 'til 10? None


Unnamed: 0,exp_id,original_input,original_output,perturbation_strategy,perturb_word_wise,duration
0,1,The clock shows 9:47 PM. How many minutes 'til...,13 minutes.,fixed,True,10.263212
1,2,The clock shows 9:47 PM. How many minutes 'til...,13 minutes.,nth_nearest (n=0),True,11.657135
2,3,The clock shows 9:47 PM. How many minutes 'til...,13 minutes.,nth_nearest (n=-1),True,10.759343


Unnamed: 0,exp_id,attribution_strategy,perturbation_strategy,perturb_word_wise,token_1,token_2,token_3,token_4,token_5,token_6,token_7,token_8,token_9,token_10,token_11,token_12,token_13,token_14,token_15
0,1,cosine,fixed,True,The 0.13,clock 0.13,shows 0.00,9 0.16,: 0.16,47 0.16,PM 0.12,. 0.12,How 0.13,many 0.13,minutes 0.00,' 0.00,til 0.00,10 0.17,? 0.17
1,1,prob_diff,fixed,True,The 0.72,clock 0.72,shows 0.10,9 0.81,: 0.81,47 0.81,PM 0.72,. 0.72,How 0.78,many 0.68,minutes 0.09,' 0.06,til 0.06,10 0.80,? 0.80
2,1,token_displacement,fixed,True,The 13.67,clock 12.67,shows 0.00,9 17.67,: 17.67,47 17.67,PM 13.67,. 13.67,How 13.67,many 13.67,minutes 0.00,' 0.00,til 0.00,10 14.67,? 14.67
3,2,cosine,nth_nearest (n=0),True,The 0.13,clock 0.13,shows 0.13,9 0.13,: 0.13,47 0.13,PM 0.13,. 0.13,How 0.13,many 0.13,minutes 0.13,' 0.13,til 0.13,10 0.13,? 0.13
4,2,prob_diff,nth_nearest (n=0),True,The 0.68,clock 0.71,shows 0.67,9 0.68,: 0.68,47 0.68,PM 0.71,. 0.71,How 0.68,many 0.71,minutes 0.71,' 0.67,til 0.67,10 0.66,? 0.66
5,2,token_displacement,nth_nearest (n=0),True,The 13.67,clock 13.00,shows 13.00,9 12.33,: 12.33,47 12.33,PM 12.33,. 12.33,How 13.67,many 12.33,minutes 13.33,' 13.00,til 13.00,10 13.67,? 13.67
6,3,cosine,nth_nearest (n=-1),True,The 0.00,clock 0.13,shows 0.13,9 0.15,: 0.15,47 0.15,PM 0.12,. 0.12,How 0.13,many 0.13,minutes 0.13,' 0.11,til 0.11,10 0.11,? 0.11
7,3,prob_diff,nth_nearest (n=-1),True,The 0.05,clock 0.69,shows 0.67,9 0.81,: 0.81,47 0.81,PM 0.80,. 0.80,How 0.79,many 0.74,minutes 0.76,' 0.27,til 0.27,10 0.28,? 0.28
8,3,token_displacement,nth_nearest (n=-1),True,The 0.00,clock 13.00,shows 12.67,9 20.00,: 20.00,47 20.00,PM 13.67,. 13.67,How 13.67,many 13.67,minutes 12.00,' 6.67,til 6.67,10 6.67,? 6.67


In [7]:
logger.print_attribution_matrix(exp_id=1)
logger.print_attribution_matrix(exp_id=1, attribution_strategy="prob_diff")
logger.print_attribution_matrix(1, "cosine")
logger.print_attribution_matrix(2, attribution_strategy="cosine")
logger.print_attribution_matrix(2, attribution_strategy="prob_diff")

Attribution matrix for cosine with perturbation strategy fixed:
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,13 (0),minutes (1),. (2)
The (0),0.688956,0.759586,0.701984
clock (1),0.688956,0.759586,0.701984
shows (2),0.0,-0.0,0.0
9 (3),0.688956,0.759586,0.817238
: (4),0.688956,0.759586,0.817238
47 (5),0.688956,0.759586,0.817238
PM (6),0.688956,0.759586,0.701984
. (7),0.688956,0.759586,0.701984
How (8),0.688956,0.759586,0.701984
many (9),0.688956,0.759586,0.701984


Attribution matrix for prob_diff with perturbation strategy fixed:
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,13 (0),minutes (1),. (2)
The (0),0.526091,0.999904,0.638828
clock (1),0.513809,0.999904,0.638828
shows (2),0.293041,2e-06,0.012882
9 (3),0.799628,0.999904,0.638828
: (4),0.799628,0.999904,0.638828
47 (5),0.799628,0.999904,0.638828
PM (6),0.509623,0.999904,0.638828
. (7),0.509623,0.999904,0.638828
How (8),0.69522,0.999904,0.638828
many (9),0.411303,0.999904,0.638828


Attribution matrix for token_displacement with perturbation strategy fixed:
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,13 (0),minutes (1),. (2)
The (0),1.0,20.0,20.0
clock (1),1.0,17.0,20.0
shows (2),0.0,0.0,0.0
9 (3),20.0,13.0,20.0
: (4),20.0,13.0,20.0
47 (5),20.0,13.0,20.0
PM (6),1.0,20.0,20.0
. (7),1.0,20.0,20.0
How (8),1.0,20.0,20.0
many (9),1.0,20.0,20.0


Attribution matrix for prob_diff with perturbation strategy fixed:
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,13 (0),minutes (1),. (2)
The (0),0.526091,0.999904,0.638828
clock (1),0.513809,0.999904,0.638828
shows (2),0.293041,2e-06,0.012882
9 (3),0.799628,0.999904,0.638828
: (4),0.799628,0.999904,0.638828
47 (5),0.799628,0.999904,0.638828
PM (6),0.509623,0.999904,0.638828
. (7),0.509623,0.999904,0.638828
How (8),0.69522,0.999904,0.638828
many (9),0.411303,0.999904,0.638828


Attribution matrix for cosine with perturbation strategy fixed:
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,13 (0),minutes (1),. (2)
The (0),0.688956,0.759586,0.701984
clock (1),0.688956,0.759586,0.701984
shows (2),0.0,-0.0,0.0
9 (3),0.688956,0.759586,0.817238
: (4),0.688956,0.759586,0.817238
47 (5),0.688956,0.759586,0.817238
PM (6),0.688956,0.759586,0.701984
. (7),0.688956,0.759586,0.701984
How (8),0.688956,0.759586,0.701984
many (9),0.688956,0.759586,0.701984


Attribution matrix for cosine with perturbation strategy nth_nearest (n=0):
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,13 (0),minutes (1),. (2)
The (0),0.688956,0.759586,0.701984
clock (1),0.688956,0.759586,0.701984
shows (2),0.688956,0.759586,0.701984
9 (3),0.688956,0.759586,0.701984
: (4),0.688956,0.759586,0.701984
47 (5),0.688956,0.759586,0.701984
PM (6),0.688956,0.759586,0.701984
. (7),0.688956,0.759586,0.701984
How (8),0.688956,0.759586,0.701984
many (9),0.688956,0.759586,0.701984


Attribution matrix for prob_diff with perturbation strategy nth_nearest (n=0):
Input Tokens (Rows) vs. Output Tokens (Columns)


Unnamed: 0,13 (0),minutes (1),. (2)
The (0),0.432522,0.999891,0.612652
clock (1),0.518021,0.999891,0.612652
shows (2),0.401822,0.999891,0.612652
9 (3),0.432369,0.999891,0.612652
: (4),0.432369,0.999891,0.612652
47 (5),0.432369,0.999891,0.612652
PM (6),0.518021,0.999891,0.612652
. (7),0.518021,0.999891,0.612652
How (8),0.432522,0.999891,0.612652
many (9),0.518021,0.999891,0.612652
