# Building run_with_saes

In [None]:
import sys 
sys.path.append("../..")
sys.path.append("..")

import os
from importlib import reload
from tqdm import tqdm

import joseph
from joseph.analysis import *
from joseph.visualisation import *
from joseph.utils import *
from joseph.data import *


# Re-execute the joseph submodules so edits to their source files are picked
# up without restarting the kernel.
# NOTE(review): reload() re-executes the modules in the order listed. If
# joseph.analysis / joseph.visualisation import names from joseph.utils or
# joseph.data, they keep the pre-reload objects; reloading dependencies first
# may be needed. Confirm against the package's own imports.
reload(joseph.analysis)
reload(joseph.visualisation)
reload(joseph.utils)
reload(joseph.data)

# Re-run the star imports so this notebook's globals point at the reloaded
# objects rather than the ones bound by the first round of imports above.
from joseph.analysis import *
from joseph.visualisation import *
from joseph.utils import *
from joseph.data import *

# turn torch grad tracking off (globally, for every cell that follows).
# `torch` is presumably provided by one of the star imports — TODO confirm.
torch.set_grad_enabled(False)

import webbrowser
from IPython.core.display import display, HTML

path_to_html = "../week_8_jan/gpt2_small_features_layer_5"


def render_feature_dashboard(feature_id, html_dir=None):
    """Open the pre-rendered HTML dashboard for one SAE feature in a browser tab.

    The dashboard is looked up at ``<html_dir>/data_<feature_id:04>.html``.
    If that file does not exist, "No HTML file found" is printed instead.

    Args:
        feature_id: Integer feature index; zero-padded to four digits in the
            filename.
        html_dir: Directory holding the dashboards. Defaults to the
            module-level ``path_to_html``, read at call time.
    """
    from pathlib import Path

    if html_dir is None:
        html_dir = path_to_html
    path = Path(html_dir) / f"data_{feature_id:04}.html"

    print(f"Feature {feature_id}")
    if path.exists():
        # as_uri() builds a valid, percent-encoded file:// URL on every OS.
        # Concatenating "file://" with an absolute path breaks on Windows
        # drive paths and on paths containing characters such as "#" or "%".
        webbrowser.open_new_tab(path.resolve().as_uri())
    else:
        print("No HTML file found")
    

In [None]:
path_to_all_layer_saes = "../GPT2-small-SAEs/"

# Filenames are expected to look like
# "final_sparse_autoencoder_gpt2-small_blocks.<layer>.hook_resid_pre_24576.pt",
# so the second "."-separated field is the layer index used as the sort key.
# Hidden files (e.g. ".DS_Store") don't follow that scheme and would make the
# int() sort key raise, so they are skipped.
sae_dir_files = [f for f in os.listdir(path_to_all_layer_saes) if not f.startswith(".")]
# print(sae_dir_files)

# SAE weight files: everything that is not a log file.
model_files = sorted(
    (f for f in sae_dir_files if "log" not in f),
    key=lambda x: int(x.split(".")[1]),
)
display(model_files)

# Per-feature log-sparsity files saved alongside the SAEs.
log_sparsity_files = sorted(
    (f for f in sae_dir_files if "log_feature_sparsity" in f),
    key=lambda x: int(x.split(".")[1]),
)
log_sparsity_files

In [None]:
from sae_training.sparse_autoencoder import SparseAutoencoder

# Map each SAE's hook point name to the loaded SAE.
gpt2_small_sparse_autoencoders = {}
for path in model_files:
    layer = int(path.split(".")[1])
    print(f"Loading layer {layer}")
    # os.path.join avoids the doubled "/" an f-string join produces, since
    # path_to_all_layer_saes already ends with a separator.
    sae = SparseAutoencoder.load_from_pretrained(os.path.join(path_to_all_layer_saes, path))
    # This notebook only evaluates SAEs, so ghost grads are switched off.
    sae.cfg.use_ghost_grads = False
    gpt2_small_sparse_autoencoders[sae.cfg.hook_point] = sae

In [None]:

# GPT-2 small, loaded with fold_ln=True.
model = HookedTransformer.from_pretrained("gpt2-small", fold_ln=True)

# Expose the split q/k/v attention inputs and the per-head attention
# results as hook points.
model.set_use_split_qkv_input(True)
model.set_use_attn_result(True)



# Evaluations

In [None]:
from sae_training.utils import LMSparseAutoencoderSessionloader

