In [47]:
import torch
from transformers import GPT2LMHeadModel
from transformer_lens import HookedTransformer


model = GPT2LMHeadModel.from_pretrained("apollo-research/gpt2_noLN").to("cpu")

# Undo my hacky LayerNorm removal
for block in model.transformer.h:
    block.ln_1.weight.data = block.ln_1.weight.data / 1e6
    block.ln_1.eps = 1e-5
    block.ln_2.weight.data = block.ln_2.weight.data / 1e6
    block.ln_2.eps = 1e-5
model.transformer.ln_f.weight.data = model.transformer.ln_f.weight.data / 1e6
model.transformer.ln_f.eps = 1e-5

# Properly replace LayerNorms by Identities
class HookedTransformerNoLN(HookedTransformer):
    def removeLN(self):
        for i in range(len(self.blocks)):
            self.blocks[i].ln1 = torch.nn.Identity()
            self.blocks[i].ln2 = torch.nn.Identity()
        self.ln_final = torch.nn.Identity()

hooked_model = HookedTransformerNoLN.from_pretrained("gpt2", hf_model=model, fold_ln=True, center_unembed=False).to("cpu")
hooked_model.removeLN()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
prompt = torch.tensor([1,2,3,4], device=device)
logits = hooked_model(prompt)

print(logits.shape)
print(logits[0, 0, :10])



Loaded pretrained model gpt2 into HookedTransformer
Moving model to device:  cpu
torch.Size([1, 4, 50257])
tensor([25.9575, 27.0594, 23.7760, 24.0054, 25.5462, 23.6517, 26.7588, 25.8552,
        27.0049, 25.4789], grad_fn=<SliceBackward0>)


In [102]:
from nnsight import LanguageModel
# import gpt2 tokenizer
from transformers import GPT2Tokenizer
model_nnsight = LanguageModel(
    "apollo-research/gpt2_noLN",
    device_map = "cpu",
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2"),
    # automodel = GPT2LMHeadModel,
    dispatch = True,
    # torch_dtype=torch.bfloat16,
)



In [103]:
model_nnsight.lm_head = # new linear layer
model_nnsight.lm_head = torch.nn.Linear(768, 50257)


Linear(in_features=768, out_features=50257, bias=False)

In [241]:
import torch
from einops import rearrange


num_layers = len(model_nnsight.transformer.h)

# Reassign embedding weights
model_nnsight.transformer.wte.weight.data.copy_(hooked_model.W_E.data)
model_nnsight.transformer.wpe.weight.data.copy_(hooked_model.W_pos.data)
model_nnsight.lm_head = torch.nn.Linear(768, 50257)
model_nnsight.lm_head.weight.data.copy_(hooked_model.W_U.data.t())

for layer in range(num_layers):
    # Replace layer norms with identity
    model_nnsight.transformer.h[layer].ln_1 = torch.nn.Identity()
    model_nnsight.transformer.h[layer].ln_2 = torch.nn.Identity()

    # Reassign attention weights (Q, K, V)
    # q = hooked_model.W_Q[layer].data.reshape(768, 64 * 12)
    # k = hooked_model.W_K[layer].data.reshape(768, 64 * 12)
    # v = hooked_model.W_V[layer].data.reshape(768, 64 * 12)
    q = rearrange(hooked_model.W_Q[layer].data, 'w d l -> d (w l)')
    k = rearrange(hooked_model.W_K[layer].data, 'w d l -> d (w l)')
    v = rearrange(hooked_model.W_V[layer].data, 'w d l -> d (w l)').T


    
    # Concatenate Q, K, V
    qkv = torch.cat([q, k, v], dim=1)
    
    # Assign to c_attn
    model_nnsight.transformer.h[layer].attn.c_attn.weight.data.copy_(qkv)
    
    #TODO: verify that the W_O needs to be transposed
    # Reassign output projection (W_O to c_proj)
    w_o = rearrange(hooked_model.W_O[layer].data, "w l d -> (w l) d").T

    model_nnsight.transformer.h[layer].attn.c_proj.weight.data.copy_(w_o)
    
    # Reassign MLP weights
    model_nnsight.transformer.h[layer].mlp.c_fc.weight.data.copy_(
        hooked_model.W_in[layer].data.reshape(3072, 768).t()
    )
    # and bias
    model_nnsight.transformer.h[layer].mlp.c_fc.bias.data.copy_(
        hooked_model.b_in[layer].data
    )
    model_nnsight.transformer.h[layer].mlp.c_proj.weight.data.copy_(
        hooked_model.W_out[layer].data.reshape(768, 3072).t()
    )
    # and bias
    model_nnsight.transformer.h[layer].mlp.c_proj.bias.data.copy_(
        hooked_model.b_out[layer].data
    )
