In [1]:
import torch
import torch.nn.functional as F

from revllm.gpt import GPT, GPTConfig
# from revllm.model_wrapper import TokenizerWrapper
from transformers import GPT2Tokenizer

In [101]:
def top_k_intersection_score(prob_tensor_a : torch.tensor, prob_tensor_b: torch.tensor, k: int) -> float:
    """
    Usin intersection scores, as defined in: https://arxiv.org/pdf/2305.13417.pdf
    Given: 
        - two 1D probability tensors
        - k: int
    Returns: their top k intersection score
    """

    assert all((len(prob_tensor_a.shape) == 1, len(prob_tensor_b.shape) == 1, prob_tensor_a.shape == prob_tensor_b.shape))

    topk_a = torch.topk(prob_tensor_a, k).indices.tolist()
    topk_b = torch.topk(prob_tensor_b, k).indices.tolist()
    intersection = set(topk_a).intersection(set(topk_b))
    
    return len(intersection) / k

def get_top_k_intersection_scores(probabilities_tensor: torch.Tensor, k: int) -> torch.Tensor:

    """
    Given: 
        - [layer, word, probabilities] tensor
        - k: int
    Returns: [layer, word, 1] tensor of intersection scores.
    """

    num_layers = probabilities_tensor.shape[0]
    num_words = probabilities_tensor.shape[1]

    final_layer_probabilities = probabilities_tensor[-1, :, :]

    intersection_scores_tensor = torch.zeros(
        (num_layers, num_words)
    ).unsqueeze(-1)

    for layer in range(num_layers):
        for word_index in range(num_words):
            layer_word_probabilities = probabilities_tensor[layer, word_index, :]
            final_word_probabilities = final_layer_probabilities[word_index, :]    
            intersection_scores_tensor[layer, word_index] = top_k_intersection_score(
                layer_word_probabilities,
                final_word_probabilities,
                k
            )

    return intersection_scores_tensor

def get_top_k_intersection_scores_control(k: int, probabilities_tensor: torch.Tensor) -> torch.Tensor:

    """
    Given a tensor whose last dimension represents probabilities, returns a tensor of intersection scores.
    """
    topk_tensor = torch.topk(probabilities_tensor, k, dim=-1).indices
    topk_tensor_final_output = topk_tensor[-1, :, :]
    topk_indices_final_output_dict = {}
    intersection_scores_tensor = torch.zeros(
        (topk_tensor.shape[0], topk_tensor.shape[1])
    ).unsqueeze(-1)

    for word in range(topk_tensor_final_output.shape[0]):
        topk_indices_final_output_dict[word] = set(topk_tensor_final_output[word].tolist())

    for layer in range(topk_tensor.shape[0]):
        layer_output_topk_tensor = topk_tensor[layer, :, :]
        for word in range(layer_output_topk_tensor.shape[0]):
            intersection = set(layer_output_topk_tensor[word].tolist()).intersection(
                topk_indices_final_output_dict[word]
            )
            intersection_scores_tensor[layer, word] = len(intersection) / k

    return intersection_scores_tensor

In [108]:
exp = F.softmax(torch.randn(4,2,8),dim = -1)
print(exp[-1,:,:])

print(torch.topk(exp, 3, dim=-1).indices)
int_scores = get_top_k_intersection_scores(exp, 3)
int_scores

tensor([[0.0285, 0.0952, 0.1723, 0.0839, 0.1112, 0.3831, 0.0743, 0.0514],
        [0.1357, 0.1781, 0.0119, 0.0201, 0.2026, 0.0221, 0.3048, 0.1246]])
tensor([[[4, 2, 7],
         [6, 0, 2]],

        [[6, 7, 4],
         [2, 0, 5]],

        [[1, 3, 6],
         [7, 5, 2]],

        [[5, 2, 4],
         [6, 4, 1]]])


tensor([[[0.6667],
         [0.3333]],

        [[0.3333],
         [0.0000]],

        [[0.0000],
         [0.0000]],

        [[1.0000],
         [1.0000]]])

