## Imports

In [1]:
%load_ext autoreload
%autoreload 2

import torch
import numpy as np
import torch
import transformer_lens
import transformers
import tiktoken

import circuitsvis as cv

from torch import nn
from tqdm import tqdm
from transformer_lens.hook_points import HookPoint
from transformer_lens import HookedTransformer

from model.config import GPTNeoWithSelfAblationConfig
from model.gpt_neo import GPTNeoWithSelfAblation
from utils.compatibility import remap_state_dict_keys

## Setup

In [2]:
# Inference-only notebook: switch autograd off globally
torch.set_grad_enabled(False)

# Prefer the GPU when one is visible, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Checkpoint under test; adjust the config overrides below to match it
model_path = "model_weights/youthful-wave-20.pt"
model_specific_config = dict(
    hidden_size=128,
    max_position_embeddings=256,
    # The two mask flags are currently not mutually exclusive
    has_layer_by_layer_ablation_mask=False,
    has_overall_ablation_mask=True,
)

Using device: cuda


## Model Loading

In [3]:
# Build the model from the notebook-level config overrides and move it to the target device
model_config = GPTNeoWithSelfAblationConfig(**model_specific_config)
model = GPTNeoWithSelfAblation(model_config).to(device)
tokenizer = tiktoken.get_encoding("gpt2")

# Load the state dictionary from the file.
# weights_only=True restricts unpickling to tensors and plain containers, so a
# checkpoint file cannot execute arbitrary code on load. Without it torch.load
# falls back to its unrestricted pickle default (NOTE(review): the stray line in
# this cell's recorded output looks like the warning torch emits for that case).
state_dict = torch.load(model_path, map_location=device, weights_only=True)

# Remap the keys in the state dictionary
remapped_state_dict = remap_state_dict_keys(state_dict)

# Load the remapped state dictionary into the model
model.load_state_dict(remapped_state_dict)

# Switch to evaluation mode; as the cell's last expression this also displays the module tree
model.eval()

  state_dict = torch.load(model_path, map_location=device)


GPTNeoWithSelfAblation(
  (wte): Embedding(50257, 128)
  (wpe): Embedding(256, 128)
  (blocks): ModuleList(
    (0-7): 8 x GPTNeoBlockWithSelfAblation(
      (ln_1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (attn): AttentionWithSelfAblation(
        (k_hook): HookPoint()
        (v_hook): HookPoint()
        (q_hook): HookPoint()
        (attn_hook): HookPoint()
        (context_hook): HookPoint()
        (ablated_context_hook): HookPoint()
        (attention): ModuleDict(
          (k_proj): Linear(in_features=128, out_features=128, bias=False)
          (v_proj): Linear(in_features=128, out_features=128, bias=False)
          (q_proj): Linear(in_features=128, out_features=128, bias=False)
          (out_proj): Linear(in_features=128, out_features=128, bias=True)
        )
      )
      (ln_2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (mlp): MLPWithSelfAblation(
        (c_fc): Linear(in_features=128, out_features=512, bias=True)
        (c_proj):

### Sanity Check

In [4]:
# One random sequence of 256 token ids (the configured max_position_embeddings).
# `high` is exclusive, so ids are drawn from [0, 50256).
inputs = torch.randint(low=0, high=50256, size=(1, 256)).to(device)

# Smoke-test run_with_cache on the random batch
output, cache = model.run_with_cache(inputs)

# List the hook names that were captured
print(cache.keys())

dict_keys(['blocks.0.attn.q_hook', 'blocks.0.attn.k_hook', 'blocks.0.attn.v_hook', 'blocks.0.attn.attn_hook', 'blocks.0.attn.context_hook', 'blocks.0.attn.ablated_context_hook', 'blocks.0.hook_attn_out', 'blocks.0.mlp.fc_activation_hook', 'blocks.0.mlp.ablated_fc_activation_hook', 'blocks.0.hook_mlp_out', 'blocks.1.attn.q_hook', 'blocks.1.attn.k_hook', 'blocks.1.attn.v_hook', 'blocks.1.attn.attn_hook', 'blocks.1.attn.context_hook', 'blocks.1.attn.ablated_context_hook', 'blocks.1.hook_attn_out', 'blocks.1.mlp.fc_activation_hook', 'blocks.1.mlp.ablated_fc_activation_hook', 'blocks.1.hook_mlp_out', 'blocks.2.attn.q_hook', 'blocks.2.attn.k_hook', 'blocks.2.attn.v_hook', 'blocks.2.attn.attn_hook', 'blocks.2.attn.context_hook', 'blocks.2.attn.ablated_context_hook', 'blocks.2.hook_attn_out', 'blocks.2.mlp.fc_activation_hook', 'blocks.2.mlp.ablated_fc_activation_hook', 'blocks.2.hook_mlp_out', 'blocks.3.attn.q_hook', 'blocks.3.attn.k_hook', 'blocks.3.attn.v_hook', 'blocks.3.attn.attn_hook', 'b

In [5]:
# Short prompt for the next-token check, encoded to a list of GPT-2 token ids
input_text = "Sam and Tom are in the park. Tom said to"
input_ids = tokenizer.encode(text=input_text)

In [6]:
# Run the prompt through the model as a batch of one, capturing activations
output, cache = model.run_with_cache(torch.tensor([input_ids]).to(device))

# Take the arg-max of the clean logits at the final position and decode it to text
tokenizer.decode(output["logits_clean"].argmax(dim=-1)[0, -1:].tolist())

' Sam'

In [7]:
# Convert ids to per-token display strings.
# A byte-level BPE token is not guaranteed to be valid UTF-8 on its own (a
# multi-byte character can be split across tokens), so undecodable bytes are
# replaced with U+FFFD instead of raising UnicodeDecodeError.
tokens = [
    tokenizer.decode_single_token_bytes(token).decode('utf-8', errors='replace')
    for token in input_ids
]

# Cached value of the attention hook in block 4 from the run above
activation_pattern = cache['blocks.4.attn.attn_hook']

# Render a single [0, 0] slice — presumably batch item 0, head 0; TODO confirm the hook's axis layout
output = cv.attention.attention_pattern(tokens=tokens, attention=activation_pattern[0,0])

In [None]:
# Show the attention-pattern visualization object assigned to `output` in the previous cell
display(output)

## Activation Cache test

In [8]:
# Decompose the residual stream from the cached run and keep the component labels.
# NOTE(review): mode="attn" presumably limits the decomposition to attention outputs — confirm against the cache implementation
residual_stream, labels = cache.decompose_resid(mode="attn", return_labels=True)