In [9]:
import os
import torch
from transformers import AutoTokenizer, AutoModel, PreTrainedModel, PreTrainedTokenizer, GPT2LMHeadModel, GPT2Tokenizer
from sklearn.neighbors import NearestNeighbors
from dotenv import load_dotenv
from typing import List
from openai import OpenAI

True

In [2]:
load_dotenv()
client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))

In [3]:
model = GPT2LMHeadModel.from_pretrained('gpt2')  # or any other checkpoint
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

word_token_embeddings = model.transformer.wte.weight  # Word Token Embeddings 
position_embeddings = model.transformer.wpe.weight  # Word Position Embeddings 
token_cosine_distances = None

In [6]:
token_cosine_distances = calculate_cosine_distances_between_tokens(word_token_embeddings)
print(token_cosine_distances.shape)

torch.Size([50257, 50257])


In [35]:
def get_model_output(input: str):
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": input}],
        logprobs=True,
        top_logprobs=20,
    )

    return response.choices[0]


# TOKEN PERTURBATION METHODS
def get_most_similar_tokens(token: str, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, n_tokens: int = 1) -> List[str]:
    token_id = tokenizer.encode(token, add_special_tokens=False)[0]
    
    # Get the embedding of the token
    embeddings = model.transformer.wte.weight
    token_embedding = embeddings[token_id, :]

    # Fit the NearestNeighbors model to the embeddings
    nbrs = NearestNeighbors(n_neighbors=n_tokens, algorithm='ball_tree').fit(embeddings.detach().numpy())
    
    # Find the nearest neighbor
    indices = nbrs.kneighbors(
        token_embedding.detach().numpy().reshape(1, -1),
        n_neighbors=10,
        return_distance=False
    )

    return [tokenizer.decode(ix) for ix in indices.flatten()]

def calculate_cosine_distances_between_tokens(word_token_embeddings: torch.Tensor) -> torch.Tensor:
    normalized_embeddings = torch.nn.functional.normalize(word_token_embeddings, p=2, dim=1)
    dot_product = torch.matmul(normalized_embeddings, normalized_embeddings.T)
    return 1 - dot_product


# OUTPUT SIMILARITY MEASURES
def is_token_in_top_20(token, top_logprobs):
    top_20_tokens = set(logprob.token for logprob in top_logprobs)
    # print(token in top_20_tokens, f'token: `{token}`, pred: `{top_logprobs[0].token}` - {top_20_tokens}')
    return token in top_20_tokens


def all_tokens_in_top_20(initial_logprobs, new_logprobs):
    if (
        initial_logprobs is None
        or new_logprobs is None
        or initial_logprobs.content is None
        or new_logprobs.content is None
    ):
        return False

    return all(
        is_token_in_top_20(initial_token.token, new_token.top_logprobs)
        for initial_token, new_token
        in zip(initial_logprobs.content, new_logprobs.content)
    )


def total_logprob_difference(initial_logprobs, perturbed_logprobs):
    # Get the logprobs of the top 20 tokens for the initial and perturbed outputs
    initial_top_logprobs = {logprob.token: logprob.logprob for logprob in initial_logprobs.content}
    perturbed_top_logprobs = {logprob.token: logprob.logprob for logprob in perturbed_logprobs.content}

    # Calculate the total difference in logprobs
    total_difference = 0
    for token, initial_logprob in initial_top_logprobs.items():
        perturbed_logprob = perturbed_top_logprobs.get(token, 0)
        total_difference += abs(initial_logprob - perturbed_logprob)

    return total_difference


def max_logprob_difference(initial_logprobs, perturbed_logprobs):
    # Get the logprobs of the top 20 tokens for the initial and perturbed outputs
    initial_top_logprobs = {logprob.token: logprob.logprob for logprob in initial_logprobs.content}
    perturbed_top_logprobs = {logprob.token: logprob.logprob for logprob in perturbed_logprobs.content}

    # Calculate the maximum difference in logprobs
    max_difference = 0
    for token, initial_logprob in initial_top_logprobs.items():
        perturbed_logprob = perturbed_top_logprobs.get(token, 0)
        max_difference = max(max_difference, abs(initial_logprob - perturbed_logprob))

    return max_difference