# Restore model, SAE and activation store together from the saved
# GPT-2 small blocks.7.hook_resid_pre SAE checkpoint.
(
    model,
    sparse_autoencoder,
    activation_store,
) = LMSparseAutoencoderSessionloader.load_session_from_pretrained(
    path="../GPT2-small-SAEs/final_sparse_autoencoder_gpt2-small_blocks.7.hook_resid_pre_24576.pt"
)
# Evaluation only: ghost grads off.
sparse_autoencoder.cfg.use_ghost_grads = False

In [None]:
cfg = sparse_autoencoder.cfg

# The store class is `ActivationsStore` (plural), matching the import used
# when building Sam's SAE later in this notebook; the singular name does not
# match it.
from sae_training.activations_store import ActivationsStore


# Rebuild the activation store from the SAE's own config so that it reads
# activations at the SAE's hook point.
activation_store = ActivationsStore(
    cfg,
    model,
)

In [None]:

@torch.no_grad()
def get_recons_loss(sparse_autoencoder, model, batch_tokens, hook_point):
    """Compare the model's per-token CE loss with and without the SAE spliced in.

    Runs ``model`` on ``batch_tokens`` four times:

    * clean (no hooks);
    * with an identity hook at ``hook_point`` (a sanity check; should match clean);
    * with the activation at ``hook_point`` replaced by the SAE reconstruction;
    * with that activation zero-ablated.

    When ``sparse_autoencoder.cfg.hook_point_head_index`` is set, only that
    head's slice (``activations[:, :, head_index]``) is replaced or
    zero-ablated. Otherwise the whole activation is.

    Returns:
        ``(score, loss, recons_loss, zero_abl_loss, no_replacement_loss)``, all
        computed with ``loss_per_token=True``, where
        ``score = (zero_abl_loss - recons_loss) / (zero_abl_loss - loss)``.
    """
    loss = model(batch_tokens, return_type="loss", loss_per_token=True)

    head_index = sparse_autoencoder.cfg.hook_point_head_index

    def no_replacement_hook(activations, hook):
        return activations

    def standard_replacement_hook(activations, hook):
        return sparse_autoencoder.forward(activations)[0].to(activations.dtype)

    def head_replacement_hook(activations, hook):
        new_activations = sparse_autoencoder.forward(activations[:, :, head_index])[0]
        activations[:, :, head_index] = new_activations.to(activations.dtype)
        return activations

    def standard_zero_ablate_hook(activations, hook):
        activations[:] = 0.0
        return activations

    def head_zero_ablate_hook(activations, hook):
        # Zero only the head the SAE replaces, so that the zero-ablation
        # baseline in `score` intervenes on the same component as the
        # reconstruction (zeroing every head inflates the denominator).
        activations[:, :, head_index] = 0.0
        return activations

    if head_index is None:
        replacement_hook = standard_replacement_hook
        zero_ablate_hook = standard_zero_ablate_hook
    else:
        replacement_hook = head_replacement_hook
        zero_ablate_hook = head_zero_ablate_hook

    def run_with(hook_fn):
        return model.run_with_hooks(
            batch_tokens,
            return_type="loss",
            fwd_hooks=[(hook_point, hook_fn)],
            loss_per_token=True,
        )

    no_replacement_loss = run_with(no_replacement_hook)
    recons_loss = run_with(replacement_hook)
    zero_abl_loss = run_with(zero_ablate_hook)

    score = (zero_abl_loss - recons_loss) / (zero_abl_loss - loss)

    return score, loss, recons_loss, zero_abl_loss, no_replacement_loss

# Score one fresh batch at the activation store's hook point.
batch_tokens = activation_store.get_batch_tokens()
(
    score,
    loss,
    recons_loss,
    zero_abl_loss,
    no_replacement_loss,
) = get_recons_loss(
    sparse_autoencoder,
    model,
    batch_tokens,
    activation_store.cfg.hook_point,
)
score.shape

In [None]:

def recons_loss_batched(sparse_autoencoder, model, activation_store, n_batches=100):
    """Summarise ``get_recons_loss`` over several batches.

    Draws ``n_batches`` token batches from ``activation_store`` and reduces
    each of the five tensors returned by ``get_recons_loss`` to its mean.

    Returns:
        pandas DataFrame with one row per batch and columns "score", "loss",
        "recons_loss", "zero_abl_loss", "no_replacement_loss".
    """
    column_names = ["score", "loss", "recons_loss", "zero_abl_loss", "no_replacement_loss"]
    batch_means = []
    for _ in tqdm(range(n_batches)):
        tokens = activation_store.get_batch_tokens()
        per_token_results = get_recons_loss(
            sparse_autoencoder, model, tokens, activation_store.cfg.hook_point
        )
        batch_means.append(tuple(t.mean().item() for t in per_token_results))

    return pd.DataFrame(batch_means, columns=column_names)

