In [None]:
# Load IPython's autoreload extension and switch it to mode 2 so that edits to
# imported project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os

# `__vsc_ipynb_file__` is the notebook's own path as provided by the VS Code
# Jupyter integration (it is not defined by this notebook). The working
# directory becomes the parent of the folder that contains the notebook.
_notebook_dir = os.path.dirname(__vsc_ipynb_file__)
_parent_of_notebook_dir = os.path.dirname(_notebook_dir)
os.chdir(_parent_of_notebook_dir)

In [None]:
import jax_smi
# Presumably starts jax-smi's background tracking of JAX device memory so it
# can be watched from outside the notebook — TODO confirm against jax-smi docs.
jax_smi.initialise_tracking()

In [None]:
# Run the Phi SAE training script's `main` (layer 11, gated, 4 devices).
#
# autoreload is already loaded and set to mode 2 by the first cell of this
# notebook, so it is not loaded again here: a second `%load_ext autoreload`
# only makes IPython print an "already loaded" notice.
from scripts.train_phi_sae import main
import saex.trainer_cache
import saex.models.micrlhf_model

# Optional line-profiling recipes. The two `saex` imports above exist so the
# `-f` targets below resolve. To use one, uncomment the `%load_ext` line and
# exactly one `%lprun` line, and comment out the plain `main(...)` call.
# %load_ext line_profiler
# %lprun -f saex.trainer_cache.BufferTrainer.train main(layer=11, is_gated=True, sparsity_coefficient=1.4e-5, n_devices=4)
# %lprun -f saex.models.micrlhf_model.MicrlhfModel.__call__ main(layer=11, is_gated=True, sparsity_coefficient=1.4e-5, n_devices=4)
# %lprun -f saex.models.micrlhf_model.MicrlhfModel.encode_texts main(layer=11, is_gated=True, sparsity_coefficient=1.4e-5, n_devices=4)
main(layer=11, is_gated=True, sparsity_coefficient=4e-6, n_devices=4, use_recip=False)

In [None]:
from scripts.train_gpt2_sae import main

# Arguments for the GPT-2 SAE training entry point. `push_to_hub` is passed as
# a 2-tuple of strings, exactly as the script is called elsewhere in this file.
hub_destination = ("nev/gpt2_medium_saes-saex-test", "l16-test")
run_kwargs = {
    "size": 2,
    "layer": 16,
    "push_to_hub": hub_destination,
}
main(**run_kwargs)

In [None]:
from saex.iterable_dataset import IterableDatasetConfig
from saex.models.micrlhf_model import MicrlhfModelConfig
from saex.model_haver import ModelHaver
from saex.sae import SAEConfig
from more_itertools import chunked


# Notebook-level constants; `batch_size` is also read by later cells.
n_features = 3072
batch_size = 64

# Keyword arguments for SAEConfig, grouped by the part of the config they name.
# Values are unchanged; only the grouping is new.
sae_settings = dict(
    # sizes
    n_dimensions=n_features,
    expansion_factor=32,
    batch_size=batch_size,
    buffer_size=2_000,
    is_gated=True,
    # initialisation and biases
    use_encoder_bias=True,
    remove_decoder_bias=False,
    encoder_init_method="orthogonal",
    decoder_init_method="pseudoinverse",
    decoder_bias_init_method="zeros",
    # losses
    reconstruction_loss_type="mse_batchnorm",
    sparsity_loss_type="l1",
    sparsity_coefficient=0,
    # dead-feature handling
    death_loss_type="dm_ghost_grads",
    death_penalty_threshold=5e-7,
    death_penalty_coefficient=0.25,
    dead_after=1_000,
    # decoder constraints and tracking
    project_updates_from_dec=True,
    restrict_dec_norm="exact",
    sparsity_tracking_epsilon=0.1,
)
sae_config = SAEConfig(**sae_settings)

dataset_config = IterableDatasetConfig(dataset_name="nev/openhermes-2.5-phi-format-text")

model_settings = dict(
    tokenizer_path="microsoft/Phi-3-mini-4k-instruct",
    gguf_path="weights/phi-3-16.gguf",
    device_map="tpu:0",
    use_flash=False,
    layer=11,
    max_seq_len=64,
)
model_config = MicrlhfModelConfig(**model_settings)

# Bundle the three configs and restore SAE weights from a local safetensors file.
haver = ModelHaver(
    model_config=model_config,
    sae_config=sae_config,
    dataset_config=dataset_config,
    sae_restore="weights/phi-l11-gated.safetensors",
)

In [None]:
# Upload the restored SAE. The two positional strings are forwarded unchanged
# to `push_to_hub`; the second looks like a name inside the repo — TODO confirm.
hub_repo = "nev/phi-3-4k-saex-test"
hub_entry = "l11-test1-recip-l0-100"
haver.sae.push_to_hub(hub_repo, hub_entry)

In [None]:
from collections import defaultdict
from tqdm import tqdm
import jax.numpy as jnp
import numpy as np
import jax