In [53]:
length = 10
k = 4
a = F.softmax(torch.randn(length))
b = F.softmax(torch.randn(length))

score = top_k_intersection_score(a, b, k)

topk_a = torch.topk(a, k).indices.tolist()
topk_b = torch.topk(b, k).indices.tolist()

print(f"Top {k} intersection score: {score}")
print(f"Top {k} a: {topk_a}")
print(f"Top {k} b: {topk_b}")

Top 4 intersection score: 0.75
Top 4 a: [2, 8, 5, 6]
Top 4 b: [0, 8, 6, 2]


  a = F.softmax(torch.randn(length))
  b = F.softmax(torch.randn(length))


In [49]:
a

tensor([0.0624, 0.0388, 0.0655, 0.1238, 0.1454, 0.0911, 0.1353, 0.0905, 0.1608,
        0.0864])

In [44]:
    for word in range(topk_tensor_final_output.shape[0]):
        topk_indices_final_output_dict[word] = set(topk_tensor_final_output[word].tolist())

    for layer in range(topk_tensor.shape[0]):
        layer_output_topk_tensor = topk_tensor[layer, :, :]
        for word in range(layer_output_topk_tensor.shape[0]):
            intersection = set(layer_output_topk_tensor[word].tolist()).intersection(
                topk_indices_final_output_dict[word]
            )
            intersection_scores_tensor[layer, word] = len(intersection) / k

    return intersection_scores_tensor

[2, 1, 5, 4]
[9, 3, 6, 0]


In [2]:
model = GPT.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

loading weights from pretrained gpt: gpt2
forcing vocab_size=50257, block_size=1024, bias=True


In [3]:
context = "The quick brown fox jumps over the lazy dog"
input_ids = tokenizer.encode(context, return_tensors='pt')
position_ids = torch.arange(input_ids.size(1)).unsqueeze(0)

In [3]:
model

GPT(
  (transformer): ModuleDict(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-11): 12 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=768, out_features=2304, bias=True)
          (c_proj): Linear(in_features=768, out_features=768, bias=True)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=768, out_features=3072, bias=True)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [7]:
output_logits = model(input_ids)

# TODO:
* Extract attention scores for each level
* I've done it with gpt2 (it's an argument in the model) - easy change to make here?

In [14]:
from transformers import GPT2Tokenizer, GPT2Model
import torch

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2', output_attentions=True)

input_ids = tokenizer.encode(context, return_tensors='pt')

# Run the model and get attention scores
outputs = model(input_ids)
attention = outputs.attentions  # Get attention scores



In [16]:
attention.shape

AttributeError: 'tuple' object has no attribute 'shape'

In [None]:

# Specify the layer and head you want to extract scores from
layer = 0  # For example, first layer
head = 0   # For example, first head
attention_scores = attention[layer][0, head]  # 0 for batch index, as we have a single input

# attention_scores now contains the attention scores for the specified layer and head
# for the input text. It's a matrix of shape (sequence_length, sequence_length),
# where each entry [i, j] represents the attention from token i to token j.

# To work with or visualize these scores further, you can convert them to numpy, for example:
attention_scores_np = attention_scores.detach().numpy()

print(attention_scores_np)

# Batch Processing

In [3]:
# def predict_hf(context: str) -> str:
#     encoded_input = tokenizer(context, return_tensors="pt").to(device)
#     input_ids = encoded_input["input_ids"].to(device)
#     attention_mask = encoded_input["attention_mask"].to(device)

#     output = model_hf.generate(**encoded_input, max_length=len(input_ids) + 15, do_sample=False)

#     return tokenizer.decode(output[0])

# Generalize to context

In [10]:
# import time
# import warnings

# import torch

# import logging

# logging.getLogger("transformers").setLevel(logging.ERROR)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# start_time = time.time()

# context = "What is the capital of France?"

# print(f"Initial context: {context}")
# print("")
# context_with_prediction = predict_hf(context)

# print(context_with_prediction)
# print("")