# Ten batches; show the mean CE loss with the SAE reconstruction spliced in.
ce_losses = recons_loss_batched(
    sparse_autoencoder, model, activation_store, n_batches=10
)
ce_losses["recons_loss"].mean()

In [None]:
ce_losses["loss"].mean()  # clean (no-SAE) CE loss for comparison

In [None]:
# Per-batch means of every metric returned by get_recons_loss.
ce_losses

In [None]:
def get_variance_explained(model, sparse_autoencoder, batch_tokens):
    """Reconstruction-quality metrics for ``sparse_autoencoder`` on one batch.

    Caches the activation at the SAE's own hook point, runs the SAE on it,
    drops sequence position 0, and returns four float tensors on the CPU:

    * ``mse_loss_sam``: element-wise squared error divided by the per-token L2
      norm of the activation after subtracting its mean over the last dim;
    * ``mse_loss``: element-wise squared error divided by the per-token L2 norm
      of the raw activation;
    * ``variance_explained``: per token, ``1 - ||sae_out - x||^2 / ||x||^2``;
    * ``variance_explained_sams``: the same, with the mean-centred ``x`` in the
      denominator.

    The last two are ``.squeeze()``-d, so size-1 dims are dropped.
    """
    # Read the hook point from the SAE itself rather than from a global
    # activation store, so the function only depends on its arguments.
    hook_point = sparse_autoencoder.cfg.hook_point
    # Only the SAE's input is needed, so cache just that one hook point.
    _, cache = model.run_with_cache(
        batch_tokens, return_type="loss", names_filter=hook_point
    )
    x = cache[hook_point]
    sae_out = sparse_autoencoder(x)[0]

    # Drop position 0 (presumably the BOS token — TODO confirm). Cast both
    # tensors to float so that they are compared at the same precision.
    x = x.float().cpu()[:, 1:, :]
    sae_out = sae_out.float().cpu()[:, 1:, :]
    x_centred = x - x.mean(-1, keepdim=True)

    squared_error = (sae_out - x).pow(2)

    # MSE Loss, normalised by the per-token activation norm.
    mse_loss_sam = squared_error / x_centred.pow(2).sum(dim=-1, keepdim=True).sqrt()
    mse_loss = squared_error / x.pow(2).sum(dim=-1, keepdim=True).sqrt()

    # Variance Explained
    per_token_l2_dist = squared_error.sum(dim=-1).squeeze()
    total_variance = x.pow(2).sum(dim=-1, keepdim=True).squeeze()
    total_variance_with_centering = x_centred.pow(2).sum(dim=-1, keepdim=True).squeeze()

    variance_explained = 1 - (per_token_l2_dist / total_variance)
    variance_explained_sams = 1 - (per_token_l2_dist / total_variance_with_centering)

    return mse_loss_sam, mse_loss, variance_explained, variance_explained_sams

(
    mse_loss_sam,
    mse_loss,
    variance_explained,
    variance_explained_sams,
) = get_variance_explained(model, sparse_autoencoder, batch_tokens)

# Normalised MSE averaged over the last dim, transposed so that each column
# (one per sequence in the batch) is drawn as a line across positions.
px.line(mse_loss.mean(-1).T.cpu())

In [None]:

def get_variance_explained_batched(
    model, sparse_autoencoder, activation_store, n_batches=100
):
    """Average ``get_variance_explained`` metrics over ``n_batches`` fresh batches.

    Returns:
        pandas DataFrame with one row per batch. The columns are "sams" and
        "our" (mean normalised MSE with the centred and the raw denominator),
        then "variance_explained_ours" and "variance_explained_sams".
    """
    # Column names, in the same order as get_variance_explained's outputs.
    column_order = ["sams", "our", "variance_explained_ours", "variance_explained_sams"]
    per_column = {name: [] for name in column_order}

    for _ in tqdm(range(n_batches)):
        batch_tokens = activation_store.get_batch_tokens()
        metrics = get_variance_explained(model, sparse_autoencoder, batch_tokens)
        for name, metric in zip(column_order, metrics):
            per_column[name].append(metric.mean().item())

    return pd.DataFrame(per_column)


