In [1]:
import torch
from revllm.gpt import GPT, GPTConfig
# from revllm.model_wrapper import TokenizerWrapper
from transformers import GPT2Tokenizer, GPT2LMHeadModel

from typing import Literal

In [2]:
model_hf = GPT2LMHeadModel.from_pretrained('gpt2')
model_distil = GPT2LMHeadModel.from_pretrained('distilgpt2')
model_nano = GPT.from_pretrained('gpt2')
model_nano.eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# tokenizer_nano = TokenizerWrapper()

loading weights from pretrained gpt: gpt2
forcing vocab_size=50257, block_size=1024, bias=True
number of parameters: 123.65M


# Batch Processing

In [3]:
def predict_hf(context: str) -> str:
    encoded_input = tokenizer(context, return_tensors="pt").to(device)
    input_ids = encoded_input["input_ids"].to(device)
    attention_mask = encoded_input["attention_mask"].to(device)

    output = model_hf.generate(**encoded_input, max_length=len(input_ids) + 15, do_sample=False)

    return tokenizer.decode(output[0])

# Generalize to context

In [10]:
import time
import warnings

import torch

import logging

logging.getLogger("transformers").setLevel(logging.ERROR)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
start_time = time.time()

context = "What is the capital of France?"

print(f"Initial context: {context}")
print("")
context_with_prediction = predict_hf(context)

print(context_with_prediction)
print("")

num_tokens_to_generate = 9

attributions_dict = {}
print("Predicted tokens, with attributions in order:")
print("(note: the first tokens are spaces)")

for _ in range(num_tokens_to_generate):
    predicted_token, predicted_token_id, input_ids, attention_mask, attributions = igs_nano(context)
    attributions_dict[predicted_token] = attributions

    # Append the predicted token ID to the input
    predicted_token_tensor = torch.tensor([[predicted_token_id]], dtype=torch.long)
    input_ids = torch.cat((input_ids, predicted_token_tensor), dim=1)
    context = tokenizer.decode(input_ids[0])

end_time = time.time()
print(f"Time elapsed: {end_time - start_time}")
print("")
print(tokenizer.decode(input_ids[0]))


Initial context: What is the capital of France?

What is the capital of France?

The capital of France is Paris.

Predicted tokens, with attributions in order:
(note: the first tokens are spaces)
Predicted Token: 


Ġof: 0.4667750298976898
Ġis: 0.09483694285154343
What: -0.20078647136688232
ĠFrance: -0.3289549946784973
Ġthe: -0.3516395092010498
?: -0.4432636499404907
Ġcapital: -0.5517855882644653

Predicted Token: 


What: -0.06957273185253143
Ġof: -0.23914484679698944
Ċ: -0.24015489220619202
Ġcapital: -0.3045584261417389
Ġthe: -0.33885839581489563
?: -0.41570526361465454
Ġis: -0.4330602288246155
ĠFrance: -0.5588936805725098

Predicted Token: 


?: 0.002080421196296811
Ġof: -0.0015987710794433951
What: -0.0022416578140109777
Ġcapital: -0.0026813633739948273
ĠFrance: -0.0035091438330709934
Ġthe: -0.00453113904222846
Ġis: -0.0058142454363405704
ĊĊ: -0.9999570846557617

Predicted Token: The

Ġof: 0.16130423545837402
Ġis: 0.14146742224693298
?: -0.08470432460308075
What: -0.219400629401206

KeyboardInterrupt: 

# For one word

In [None]:
# def igs_nano(context: str, n_steps: int = 50) -> tuple[str, int, torch.Tensor]:
#     encoded_input = tokenizer(context, return_tensors="pt") #returns a dict
#     input_ids = encoded_input["input_ids"]  # [1, 7]
#     baseline_input_ids = torch.zeros_like(input_ids) #[1, 7]

#     attention_mask = encoded_input["attention_mask"] # [1, 7]
#     all_tokens = tokenizer.convert_ids_to_tokens(input_ids[0].detach().tolist())    # [1, 7]

#     input_embeddings = model_nano.transformer.wte(input_ids).to(device) #[1, 7, 768]
#     baseline_embeddings = model_nano.transformer.wte(baseline_input_ids).to(device) #[1, 7, 768]

#     model_nano.eval()

#     output_logits = model_nano(input_ids)[0] #[1, 7, 50257]
#     next_token_logits = output_logits[0, 0, :]  # [50257]
#     predicted_token_id = torch.argmax(next_token_logits).item() # [1]

#     position_ids = torch.arange(0, input_embeddings.size(1)).unsqueeze(0) #[1, 7]
#     position_embeddings = model_nano.transformer.wpe(position_ids) #[1, 7, 768]

# # choose target_word_index
#     for target_word_index in range(input_embeddings.size(1)):
    