# Replace final layer norm with identity
model_nnsight.transformer.ln_f = torch.nn.Identity()

print(f"Weights reassigned successfully for all {num_layers} layers.")

# Usage
# reassign_weights(hooked_model, model_nnsight)

torch.Size([768, 768])
torch.Size([768, 768])
torch.Size([768, 768])
torch.Size([768, 768])
torch.Size([768, 768])
torch.Size([768, 768])
torch.Size([768, 768])
torch.Size([768, 768])
torch.Size([768, 768])
torch.Size([768, 768])
torch.Size([768, 768])
torch.Size([768, 768])
Weights reassigned successfully for all 12 layers.


In [212]:
hooked_model.W_O[layer].shape
w_o = rearrange(hooked_model.W_O[layer].data, "w l d -> (w l) d")
w_o.T == hooked_model.W_O[layer].data.reshape(768, 768)

tensor([[ True, False, False,  ..., False, False, False],
        [False,  True, False,  ..., False, False, False],
        [False, False,  True,  ..., False, False, False],
        ...,
        [False, False, False,  ...,  True, False, False],
        [False, False, False,  ..., False,  True, False],
        [False, False, False,  ..., False, False,  True]])

In [181]:
from einops import rearrange
tt = hooked_model.W_Q[layer].data
q_w = rearrange(tt, 'w d l -> d (w l)')
q = hooked_model.W_Q[layer].data.reshape(768, 64 * 12)
q_w

tensor([[-0.0142,  0.0025, -0.0053,  ...,  0.0020, -0.0021, -0.0142],
        [ 0.0008, -0.0083, -0.0019,  ..., -0.0025, -0.0029, -0.0009],
        [ 0.0077, -0.0039,  0.0032,  ..., -0.0027,  0.0036, -0.0049],
        ...,
        [ 0.0025,  0.0060,  0.0022,  ..., -0.0117,  0.0016,  0.0092],
        [-0.0054,  0.0048,  0.0004,  ..., -0.0070, -0.0028,  0.0018],
        [ 0.0096,  0.0027, -0.0016,  ...,  0.0030, -0.0039, -0.0040]])

In [219]:
hooked_model.W_Q[layer].data.shape, hooked_model.W_V[layer].data.shape

(torch.Size([12, 768, 64]), torch.Size([12, 768, 64]))

In [194]:
model_nnsight.tokenizer.pad_token = model_nnsight.tokenizer.eos_token

In [242]:
# with model_nnsight.trace("The Eiffel Tower is in the city of") as runner:
with torch.no_grad():
    with model_nnsight.trace(prompt) as runner:
        embedding = model_nnsight.transformer.wte.output.save()
        layer_1_output = model_nnsight.transformer.h[0].ln_1.output.save()
        attn_mid = model_nnsight.transformer.h[0].attn.c_attn.output.save()
        attn_out = model_nnsight.transformer.h[0].attn.c_proj.output.save()
        logits2 = model_nnsight.lm_head.output.save()

print(attn_out)
# print(embedding)

tensor([[[-0.1718,  1.5577,  1.8667,  ..., -1.3739,  5.0515, -5.4432],
         [-0.6355,  0.7799,  1.6813,  ..., -0.8211,  4.3663, -4.2849],
         [-0.4399,  1.3715,  1.9858,  ..., -1.1645,  3.6269, -2.7695],
         [ 0.0286,  0.8474,  1.3604,  ..., -0.9685,  3.2331, -2.9667]]])


In [243]:
(cache["blocks.0.hook_resid_pre"] == layer_1_output).all()

tensor(True)

In [244]:
(cache["blocks.0.hook_attn_out"] == attn_out)

tensor([[[False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False]]])

In [245]:
cache["blocks.0.hook_attn_out"], attn_out

(tensor([[[ 1.7667,  0.8346, -0.4114,  ..., -0.1989,  0.0705,  0.0129],
          [ 0.7096, -0.2735, -0.5014,  ..., -0.1504,  0.1146,  0.1199],
          [-0.8842,  0.2575, -0.2667,  ..., -0.0960,  0.0550,  0.0929],
          [-0.1350,  0.2363, -0.2860,  ..., -0.0152, -0.0279,  0.1651]]]),
 tensor([[[-0.1718,  1.5577,  1.8667,  ..., -1.3739,  5.0515, -5.4432],
          [-0.6355,  0.7799,  1.6813,  ..., -0.8211,  4.3663, -4.2849],
          [-0.4399,  1.3715,  1.9858,  ..., -1.1645,  3.6269, -2.7695],
          [ 0.0286,  0.8474,  1.3604,  ..., -0.9685,  3.2331, -2.9667]]]))