# get_variance_explained_batched draws its own batches from the store, so no
# separate batch_tokens fetch is needed here (it would only consume a batch).
losses = get_variance_explained_batched(
    model, sparse_autoencoder, activation_store, n_batches=10
)
losses

In [None]:
losses["variance_explained_ours"].mean()  # raw-norm denominator

In [None]:
losses["variance_explained_sams"].mean()  # mean-centred denominator

## Sanity Check

In [None]:
hook_point = activation_store.cfg.hook_point

# Alternative sanity-check prompt:
#   prompt, answer = "The quick brown fox jumps over the lazy", " dog"
prompt = " John and Mary went to the park. Then John said to"
answer = " Mary"

# Baseline: the unmodified model's answer ranking and greedy continuation.
utils.test_prompt(prompt, answer, model)
print(model.generate(prompt, max_new_tokens=20, stop_at_eos=False, temperature=0))


def standard_replacement_hook(activations, hook: HookPoint):
    """Replace the hooked activation with the SAE's reconstruction."""
    activations = sparse_autoencoder.forward(activations)[0].to(activations.dtype)
    return activations


# The same checks with the SAE reconstruction spliced in at hook_point.
with model.hooks(fwd_hooks=[(hook_point, standard_replacement_hook)]):
    utils.test_prompt(prompt=prompt, answer=answer, model=model)
    print(model.generate(prompt, max_new_tokens=20, stop_at_eos=False, temperature=0))


## Loading Sam's SAEs

In [None]:
sams_layer1_sae_path = "layer_1_resid_post_ae_245000.pt"

# weights_only=True restricts unpickling to tensors and plain containers.
# The checkpoint is used here as a dict of named tensors, and this avoids
# executing arbitrary pickled code from a downloaded file.
weights = torch.load(sams_layer1_sae_path, map_location="cpu", weights_only=True)
display(weights.keys())

# Map the checkpoint's parameter names onto SparseAutoencoder's names.
rename_map = {
    "bias": "b_dec",
    "encoder.bias": "b_enc",
    "decoder.weight": "W_dec",
    "encoder.weight": "W_enc",
}
new_weights = {new_name: weights[old_name] for old_name, new_name in rename_map.items()}

# Transpose (not rotate) the weight matrices. The checkpoint keys look like
# nn.Linear parameters, which are stored as (out_features, in_features);
# presumably SparseAutoencoder uses the opposite layout — the strict
# load_state_dict in the next cell will fail if not.
weights_to_rotate = ["W_enc", "W_dec"]
for w in weights_to_rotate:
    new_weights[w] = new_weights[w].T

display(new_weights.keys())


In [None]:
from sae_training.utils import LMSparseAutoencoderSessionloader

# Replace model / SAE / activation store with the GPT-2 small
# blocks.2.hook_resid_pre session (resid_pre at layer 2 should be good).
(
    model,
    sparse_autoencoder,
    activation_store,
) = LMSparseAutoencoderSessionloader.load_session_from_pretrained(
    path="../GPT2-small-SAEs/final_sparse_autoencoder_gpt2-small_blocks.2.hook_resid_pre_24576.pt"
)
sparse_autoencoder.cfg.use_ghost_grads = False

In [None]:
# Build an SAE with matching shapes so Sam's converted weights can be loaded into it.
from sae_training.activations_store import ActivationsStore
from sae_training.sparse_autoencoder import SparseAutoencoder
from sae_training.config import LanguageModelSAERunnerConfig

# Prefer Apple's "mps" backend, falling back to CPU where it is unavailable.
# One variable keeps the config, SAE and model on the same device.
sams_device = "mps" if torch.backends.mps.is_available() else "cpu"

cfg = LanguageModelSAERunnerConfig(
    model_name="pythia-70m-deduped",
    hook_point="blocks.1.hook_resid_post",
    dataset_path="EleutherAI/the_pile_deduplicated",
    hook_point_layer=1,
    feature_sampling_method=None,
    d_in=512,
    # lr / l1_coefficient are 0: this SAE is only evaluated here, not trained.
    lr=0.0,
    l1_coefficient=0.0,
    expansion_factor=64,
    device=sams_device,
    store_batch_size=32,
    n_batches_in_buffer=128,
    use_ghost_grads=False,
)