def token_displacement(initial_logprobs, perturbed_logprobs):
    # Get the top 20 tokens for the initial and perturbed outputs
    initial_top_tokens = [logprob.token for logprob in initial_logprobs.content]
    perturbed_top_tokens = [logprob.token for logprob in perturbed_logprobs.content]

    # Calculate the total displacement of tokens
    total_displacement = 0
    for i, token in enumerate(initial_top_tokens):
        if token in perturbed_top_tokens:
            total_displacement += abs(i - perturbed_top_tokens.index(token))

    return total_displacement

In [26]:
def calculate_token_importance(input_sequence: str, model: PreTrainedModel, tokenizer: PreTrainedTokenizer):
    tokens = tokenizer.tokenize(input_sequence)
    initial_output = get_model_output(input_sequence)

    # Initialize a dictionary to store the importance of each token
    token_importance = {}

    for i, token in enumerate(tokens):
        print(i)
        # Get the most similar tokens to the current token
        similar_tokens = get_most_similar_tokens(token, model, tokenizer, n_tokens=20)

        # For each similar token
        for similar_token in similar_tokens:
            # Replace the current token with the similar token
            perturbed_input = " ".join(tokens[:i] + [similar_token] + tokens[i+1:])

            # Get the output logprobs for the perturbed input
            perturbed_output = get_model_output(perturbed_input)

            # If all output logprobs are no longer in the top 20
            if not all_tokens_in_top_20(initial_output.logprobs, perturbed_output.logprobs):
                # Assign importance to the current token based on the index of the similar token
                token_importance[token] = similar_tokens.index(similar_token)
                break

    return initial_output.message.content, token_importance

In [28]:
input_sequence = "Translate to French: 'I am Mike'"
output_message, token_importance = calculate_token_importance(input_sequence, model, tokenizer)

# Print the importance of each token
print(output_message)
for token, importance in token_importance.items():
    print(f"Token: {token}, Importance: {importance}")

0
1
2
3
4
5
6
7
8
9
ChatCompletionMessage(content='Je suis Mike.', role='assistant', function_call=None, tool_calls=None)
Token: Trans, Importance: 4
Token: late, Importance: 7
Token: Ġto, Importance: 1
Token: ĠFrench, Importance: 0
Token: I, Importance: 5
Token: Ġam, Importance: 7
Token: ĠMike, Importance: 1
Token: ', Importance: 1


In [36]:
# Inspect values of different output similarity measures

input_sequences = [
    "Translate the following English text to Spanish: 'Hello, how are you?'",
    "What is the capital of France?",
    "Who won the world series in 2020?",
]

# Define some example perturbations
perturbations = [
    "Translate the following English message to Spanish: 'Hello, how are you?'",
    "What is the population of France?",
    "Who won the world series in 2019?",
]

# For each input sequence and its perturbation
for input_sequence, perturbation in zip(input_sequences, perturbations):
    print(f"Input: {input_sequence}")
    print(f"Perturbation: {perturbation}")

    # Get the initial and perturbed output logprobs
    initial_output = get_model_output(input_sequence)
    perturbed_output = get_model_output(perturbation)

    # Calculate the total logprob difference
    print(initial_output.logprobs)
    total_diff = total_logprob_difference(initial_output.logprobs, perturbed_output.logprobs)
    print(f"Total Logprob Difference: {total_diff}")

    # Calculate the max logprob difference
    max_diff = max_logprob_difference(initial_output.logprobs, perturbed_output.logprobs)
    print(f"Max Logprob Difference: {max_diff}")

    # Calculate the token displacement
    displacement = token_displacement(initial_output.logprobs, perturbed_output.logprobs)
    print(f"Token Displacement: {displacement}")

    print("\n")