In [246]:
nn_q, nn_k, nn_v = attn_mid.chunk(3, dim=-1)

In [247]:
t_q = cache["blocks.0.attn.hook_q"].reshape(1, 4, 12*64)
t_k = cache["blocks.0.attn.hook_k"].reshape(1, 4, 12*64)
t_v = cache["blocks.0.attn.hook_v"].reshape(1, 4, 12*64)

In [251]:
nn_v, t_v

(tensor([[[-4.8225e-02, -3.1213e-01,  4.0030e-01,  ..., -1.2137e+00,
           -1.0406e-01,  3.6115e-01],
          [ 2.9163e-01,  8.9444e-04,  1.9521e-01,  ..., -4.4900e-01,
            9.9427e-02, -7.0002e-01],
          [-2.4940e-01, -2.8410e-02, -3.4262e-03,  ..., -1.2382e-01,
            2.7189e-01, -8.3504e-01],
          [-6.9130e-02, -3.5274e-01, -1.0790e-01,  ..., -1.8287e-01,
            1.0387e-01, -8.1425e-01]]]),
 tensor([[[ 0.3170,  0.1050,  0.1032,  ...,  0.0614, -0.1157,  0.1413],
          [-0.1031,  0.2391, -0.1225,  ..., -0.0869,  0.4240,  0.1991],
          [ 0.0802, -0.1100,  0.1615,  ...,  0.1792,  0.0789,  0.3793],
          [ 0.0471,  0.0282,  0.0530,  ..., -0.1254,  0.1734,  0.2399]]]))

In [250]:
(nn_q ==t_q).all(), (nn_k == t_k).all(), (nn_v == t_v).all()

(tensor(True), tensor(True), tensor(False))

In [249]:
cache["blocks.0.hook_attn_out"]

tensor([[[ 1.7667,  0.8346, -0.4114,  ..., -0.1989,  0.0705,  0.0129],
         [ 0.7096, -0.2735, -0.5014,  ..., -0.1504,  0.1146,  0.1199],
         [-0.8842,  0.2575, -0.2667,  ..., -0.0960,  0.0550,  0.0929],
         [-0.1350,  0.2363, -0.2860,  ..., -0.0152, -0.0279,  0.1651]]])

In [127]:
logits, cache = hooked_model.run_with_cache(prompt)
print(logits)
(cache["hook_embed"] == embedding).all()

tensor([[[  86.2191,   95.8736,    8.8613,  ...,  -66.9563,  -51.0069,
            68.8289],
         [ 276.7668,  329.2255,   60.9252,  ..., -353.2906, -310.6661,
           226.2671],
         [ 304.9153,  355.0207,  104.5744,  ..., -398.3969, -337.3764,
           256.1240],
         [ 265.2124,  299.4694,   95.2499,  ..., -390.3568, -341.6725,
           231.0227]]], grad_fn=<ViewBackward0>)


tensor(True)

In [105]:
hooked_model.W_E.data, model_nnsight.transformer.wte.weight.data

(tensor([[-0.0989, -0.0314,  0.0315,  ..., -0.1332,  0.0162,  0.0412],
         [ 0.0224, -0.0670,  0.0560,  ...,  0.0702, -0.0031,  0.0405],
         [-0.0920,  0.0471,  0.1977,  ...,  0.0885, -0.1139, -0.0840],
         ...,
         [-0.0403, -0.0671,  0.0289,  ...,  0.0636,  0.0674, -0.0481],
         [ 0.1565,  0.0520,  0.0939,  ..., -0.1061,  0.0661, -0.0163],
         [ 0.0319, -0.0337,  0.0541,  ...,  0.0057,  0.1539,  0.1086]]),
 tensor([[-0.0989, -0.0314,  0.0315,  ..., -0.1332,  0.0162,  0.0412],
         [ 0.0224, -0.0670,  0.0560,  ...,  0.0702, -0.0031,  0.0405],
         [-0.0920,  0.0471,  0.1977,  ...,  0.0885, -0.1139, -0.0840],
         ...,
         [-0.0403, -0.0671,  0.0289,  ...,  0.0636,  0.0674, -0.0481],
         [ 0.1565,  0.0520,  0.0939,  ..., -0.1061,  0.0661, -0.0163],
         [ 0.0319, -0.0337,  0.0541,  ...,  0.0057,  0.1539,  0.1086]]))

In [93]:
model_nnsight.transformer.wte.weight.data.copy_(hooked_model.W_E.data)

