In [1]:
from transformer_lens import HookedTransformer
from sae_lens import SAE
import plotly.express as px
import torch

# Pick the best available accelerator: Apple-silicon MPS first (what this notebook was written on),
# then CUDA, then CPU, so the notebook also runs on machines without MPS.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# `from_pretrained_no_processing` loads the weights without TransformerLens's weight processing —
# presumably so residual-stream activations match the unprocessed model the SAE was trained on
# (TODO confirm against the SAE release's training setup).
model = HookedTransformer.from_pretrained_no_processing("meta-llama/Meta-Llama-3-8B", device=device)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model meta-llama/Meta-Llama-3-8B into HookedTransformer


In [2]:
from sae_lens import SAE
# SAE.from_pretrained returns three values:
#   - sae:       the SAE module itself
#   - cfg_dict:  whatever config was stored in the HF repo. This is NOT the same as `sae.cfg`, but
#                the SAE config is derived from it, and it may hold extra information useful for
#                analysing the SAE (e.g. instantiating an activation store)
#   - sparsity:  the feature sparsities stored in the HF repo, returned for convenience
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="sae-llama-3-8b-eai",  # see other options in sae_lens/pretrained_saes.yaml
    sae_id="blocks.3.hook_resid_post",  # won't always be a hook point
    device=device,  # same device as the model so the SAE can consume its activations directly
)

In [7]:
# Override the SAE's configured context length before the ActivationsStore is built from this SAE
# (ActivationsStore.from_sae(sae=sae) presumably reads it from sae.cfg). 128 tokens is assumed to be
# a memory-saving choice for the 8B model on this device — TODO confirm it matches the intended eval setup.
SAE_CONTEXT_SIZE = 128
sae.cfg.context_size = SAE_CONTEXT_SIZE

In [8]:
# Cache only the activation at the SAE's hook point, then encode it into SAE feature activations.
# This is pure inspection, so disable autograd: otherwise the 8B forward pass and the SAE encode
# build a graph that stays alive through `_` / `feature_acts`.
with torch.no_grad():
    _, cache = model.run_with_cache("I like to eat apples and bananas", names_filter=[sae.cfg.hook_name])
    feature_acts = sae.encode(cache[sae.cfg.hook_name])

In [9]:
from sae_lens import ActivationsStore

# Activation store built from the SAE itself (so it follows the SAE's config), running `model`
# on `device` with a streamed dataset.
activation_store = ActivationsStore.from_sae(
    model=model,
    sae=sae,
    device=device,
    streaming=True,
    # Sizing is deliberately conservative so the same settings can be reused for larger
    # models without running out of memory.
    n_batches_in_buffer=4,
    store_batch_size_prompts=8,
    train_batch_size_tokens=4096,
)




In [11]:
from sae_lens import run_evals

# Evaluate the SAE: reconstruction norms plus cross-entropy loss with the SAE spliced in,
# compared against the clean model and a zero-ablated hook.
metrics = run_evals(sae, activation_store, model, n_eval_batches=8, eval_batch_size_prompts=8)
display(metrics)

# CE_loss_score = (ce_with_ablation - ce_with_sae) / (ce_with_ablation - ce_without_sae), which is
# consistent with the recorded numbers. A negative score means splicing in the SAE hurts more than
# zero-ablating the hook entirely — almost certainly a setup mismatch, not a usable SAE.
ce_loss_score = metrics.get("metrics/CE_loss_score")
if ce_loss_score is not None and ce_loss_score < 0:
    print(
        f"WARNING: CE_loss_score={ce_loss_score:.3f} < 0 -- the SAE reconstruction is worse than "
        "zero-ablating this hook. Check that the model preprocessing, hook point and device match "
        "how the SAE was trained."
    )

{'metrics/l2_norm': 2.8263065814971924, 'metrics/l2_ratio': 0.8937495946884155, 'metrics/l2_norm_in': 5.387141227722168, 'metrics/CE_loss_score': -0.03879524770309217, 'metrics/ce_loss_without_sae': 2.813219517469406, 'metrics/ce_loss_with_sae': 12.110986828804016, 'metrics/ce_loss_with_ablation': 11.761781692504883}


In [12]:
import circuitsvis as cv 
prompt = "1 2 3 4 5 6 7 8 9 10 11 12"

# Tokenise once and run a single forward pass. The logits from run_with_cache are reused for the
# visualisation instead of calling `model(prompt)` a second time (same tokens, same forward pass).
tokens = model.to_tokens(prompt)
with torch.no_grad():
    logits, cache = model.run_with_cache(tokens)

# circuitsvis per-token log-prob view: batch element 0's logits -> log-probs over the vocabulary.
display(cv.logits.token_log_probs(tokens, logits[0].log_softmax(dim=-1), model.to_string))
