In [109]:
import json

import matplotlib.pyplot as plt
import numpy as np
import torch as t
import torch.nn.functional as F
from beartype import beartype as typed
from datasets import load_dataset
from jaxtyping import Float, Int
from torch import Tensor as TT
from transformer_lens import (
    ActivationCache,
    HookedTransformer,
    HookedTransformerConfig,
    utils,
)
from einops import rearrange
from transformers import AutoModelForCausalLM, AutoTokenizer

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [27]:
# Translation table for per-block parameter names: the part of a source
# state-dict key that follows "transformer.h.<layer>." is looked up here to get
# the part that follows "blocks.<layer>." in the HookedTransformer state dict.
mapping = dict(
    [
        # attention projections
        ("attn.attention.k_proj.weight", "attn.W_K"),
        ("attn.attention.q_proj.weight", "attn.W_Q"),
        ("attn.attention.v_proj.weight", "attn.W_V"),
        ("attn.attention.out_proj.bias", "attn.b_O"),
        ("attn.attention.out_proj.weight", "attn.W_O"),
        # layer norms
        ("ln_1.bias", "ln1.b"),
        ("ln_1.weight", "ln1.w"),
        ("ln_2.bias", "ln2.b"),
        ("ln_2.weight", "ln2.w"),
        # feed-forward
        ("mlp.c_fc.bias", "mlp.b_in"),
        ("mlp.c_fc.weight", "mlp.W_in"),
        ("mlp.c_proj.bias", "mlp.b_out"),
        ("mlp.c_proj.weight", "mlp.W_out"),
    ]
)

# Hand-written HookedTransformerConfig for the "brackets-nested" model
# (8 layers, 16 heads of width 16, d_model 256, vocabulary of 502 tokens,
# attention alternating between "global" and "local" layers).
config = HookedTransformerConfig(
    # --- identity ---
    model_name="brackets-nested",
    original_architecture="GPTNeoForCausalLM",
    tokenizer_name="Mlxa/brackets-nested",
    # --- sizes ---
    n_layers=8,
    n_heads=16,
    d_head=16,
    d_model=256,
    d_mlp=1024,
    d_vocab=502,
    d_vocab_out=502,
    n_ctx=2048,
    n_params=6291456,
    # --- attention ---
    attention_dir="causal",
    attn_only=False,
    attn_types=["global", "local"] * 4,  # one entry per layer
    use_local_attn=True,
    window_size=256,
    use_attn_scale=False,
    scale_attn_by_inverse_layer_idx=False,
    use_attn_result=False,
    use_split_qkv_input=False,
    parallel_attn_mlp=False,
    # --- MLP / normalisation ---
    act_fn="gelu_new",
    gated_mlp=False,
    normalization_type="LN",
    eps=1e-5,
    final_rms=False,
    # --- positional embedding ---
    positional_embedding_type="standard",
    rotary_dim=None,
    # --- initialisation ---
    init_mode="gpt2",
    init_weights=True,
    initializer_range=0.05,
    seed=None,
    # --- runtime ---
    device="cpu",
    n_devices=1,
    use_hook_tokens=False,
    # --- checkpoint bookkeeping (unused) ---
    from_checkpoint=False,
    checkpoint_index=None,
    checkpoint_label_type=None,
    checkpoint_value=None,
)

In [119]:
# Port the pretrained "Mlxa/brackets-nested" weights into a freshly built
# HookedTransformer: every entry of the source state dict is renamed to its
# HookedTransformer key, reshaped to the layout of the target parameter, and
# the resulting state dict is loaded into `model`.
model = HookedTransformer(config)
param_holder = AutoModelForCausalLM.from_pretrained("Mlxa/brackets-nested")
old_dict = param_holder.state_dict()
new_dict = model.state_dict()
ka = list(old_dict.keys())
kb = list(new_dict.keys())

