## Imports

In [16]:
import torch as t
from transformers import AutoTokenizer, AutoModelForCausalLM
from utils import HookedMistral

# Pick the compute device: the GPU when CUDA is present, otherwise the CPU.
# NOTE(review): `device` is not referenced in the cells visible here (the model is
# placed via device_map="auto"); confirm later cells use it before removing.
device = "cpu"
if t.cuda.is_available():
    device = "cuda"

# Only forward passes are run here, so autograd is switched off globally.
t.set_grad_enabled(False)

## Load model

In [2]:
# Load Mistral-7B-Instruct in fp16, letting `device_map="auto"` place the weights
# on the available device(s).
model_name = "mistralai/Mistral-7B-Instruct-v0.1"

# Download/cache location shared by the tokenizer and the model weights.
# NOTE(review): absolute, machine-specific path — adjust for your environment.
CACHE_DIR = "/workspace/cache/"

# Token id assigned as the pad token (the tokenizer presumably defines none by default).
# NOTE(review): id 1 is presumably `<s>` (BOS) in Mistral's vocab — confirm that padding
# with this id is what HookedMistral expects (e.g. that padded positions are masked).
PAD_TOKEN_ID = 1

# `padding` is a per-call argument (`tokenizer(texts, padding=True)`), not a constructor
# option, so it is not passed here. `padding_side="left"` is a constructor option: with
# left padding, the real tokens of every row end at the final position.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    padding_side="left",
    cache_dir=CACHE_DIR,  # keep tokenizer files next to the model weights
)
tokenizer.pad_token_id = PAD_TOKEN_ID

hf_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=t.float16,
    device_map="auto",
    cache_dir=CACHE_DIR,
)

tokenizer_config.json:   0%|          | 0.00/1.47k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/72.0 [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [17]:
# Wrap the HF model so its submodules can be addressed by dotted name (e.g. for
# run_with_cache), then list every available name.
model = HookedMistral(hf_model, tokenizer)
model.print_names()

model
model.embed_tokens
model.layers
model.layers.{0..31}
model.layers.{0..31}.self_attn
model.layers.{0..31}.self_attn.q_proj
model.layers.{0..31}.self_attn.k_proj
model.layers.{0..31}.self_attn.v_proj
model.layers.{0..31}.self_attn.o_proj
model.layers.{0..31}.self_attn.rotary_emb
model.layers.{0..31}.mlp
model.layers.{0..31}.mlp.gate_proj
model.layers.{0..31}.mlp.up_proj
model.layers.{0..31}.mlp.down_proj
model.layers.{0..31}.mlp.act_fn
model.layers.{0..31}.input_layernorm
model.layers.{0..31}.post_attention_layernorm
model.norm
lm_head


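## Sanity check: residual stream decomposition

A decoder layer's output should equal the residual stream entering it (the previous layer's output) plus the attention and MLP outputs added to it.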

In [22]:
# Fresh wrapper with no hooks attached, so only the caching hooks below are active.
# (HookedMistral is already imported in the imports cell; re-importing a cached module
# would not pick up edits to utils.py anyway.)
model = HookedMistral(hf_model, tokenizer)
model.reset_hooks()

# Layer under test. Its input residual stream is taken to be the output of
# layer LAYER - 1.
LAYER = 29
names = [
    f"model.layers.{LAYER - 1}",
    f"model.layers.{LAYER}.self_attn",
    f"model.layers.{LAYER}.mlp",
    f"model.layers.{LAYER}",
]
logits, cache = model.run_with_cache("Will the tensors match?", names)

print(logits.shape)
print(cache.keys())

# Assumes run_with_cache stores one tensor per module name (the module's main output);
# see utils.HookedMistral.
resid_in, attn_out, mlp_out, resid_out = (cache[name] for name in names)
residual_matches = t.allclose(resid_out, resid_in + attn_out + mlp_out, atol=1e-5)
# Fail loudly on Restart & Run All instead of silently displaying False.
assert residual_matches, f"layer {LAYER} output != input + attn + mlp"
residual_matches

torch.Size([1, 7, 32000])
dict_keys(['model.layers.28', 'model.layers.29.self_attn', 'model.layers.29.mlp', 'model.layers.29'])


True