In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import jax_smi
# Turn on jax_smi's device-memory tracking for this kernel process (presumably so
# memory use can be watched externally while the model runs — not verifiable here).
jax_smi.initialise_tracking()

In [None]:
from saex.iterable_dataset import IterableDatasetConfig
from saex.models.micrlhf_model import MicrlhfModelConfig
from saex.model_haver import ModelHaver
from saex.sae import SAEConfig
from more_itertools import chunked


# NOTE: passed to SAEConfig as `n_dimensions`, i.e. the SAE *input* width (presumably
# the hidden size of the hooked layer) — not the number of SAE latents.
n_features = 3072
# Shared by SAEConfig and by the text-batching loops further down the notebook.
batch_size = 32

sae_config = SAEConfig(
    # architecture
    n_dimensions=n_features,
    expansion_factor=32,
    is_gated=True,
    use_encoder_bias=True,
    remove_decoder_bias=False,
    # batch / buffer sizes
    batch_size=batch_size,
    buffer_size=2_000,
    # initialisation
    encoder_init_method="orthogonal",
    decoder_init_method="pseudoinverse",
    decoder_bias_init_method="zeros",
    # reconstruction and sparsity terms
    reconstruction_loss_type="mse_batchnorm",
    sparsity_loss_type="l1",
    sparsity_coefficient=0,
    sparsity_tracking_epsilon=0.1,
    # dead-latent handling (semantics defined by SAEConfig)
    death_loss_type="dm_ghost_grads",
    death_penalty_threshold=5e-7,
    death_penalty_coefficient=0.25,
    dead_after=1_000,
    # decoder constraints
    project_updates_from_dec=True,
    restrict_dec_norm="exact",
)

dataset_config = IterableDatasetConfig(dataset_name="nev/openhermes-2.5-phi-format-text")

model_config = MicrlhfModelConfig(
    tokenizer_path="microsoft/Phi-3-mini-4k-instruct",
    gguf_path="weights/phi-3-16.gguf",
    device_map="tpu:0",
    layer=11,  # presumably the layer whose activations the SAE reads
    max_seq_len=128,
)

# Bundle model, SAE and dataset; SAE weights are restored from the safetensors file.
haver = ModelHaver(
    model_config=model_config,
    sae_config=sae_config,
    dataset_config=dataset_config,
    sae_restore="weights/phi-l11-gated.safetensors",
)

In [None]:
from collections import defaultdict
from tqdm.auto import tqdm
import jax.numpy as jnp
import numpy as np
# Stream the dataset through the model and SAE once, recording every nonzero SAE
# activation as  activ_cache[feature] -> [(token_index, activation), ...].
# token_index is a running count of rows of `hiddens` across batches; `visualize`
# re-derives the same numbering by re-tokenizing the dataset, so both passes must
# see the dataset in the same order.
from itertools import islice

MAX_BATCHES = None  # cap on text batches to encode; None = run until the dataset ends

tokens_processed = 0
activ_cache = defaultdict(list)
bar = tqdm(haver.create_dataset())
for texts in islice(chunked(bar, batch_size), MAX_BATCHES):
    activations, model_misc = haver.model(texts)
    # Assumed to mark real (nonzero) vs padding (0) positions, one entry per row of
    # `hiddens` — TODO confirm against ModelHaver. May be missing entirely.
    mask = model_misc.get("mask")
    pre_relu, hiddens = haver.sae.encode(activations)

    # Mean number of active latents per token (L0), averaged over unmasked positions.
    active_per_token = (hiddens != 0).sum(-1)
    if mask is None:
        l0 = active_per_token.mean()
    else:
        l0 = (active_per_token * mask).mean() / mask.mean()
    bar.set_postfix(l0=float(l0))

    hiddens = np.asarray(hiddens.astype(jnp.float16))
    n_rows = hiddens.shape[0]
    if mask is None:
        valid_rows = np.ones(n_rows, dtype=bool)
    else:
        valid_rows = np.broadcast_to(np.asarray(mask).reshape(-1) != 0, (n_rows,))

    # All (row, feature) pairs with a nonzero activation, in row-major order (the
    # order a per-row loop would visit them), excluding masked-out rows so padding
    # positions never appear as "max-activating" examples.
    rows, feats = np.nonzero(hiddens)
    keep = valid_rows[rows]
    rows, feats = rows[keep], feats[keep]
    for i, f, a in zip(rows.tolist(), feats.tolist(), hiddens[rows, feats].tolist()):
        activ_cache[f].append((tokens_processed + i, a))
    # Advance by every row, masked or not, so indices stay aligned with the
    # flattened token lists that `visualize` rebuilds.
    tokens_processed += n_rows