sams_sparse_autoencoder = SparseAutoencoder(cfg)
print(sams_sparse_autoencoder.state_dict().keys())
# load_state_dict is strict by default, so any missing or unexpected key fails loudly.
sams_sparse_autoencoder.load_state_dict(new_weights)
sams_sparse_autoencoder.to(sams_device)
pythia_70m_model = HookedTransformer.from_pretrained(
    "pythia-70m-deduped", device=sams_device, fold_ln=True
)
pythia_70m_activation_store = ActivationsStore(sams_sparse_autoencoder.cfg, pythia_70m_model)

In [None]:
# Sanity check: clean Pythia-70m CE loss on one batch from the store.
batch_tokens = pythia_70m_activation_store.get_batch_tokens()
# A bare expression in the middle of a cell is discarded by Jupyter, so
# display the shape explicitly.
display(batch_tokens.shape)
loss = pythia_70m_model(batch_tokens, return_type="loss")
loss

In [None]:
# pythia 70m deduped: Sam's SAE on the raw activations.
activations = pythia_70m_activation_store.next_batch()
sae_out, feature_acts, loss, mse_loss, l1_loss, mse_loss_ghost_resid = sams_sparse_autoencoder(activations)
# Assumes rows are tokens, which is consistent with the per-row L1/L0 below
# (sum over dim=1). The per-token L2 norm is therefore taken over the last dim.
# "In" is the SAE input and "Out" is its reconstruction.
print("Norms")
print("In Norm", activations.norm(dim=-1).mean().item())
print("Out norm", sae_out.norm(dim=-1).mean().item())
print("-"*20)
print("Sparsity")
print("L1", feature_acts.sum(dim=1).mean().item())
print("L0", (feature_acts > 0).float().sum(dim=1).mean().item()) # Way too many features firing.
print("-"*20)
print("Loss")
print("MSE", mse_loss.item())

In [None]:
# pythia 70m deduped: Sam's SAE with b_dec added to its input.
activations = pythia_70m_activation_store.next_batch()
activation_plus_b_dec = activations + sams_sparse_autoencoder.b_dec
sae_out, feature_acts, loss, mse_loss, l1_loss, mse_loss_ghost_resid = sams_sparse_autoencoder(activation_plus_b_dec)
# Assumes rows are tokens (matching the per-row L1/L0 below), so the norm is
# taken per token over the last dim. "In" reports the raw activations (before
# b_dec is added); "Out" reports the SAE reconstruction.
print("Norms")
print("In Norm", activations.norm(dim=-1).mean().item())
print("Out norm", sae_out.norm(dim=-1).mean().item())
print("-"*20)
print("Sparsity")
print("L1", feature_acts.sum(dim=1).mean().item())
print("L0", (feature_acts > 0).float().sum(dim=1).mean().item()) # Way too many features firing.
print("-"*20)
print("Loss")
print("MSE", mse_loss.item())

In [None]:
# This is a GPT-2 small residual stream SAE (layer 2 resid pre).
activations = activation_store.next_batch()
sae_out, feature_acts, loss, mse_loss, l1_loss, mse_loss_ghost_resid = sparse_autoencoder(activations)
# Assumes rows are tokens (matching the per-row L1/L0 below), so the norm is
# taken per token over the last dim. "In" is the SAE input and "Out" is its
# reconstruction.
print("Norms")
print("In Norm", activations.norm(dim=-1).mean().item())
print("Out norm", sae_out.norm(dim=-1).mean().item())
print("-"*20)
print("Sparsity")
print("L1", feature_acts.sum(dim=1).mean().item())
print("L0", (feature_acts > 0).float().sum(dim=1).mean().item())  # mean active features per token
print("-"*20)
print("Loss")
print("MSE", mse_loss.item())

In [None]:

