## Imports

In [None]:
import re
import torch as t
from functools import partial
from transformers import AutoTokenizer, AutoModelForCausalLM
from utils import HookedMistral

# This notebook only runs inference, so autograd is switched off globally.
t.set_grad_enabled(False)

# Use the GPU when torch can see one; otherwise fall back to the CPU.
if t.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

## Load model

In [None]:
# Hugging Face Hub identifier of the checkpoint used throughout the notebook.
model_name = "mistralai/Mistral-7B-Instruct-v0.1"

# Left padding keeps the final position of every sequence aligned in a batch.
# `padding` is an argument of the tokenizer's __call__, not a load-time
# option, so it is not passed here; request it when tokenizing instead,
# e.g. tokenizer(texts, padding=True).
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
# NOTE(review): id 1 is presumably the BOS token (`<s>`) in this vocabulary,
# reused as the pad token -- confirm against the tokenizer's special tokens.
tokenizer.pad_token_id = 1

# Load the weights in half precision and let `device_map="auto"` decide
# where the modules are placed.
hf_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=t.float16,
    device_map="auto",
)

In [None]:
# Wrap the Hugging Face model and its tokenizer in the project's HookedMistral
# helper (imported from utils).
model = HookedMistral(hf_model, tokenizer)
# Presumably lists the module names that can be requested for caching --
# confirm against the implementation in utils.
model.print_names()

## Run some tests

In [None]:
# Modules whose outputs are requested from the cached forward pass: the
# output of layer 28, the two sub-modules of layer 29, and layer 29 itself.
names = [
    "model.layers.28",
    "model.layers.29.self_attn", "model.layers.29.mlp",
    "model.layers.29",
]
logits, cache = model.run_with_cache("Will the tensors match?", names)

print(logits.shape)
print(cache.keys())

# Check that the output of layer 29 equals the output of layer 28 plus the
# attention and MLP contributions of layer 29. The additions are kept in the
# same left-to-right order so floating-point rounding is unchanged.
layer_29_out = cache["model.layers.29"]
recomposed = (
    cache["model.layers.28"]
    + cache["model.layers.29.self_attn"]
    + cache["model.layers.29.mlp"]
)
t.allclose(layer_29_out, recomposed, atol=1e-5)