# Build `activ_cache`: feature id -> list of (global token position, activation).
#
# Only positions that are unmasked AND have a non-zero activation for the
# feature are recorded. Previously every unmasked token of a batch was appended
# to every feature that fired anywhere in that batch (zeros included), which
# inflated the cache and made `len(cache) / tokens_processed` meaningless as a
# firing frequency.
#
# Global positions count every row of `hiddens` (masked rows included), so
# `tokens_processed` advances by `hiddens.shape[0]` per batch.
tokens_processed = 0
activ_cache = defaultdict(list)
for texts in chunked(bar := tqdm(haver.create_dataset()), batch_size):
    activations, model_misc = haver.model(texts)
    # NOTE(review): `.get` yields None if the model returns no mask, and the
    # arithmetic below would then fail — confirm a mask is always present.
    mask = model_misc.get("mask")
    _pre_relu, hiddens = haver.sae.encode(activations)

    # Masked mean of the per-token count of non-zero features (L0).
    bar.set_postfix(l0=((hiddens != 0).sum(-1) * mask).mean() / mask.mean())

    # Find the sparse (row, feature) hits on device and move only those to the
    # host, instead of iterating device arrays element by element.
    fired = (hiddens != 0) & jnp.asarray(mask).astype(bool)[:, None]
    rows, feats = jnp.nonzero(fired)
    values = hiddens[rows, feats].astype(jnp.float32)

    # `.tolist()` yields plain Python ints/floats, which are cheap to store,
    # hashable, and compare quickly in later cells.
    for row, feat, activ in zip(
        np.asarray(rows).tolist(),
        np.asarray(feats).tolist(),
        np.asarray(values).tolist(),
    ):
        activ_cache[feat].append((tokens_processed + row, activ))

    tokens_processed += hiddens.shape[0]

In [None]:
def visualize(feature, thresh=6.0, max_freq=0.03, context=24, lookahead=4):
    """Print the text around the strongest cached activations of one feature.

    Reads the notebook globals `activ_cache`, `tokens_processed`, `haver` and
    `batch_size`, and re-iterates `haver.create_dataset()` to recover tokens.

    Args:
        feature: Feature id used as the key into `activ_cache`.
        thresh: Activations below this value are not printed; if no cached
            activation reaches it, nothing is printed at all.
        max_freq: If `len(cache) / tokens_processed` exceeds this, only the
            frequency is printed and the function returns.
        context: Number of tokens shown up to and including the firing token.
        lookahead: Number of tokens shown after the firing token.

    Returns:
        None. Output is printed.

    NOTE(review): positions are matched by counting the tokens returned by
    `haver.model.to_tokens` batch by batch. This assumes that count lines up
    with the row indexing used when `activ_cache` was filled — confirm.
    NOTE(review): context is sliced from the flattened batch, so it can run
    across the boundary between two texts of the same batch.
    """
    # `.get` avoids inserting an empty list into the defaultdict for every
    # feature id that is merely probed.
    cache = activ_cache.get(feature)
    if not cache:
        return
    tokens, activs = zip(*cache)
    if max(activs) < thresh:
        return
    freq = len(tokens) / tokens_processed
    print(freq)
    if freq > max_freq:
        return

    # Position -> activation, keeping the first entry per position (the same
    # entry `tuple.index` would have found). `int()` makes keys hashable even
    # when the cache holds array scalars.
    first_seen = {}
    for position, activ in cache:
        first_seen.setdefault(int(position), activ)
    # Same filter as before (`activ < thresh` is skipped), in ascending order.
    wanted = sorted(pos for pos, activ in first_seen.items() if not activ < thresh)
    if not wanted:
        return

    tokens_viewed = 0
    cursor = 0  # index of the next entry of `wanted` still to print
    for texts in chunked(tqdm(haver.create_dataset()), batch_size):
        toks = haver.model.to_tokens(texts)
        all_tokens = [t for tok in toks for t in tok]
        batch_end = tokens_viewed + len(all_tokens)
        # Visit only the wanted positions that fall inside this batch, rather
        # than testing every token position against the whole cache.
        while cursor < len(wanted) and wanted[cursor] < batch_end:
            position = wanted[cursor]
            cursor += 1
            i = position - tokens_viewed
            print(first_seen[position],
                  repr(haver.model.decode(all_tokens[max(0, i - context + 1):i + 1])),
                  repr(haver.model.decode(all_tokens[i + 1:i + 1 + lookahead])))
        tokens_viewed = batch_end
        if cursor >= len(wanted):
            # Everything printable has been shown; stop reading the dataset.
            break
# Inspect a window of 100 consecutive feature ids. The upper bound used to be
# written `10_0100`, which Python reads as 100_100 and which looks like a typo
# for `10_100`; as written it swept roughly 90k feature ids.
for i in range(10_000, 10_100):
    visualize(i)

In [None]:
from scripts.train_gated_sae import main

# Launch the gated-SAE training script, overriding only the cache batch size.
run_kwargs = {"cache_batch_size": 256}
main(**run_kwargs)

In [None]:
from scripts.train_phi_sae import main

# Launch the Phi SAE training script for layer 11 with the gated variant on.
run_kwargs = {"layer": 11, "is_gated": True}
main(**run_kwargs)

In [None]:
from scripts.train_phi_sae import main

# Same layer-11 gated run, with an explicit sparsity coefficient.
run_kwargs = {
    "layer": 11,
    "is_gated": True,
    "sparsity_coefficient": 4.2e-6,
}
main(**run_kwargs)

In [None]:
from scripts.train_phi_sae import main

# Launch the Phi SAE training script for layer 30 with all other defaults.
target_layer = 30
main(layer=target_layer)

In [None]:
from scripts.train_gated_sae import main

# Launch the gated-SAE training script with `is_xl` set, on layer 20.
run_kwargs = {
    "is_xl": True,
    "layer": 20,
    "cache_batch_size": 256,
}
main(**run_kwargs)
# Alternative kept for reference (layer 30, restoring from saved weights):
# main(is_xl=True, layer=30, cache_batch_size=256, restore="weights/gpt2-20-gated.safetensors")

In [None]:
from scripts.train_gpt2_sae import main

# Launch the GPT-2 SAE training script on layer 32, restoring from saved weights.
restore_path = "weights/gpt2-20-base.safetensors"
main(restore=restore_path, layer=32)
# Alternatives kept for reference:
# python -m scripts.train_gpt2_sae --save_steps 0 --is_xl=False --layer=9
# main(save_steps=0, is_xl=False, layer=1)