## Imports

In [2]:
%load_ext autoreload
%autoreload 2

import torch
import numpy as np
import torch
import transformer_lens
import transformers
from torch import nn
from tqdm import tqdm
from transformer_lens.hook_points import HookPoint
from transformer_lens import HookedTransformer

from model.config import GPTNeoWithSelfAblationConfig
from model.gpt_neo import GPTNeoWithSelfAblation

## Setup

In [3]:
# Inference only: turn autograd off for the whole session.
torch.set_grad_enabled(False)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Checkpoint under test; adjust the config overrides below to match how it was trained.
model_path = "model_weights/youthful-wave-20.pt"
model_specific_config = dict(
    hidden_size=128,
    max_position_embeddings=256,
    # The two ablation-mask flags are independent for now (both may be enabled).
    has_layer_by_layer_ablation_mask=False,
    has_overall_ablation_mask=True,
)

Using device: cuda


## Model Loading

In [4]:
# Build the self-ablation GPT-Neo variant from the config overrides chosen in the setup cell.
model_config = GPTNeoWithSelfAblationConfig(**model_specific_config)
model = GPTNeoWithSelfAblation(model_config)

# weights_only=True restricts unpickling to tensors and plain containers, which is
# all a state_dict needs. It avoids executing arbitrary pickled code from the
# checkpoint file and should silence the warning echoed in this cell's output
# (presumably torch.load's FutureWarning about the weights_only default).
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))

# Switch to evaluation mode; as the last expression this also displays the module tree.
model.eval()

  model.load_state_dict(torch.load(model_path, map_location=device))


GPTNeoWithSelfAblation(
  (transformer): ModuleDict(
    (wte): Embedding(50257, 128)
    (wpe): Embedding(256, 128)
    (h): ModuleList(
      (0-7): 8 x GPTNeoBlockWithSelfAblation(
        (ln_1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (attn): AttentionWithSelfAblation(
          (attention): ModuleDict(
            (k_proj): Linear(in_features=128, out_features=128, bias=False)
            (v_proj): Linear(in_features=128, out_features=128, bias=False)
            (q_proj): Linear(in_features=128, out_features=128, bias=False)
            (out_proj): Linear(in_features=128, out_features=128, bias=True)
          )
        )
        (ln_2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (mlp): MLPWithSelfAblation(
          (c_fc): Linear(in_features=128, out_features=512, bias=True)
          (c_proj): Linear(in_features=512, out_features=128, bias=True)
          (act): NewGELUActivation()
        )
      )
    )
    (ln_f): LayerNorm((128,), 

## Hooked Model Definition

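Following the approach used in the Codebook repo, we first create a pretrained `HookedTransformer` and then overwrite its weights with those of our own model. (TODO: confirm this is the intended way to do it.)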

In [5]:
# Every weight post-processing flag is set to False so the loaded parameters
# keep their raw values and can later be overwritten one-to-one.
hooked_kwargs = {
    "center_unembed": False,
    "center_writing_weights": False,
    "fold_ln": False,
    "fold_value_biases": False,
    "refactor_factored_attn_matrices": False,
    "device": device,
}

# Use the pretrained tiny-stories-3M HookedTransformer as the scaffold.
hooked_model = HookedTransformer.from_pretrained("tiny-stories-3M", **hooked_kwargs)
hooked_model = hooked_model.to(device)
hooked_model.eval()

Loaded pretrained model tiny-stories-3M into HookedTransformer
Moving model to device:  cuda




HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-7): 8 x TransformerBlock(
      (ln1): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resi

In [6]:
# Copy every attribute of our custom config that the HookedTransformer config
# does not already define; attributes that exist on hooked_model.cfg are left untouched.
for name, value in vars(model_config).items():
    if name not in vars(hooked_model.cfg):
        setattr(hooked_model.cfg, name, value)

In [7]:
# Show the module tree again after the extra config attributes were copied onto hooked_model.cfg.
print(hooked_model)

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-7): 8 x TransformerBlock(
      (ln1): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resi

In [8]:
def reshape_checkpoint_weights(state_dict, model_state_dict, num_heads=16, d_model=128):
    """Convert nn.Linear-style weights into the layouts HookedTransformer expects.

    `state_dict` is expected to already use HookedTransformer key names but to
    still hold tensors in `nn.Linear` layout, i.e. `[out_features, in_features]`.

    Conversions applied (selected by substring match on the key):
      * attn.W_Q / W_K / W_V: [(head d_head), d_model] -> [head, d_model, d_head]
      * attn.W_O:             [d_model, (head d_head)] -> [head, d_head, d_model]
      * unembed.W_U:          [d_vocab, d_model]       -> [d_model, d_vocab]
    Every other entry is passed through unchanged.

    Args:
        state_dict: Mapping of (already renamed) parameter names to tensors.
        model_state_dict: State dict of the target model; any key missing from
            `state_dict` is filled in from here so the result is complete.
        num_heads: Number of attention heads (default keeps the previous
            hard-coded value of 16).
        d_model: Residual stream width (default keeps the previous hard-coded
            value of 128).

    Returns:
        A new dict; the input mappings are not modified.
    """
    d_head = d_model // num_heads
    new_state_dict = {}

    for key, value in state_dict.items():
        if "attn.W_Q" in key or "attn.W_K" in key or "attn.W_V" in key:
            # Rows are grouped by head: split them, then put d_model before d_head.
            new_state_dict[key] = value.reshape(num_heads, d_head, d_model).transpose(1, 2)

        elif "attn.W_O" in key:
            # The output projection has its *columns* grouped by head, so it must
            # be transposed before the head axis is split off. Splitting the rows
            # directly yields the right shape but scrambles the weights.
            new_state_dict[key] = value.T.reshape(num_heads, d_head, d_model)

        elif "unembed.W_U" in key:
            # Transpose [d_vocab, d_model] to [d_model, d_vocab].
            new_state_dict[key] = value.T

        else:
            # No layout change needed.
            new_state_dict[key] = value

    # Fill in anything the checkpoint did not provide from the target model itself.
    for key, value in model_state_dict.items():
        new_state_dict.setdefault(key, value)

    return new_state_dict

In [9]:
def remap_state_dict_keys(state_dict):
    """Build the key translation table from the custom GPT-Neo model to HookedTransformer.

    Args:
        state_dict: State dict of the source model. Only its keys are inspected,
            to find out how many `transformer.h.<i>` blocks exist. If no such
            key is present (or the argument is empty/None), 8 blocks are assumed,
            which matches the previous hard-coded behaviour.

    Returns:
        A dict mapping source parameter names to HookedTransformer parameter
        names. Note that this is a name -> name table, not a state dict.
    """
    layer_prefix = "transformer.h."
    layer_indices = set()
    for key in (state_dict or ()):
        if key.startswith(layer_prefix):
            index = key[len(layer_prefix):].split(".", 1)[0]
            if index.isdecimal():
                layer_indices.add(int(index))
    num_layers = max(layer_indices) + 1 if layer_indices else 8

    # (source suffix, target suffix) for every per-block parameter that only
    # needs renaming here. The k/v bias entries are harmless when the source
    # model has no such parameters: they are simply never looked up.
    # NOTE(review): MLP parameters (mlp.c_fc / mlp.c_proj) have no entry here;
    # besides renaming they would also need a transpose, so they cannot simply
    # be added to this table.
    per_layer = (
        ("ln_1.weight", "ln1.w"),
        ("ln_1.bias", "ln1.b"),
        ("ln_2.weight", "ln2.w"),
        ("ln_2.bias", "ln2.b"),
        ("attn.attention.q_proj.weight", "attn.W_Q"),
        ("attn.attention.out_proj.weight", "attn.W_O"),
        ("attn.attention.out_proj.bias", "attn.b_O"),
        ("attn.attention.k_proj.weight", "attn.W_K"),
        ("attn.attention.v_proj.weight", "attn.W_V"),
        ("attn.attention.k_proj.bias", "attn.b_K"),
        ("attn.attention.v_proj.bias", "attn.b_V"),
    )

    key_map = {}
    for i in range(num_layers):
        for source_suffix, target_suffix in per_layer:
            key_map[f"transformer.h.{i}.{source_suffix}"] = f"blocks.{i}.{target_suffix}"

    # Parameters that exist once per model.
    key_map["transformer.wte.weight"] = "embed.W_E"
    key_map["transformer.wpe.weight"] = "pos_embed.W_pos"
    key_map["transformer.ln_f.weight"] = "ln_final.w"
    key_map["transformer.ln_f.bias"] = "ln_final.b"
    key_map["lm_head.weight"] = "unembed.W_U"

    return key_map

In [14]:
def _convert_mlp_params(state_dict):
    """Rename `transformer.h.<i>.mlp.*` entries to HookedTransformer names, transposing weights."""
    layer_prefix = "transformer.h."
    # source suffix -> (target suffix, whether the tensor must be transposed).
    # nn.Linear stores weights as [out_features, in_features]; the hooked MLP
    # is assumed to store W_in / W_out as [in_features, out_features].
    suffixes = {
        ".mlp.c_fc.weight": (".mlp.W_in", True),
        ".mlp.c_fc.bias": (".mlp.b_in", False),
        ".mlp.c_proj.weight": (".mlp.W_out", True),
        ".mlp.c_proj.bias": (".mlp.b_out", False),
    }

    converted = {}
    for key, value in state_dict.items():
        for source_suffix, (target_suffix, transpose) in suffixes.items():
            if key.startswith(layer_prefix) and key.endswith(source_suffix):
                layer = key[len(layer_prefix):-len(source_suffix)]
                converted[f"blocks.{layer}{target_suffix}"] = value.T if transpose else value
                break
        else:
            converted[key] = value
    return converted


def load_remapped_state_dict(hooked_model, model):
    """Copy the weights of `model` into `hooked_model`, translating names and layouts.

    Args:
        hooked_model: Target model; must provide `state_dict()` and
            `load_state_dict()`. It is modified in place.
        model: Source model whose `state_dict()` is read.

    Returns:
        The result of `hooked_model.load_state_dict(..., strict=False)`, which
        lists missing and unexpected keys. Because loading is non-strict, any
        source parameter without a translation is silently ignored; inspecting
        the returned value is the only way to notice that.
    """
    original_state_dict = model.state_dict()

    # Rename the keys that only differ by name.
    key_map = remap_state_dict_keys(original_state_dict)
    renamed_state_dict = {key_map.get(k, k): v for k, v in original_state_dict.items()}

    # The MLP parameters need a rename *and* a transpose; without this step they
    # would be dropped and the hooked model would keep its pretrained MLP weights.
    renamed_state_dict = _convert_mlp_params(renamed_state_dict)

    # Bring attention / unembedding weights into the hooked model's layout and
    # fill in whatever the source model does not provide.
    reshaped_state_dict = reshape_checkpoint_weights(renamed_state_dict, hooked_model.state_dict())

    # The positional embedding is deliberately not transferred (presumably the
    # two models use different context lengths, which would be a shape mismatch).
    reshaped_state_dict.pop('pos_embed.W_pos', None)

    return hooked_model.load_state_dict(reshaped_state_dict, strict=False)

# Transfer the weights: `model` is the source PyTorch model,
# `hooked_model` is the HookedTransformer that receives them in place.
load_remapped_state_dict(hooked_model=hooked_model, model=model)


In [20]:
# Smoke test: push one random batch of shape (1, 128) through the hooked model.
# `high` is exclusive, so token ids are drawn from 0..50255.
input_ids = torch.randint(low=0, high=50256, size=(1, 128)).to(device)
output = hooked_model(input_ids)

In [22]:
# Show the module tree once more after the weights were loaded and a forward pass was run.
print(hooked_model)

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-7): 8 x TransformerBlock(
      (ln1): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resi