bar.close()

In [29]:
def visualize(feature, thresh=10.0, context_before=24, context_after=4):
    """Print the dataset contexts in which SAE `feature` fired at or above `thresh`.

    Looks up the ``(token_index, activation)`` pairs recorded for `feature` in the
    module-level ``activ_cache``. It then re-streams ``haver.create_dataset()`` in
    chunks of ``batch_size`` texts and re-tokenizes each chunk with
    ``haver.model.to_tokens``. The flattened tokens are numbered with a running index.

    Whenever that index has a cached activation ``>= thresh``, one line is printed:
    the activation, the decoded ``context_before`` preceding tokens, and the decoded
    ``context_after``-token window starting at the activating token. Streaming stops
    once the largest cached index has been passed.

    The numbering must match the one used when ``activ_cache`` was filled, so the
    dataset is assumed to be yielded in the same order on every call — TODO confirm.

    Args:
        feature: SAE latent index.
        thresh: minimum activation to print. The feature is skipped without touching
            the dataset when its largest cached activation is below this.
        context_before: number of tokens of left context to show (default 24).
        context_after: number of tokens shown starting at the activating token
            (default 4).

    Returns:
        None; results are printed.
    """
    # .get() instead of [] so probing a feature with no activations does not insert
    # an empty list into the defaultdict as a side effect.
    cache = activ_cache.get(feature)
    if not cache:
        return
    _, activs = zip(*cache)
    if max(activs) < thresh:
        return
    # token index -> activation. This gives an O(1) lookup per token instead of
    # scanning the whole cache for every token. setdefault keeps the first entry for
    # a repeated index, which is what tuple.index() returned.
    activ_by_token = {}
    for token_id, activ in cache:
        activ_by_token.setdefault(token_id, activ)
    last_token = max(activ_by_token)  # constant for the whole scan

    offset = 0
    stream = tqdm(haver.create_dataset(), desc=f"feature {feature}", leave=False)
    for texts in chunked(stream, batch_size):
        toks = haver.model.to_tokens(texts)
        # Flatten the batch so positions line up with the flat indices in the cache.
        # As a consequence, left context can run into the previous text of the batch.
        all_tokens = [t for tok in toks for t in tok]
        for i in range(len(all_tokens)):
            activ = activ_by_token.get(offset + i)
            if activ is None or activ < thresh:
                continue
            left = haver.model.decode(all_tokens[max(0, i - context_before):i])
            right = haver.model.decode(all_tokens[i:i + context_after])
            print(activ, repr(left), repr(right))
        offset += len(all_tokens)
        if offset > last_token:
            break
    stream.close()
# Print max-activating contexts for a window of SAE features, with a blank line
# between features. The window is 100 features wide: scanning tens of thousands of
# features would re-stream the dataset for every one that passes the threshold.
FEATURE_START = 10_000
FEATURE_STOP = 10_100  # exclusive
for feature_idx in range(FEATURE_START, FEATURE_STOP):
    visualize(feature_idx)
    print()

0it [00:00, ?it/s]

10.78125 ', two minutes into the Grand<s><|user|> You are a helpful assistant, who always provide explanation. Think like you are answering' 'to a five year'
11.3125 '. Here is an example HTML<s><|user|> You are a helpful assistant, who always provide explanation. Think like you are answering' 'to a five year'
11.203125 '\n\n\tif (date<s><|user|> You are a helpful assistant, who always provide explanation. Think like you are answering' 'to a five year'
10.484375 'The power goes out, so<s><|user|> You are a helpful assistant, who always provide explanation. Think like you are answering' 'to a five year'
11.3828125 '\n```\n# Poor<s><|user|> You are a helpful assistant, who always provide explanation. Think like you are answering' 'to a five year'
10.984375 '<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><s><|user|> You are a helpful assistant, who always provide explanation. Think like you are answering' 'to a five year'
10.5 'through a kind of arm-el<s><|u

0it [00:00, ?it/s]

11.3359375 'the shortest distance between the circles defined by $x^2-10x +y^2-4y' '-7=0'
























































































































































































































































































0it [00:00, ?it/s]

In [None]:
# NOTE(review): the only visible assignment to `f` is the loop variable of the
# activation-caching loop, so this appears to display the value leaked from that
# loop (the last feature index it recorded). It raises NameError on a fresh kernel
# if that loop has not run. Leftover inspection cell.
f

