# Training a basic SAE with SAELens

This tutorial demonstrates training a simple, relatively small Sparse Autoencoder (SAE) on the `tiny-stories-1M` model.

As the SAELens library is under active development, please open an issue [here](https://github.com/jbloomAus/SAELens) if this tutorial is stale.

## Setup

In [None]:
try:
    # import google.colab # type: ignore
    # from google.colab import output
    %pip install sae-lens transformer-lens circuitsvis
except:
    from IPython import get_ipython  # type: ignore

    ipython = get_ipython()
    assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

In [None]:
import torch
import os

from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print("Using device:", device)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Training an SAE

Now we're ready to train our SAE. We'll make a runner config, instantiate the runner, and the rest is taken care of for us!

During training, you can use Weights & Biases (wandb) to check key metrics which indicate how well we are able to optimize the variables we care about.

To get a better sense of which variables to look at, you can read my (Joseph's) post [here](https://www.lesswrong.com/posts/f9EgfLSurAiqRJySD/open-source-sparse-autoencoders-for-all-residual-stream) and especially look at my weights and biases report [here](https://links-cdn.wandb.ai/wandb-public-images/links/jbloom/uue9i416.html).

A few tips:
- Feel free to reorganize your wandb dashboard to put L0, CE_Loss_score, explained variance and other key metrics in one section at the top.
- Make a [run comparer](https://docs.wandb.ai/guides/app/features/panels/run-comparer) when tuning hyperparameters.
- You can download the resulting sparse autoencoder / sparsity estimate from wandb and upload them to huggingface if you want to share your SAE with others. The files are:
    - cfg.json (training config)
    - sae_weights.safetensors (model weights)
    - sparsity.safetensors (sparsity estimate)

## MLP Out

I've tuned the hyperparameters below for a decent SAE which achieves 86% CE loss recovered and an L0 of ~85, and runs in about 2 hours on an M3 Max. You can get an SAE that looks better faster if you only consider L0 and CE loss, but it will likely have more dense features and more dead features. Here's a [link](https://wandb.ai/jbloom/sae_lens_tutorial) to my output from two runs with two different L1 coefficients.
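
To make the two headline metrics concrete, here is a minimal sketch of how they are commonly defined. The function names and the numbers are hypothetical and chosen only for illustration; this is not the SAELens implementation. "CE loss recovered" is taken here to be the fraction of the gap between zero-ablating the hook point and the clean model that the SAE reconstruction closes, and L0 is the average number of active features per token.

```python
def ce_loss_recovered(clean_loss, sae_loss, zero_ablation_loss):
    """Fraction of the (zero-ablation - clean) loss gap recovered by the SAE."""
    return (zero_ablation_loss - sae_loss) / (zero_ablation_loss - clean_loss)


def l0(feature_acts):
    """Average number of strictly positive feature activations per token."""
    active_per_token = [sum(1 for a in acts if a > 0) for acts in feature_acts]
    return sum(active_per_token) / len(active_per_token)


# Made-up losses, purely for illustration.
score = ce_loss_recovered(clean_loss=2.0, sae_loss=2.5, zero_ablation_loss=4.5)
print(f"CE loss recovered: {score:.0%}")

# Two tokens, four features each (made-up activations).
acts = [[0.0, 1.2, 0.0, 0.3], [0.0, 0.0, 0.0, 0.7]]
print("L0:", l0(acts))
```

A score of 1 would mean that splicing in the SAE reconstruction leaves the model's loss unchanged, and a score of 0 would mean it is no better than zero-ablating the activations.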

In [None]:
total_training_steps = 30_000  # probably we should do more
batch_size = 4096
total_training_tokens = total_training_steps * batch_size

lr_warm_up_steps = 0
lr_decay_steps = total_training_steps // 5  # 20% of training
l1_warm_up_steps = total_training_steps // 20  # 5% of training

cfg = LanguageModelSAERunnerConfig(
    # Data Generating Function (Model + Training Distribution)
    model_name="tiny-stories-1M",  # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)
    hook_name="blocks.0.hook_mlp_out",  # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)
    hook_layer=0,  # The index of the layer that hook_name belongs to.
    d_in=64,  # the width of the mlp output.
    dataset_path="apollo-research/roneneldan-TinyStories-tokenizer-gpt2",  # this is a tokenized language dataset on Huggingface for the Tiny Stories corpus.
    is_dataset_tokenized=True,
    streaming=True,  # we could pre-download the token dataset if it was small.
    # SAE Parameters
    mse_loss_normalization=None,  # We won't normalize the mse loss,
    expansion_factor=4,  # the width of the SAE. Larger will result in better stats but slower training.
    b_dec_init_method="zeros",  # Initialize the decoder bias to zeros (the geometric median is another option).
    apply_b_dec_to_input=False,  # We won't subtract the decoder bias from the input.
    normalize_sae_decoder=False,
    scale_sparsity_penalty_by_decoder_norm=True,
    decoder_heuristic_init=True,
    init_encoder_as_decoder_transpose=True,
    normalize_activations="expected_average_only_in",
    # Training Parameters
    lr=5e-5,  # lower the better, we'll go fairly high to speed up the tutorial.
    adam_beta1=0.9,  # adam params (default, but once upon a time we experimented with these.)
    adam_beta2=0.999,
    lr_scheduler_name="constant",  # constant learning rate with warmup. Could be better schedules out there.
    lr_warm_up_steps=lr_warm_up_steps,  # this can help avoid too many dead features initially.
    lr_decay_steps=lr_decay_steps,  # this will help us avoid overfitting.
    l1_coefficient=5,  # will control how sparse the feature activations are
    l1_warm_up_steps=l1_warm_up_steps,  # this can help avoid too many dead features initially.
    lp_norm=1.0,  # the L1 penalty (and not a Lp for p < 1)
    train_batch_size_tokens=batch_size,
    context_size=512,  # controls the length of the prompts we feed to the model. Larger is better but slower.
    # Activation Store Parameters
    n_batches_in_buffer=64,  # controls how many activations we store / shuffle.
    training_tokens=total_training_tokens,  # ~123 million tokens (30,000 steps x 4,096 tokens) is quite a few, but we want to see good stats. Get a coffee, come back.
    store_batch_size_prompts=16,
    # Resampling protocol
    use_ghost_grads=False,  # we don't use ghost grads anymore.
    feature_sampling_window=1000,  # this controls our reporting of feature sparsity stats
    dead_feature_window=1000,  # would affect resampling or ghost grads if we were using them.
    dead_feature_threshold=1e-4,  # would affect resampling or ghost grads if we were using them.
    # WANDB
    log_to_wandb=True,  # always use wandb unless you are just testing code.
    wandb_project="sae_lens_tutorial",
    wandb_log_frequency=30,
    eval_every_n_wandb_logs=20,
    # Misc
    device=device,
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    dtype="float32",
)
# look at the next cell to see some instruction for what to do while this is running.
sparse_autoencoder = SAETrainingRunner(cfg).run()
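
As a quick sanity check on the config above, the token budget and schedule lengths follow directly from the step count and batch size. The sketch below just repeats that arithmetic; the assumption that the SAE width equals `d_in * expansion_factor` is mine, based on how the two parameters are described in the config comments.

```python
total_training_steps = 30_000
batch_size = 4096  # tokens per training step
d_in = 64
expansion_factor = 4

total_training_tokens = total_training_steps * batch_size
lr_decay_steps = total_training_steps // 5  # last 20% of training
l1_warm_up_steps = total_training_steps // 20  # first 5% of training
d_sae = d_in * expansion_factor  # assumed SAE width

print(f"training tokens:  {total_training_tokens:,}")
print(f"lr decay steps:   {lr_decay_steps:,}")
print(f"l1 warm-up steps: {l1_warm_up_steps:,}")
print(f"SAE width:        {d_sae}")
```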

# TO DO: Understanding TinyStories-1M with our SAE

I haven't had time yet to complete this section, but I'd love to see a PR where someone uses an SAE they trained in this tutorial to understand this model better.

In [None]:
import os
os.getcwd()

In [None]:
directory_name=os.path.join(os.getcwd(),"my_sae")

In [None]:
from sae_lens import SAE
sae=SAE.load_from_pretrained(path=directory_name)
sae.use_error_term=False
sae

In [None]:
from datasets import load_dataset,Dataset,DatasetDict
import torch
ds_valid = load_dataset("roneneldan/TinyStories", split="validation")

raw_datasets = DatasetDict(
    {
        "valid": ds_valid
    }
)
raw_datasets

In [None]:
from transformers import AutoTokenizer

context_length = 128
tokenizer = AutoTokenizer.from_pretrained("roneneldan/TinyStories-1M")

outputs = tokenizer(
    raw_datasets["valid"][:2]["text"],
    truncation=True,
    max_length=context_length,
    return_overflowing_tokens=True,
    return_length=True,
)

print(f"Input IDs length: {len(outputs['input_ids'])}")
print(f"Input chunk lengths: {(outputs['length'])}")
print(f"Chunk mapping: {outputs['overflow_to_sample_mapping']}")

In [None]:
def tokenize(element):
    outputs = tokenizer(
        element["text"],
        truncation=True,
        max_length=context_length,
        return_overflowing_tokens=True,
        return_length=True,
    )
    input_batch = []
    for length, input_ids in zip(outputs["length"], outputs["input_ids"]):
        if length == context_length:
            input_batch.append(input_ids)
    return {"tokens": input_batch}


tokenized_datasets = raw_datasets.map(
    tokenize, batched=True, remove_columns=raw_datasets["valid"].column_names
)
tokenized_datasets

In [None]:
import torch
tokenized_datasets=tokenized_datasets.with_format("torch")
tokenized_datasets

In [None]:
#getting model
from transformer_lens import HookedTransformerConfig,HookedTransformer
my_model=HookedTransformer.from_pretrained("tiny-stories-1M")
my_model.cfg

In [None]:
#running code with no-directional ablation
from torch.utils.data import DataLoader
from tqdm import tqdm
def no_abalation(
    model: HookedTransformer,
    batch_size,
    dataset: Dataset,
    device: str,
    sae:SAE
) -> tuple[float,float]:
    """
    Runs the model over `dataset` without any ablation and accumulates two sums over batches.
    Args:
        model: The model to evaluate
        batch_size: Number of sequences per batch
        dataset: Tokenized dataset with a "tokens" column
        device: Device to run on
        sae: The SAE trained on blocks.0.hook_mlp_out
    Returns:
        (sum of per-batch language-modeling losses, sum of per-batch SAE reconstruction MSEs)
    """
    net_model_loss=0.0
    net_reconstruction_loss=0.0
    dataloader = DataLoader(dataset,batch_size,shuffle=True)
    model.to(device)
    sae.to(device)
    with torch.no_grad():
      for step, batch in tqdm(enumerate(dataloader)):
          torch.cuda.empty_cache()
          tokens = batch["tokens"].to(device)
          loss,cache=model.run_with_cache(tokens,return_type="loss")
          sae_input_acts=cache["blocks.0.hook_mlp_out"]
          sae_reconstruction=sae.decode(sae.encode(sae_input_acts))
          reconstruction_loss=torch.nn.functional.mse_loss(sae_reconstruction,sae_input_acts)
          net_model_loss+=loss.item()
          net_reconstruction_loss+=reconstruction_loss.item()
          del cache
          del tokens
          del loss
          del reconstruction_loss
          del sae_input_acts
          del sae_reconstruction
          torch.cuda.empty_cache()
    return net_model_loss,net_reconstruction_loss


In [None]:
import gc

In [None]:
loss1=0.0
loss2=0.0
num_thousand_rows=int(len(tokenized_datasets["valid"])/1000)
for iteration in range(0,num_thousand_rows):
  test_dataset=tokenized_datasets["valid"].select(range(iteration*1000,(iteration+1)*1000))
  net_model_loss,net_reconstruction_loss=no_abalation(my_model,128,test_dataset,"cuda",sae)
  loss1+=net_model_loss
  loss2+=net_reconstruction_loss
  del test_dataset
  del net_model_loss
  del net_reconstruction_loss
  gc.collect()
  torch.cuda.empty_cache()

In [None]:
print(loss1,loss2)

In [None]:
%pip install einops

In [None]:
#helper function that removes SAE direction from hook_points
import einops
def directional_hook_function_helper(
    residual_layer,
    hook,
    sae_decoder_unit_vector,
):

  new_residual_layer=residual_layer-einops.einsum(einops.einsum(sae_decoder_unit_vector,sae_decoder_unit_vector,"a,b->a b"),residual_layer,"dmodn dmod,batch pos dmod -> batch pos dmodn")
  return new_residual_layer
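
In equation form, and assuming the vector passed in is a unit vector $\hat v$ (as the parameter name suggests), the hook above replaces each activation vector $x$ with its projection onto the subspace orthogonal to $\hat v$:

$$
x' = x - (\hat v \hat v^{\top})\,x = x - (\hat v \cdot x)\,\hat v, \qquad \lVert \hat v \rVert_2 = 1 .
$$

Taking the dot product with $\hat v$ gives $\hat v \cdot x' = \hat v \cdot x - (\hat v \cdot x)(\hat v \cdot \hat v) = 0$, so nothing of $x'$ remains along the ablated direction. If the vector is not normalized, the subtraction over- or under-shoots and this no longer holds.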



In [None]:
#helper function that reconstructs input without dindex
def reconstruction_helper(sae:SAE,input,dindex:int):
  feats=sae.encode(input)
  feats[:,:,dindex]=0
  reconstruction=sae.decode(feats)
  return reconstruction

In [None]:
#utility code to create `batchsize` random unit vectors of dimension `dmod`; returns shape (dmod, batchsize), one unit vector per column
def create_rand_vecs(dmod,batchsize=64):
    my_vec=torch.rand((dmod,batchsize))
    my_vec=torch.nn.functional.normalize(my_vec,dim=0).to("cuda")
    return my_vec

In [None]:
#runs the model with a forward hook that ablates a direction and returns the summed loss
from torch.utils.data import DataLoader
from tqdm import tqdm
def yes_abalation(
    model: HookedTransformer,
    batch_size,
    dataset: Dataset,
    device: str,
    hook
) -> float:
    """
    Runs the model over `dataset` with `hook` applied at every hook point whose name contains "resid".
    Args:
        model: The model to evaluate
        batch_size: Number of sequences per batch
        dataset: Tokenized dataset with a "tokens" column
        device: Device to run on
        hook: Forward hook function, e.g. a directional ablation
    Returns:
        The sum of per-batch language-modeling losses
    """
    net_model_loss=0.0
    dataloader = DataLoader(dataset,batch_size,shuffle=True)
    model.to(device)
    with torch.no_grad():
      for step, batch in tqdm(enumerate(dataloader)):
          torch.cuda.empty_cache()
          tokens = batch["tokens"].to(device)
          loss=model.run_with_hooks(tokens,return_type="loss",fwd_hooks=[(lambda name: "resid" in name, hook)])
          net_model_loss+=loss.item()
          del tokens
          del loss
          torch.cuda.empty_cache()
    return net_model_loss

In [None]:
from functools import partial
num_ablations=64
feature_abalation_losses=[]
num_thousand_rows=int(len(tokenized_datasets["valid"])/1000)
for ablation_idx in range(num_ablations):
  abalation_vector=((sae.W_dec[ablation_idx])/torch.norm(sae.W_dec[ablation_idx])).to("cuda")
  hook_function=partial(directional_hook_function_helper,sae_decoder_unit_vector=abalation_vector)
  loss1=0
  for iteration in range(num_thousand_rows):
    test_dataset=tokenized_datasets["valid"].select(range(iteration*1000,(iteration+1)*1000))
    net_model_loss=yes_abalation(my_model,128,test_dataset,"cuda",hook_function)
    loss1+=net_model_loss
    gc.collect()
    torch.cuda.empty_cache()
  feature_abalation_losses.append(loss1)
  print(f"The Loss at iteration:{ablation_idx+1} is {loss1}")

In [None]:
feature_abalation_losses

In [None]:
def yes_abalation_sae(
    model: HookedTransformer,
    batch_size,
    dataset: Dataset,
    device: str,
    sae:SAE,
    ablation_idx
) -> float:
    """
    Measures the SAE reconstruction error on blocks.0.hook_mlp_out with one SAE feature zeroed out.
    Args:
        model: The model that produces the activations
        batch_size: Number of sequences per batch
        dataset: Tokenized dataset with a "tokens" column
        device: Device to run on
        sae: The SAE trained on blocks.0.hook_mlp_out
        ablation_idx: Index of the SAE feature to zero before decoding
    Returns:
        The sum of per-batch reconstruction MSEs
    """
    net_sae_loss=0.0
    dataloader = DataLoader(dataset,batch_size,shuffle=True)
    model.to(device)
    sae.to(device)
    with torch.no_grad():
      for step, batch in tqdm(enumerate(dataloader)):
          torch.cuda.empty_cache()
          tokens = batch["tokens"].to(device)
          loss,cache=model.run_with_cache(tokens,return_type="loss")
          sae_input_acts=cache["blocks.0.hook_mlp_out"]
          sae_reconstruction=reconstruction_helper(sae,sae_input_acts,ablation_idx)
          reconstruction_loss=torch.nn.functional.mse_loss(sae_reconstruction,sae_input_acts)
          net_sae_loss+=reconstruction_loss.item()
          del cache
          del tokens
          del loss
          torch.cuda.empty_cache()
    return net_sae_loss


In [None]:
reconstruction_losses=[]
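num_ablations=64  # same number of features as the directional ablation above, so the later scatter plot gets equal-length inputs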
for ablation_idx in range(num_ablations):
  loss=0
  for iteration in range(num_thousand_rows):
    test_dataset=tokenized_datasets["valid"].select(range(iteration*1000,(iteration+1)*1000))
    net_model_loss=yes_abalation_sae(my_model,128,test_dataset,"cuda",sae,ablation_idx)
    loss+=net_model_loss
    gc.collect()
    torch.cuda.empty_cache()
  reconstruction_losses.append(loss)
  print(f"Loss is:{loss}, ablation_idx is:{ablation_idx}")


In [None]:
reconstruction_losses

In [None]:
#ablating random vectors to see the difference in loss
from functools import partial
num_ablations=64
random_losses=[]
abalation_vectors=create_rand_vecs(my_model.cfg.d_model,num_ablations)  # shape (d_model, num_ablations)
for ablation_idx in range(num_ablations):
  abalation_vector=abalation_vectors[:,ablation_idx]  # each column is one unit vector
  hook_function=partial(directional_hook_function_helper,sae_decoder_unit_vector=abalation_vector)
  loss1=0
  for iteration in range(num_thousand_rows):
    test_dataset=tokenized_datasets["valid"].select(range(iteration*1000,(iteration+1)*1000))
    net_model_loss=yes_abalation(my_model,128,test_dataset,"cuda",hook_function)
    loss1+=net_model_loss
    gc.collect()
    torch.cuda.empty_cache()
  random_losses.append(loss1)

In [None]:
random_losses

In [None]:
from sae_lens import HookedSAETransformer,ActivationsStore
import plotly.express as px
my_model_sae=HookedSAETransformer.from_pretrained("tiny-stories-1M")
tiny_stories_act_store = ActivationsStore.from_sae(
    model=my_model_sae,
    sae=sae,
    streaming=True,
    context_size=128,
    dataset=tokenized_datasets["valid"],
    store_batch_size_prompts=16,
    n_batches_in_buffer=32,
    device=str(device),
    total_tokens=27066
)


In [None]:
from sae_lens import HookedSAETransformer,ActivationsStore
import plotly.express as px
def show_activation_histogram(
    model: HookedSAETransformer,
    sae,
    act_store: ActivationsStore,
    latent_idx: int,
    total_batches: int = 200,
):
    """
    Displays the activation histogram for a particular latent, computed across `total_batches` batches from `act_store`.
    """
    sae_acts_post_hook_name = f"{sae.cfg.hook_name}.hook_sae_acts_post"
    all_positive_acts = []
    sae.to("cuda")
    for i in tqdm(range(total_batches), desc="Computing activations for histogram"):
        tokens = act_store.get_batch_tokens().to("cuda")
        _, cache = model.run_with_cache_with_saes(
            tokens,
            saes=[sae],
            stop_at_layer=sae.cfg.hook_layer + 1,
            names_filter=[sae_acts_post_hook_name],
        )
        acts = cache[sae_acts_post_hook_name][..., latent_idx]
        all_positive_acts.extend(acts[acts > 0].cpu().tolist())

    frac_active = len(all_positive_acts) / (total_batches * act_store.store_batch_size_prompts * act_store.context_size)

    px.histogram(
        all_positive_acts,
        nbins=50,
        title=f"ACTIVATIONS DENSITY {frac_active:.3%}",
        labels={"value": "Activation"},
        width=800,
        template="ggplot2",
        color_discrete_sequence=["darkorange"],
    ).update_layout(bargap=0.02, showlegend=False).show()



In [None]:
num_indices=64
for i in range(num_indices):
    show_activation_histogram(my_model_sae,sae,tiny_stories_act_store,latent_idx=i)

In [None]:
def show_activation_amount(
    model: HookedSAETransformer,
    sae,
    act_store: ActivationsStore,
    latent_idx: int,
    total_batches: int = 500,
):
    """
    Returns the fraction of tokens on which a particular latent is active, computed across `total_batches` batches from `act_store`.
    """
    sae_acts_post_hook_name = f"{sae.cfg.hook_name}.hook_sae_acts_post"
    all_positive_acts = []
    sae.to("cuda")
    for i in tqdm(range(total_batches), desc="Computing activations for histogram"):
        tokens = act_store.get_batch_tokens().to("cuda")
        _, cache = model.run_with_cache_with_saes(
            tokens,
            saes=[sae],
            stop_at_layer=sae.cfg.hook_layer + 1,
            names_filter=[sae_acts_post_hook_name],
        )
        acts = cache[sae_acts_post_hook_name][..., latent_idx]
        all_positive_acts.extend(acts[acts > 0].cpu().tolist())

    frac_active = len(all_positive_acts) / (total_batches * act_store.store_batch_size_prompts * act_store.context_size)
    return frac_active

In [None]:
num_indices=64
activation_precentages=[]
for i in range(num_indices):
    activation=show_activation_amount(my_model_sae,sae,tiny_stories_act_store,latent_idx=i)
    activation_precentages.append(activation)
activation_precentages

In [None]:
activation_precentages

In [None]:
#idea2 trying to find the vector which when ablated maximises the loss over the dataset
from transformer_lens import train
import wandb
from torch import optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from functools import partial
def max_ablation_train(
    model: HookedTransformer,
    config,
    dataset: Dataset,
    my_vector,
    HookFunctionHelper,
) -> HookedTransformer:
    """
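    Optimizes `my_vector` so that ablating its direction from the residual stream maximizes the model's language-modeling loss.
    Args:
        model: The model whose loss is maximized (its weights are not passed to the optimizer)
        config: The training configuration
        dataset: Tokenized dataset with a "tokens" column
        my_vector: The learnable vector; it is normalized before being passed to the hook
        HookFunctionHelper: Hook function that accepts a `sae_decoder_unit_vector` keyword argument
    Returns:
        The model (the learned direction lives in `my_vector`, which is updated in place)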
    """
    torch.manual_seed(config.seed)
    #model.train()
    if config.wandb:
        if config.wandb_project_name is None:
            config.wandb_project_name = "NeelNandaApplication"
        wandb.init(project=config.wandb_project_name, config=vars(config))

    if config.optimizer_name in ["Adam", "AdamW"]:
        # Weight decay in Adam is implemented badly, so use AdamW instead (see PyTorch AdamW docs)
        if config.weight_decay is not None:
            optimizer = optim.AdamW(
                [my_vector],
                lr=config.lr,
                weight_decay=config.weight_decay,
            )
        else:
            optimizer = optim.Adam(
                [my_vector],
                lr=config.lr,
            )
    elif config.optimizer_name == "SGD":
        optimizer = optim.SGD(
            [my_vector],
            lr=config.lr,
            weight_decay=(config.weight_decay if config.weight_decay is not None else 0.0),
            momentum=config.momentum,
        )
    else:
        raise ValueError(f"Optimizer {config.optimizer_name} not supported")

    scheduler = None
    if config.warmup_steps > 0:
        scheduler = optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda step: min(1.0, step / config.warmup_steps),
        )

    dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)

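    model.to(config.device)
    # Only `my_vector` is optimized, so freeze the model to avoid accumulating unused gradients.
    model.requires_grad_(False)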

    for epoch in tqdm(range(1, config.num_epochs + 1)):
        samples = 0
        for step, batch in tqdm(enumerate(dataloader)):
            tokens = batch["tokens"].to(config.device)
            abalation_vector=((my_vector)/torch.norm(my_vector)).to(config.device)
            hook_function=partial(HookFunctionHelper,sae_decoder_unit_vector=abalation_vector)
            loss = model.run_with_hooks(
                tokens,
                fwd_hooks=[(lambda name: "resid" in name, hook_function)],
                return_type="loss",
            )
            (-loss).backward()
            if config.max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_([my_vector], config.max_grad_norm)
            optimizer.step()
            if config.warmup_steps > 0:
                assert scheduler is not None
                scheduler.step()
            optimizer.zero_grad()

            samples += tokens.shape[0]

            if config.wandb:
                wandb.log({"train_loss": loss.item(), "samples": samples, "epoch": epoch})

            if config.print_every is not None and step % config.print_every == 0:
                print(f"Epoch {epoch} Samples {samples} Step {step} Loss {loss.item()}")

            if (
                config.save_every is not None
                and step % config.save_every == 0
                and config.save_dir is not None
            ):
                torch.save(my_vector.detach().cpu(), f"{config.save_dir}/ablation_vector_{step}.pt")

            if config.max_steps is not None and step >= config.max_steps:
                break
            del tokens
            del loss
            gc.collect()
            torch.cuda.empty_cache()

    return model

In [None]:
my_config=train.HookedTransformerTrainConfig(
    num_epochs=10,
    batch_size=64,
    lr=1e-3,
    seed=42,
    #wandb
    wandb=True,
    wandb_project_name="NeelNandaApplication",
    #device
    device="cuda",
    optimizer_name="Adam",
    warmup_steps=0,
    print_every=50,
)

In [None]:
max_ablation_tensor=torch.nn.parameter.Parameter(torch.normal(mean=torch.zeros(64),std=torch.ones(64)))
max_ablation_train(my_model,my_config,tokenized_datasets["valid"],max_ablation_tensor,directional_hook_function_helper)

In [None]:
import matplotlib.pyplot as plt

def plot_xy_chart(x,y):
    # Scatter plot of ablation loss (x) against reconstruction loss (y); x and y must have equal length.

    plt.figure(figsize=(8, 6))
    plt.scatter(x, y, color='blue', marker='o')

    plt.xlabel('Ablation_Loss')
    plt.ylabel('Reconstruction_Loss')
    plt.title('Ablation Loss vs Reconstruction Loss')

    plt.grid(True)
    plt.show()

plot_xy_chart(feature_abalation_losses,reconstruction_losses)

In [None]:
import plotly.graph_objects as go
import numpy as np
import scipy.stats as stats
def gaussian_pdf(x, mean, std, scale):
    return (1 / (std * np.sqrt(2 * np.pi))) * np.exp(-0.5 * ((x - mean) / std) ** 2) * scale

def plot_distribution(data1, data2, bins=100, title="Frequency vs Loss"):
    hist1 = go.Histogram(x=data1, nbinsx=bins, opacity=0.6, name='Feature_Ablation_loss', marker=dict(color='blue'))
    hist2 = go.Histogram(x=data2, nbinsx=bins, opacity=0.6, name='Random_Ablation_loss', marker=dict(color='red'))

    x_range = np.linspace(min(min(data1), min(data2)), max(max(data1), max(data2)), 100)
    pdf1 = stats.norm.pdf(x_range, np.mean(data1), np.std(data1)) * len(data1) * (max(data1) - min(data1)) / bins
    pdf2 = stats.norm.pdf(x_range, np.mean(data2), np.std(data2)) * len(data2) * (max(data2) - min(data2)) / bins
    
    fit_curve1 = go.Scatter(x=x_range, y=pdf1, mode='lines', name='Fit Distribution 1', line=dict(color='blue'))
    fit_curve2 = go.Scatter(x=x_range, y=pdf2, mode='lines', name='Fit Distribution 2', line=dict(color='red'))
    
    fig = go.Figure(data=[hist1, hist2, fit_curve1, fit_curve2])
    fig.update_layout(title=title, xaxis_title='Value', yaxis_title='Frequency', barmode='overlay')
    fig.show()


# Example usage
plot_distribution(feature_abalation_losses, random_losses)
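
The scaling applied to the fitted curves above converts a probability density into expected histogram counts: a density integrates to 1, so multiplying by the number of samples and by the bin width gives the expected number of samples per bin. Below is a stdlib-only sketch of that idea. The helper name and the data are hypothetical, and `pstdev` is used because `np.std` defaults to the population standard deviation. One caveat: plotly treats `nbinsx` as a maximum number of bins, so the bin width assumed in the plotting cell may not match the bins plotly actually draws.

```python
from statistics import NormalDist, mean, pstdev


def expected_counts(data, bin_edges):
    """Expected histogram counts per bin under a Gaussian fitted to `data`."""
    fitted = NormalDist(mean(data), pstdev(data))
    counts = []
    for left, right in zip(bin_edges[:-1], bin_edges[1:]):
        center = (left + right) / 2
        # density at the bin center x bin width x number of samples
        counts.append(len(data) * (right - left) * fitted.pdf(center))
    return counts


# Made-up loss values, purely for illustration.
data = [9.0, 10.0, 10.0, 11.0, 10.5, 9.5, 10.0, 10.2, 9.8, 10.0]
edges = [5.0 + 0.1 * i for i in range(101)]  # 100 bins spanning 5 to 15
counts = expected_counts(data, edges)
print(f"sum of expected counts: {sum(counts):.3f} (samples: {len(data)})")
```

When the bins cover essentially all of the fitted distribution, the expected counts add up to the number of samples, which is a handy check that the overlay is on the same scale as the histogram.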

In [None]:
def plot_xy_chart(x,y):
    # Scatter plot of ablation loss (x) against activation fraction (y); x and y must have equal length.

    plt.figure(figsize=(8, 6))
    plt.scatter(x, y, color='blue', marker='o')

    plt.xlabel('Ablation_Loss')
    plt.ylabel('Percentage_Activation')
    plt.title('Ablation Loss vs Percentage Activation')

    plt.grid(True)
    plt.show()

plot_xy_chart(feature_abalation_losses,activation_precentages)

In [None]:
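#the hook expects a unit vector on the model's device, so normalize, detach and move the learned vector first
unit_max_ablation=(max_ablation_tensor/torch.norm(max_ablation_tensor)).detach().to("cuda")
hook_function=partial(directional_hook_function_helper,sae_decoder_unit_vector=unit_max_ablation)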
test_dataset_new=tokenized_datasets["valid"]["tokens"][:100,:]
ans_tokens=my_model.run_with_hooks(test_dataset_new,return_type="logits",fwd_hooks=[(lambda name: "resid" in name, hook_function)])
my_model.tokenizer.decode(ans_tokens[10].argmax(dim=1))

In [None]:
max_ablation_tensor=max_ablation_tensor.to("cuda")
w_dec=sae.W_dec
dot_product=einops.einsum(max_ablation_tensor,w_dec,"a,b a -> b")
cosine_sim=dot_product/(torch.norm(w_dec,dim=1)*torch.norm(max_ablation_tensor))
cosine_sim

In [None]:
random_test_vec=torch.rand(64).to("cuda")
rand_dot_product=einops.einsum(random_test_vec,w_dec,"a,b a -> b")
rand_cosine_sim=rand_dot_product/(torch.norm(w_dec,dim=1)*torch.norm(random_test_vec))
rand_cosine_sim
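
One thing to keep in mind when reading a random baseline like this: `torch.rand` draws each entry uniformly from [0, 1), so every such vector points into the positive orthant and the directions are far from isotropic. The stdlib-only sketch below (hypothetical helper names, pure Python rather than torch so it runs anywhere) compares the average cosine similarity between pairs of random 64-dimensional vectors under the two sampling schemes.

```python
import math
import random


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def mean_pairwise_cosine(sample, dim=64, n_pairs=200, seed=0):
    """Average cosine similarity over `n_pairs` independently drawn vector pairs."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_pairs):
        u = [sample(rng) for _ in range(dim)]
        v = [sample(rng) for _ in range(dim)]
        total += cosine(u, v)
    return total / n_pairs


uniform_mean = mean_pairwise_cosine(lambda rng: rng.random())  # like torch.rand
gaussian_mean = mean_pairwise_cosine(lambda rng: rng.gauss(0.0, 1.0))  # like torch.randn
print(f"uniform [0, 1): {uniform_mean:.2f}")
print(f"gaussian:       {gaussian_mean:.2f}")
```

Uniform vectors are strongly aligned with one another, while Gaussian vectors are close to orthogonal on average, so drawing the baseline with `torch.randn` would be one way to get directions without this bias.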

In [None]:
cosine_sims1=cosine_sim.tolist()
cosine_sims2=rand_cosine_sim.tolist()


In [None]:
def plot_values_distribution(values, bins=20, title="Value Distribution"):
    fig = go.Figure()
    fig.add_trace(go.Histogram(x=values, nbinsx=bins, opacity=0.6, name='Values', marker=dict(color='green')))
    fig.update_layout(title=title, xaxis_title='Value', yaxis_title='Frequency', barmode='overlay')
    fig.show()

plot_values_distribution(cosine_sims1)
plot_values_distribution(cosine_sims2)