#         target_word_embedding = input_embeddings[0,target_word_index,:].unsqueeze(0) #[1, 768]
#         target_word_baseline = baseline_embeddings[0,target_word_index,:].unsqueeze(0) #[1, 768]

#         alphas = torch.linspace(0, 1, steps=n_steps).unsqueeze(-1) #[50, 1]

#         step_embeddings = target_word_baseline + alphas * (target_word_embedding - target_word_baseline) #[50, 768]
#         step_embeddings.requires_grad_(True) #[50, 768]
#         step_embeddings.retain_grad()
#         step_embeddings.grad = None

#         forward_embeddings = step_embeddings + position_embeddings #[50.768]
#         forward_embeddings = model_nano.transformer.drop(forward_embeddings) #[50, 768]   

#         for block in model_nano.transformer.h:
#             forward_embeddings = block(forward_embeddings) #[50, 768]

#         forward_embeddings = model_nano.transformer.ln_f(forward_embeddings) #[50, 768]
#         output_at_step = model_nano.lm_head(forward_embeddings) #[50, 50257]

#         class_output_at_step = output_at_step[:, predicted_token_id] #[50]
#         summed_output_for_gradient_computation = class_output_at_step.sum() #[1]
#         summed_output_for_gradient_computation.backward(retain_graph=True)
#         # class_output_at_step.backward(retain_graph=True)
#         # DO THINGS

#         assert step_embeddings.grad is not None

#         step_embeddings_grad_pre_sum = step_embeddings.grad/n_steps #[50, 768]

#         target_word_igs = step_embeddings_grad_pre_sum.sum(dim=0) #[1, 768]
#         target_word_igs = target_word_igs * (target_word_embedding - target_word_baseline) #[1, 768]

# Interrupt Model at Will

In [None]:
context = "What is the capital of France?"

encoded_input = tokenizer(context, return_tensors='pt')
input_ids = encoded_input['input_ids']
attention_mask = encoded_input['attention_mask']

In [None]:
def predict_with_interruption(context: str, 
                              layer: Literal['word_embeddings', 
                                             'position_embeddings',
                                             'embedding_layer',
                                             'block_0',
                                             'block_1',
                                             'block_2',
                                             'block_3',
                                             'block_4',
                                             'block_5',
                                             'block_6',
                                             'block_7',
                                             'block_8',
                                             'block_9',
                                             'block_10',
                                             'block_11']
                               ) -> tuple[int, torch.Tensor]:

    blocks_list = ['block_0',
                   'block_1',
                   'block_2',
                   'block_3',
                   'block_4',
                   'block_5',
                   'block_6',
                   'block_7',
                   'block_8',
                   'block_9',
                   'block_10',
                   'block_11']

    input_ids = tokenizer.encode(context, return_tensors='pt')
    word_embeddings = model.transformer.wte(input_ids)

    if layer == word_embeddings:
        extracted_entity = word_embeddings

    position_ids = torch.arange(0, input_ids.size(-1)).unsqueeze(0)
    position_embeddings = model.transformer.wpe(position_ids)

    model_forward_embeddings = word_embeddings + position_embeddings

    if layer == position_embeddings:
        extracted_entity = model_forward_embeddings

    model_forward_embeddings = model.transformer.drop(model_forward_embeddings)

    if layer == embedding_layer:
        extracted_entity = model_forward_embeddings

    block_counter = 0
    for block in model.transformer.h:
        model_forward_embeddings = block(model_forward_embeddings)[0]
        if layer == blocks_list[block_counter]:
            extracted_entity = model_forward_embeddings
        block_counter += 1

    model_forward_embeddings = model.transformer.ln_f(model_forward_embeddings)

    model_forward_embeddings = model.lm_head(model_forward_embeddings)

    predicted_next_token_logits = model_forward_embeddings[0, -1, :]
    predicted_next_token = torch.argmax(predicted_next_token_logits).item()

    return predicted_next_token, extracted_entity

In [None]:
interrupt_at_layer(8)

12


198

In [None]:
predicted_next_token_logits

tensor([-120.2345, -119.8962, -121.1645,  ..., -131.9775, -131.2008,
        -116.6154], grad_fn=<SliceBackward0>)

In [None]:
from transformers import GPT2Tokenizer, GPT2Model
tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
model = GPT2Model.from_pretrained('distilgpt2')
text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors='pt')
output = model(**encoded_input)


Downloading vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading config.json:   0%|          | 0.00/762 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/353M [00:00<?, ?B/s]

In [None]:
model

GPT2Model(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-5): 6 x GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2Attention(
        (c_attn): Conv1D()
        (c_proj): Conv1D()
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D()
        (c_proj): Conv1D()
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)