In [None]:
# Summarise instead of displaying the raw cache. It holds one (token_index,
# activation) pair per nonzero activation recorded by the caching loop, which is far
# too large to render usefully.
{
    "features_with_activations": sum(1 for pairs in activ_cache.values() if pairs),
    "cached_activations": sum(len(pairs) for pairs in activ_cache.values()),
}

In [1]:
%load_ext autoreload
%autoreload 2
# Launch scripts.train_phi_sae.main on layer 11 with is_gated=True,
# sparsity_coefficient=1.4e-5 and n_devices=4. The captured log below reports
# "Training for 100000 iterations" and "Sparsity coefficient: 1.4e-05".
from scripts.train_phi_sae import main
main(layer=11, is_gated=True, sparsity_coefficient=1.4e-5, n_devices=4)



Loading model...


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Creating SAE...
Loading dataset...


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


Creating buffer...
Training for 100000 iterations
Sparsity coefficient: 1.4e-05


[34m[1mwandb[0m: Currently logged in as: [33mneverix[0m. Use [1m`wandb login --relogin`[0m to force relogin


Learning rate: 0.0004 warmed up for 128 iterations and cycled every 100000 iterations


  0%|          | 0/100000 [00:00<?, ?it/s]

buf 0.0
49170.3671875
0.0
buf 24.585186004638672
49169.7109375
4917.037109375
buf 49.1700439453125
48581.9765625
9342.3046875
buf 73.4610366821289
47364.4140625
13266.271484375
buf 97.14324188232422
45614.1796875
16676.0859375
buf 119.95034790039062
43386.0859375
19569.89453125
buf 141.64337158203125
41042.234375
21951.51171875
buf 162.16448974609375
38310.20703125
23860.5859375
buf 181.31961059570312
35663.01171875
25305.546875
buf 199.15109252929688
33010.59375
26341.29296875
buf 215.65640258789062
30800.330078125
27008.22265625
buf 231.05654907226562
28479.984375
27387.4296875
buf 245.29656982421875
26396.130859375
27496.6875
buf 258.49462890625
24632.990234375
27386.62890625
buf 270.81109619140625
23329.203125
27111.265625
buf 282.4757080078125
21993.94140625
26733.0625
buf 293.47271728515625
21034.201171875
26259.146484375
buf 303.98980712890625
19922.357421875
25736.65234375
buf 313.9509582519531
19120.798828125
25155.22265625
buf 323.51141357421875
18148.45703125
24551.77734375


In [None]:
# scripts.train_gated_sae.main, overriding only cache_batch_size; every other option
# uses the script's defaults.
from scripts.train_gated_sae import main
main(cache_batch_size=256)

In [None]:
# scripts.train_phi_sae.main on layer 11 with is_gated=True; sparsity_coefficient and
# the remaining options are left at the script defaults.
from scripts.train_phi_sae import main
main(layer=11, is_gated=True)

In [None]:
# scripts.train_phi_sae.main on layer 11 with is_gated=True and an explicit
# sparsity_coefficient of 4.2e-6.
from scripts.train_phi_sae import main
main(layer=11, is_gated=True, sparsity_coefficient=4.2e-6)

In [None]:
# scripts.train_phi_sae.main on layer 30; is_gated and every other option are left
# at the script defaults.
from scripts.train_phi_sae import main
main(layer=30)

In [None]:
# scripts.train_gated_sae.main with is_xl=True on layer 20 and cache_batch_size=256.
from scripts.train_gated_sae import main
main(is_xl=True, layer=20, cache_batch_size=256)
# Alternative run kept for reference: layer 30, passing the layer-20 gated weights as `restore`.
# main(is_xl=True, layer=30, cache_batch_size=256, restore="weights/gpt2-20-gated.safetensors")

In [None]:
# scripts.train_gpt2_sae.main on layer 32, passing gpt2-20-base weights as `restore`.
# NOTE(review): the restore file is named for layer 20 while layer=32 is trained —
# confirm this cross-layer restore is intended. is_xl is not passed, so the script
# default applies.
from scripts.train_gpt2_sae import main
main(restore="weights/gpt2-20-base.safetensors", layer=32)
# Other invocations kept for reference (is_xl=False, save_steps=0):
# python -m scripts.train_gpt2_sae --save_steps 0 --is_xl=False --layer=9
# main(save_steps=0, is_xl=False, layer=1)