tensor([[-0.0989, -0.0314,  0.0315,  ..., -0.1332,  0.0162,  0.0412],
        [ 0.0224, -0.0670,  0.0560,  ...,  0.0702, -0.0031,  0.0405],
        [-0.0920,  0.0471,  0.1977,  ...,  0.0885, -0.1139, -0.0840],
        ...,
        [-0.0403, -0.0671,  0.0289,  ...,  0.0636,  0.0674, -0.0481],
        [ 0.1565,  0.0520,  0.0939,  ..., -0.1061,  0.0661, -0.0163],
        [ 0.0319, -0.0337,  0.0541,  ...,  0.0057,  0.1539,  0.1086]])

In [53]:
logits, logits2

(tensor([[[25.9575, 27.0594, 23.7760,  ..., 17.6671, 18.9357, 26.1557],
          [24.1554, 25.3277, 22.7836,  ..., 11.8491, 13.6343, 20.8459],
          [24.2269, 24.8723, 26.2722,  ..., 11.1776, 14.7698, 21.2890],
          [20.3492, 20.0784, 22.7114,  ...,  7.8679, 10.2955, 17.9141]]],
        grad_fn=<ViewBackward0>),
 tensor([[[  1820.8167,  -4156.5581, -10981.4482,  ...,   6701.2451,
            10023.3066,  -3584.3672],
          [ 12723.0762,   4954.4199,  -3065.8164,  ...,  15503.4863,
            13852.7109,   4862.6143],
          [ 23775.0645,  15293.0449,   9126.6895,  ...,  25649.3242,
            21742.1602,  18540.6562],
          [ 20258.9121,  13101.9219,   5408.4473,  ...,  19866.7871,
            16596.6367,  13572.1055]]], grad_fn=<UnsafeViewBackward0>))

In [66]:
model_nnsight

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): Identity()
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0, inplace=False)
          (resid_dropout): Dropout(p=0, inplace=False)
        )
        (ln_2): Identity()
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0, inplace=False)
        )
      )
    )
    (ln_f): Identity()
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
  (generator): WrapperModule()
)

In [3]:
model_nnsight.transformer.wte.weight.data = hooked_model.W_E.data
model_nnsight.transformer.wpe.weight.data = hooked_model.W_pos.data
model_nnsight.lm_head.weight.data = hooked_model.W_U.data.T

In [32]:
# Ln Done already
# Attn
hooked_model.W_in.shape, hooked_model.W_out.shape, model_nnsight.transformer.h[0].mlp.c_fc.weight.data.shape, model_nnsight.transformer.h[0].mlp.c_proj.weight.data.shape

(torch.Size([12, 768, 3072]),
 torch.Size([12, 3072, 768]),
 torch.Size([768, 3072]),
 torch.Size([3072, 768]))

In [87]:
hooked_model.W_pos.shape, model_nnsight.transformer.wpe.weight.data.shape

(torch.Size([1024, 768]), torch.Size([1024, 768]))

In [17]:
hooked_model.W_K[0].data.shape, model_nnsight.transformer.h[0].attn.c_attn.weight.data.shape

(torch.Size([12, 768, 64]), torch.Size([768, 2304]))

In [21]:
model_nnsight.transformer.h[0].attn.c_proj.weight.data.shape

torch.Size([768, 768])

In [28]:
# for i in range(len(self.blocks)):
#     self.blocks[i].ln1 = torch.nn.Identity()
#     self.blocks[i].ln2 = torch.nn.Identity()
# self.ln_final = torch.nn.Identity()
for i in range(len(model_nnsight.transformer.h)):
    model_nnsight.transformer.h[i].ln_1 = torch.nn.Identity()
    model_nnsight.transformer.h[i].ln_2 = torch.nn.Identity()
model_nnsight.transformer.ln_f = torch.nn.Identity()

In [30]:
model_nnsight.transformer.h[0].ln_1

Identity()

In [15]:
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0, inplace=False)
          (resid_dropout): Dropout(p=0, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
prompt = torch.tensor([1,2,3,4], device=device)
logits = hooked_model(prompt)

print(logits.shape)
print(logits[0, 0, :10])

torch.Size([1, 4, 50257])
tensor([25.9575, 27.0594, 23.7760, 24.0054, 25.5462, 23.6517, 26.7588, 25.8552,
        27.0049, 25.4789], grad_fn=<SliceBackward0>)


In [5]:
hooked_model

HookedTransformerNoLN(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): Identity()
      (ln2): Identity()
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint()
      (hook_resid_mid): HookPoint()
      (hook_resid_post): HookPoint()
    )
  )
  (ln_final): Identity()
  (unembed): Unembed()
)