Input: Translate the following English text to Spanish: 'Hello, how are you?'
Perturbation: Translate the following English message to Spanish: 'Hello, how are you?'
ChoiceLogprobs(content=[ChatCompletionTokenLogprob(token="'H", bytes=[39, 72], logprob=-1.0393208, top_logprobs=[TopLogprob(token='Hola', bytes=[72, 111, 108, 97], logprob=-0.5348233), TopLogprob(token="'H", bytes=[39, 72], logprob=-1.0393208), TopLogprob(token='¡', bytes=[194, 161], logprob=-2.9327002), TopLogprob(token='"H', bytes=[34, 72], logprob=-4.978692), TopLogprob(token="'", bytes=[39], logprob=-8.790809), TopLogprob(token='"', bytes=[34], logprob=-10.078332), TopLogprob(token='The', bytes=[84, 104, 101], logprob=-10.162473), TopLogprob(token='-H', bytes=[45, 72], logprob=-10.2620325), TopLogprob(token="''", bytes=[39, 39], logprob=-10.335744), TopLogprob(token='Translation', bytes=[84, 114, 97, 110, 115, 108, 97, 116, 105, 111, 110], logprob=-10.907006), TopLogprob(token='´', bytes=[194, 180], logprob=-11.357045)

In [8]:
token = 'cat'
token_index = tokenizer.encode(token, add_special_tokens=False)[0]

# Get the cosine distances for the selected token
token_distances = token_cosine_distances[token_index]

sorted_indices = torch.argsort(token_distances)
percentiles = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
percentile_indices = [sorted_indices[int(len(sorted_indices) * p)] for p in percentiles]
token = tokenizer.decode(token_index)
percentile_tokens = [tokenizer.decode(idx) for idx in percentile_indices]

print(f"Token: {token}")
for p, t in zip(percentiles, percentile_tokens):
    print(f"{int(p * 100)}% percentile token: {t}")

Token: cat
10% percentile token: 2000
20% percentile token: tw
30% percentile token:  Sho
40% percentile token:  Armed
50% percentile token:  overwhelm
60% percentile token:  Wasteland
70% percentile token:  Pieces
80% percentile token:  bright
90% percentile token:  middle


In [38]:

initial_logprobs = get_output_logprobs('the 5 continents are europe, asia, africa')
new_logprobs = get_output_logprobs('all 5 continents are europe, asia, africa')

In [39]:
print(initial_logprobs.message)
print(new_logprobs.message)
all_tokens_in_top_20(initial_logprobs.logprobs, new_logprobs.logprobs)

ChatCompletionMessage(content=', north america, south america, and australia.', role='assistant', function_call=None, tool_calls=None)
ChatCompletionMessage(content=', Australia, and the Americas.', role='assistant', function_call=None, tool_calls=None)
True token: `,`, pred: `,` - {' ,', 'Australia', ' ', 'the', ' North', 'amer', ' the', 'The', 'am', 'north', '\n\n', ' australia', 'North', ' americ', ',', ' south', ' Australia', ' north', ' america', 'south'}
True token: ` north`, pred: ` North` - {' o', ' Antarctica', ' ', '  \n', ' O', ' \n', ' North', ' the', ' South', ' \n\n', 'am', ' australia', ' americ', ' America', ' south', ' Australia', ' north', ' Americas', ' america', ' Austral'}
False token: ` america`, pred: `,` - {' ,', ' /', ' and', ' O', '/A', '.', ' or', '/', '/P', ' (', ',', ',\n\n', '-O', ',and', ',\n', '/O', '/New', '/o', ' &', '(O'}


False

In [None]:
tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModel.from_pretrained('gpt2')

# Define the input sequence
input_sequence = "Translate the following English text to French: 'Hello, how are you?'"

baseline_token = " "
tokens = tokenizer.tokenize(input_sequence)
embeddings = model(**tokenizer(tokens, return_tensors='pt'))[0][0].detach().numpy()
initial_output = get_output_logprobs(input_sequence)

for i, token in enumerate(tokens):
    print(f"Processing token: {token}")

    distant_tokens = get_increasingly_distant_tokens(token, embeddings, tokenizer)

    # For each distant token
    for distant_token in distant_tokens:
        print(f"  Replacing with distant token: {distant_token}")

        # Replace the current token with the distant token
        perturbed_input = " ".join(tokens[:i] + [distant_token] + tokens[i+1:])

        # Get the output logprobs for the perturbed input
        perturbed_logprobs = get_output_logprobs(perturbed_input)

        # If all output logprobs are no longer in the top 20
        if all_logprobs_not_in_top_20(perturbed_logprobs, initial_output.logprobs):
            # Assign importance to the current token based on the index of the distant token
            importance = get_importance(distant_tokens.index(distant_token))
            print(f"Importance of token '{token}': {importance}")
            break