## Imports

In [74]:
%load_ext autoreload
%autoreload 2

import torch
import numpy as np
import torch
import transformer_lens
import transformers
import tiktoken

import circuitsvis as cv

from torch import nn
from tqdm import tqdm
from transformer_lens.hook_points import HookPoint
from transformer_lens import HookedTransformer

from model.config import GPTNeoWithSelfAblationConfig
from model.gpt_neo import GPTNeoWithSelfAblation

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Setup

In [80]:
# Checkpoint under test; point this at another file to evaluate a different model
model_path = "model_weights/youthful-wave-20.pt"

# Config values handed to the model config class; adjust them if the checkpoint needs different ones
model_specific_config = dict(
    hidden_size=128,
    max_position_embeddings=256,
    # The two ablation-mask flags below are currently not mutually exclusive
    has_layer_by_layer_ablation_mask=False,
    has_overall_ablation_mask=True,
)

# Use the GPU when one is present, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# This notebook only runs inference, so gradient tracking is switched off globally
torch.set_grad_enabled(False)

Using device: cuda


## Model Loading

In [81]:
# Build the model from the config defined in the setup cell and move it to the selected device
model_config = GPTNeoWithSelfAblationConfig(**model_specific_config)
model = GPTNeoWithSelfAblation(model_config).to(device)
tokenizer = tiktoken.get_encoding("gpt2")

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while being loaded. It
# also avoids the warning torch emits when the flag is left unset.
# NOTE(review): assumes the checkpoint holds a plain state_dict — confirm
state_dict = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)

# Switch to evaluation mode; as the cell's last expression this also displays the module tree
model.eval()

  model.load_state_dict(torch.load(model_path, map_location=device))


GPTNeoWithSelfAblation(
  (transformer): ModuleDict(
    (wte): Embedding(50257, 128)
    (wpe): Embedding(256, 128)
    (h): ModuleList(
      (0-7): 8 x GPTNeoBlockWithSelfAblation(
        (ln_1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (attn): AttentionWithSelfAblation(
          (hook_k): HookPoint()
          (hook_v): HookPoint()
          (hook_q): HookPoint()
          (attn_hook): HookPoint()
          (context): HookPoint()
          (ablated_context): HookPoint()
          (attention): ModuleDict(
            (k_proj): Linear(in_features=128, out_features=128, bias=False)
            (v_proj): Linear(in_features=128, out_features=128, bias=False)
            (q_proj): Linear(in_features=128, out_features=128, bias=False)
            (out_proj): Linear(in_features=128, out_features=128, bias=True)
          )
        )
        (ln_2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (mlp): MLPWithSelfAblation(
          (c_fc): Linear(in_fe

### Sanity Check

In [82]:
# Sanity check: push one random batch through the model and confirm that
# run_with_cache returns an activation cache.
# The input dimensions are read from the objects configured in earlier cells
# rather than repeated as literals, so the batch still fits the model's position
# embeddings when the config is changed.
seq_len = model_specific_config['max_position_embeddings']
vocab_size = tokenizer.n_vocab  # upper bound of randint is exclusive, so every token id can be drawn
inputs = torch.randint(0, vocab_size, (1, seq_len), device=device)

# Try run_with_cache
output, cache = model.run_with_cache(inputs)

# List the names of the cached activations
print(cache.keys())

dict_keys(['transformer.h.0.attn.hook_q', 'transformer.h.0.attn.hook_k', 'transformer.h.0.attn.hook_v', 'transformer.h.0.attn.attn_hook', 'transformer.h.0.attn.context', 'transformer.h.0.attn.ablated_context', 'transformer.h.0.attn_hook', 'transformer.h.0.mlp.hook_fc_activation', 'transformer.h.0.mlp.hook_ablated_fc_activation', 'transformer.h.0.mlp_hook', 'transformer.h.1.attn.hook_q', 'transformer.h.1.attn.hook_k', 'transformer.h.1.attn.hook_v', 'transformer.h.1.attn.attn_hook', 'transformer.h.1.attn.context', 'transformer.h.1.attn.ablated_context', 'transformer.h.1.attn_hook', 'transformer.h.1.mlp.hook_fc_activation', 'transformer.h.1.mlp.hook_ablated_fc_activation', 'transformer.h.1.mlp_hook', 'transformer.h.2.attn.hook_q', 'transformer.h.2.attn.hook_k', 'transformer.h.2.attn.hook_v', 'transformer.h.2.attn.attn_hook', 'transformer.h.2.attn.context', 'transformer.h.2.attn.ablated_context', 'transformer.h.2.attn_hook', 'transformer.h.2.mlp.hook_fc_activation', 'transformer.h.2.mlp.ho

In [83]:
# Prompt used for the qualitative check in the following cells
input_text = "Sam and Tom are in the park. Tom said to"
# Encode the prompt into token ids with the tokenizer created in the model-loading cell
input_ids = tokenizer.encode(input_text)

In [84]:
# Run the encoded prompt as a batch of one and keep the activation cache for later cells
prompt_batch = torch.tensor(input_ids).unsqueeze(0).to(device)
output, cache = model.run_with_cache(prompt_batch)

# Take the highest-scoring token id at every position, then decode only the final position
greedy_ids = output["logits_clean"].argmax(dim=-1)
tokenizer.decode(greedy_ids[0, -1:].tolist())

' Sam'

In [94]:
# Convert each token id to its text for display. A single BPE token is a byte
# sequence that is not guaranteed to be valid UTF-8 on its own, so undecodable
# bytes are replaced instead of raising UnicodeDecodeError.
tokens = [
    tokenizer.decode_single_token_bytes(token).decode('utf-8', errors='replace')
    for token in input_ids
]

# Cached attention hook output of block 4, taken from the most recent run_with_cache call
activation_pattern = cache['transformer.h.4.attn.attn_hook']

# Build the circuitsvis attention view for the [0, 0] slice of the cached tensor.
# NOTE(review): presumably the leading axes are (batch, head), i.e. head 0 of the only prompt — confirm
output = cv.attention.attention_pattern(tokens=tokens, attention=activation_pattern[0, 0])

In [95]:
# Render the attention-pattern visualisation built in the previous cell
display(output)