# num_tokens_to_generate = 9

# attributions_dict = {}
# print("Predicted tokens, with attributions in order:")
# print("(note: the first tokens are spaces)")

# for _ in range(num_tokens_to_generate):
#     predicted_token, predicted_token_id, input_ids, attention_mask, attributions = igs_nano(context)
#     attributions_dict[predicted_token] = attributions

#     # Append the predicted token ID to the input
#     predicted_token_tensor = torch.tensor([[predicted_token_id]], dtype=torch.long)
#     input_ids = torch.cat((input_ids, predicted_token_tensor), dim=1)
#     context = tokenizer.decode(input_ids[0])

# end_time = time.time()
# print(f"Time elapsed: {end_time - start_time}")
# print("")
# print(tokenizer.decode(input_ids[0]))


Initial context: What is the capital of France?

What is the capital of France?

The capital of France is Paris.

Predicted tokens, with attributions in order:
(note: the first tokens are spaces)
Predicted Token: 


Ġof: 0.4667750298976898
Ġis: 0.09483694285154343
What: -0.20078647136688232
ĠFrance: -0.3289549946784973
Ġthe: -0.3516395092010498
?: -0.4432636499404907
Ġcapital: -0.5517855882644653

Predicted Token: 


What: -0.06957273185253143
Ġof: -0.23914484679698944
Ċ: -0.24015489220619202
Ġcapital: -0.3045584261417389
Ġthe: -0.33885839581489563
?: -0.41570526361465454
Ġis: -0.4330602288246155
ĠFrance: -0.5588936805725098

Predicted Token: 


?: 0.002080421196296811
Ġof: -0.0015987710794433951
What: -0.0022416578140109777
Ġcapital: -0.0026813633739948273
ĠFrance: -0.0035091438330709934
Ġthe: -0.00453113904222846
Ġis: -0.0058142454363405704
ĊĊ: -0.9999570846557617

Predicted Token: The

Ġof: 0.16130423545837402
Ġis: 0.14146742224693298
?: -0.08470432460308075
What: -0.219400629401206

KeyboardInterrupt: 

# For one word

In [None]:
# def igs_nano(context: str, n_steps: int = 50) -> tuple[str, int, torch.Tensor]:
#     encoded_input = tokenizer(context, return_tensors="pt") #returns a dict
#     input_ids = encoded_input["input_ids"]  # [1, 7]
#     baseline_input_ids = torch.zeros_like(input_ids) #[1, 7]

#     attention_mask = encoded_input["attention_mask"] # [1, 7]
#     all_tokens = tokenizer.convert_ids_to_tokens(input_ids[0].detach().tolist())    # [1, 7]

#     input_embeddings = model_nano.transformer.wte(input_ids).to(device) #[1, 7, 768]
#     baseline_embeddings = model_nano.transformer.wte(baseline_input_ids).to(device) #[1, 7, 768]

#     model_nano.eval()

#     output_logits = model_nano(input_ids)[0] #[1, 7, 50257]
#     next_token_logits = output_logits[0, 0, :]  # [50257]
#     predicted_token_id = torch.argmax(next_token_logits).item() # [1]

#     position_ids = torch.arange(0, input_embeddings.size(1)).unsqueeze(0) #[1, 7]
#     position_embeddings = model_nano.transformer.wpe(position_ids) #[1, 7, 768]

# # choose target_word_index
#     for target_word_index in range(input_embeddings.size(1)):
    
#         target_word_embedding = input_embeddings[0,target_word_index,:].unsqueeze(0) #[1, 768]
#         target_word_baseline = baseline_embeddings[0,target_word_index,:].unsqueeze(0) #[1, 768]

#         alphas = torch.linspace(0, 1, steps=n_steps).unsqueeze(-1) #[50, 1]

#         step_embeddings = target_word_baseline + alphas * (target_word_embedding - target_word_baseline) #[50, 768]
#         step_embeddings.requires_grad_(True) #[50, 768]
#         step_embeddings.retain_grad()
#         step_embeddings.grad = None

