In [1]:
import os
import json
from dotenv import load_dotenv

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from dictionary_learning import AutoEncoder
from dictionary_learning import utils
import config


  from .autonotebook import tqdm as notebook_tqdm


NOTE: Training on 500000000 tokens


In [2]:
load_dotenv()
hf_token = os.getenv("HF_TOKEN")

# The GPU index is machine-specific: override it with CUDA_DEVICE_INDEX
# (e.g. in .env). The default of 7 matches the previously hard-coded 'cuda:7'.
gpu_index = int(os.getenv("CUDA_DEVICE_INDEX", "7"))
if torch.cuda.is_available():
    n_gpus = torch.cuda.device_count()
    # Fail here with a clear message instead of an opaque
    # "invalid device ordinal" error at the first `.to(device)`.
    if not 0 <= gpu_index < n_gpus:
        raise RuntimeError(
            f"CUDA_DEVICE_INDEX={gpu_index} is out of range: only {n_gpus} GPU(s) visible"
        )
    device = f"cuda:{gpu_index}"
else:
    device = 'cpu'

model_id = 'google/gemma-2-2b'

sae_release_id = 'gemma-scope-2b-pt-mlp-canonical'
sae_id = 'layer_12/width_16k/canonical'

# sae_id looks like 'layer_<N>/width_<W>/<variant>'; extract N as the hooked layer.
layer = int(sae_id.split("/")[0].split("_")[1])

ctx_len = 128  # max_length used when tokenizing prompts (longer prompts are truncated)
add_special_tokens = True  # forwarded to the tokenizer call

#### Load dataset

Read the Alpaca prompts from `alpaca.json` into the list `data_l`.

In [3]:
# Close the file deterministically and decode it explicitly as UTF-8
# instead of relying on the platform's default encoding.
with open("alpaca.json", "r", encoding="utf-8") as alpaca_file:
    data = json.load(alpaca_file)

# Assumes each record is a single-key dict whose value holds a 'prompt' field
# (inferred from the indexing below; verify against alpaca.json).
data_l = [list(record.values())[0]['prompt'] for record in data]
assert data_l, "alpaca.json contained no prompts"

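#### Load models

Load the base LM (truncated after the hooked layer), its tokenizer, and the pretrained SAE for that layer.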

In [4]:
# Base LM, loaded in the dtype configured for this model id and moved to `device`.
llm_dtype = config.LLM_CONFIG[model_id].dtype
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=hf_token,
    dtype=llm_dtype,
).to(device)
print(f"loaded model {model_id} on {device}\n")

# Drop the layers after `layer`. This presumably keeps blocks 0..layer, judging by
# the before/after parameter counts it prints; TODO confirm in utils.truncate_model.
model = utils.truncate_model(model, layer)

tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)

# NOTE(review): this rebinds `layer` to whatever get_submodule returns alongside
# the module to hook (presumably the same index); confirm in utils.get_submodule.
submodule, layer = utils.get_submodule(model, layer)

ae = AutoEncoder.from_pretrained(
    load_from_sae_lens=True,
    release=sae_release_id,
    sae_id=sae_id,
    device=device,
)
print(f"\nloaded SAE {sae_release_id} hooked to layer {layer} of {model_id} on {device}")


Loading checkpoint shards: 100%|██████████| 3/3 [00:03<00:00,  1.27s/it]


loaded model google/gemma-2-2b on cuda:7

Model parameters before truncation: 2,614,341,888
Model parameters after truncation: 1,602,084,096

loaded SAE gemma-scope-2b-pt-mlp-canonical hooked to layer 12 of google/gemma-2-2b on cuda:7


In [None]:
# Tokenize all prompts as one batch: return PyTorch tensors, pad to the longest
# prompt (padding=True), cap each prompt at ctx_len tokens, then move to `device`.
tokenizer_kwargs = dict(
    return_tensors='pt',
    max_length=ctx_len,
    padding=True,
    truncation=True,
    add_special_tokens=add_special_tokens,
)
tokens = tokenizer(data_l, **tokenizer_kwargs).to(device)