@torch.no_grad()
def get_recons_loss(sparse_autoencoder, model, batch_tokens, hook_point, add_b_dec=False):
    """Compare the model's per-token CE loss with and without the SAE spliced in.

    Runs ``model`` on ``batch_tokens`` four times:

    * clean (no hooks);
    * with an identity hook at ``hook_point`` (a sanity check; should match clean);
    * with the activation at ``hook_point`` replaced by the SAE reconstruction;
    * with that activation zero-ablated.

    When ``sparse_autoencoder.cfg.hook_point_head_index`` is set, only that
    head's slice (``activations[:, :, head_index]``) is replaced or
    zero-ablated. Otherwise the whole activation is.

    Args:
        add_b_dec: If True, ``sparse_autoencoder.b_dec`` is added to the
            activation before it is passed to the SAE. This applies in both
            whole-activation and per-head mode. The reconstruction is cast
            back to the hooked activation's dtype.

    Returns:
        ``(score, loss, recons_loss, zero_abl_loss, no_replacement_loss)``, all
        computed with ``loss_per_token=True``, where
        ``score = (zero_abl_loss - recons_loss) / (zero_abl_loss - loss)``.
    """
    loss = model(batch_tokens, return_type="loss", loss_per_token=True)

    head_index = sparse_autoencoder.cfg.hook_point_head_index

    def reconstruct(x, dtype):
        # Shared by both replacement hooks so add_b_dec is honoured in each.
        if add_b_dec:
            x = x + sparse_autoencoder.b_dec
        return sparse_autoencoder.forward(x)[0].to(dtype)

    def no_replacement_hook(activations, hook):
        return activations

    def standard_replacement_hook(activations, hook):
        return reconstruct(activations, activations.dtype)

    def head_replacement_hook(activations, hook):
        activations[:, :, head_index] = reconstruct(
            activations[:, :, head_index], activations.dtype
        )
        return activations

    def standard_zero_ablate_hook(activations, hook):
        activations[:] = 0.0
        return activations

    def head_zero_ablate_hook(activations, hook):
        # Zero only the head the SAE replaces, so that the zero-ablation
        # baseline in `score` intervenes on the same component as the
        # reconstruction (zeroing every head inflates the denominator).
        activations[:, :, head_index] = 0.0
        return activations

    if head_index is None:
        replacement_hook = standard_replacement_hook
        zero_ablate_hook = standard_zero_ablate_hook
    else:
        replacement_hook = head_replacement_hook
        zero_ablate_hook = head_zero_ablate_hook

    def run_with(hook_fn):
        return model.run_with_hooks(
            batch_tokens,
            return_type="loss",
            fwd_hooks=[(hook_point, hook_fn)],
            loss_per_token=True,
        )

    no_replacement_loss = run_with(no_replacement_hook)
    recons_loss = run_with(replacement_hook)
    zero_abl_loss = run_with(zero_ablate_hook)

    score = (zero_abl_loss - recons_loss) / (zero_abl_loss - loss)

    return score, loss, recons_loss, zero_abl_loss, no_replacement_loss

def recons_loss_batched(sparse_autoencoder, model, activation_store, n_batches=100, add_b_dec=False):
    """Per-batch means of the ``get_recons_loss`` metrics over ``n_batches`` batches.

    ``add_b_dec`` is forwarded unchanged to ``get_recons_loss``.

    Returns:
        pandas DataFrame with one row per batch and columns "score", "loss",
        "recons_loss", "zero_abl_loss", "no_replacement_loss".
    """
    rows = []
    for _ in tqdm(range(n_batches)):
        tokens = activation_store.get_batch_tokens()
        metrics = get_recons_loss(
            sparse_autoencoder,
            model,
            tokens,
            activation_store.cfg.hook_point,
            add_b_dec=add_b_dec,
        )
        rows.append(tuple(metric.mean().item() for metric in metrics))

    return pd.DataFrame(
        rows,
        columns=["score", "loss", "recons_loss", "zero_abl_loss", "no_replacement_loss"],
    )


# Sam's Pythia-70m SAE: first on the raw activations...
ce_losses = recons_loss_batched(
    sams_sparse_autoencoder,
    pythia_70m_model,
    pythia_70m_activation_store,
    n_batches=100,
)
display(ce_losses)
print(ce_losses["score"].mean())
print(ce_losses["recons_loss"].mean())

# ...then with b_dec added to the SAE input.
ce_losses = recons_loss_batched(
    sams_sparse_autoencoder,
    pythia_70m_model,
    pythia_70m_activation_store,
    n_batches=100,
    add_b_dec=True,
)
display(ce_losses)
print(ce_losses["score"].mean())
print(ce_losses["recons_loss"].mean())


In [None]:
# GPT-2 small layer-2 resid_pre SAE, without adding b_dec.
ce_losses = recons_loss_batched(
    sparse_autoencoder,
    model,
    activation_store,
    n_batches=100,
    add_b_dec=False,
)
display(ce_losses)
print(ce_losses["score"].mean())
print(ce_losses["recons_loss"].mean())

## OK, let's just compare weights

In [None]:
# Sam's SAE, built from the converted checkpoint. The bare name `sams` is
# not defined anywhere in this notebook.
sams_sparse_autoencoder

In [None]:
# Histogram of the GPT-2 small SAE's encoder bias values.
px.histogram(
    sparse_autoencoder.b_enc.detach().cpu().numpy()
)