In [1]:
import torch
from GPTNeoWithIntermediates import GPTNeoWithIntermediates
from transformers import GPT2Tokenizer

In [2]:
tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
model = GPTNeoWithIntermediates.from_pretrained("EleutherAI/gpt-neo-125M")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

GPTNeoWithIntermediates(
  (transformer): GPTNeoModel(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(2048, 768)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPTNeoBlock(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPTNeoAttention(
          (attention): GPTNeoSelfAttention(
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (resid_dropout): Dropout(p=0.0, inplace=False)
            (k_proj): Linear(in_features=768, out_features=768, bias=False)
            (v_proj): Linear(in_features=768, out_features=768, bias=False)
            (q_proj): Linear(in_features=768, out_features=768, bias=False)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPTNeoMLP(
          (c_fc): Linear(in_features=768, out_features=3072, bias=True)
          (c_proj): Linear

In [3]:
import time

In [4]:
# Set the padding token to the eos_token
tokenizer.pad_token = tokenizer.eos_token

def perform_inference_with_all_intermediates(input_texts, selected_layers=None):
    # Tokenize the input texts in parallel
    inputs = tokenizer(input_texts, return_tensors="pt", padding=True, truncation=True).to(device)
    print(inputs)
    input_ids, attention_mask = inputs.input_ids, inputs.attention_mask
    
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)

    predicted_ids = torch.argmax(outputs["logits"], dim=-1)
    generated_texts = [tokenizer.decode(ids, skip_special_tokens=True) for ids in predicted_ids]

    if selected_layers is None:
        selected_layers = list(range(len(outputs["layer_outputs"])))
    
    # Extract the intermediate outputs for selected layers for each input
    selected_outputs = [
        {i: layer_output for i, layer_output in enumerate(outputs["layer_outputs"]) if i in selected_layers}
        for _ in input_texts
    ]

    return [
        {
            "generated_text": generated_texts[idx],
            "layer_outputs": selected_outputs[idx],
        }
        for idx in range(len(input_texts))
    ]

In [5]:
input_texts = ["Once upon a time", "The quick brown fox jumps over the lazy dog", "brave", "ollowing of the past Revolution War were the theoy's fought much in the American of tactics of the American-. of were.. the, the was the lesson of young strateg how importance idea of theming a a best way to win a shipscladads."]
selected_layers = [1, 5, 11]

t = time.time()
results = perform_inference_with_all_intermediates(input_texts, selected_layers)
print(time.time() - t)

for idx, result in enumerate(results):
    print(f"Generated Text for input {idx + 1}: {result['generated_text']}")
    # for layer_num, layer_output in result["layer_outputs"].items():
    #     print(f"Layer {layer_num} - Attention Block Output Shape:", layer_output["attn_out"].shape)
    #     print(f"Layer {layer_num} - MLP Hidden Output Shape:", layer_output["mlp_hidden"].shape)
    #     print(f"Layer {layer_num} - MLP Final Output Shape:", layer_output["mlp_final"].shape)

{'input_ids': tensor([[ 7454,  2402,   257,   640, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256],
        [  464,  2068,  7586, 21831, 18045,   625,   262, 16931,  3290, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256],
        [   65,  5758, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 5025