loaded = set()
for x in ka:
    # Resolve the target key; `y` stays None when the source key is not recognised.
    y = None
    if "transformer.h." in x:
        # "transformer.h.<layer>.<suffix>" -> "blocks.<layer>.<mapped suffix>".
        # Split on the dot instead of slicing a fixed width so that layer
        # indices with more than one digit are handled too.
        layer, _, suffix = x.split("transformer.h.", 1)[1].partition(".")
        if suffix in mapping:
            y = f"blocks.{layer}.{mapping[suffix]}"
    elif "lm_head" in x:
        y = "unembed.W_U"
    elif "wpe" in x:
        y = "pos_embed.W_pos"
    elif "wte" in x:
        y = "embed.W_E"
    elif "ln_f.bias" in x:
        y = "ln_final.b"
    elif "ln_f.weight" in x:
        y = "ln_final.w"

    if y is None:
        # Report and skip: continuing with an unset/stale `y` would either
        # crash or write this tensor over a parameter that was already loaded.
        print("UNKNOWN", x)
        continue

    # Validate the target before indexing `new_dict` with it.
    assert y in kb, f"{x} was mapped to {y}, which is not a HookedTransformer key"
    target_shape = new_dict[y].shape

    data = old_dict[x]
    # Layout is inferred purely from shapes: a 2-D tensor that does not match
    # the target is transposed first.
    # NOTE(review): a square 2-D weight with a 2-D target would match as-is and
    # NOT be transposed; with this config no such pair exists (d_mlp != d_model,
    # d_vocab != d_model) - re-check if the sizes change.
    if data.shape != target_shape and data.ndim == 2:
        data = data.T
    # Still mismatched against a 3-D target: split the fused head axis out into
    # a leading per-head axis. Which axis is fused is decided by which of the
    # target's last two dimensions is the larger one.
    if data.shape != target_shape and len(target_shape) == 3:
        n_heads = target_shape[0]
        if target_shape[1] > target_shape[2]:
            data = rearrange(data, "d_in (h d_out) -> h d_in d_out", h=n_heads)
        else:
            data = rearrange(data, "(h d_in) d_out -> h d_in d_out", h=n_heads)
    print(x, "->", y, ":", data.shape, target_shape)
    assert target_shape == data.shape
    new_dict[y] = data
    loaded.add(y)

# Target entries that had no source counterpart: anything whose key contains
# "b_" is zeroed, everything else keeps the value the fresh model already holds.
for x in kb:
    if x in loaded:
        continue
    elif "b_" in x:
        print("zero", x, new_dict[x].shape)
        new_dict[x] = t.zeros_like(new_dict[x])
        loaded.add(x)
    else:
        print("constant", x)
        loaded.add(x)

print(len(loaded))
model.load_state_dict(new_dict)

Using eos_token, but it is not set yet.
Using bos_token, but it is not set yet.


transformer.wte.weight -> embed.W_E : torch.Size([502, 256]) torch.Size([502, 256])
transformer.wpe.weight -> pos_embed.W_pos : torch.Size([2048, 256]) torch.Size([2048, 256])
transformer.h.0.ln_1.weight -> blocks.0.ln1.w : torch.Size([256]) torch.Size([256])
transformer.h.0.ln_1.bias -> blocks.0.ln1.b : torch.Size([256]) torch.Size([256])
transformer.h.0.attn.attention.k_proj.weight -> blocks.0.attn.W_K : torch.Size([16, 256, 16]) torch.Size([16, 256, 16])
transformer.h.0.attn.attention.v_proj.weight -> blocks.0.attn.W_V : torch.Size([16, 256, 16]) torch.Size([16, 256, 16])
transformer.h.0.attn.attention.q_proj.weight -> blocks.0.attn.W_Q : torch.Size([16, 256, 16]) torch.Size([16, 256, 16])
transformer.h.0.attn.attention.out_proj.weight -> blocks.0.attn.W_O : torch.Size([16, 16, 256]) torch.Size([16, 16, 256])
transformer.h.0.attn.attention.out_proj.bias -> blocks.0.attn.b_O : torch.Size([256]) torch.Size([256])
transformer.h.0.ln_2.weight -> blocks.0.ln2.w : torch.Size([256]) torch.

<All keys matched successfully>

In [130]:
# Sample a continuation of a short bracket prompt from the Hugging Face model
# and decode it back to text. Only `param_holder` is sampled here; generating
# with the HookedTransformer (`model.generate`) is not run in this cell.
from utils import generate_sample

prompt = "<1 <2 <3"
tokenizer = model.tokenizer
inputs = tokenizer(prompt, return_tensors="pt")
# The tokenizer output is forwarded unchanged (no keys are dropped first).
result = param_holder.generate(**inputs, max_new_tokens=20, do_sample=True)
sampled_ids = result.squeeze()
tokenizer.decode(sampled_ids)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


'<1 <2 <3 <216 <78 <85 <100 <161 161> <132 <29 29> <42 42> <130 <42 <17 <91 <68 <125 125> <48 16>'