#         forward_embeddings = step_embeddings + position_embeddings #[50.768]
#         forward_embeddings = model_nano.transformer.drop(forward_embeddings) #[50, 768]   

#         for block in model_nano.transformer.h:
#             forward_embeddings = block(forward_embeddings) #[50, 768]

#         forward_embeddings = model_nano.transformer.ln_f(forward_embeddings) #[50, 768]
#         output_at_step = model_nano.lm_head(forward_embeddings) #[50, 50257]

#         class_output_at_step = output_at_step[:, predicted_token_id] #[50]
#         summed_output_for_gradient_computation = class_output_at_step.sum() #[1]
#         summed_output_for_gradient_computation.backward(retain_graph=True)
#         # class_output_at_step.backward(retain_graph=True)
#         # DO THINGS

#         assert step_embeddings.grad is not None

#         step_embeddings_grad_pre_sum = step_embeddings.grad/n_steps #[50, 768]

#         target_word_igs = step_embeddings_grad_pre_sum.sum(dim=0) #[1, 768]
#         target_word_igs = target_word_igs * (target_word_embedding - target_word_baseline) #[1, 768]

# Interrupt Model at Will

In [None]:
# context = "What is the capital of France?"

# encoded_input = tokenizer(context, return_tensors='pt')
# input_ids = encoded_input['input_ids']
# attention_mask = encoded_input['attention_mask']

In [None]:
# def predict_with_interruption(context: str, 
#                               layer: Literal['word_embeddings', 
#                                              'position_embeddings',
#                                              'embedding_layer',
#                                              'block_0',
#                                              'block_1',
#                                              'block_2',
#                                              'block_3',
#                                              'block_4',
#                                              'block_5',
#                                              'block_6',
#                                              'block_7',
#                                              'block_8',
#                                              'block_9',
#                                              'block_10',
#                                              'block_11']
#                                ) -> tuple[int, torch.Tensor]:

#     blocks_list = ['block_0',
#                    'block_1',
#                    'block_2',
#                    'block_3',
#                    'block_4',
#                    'block_5',
#                    'block_6',
#                    'block_7',
#                    'block_8',
#                    'block_9',
#                    'block_10',
#                    'block_11']

#     input_ids = tokenizer.encode(context, return_tensors='pt')
#     word_embeddings = model.transformer.wte(input_ids)

#     if layer == word_embeddings:
#         extracted_entity = word_embeddings

#     position_ids = torch.arange(0, input_ids.size(-1)).unsqueeze(0)
#     position_embeddings = model.transformer.wpe(position_ids)

#     model_forward_embeddings = word_embeddings + position_embeddings

#     if layer == position_embeddings:
#         extracted_entity = model_forward_embeddings

#     model_forward_embeddings = model.transformer.drop(model_forward_embeddings)

#     if layer == embedding_layer:
#         extracted_entity = model_forward_embeddings

#     block_counter = 0
#     for block in model.transformer.h:
#         model_forward_embeddings = block(model_forward_embeddings)[0]
#         if layer == blocks_list[block_counter]:
#             extracted_entity = model_forward_embeddings
#         block_counter += 1

#     model_forward_embeddings = model.transformer.ln_f(model_forward_embeddings)

#     model_forward_embeddings = model.lm_head(model_forward_embeddings)

#     predicted_next_token_logits = model_forward_embeddings[0, -1, :]
#     predicted_next_token = torch.argmax(predicted_next_token_logits).item()

#     return predicted_next_token, extracted_entity

In [None]:
# interrupt_at_layer(8)

12


198

In [None]:
# from transformers import GPT2Tokenizer, GPT2Model
# tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
# model = GPT2Model.from_pretrained('distilgpt2')
# text = "Replace me by any text you'd like."
# encoded_input = tokenizer(text, return_tensors='pt')
# output = model(**encoded_input)


Downloading vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading config.json:   0%|          | 0.00/762 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/353M [00:00<?, ?B/s]