In [1]:
# So HF saves cache to RunPod's persistent volume
import os
os.environ["TRANSFORMERS_CACHE"] = "/workspace/cache/"

In [2]:
import torch
import transformer_lens.loading_from_pretrained as loading

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformer_lens import HookedTransformer

torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fb043fba2c0>

In [3]:
# Example of loading from hf then tlens
dtype = torch.float32
weights_source = "NousResearch/Llama-2-7b-chat-hf"
tlens_arch = "Llama-2-7b-chat-hf"

In [4]:
# Load from hf
tokenizer = AutoTokenizer.from_pretrained(weights_source)
hf_model = AutoModelForCausalLM.from_pretrained(weights_source, torch_dtype=dtype)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



In [5]:
# Load from hf into tlens
cfg = loading.get_pretrained_model_config(tlens_arch, torch_type=dtype)
cf_model = HookedTransformer(cfg, tokenizer=tokenizer)
state_dict = loading.get_pretrained_state_dict(tlens_arch, cfg, hf_model)
cf_model.load_state_dict(state_dict, strict=False)

_IncompatibleKeys(missing_keys=['blocks.0.attn.mask', 'blocks.0.attn.IGNORE', 'blocks.0.attn.rotary_sin', 'blocks.0.attn.rotary_cos', 'blocks.1.attn.mask', 'blocks.1.attn.IGNORE', 'blocks.1.attn.rotary_sin', 'blocks.1.attn.rotary_cos', 'blocks.2.attn.mask', 'blocks.2.attn.IGNORE', 'blocks.2.attn.rotary_sin', 'blocks.2.attn.rotary_cos', 'blocks.3.attn.mask', 'blocks.3.attn.IGNORE', 'blocks.3.attn.rotary_sin', 'blocks.3.attn.rotary_cos', 'blocks.4.attn.mask', 'blocks.4.attn.IGNORE', 'blocks.4.attn.rotary_sin', 'blocks.4.attn.rotary_cos', 'blocks.5.attn.mask', 'blocks.5.attn.IGNORE', 'blocks.5.attn.rotary_sin', 'blocks.5.attn.rotary_cos', 'blocks.6.attn.mask', 'blocks.6.attn.IGNORE', 'blocks.6.attn.rotary_sin', 'blocks.6.attn.rotary_cos', 'blocks.7.attn.mask', 'blocks.7.attn.IGNORE', 'blocks.7.attn.rotary_sin', 'blocks.7.attn.rotary_cos', 'blocks.8.attn.mask', 'blocks.8.attn.IGNORE', 'blocks.8.attn.rotary_sin', 'blocks.8.attn.rotary_cos', 'blocks.9.attn.mask', 'blocks.9.attn.IGNORE', 'blo

In [6]:
tokens = tokenizer("Why did the chicken cross the road?", return_tensors="pt")["input_ids"]
hf_logits = hf_model(tokens).logits
cf_logits = cf_model(tokens)

torch.allclose(hf_logits.cpu(), cf_logits.cpu(), atol=1e-4)

True

In [7]:
cf_model

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0-31): 32 x TransformerBlock(
      (ln1): RMSNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): RMSNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
      )
      (mlp): GatedMLP(
        (hook_pre): HookPoint()
        (hook_pre_linear): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint(

In [8]:
cf_model.cfg

HookedTransformerConfig:
{'act_fn': 'silu',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 128,
 'd_mlp': 11008,
 'd_model': 4096,
 'd_vocab': 32000,
 'd_vocab_out': 32000,
 'default_prepend_bos': True,
 'device': device(type='cuda'),
 'dtype': torch.float32,
 'eps': 1e-05,
 'final_rms': True,
 'from_checkpoint': False,
 'gated_mlp': True,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': 0.0125,
 'model_name': 'Llama-2-7b-chat-hf',
 'n_ctx': 4096,
 'n_devices': 1,
 'n_heads': 32,
 'n_layers': 32,
 'n_params': 5033164800,
 'normalization_type': 'RMS',
 'original_architecture': 'LlamaForCausalLM',
 'parallel_attn_mlp': False,
 'positional_embedding_type': 'rotary',
 'post_embedding_ln': False,
 'rotary_dim': 128,
 'scale_attn_by_inverse_layer_idx': False,
 'seed': None,
 'tokenizer_name': 'Llama-2-7b-chat-hf',
 'tokenizer_prepends_bos': True,
 'use_at