# Balancing Hedging and Absorption in Matryoshka SAEs

Hedging and absorption have opposite effects on the encoder of an SAE. Absorption causes the latent tracking the parent feature to merge in negative components of child features. Hedging does the opposite: it merges positive components of child features into the parent latent.

As a result, it's actually beneficial to allow the outer layers of a Matryoshka SAE to exert some absorption pressure on the inner layers to reduce the amount of hedging and get closer to learning the true features.

In this notebook, we explore this idea in a toy model.

## Create toy model and hierarchical feature generator

In [None]:
from hedging_paper.toy_models.toy_model import ToyModel
from hedging_paper.toy_models.tree_feature_generator import TreeFeatureGenerator

# Toy model: DEFAULT_NUM_FEATS features in a DEFAULT_D_IN-dimensional hidden
# space, to be recovered by an SAE with DEFAULT_D_SAE latents.
DEFAULT_D_IN = 50
DEFAULT_D_SAE = 4
DEFAULT_NUM_FEATS = 4

toy_model = ToyModel(num_feats=DEFAULT_NUM_FEATS, hidden_dim=DEFAULT_D_IN)

# Hierarchy: one parent (0.25) with four children that are not mutually
# exclusive -- three read-out children sharing the same firing probability
# (0.15), plus one child at 0.55 that is not read out.
children1 = [TreeFeatureGenerator(0.15) for _ in range(3)]
children1.append(TreeFeatureGenerator(0.55, is_read_out=False))
parent1 = TreeFeatureGenerator(0.25, children1, mutually_exclusive_children=False)

# Root node (probability 1.0, not itself read out) that wraps the hierarchy.
feat_generator = TreeFeatureGenerator(1.0, [parent1], is_read_out=False)

## Balancing the inner and outer Matryoshka layers

Can we find an inner-layer loss multiplier that balances hedging in the inner layer against absorption pressure from the outer layer?

In [None]:
import torch
from hedging_paper.saes.matryoshka_sae import MatryoshkaSAE, MatryoshkaSAEConfig, MatryoshkaSAERunnerConfig
from hedging_paper.toy_models.initialization import init_sae_to_match_model
from hedging_paper.toy_models.train_toy_sae import train_toy_sae

# Detached Matryoshka SAE (plotted below as beta = infinity).
# Two nested steps: an inner prefix of 1 latent inside all DEFAULT_D_SAE
# latents. `skip_outer_loss=True` presumably avoids double-counting the
# full-width loss, since the last step already spans every latent -- TODO
# confirm against MatryoshkaSAE. `use_delta_loss=True` is the setting that
# distinguishes this run from the balanced run further down.
cfg = MatryoshkaSAERunnerConfig(
    d_in=toy_model.embed.weight.shape[0],
    d_sae=DEFAULT_D_SAE,
    l1_coefficient=3e-2,
    normalize_sae_decoder=False,
    scale_sparsity_penalty_by_decoder_norm=True,
    init_encoder_as_decoder_transpose=True,
    apply_b_dec_to_input=True,
    b_dec_init_method="zeros",
    matryoshka_steps=[1, DEFAULT_D_SAE],
    skip_outer_loss=True,
    use_delta_loss=True,
)
sae_full_mat = MatryoshkaSAE(MatryoshkaSAEConfig.from_sae_runner_config(cfg))
init_sae_to_match_model(
    sae_full_mat,
    toy_model,
    noise_level=0.0,
    # Identity ordering over the 4 toy features; presumably SAE latent i is
    # initialized to toy feature i -- confirm in init_sae_to_match_model.
    feature_ordering=torch.tensor([0, 1, 2, 3]),
)


# NOTE: occasionally this gets stuck in a poor local minimum. If this happens,
# rerun the cell and it should converge properly.
train_toy_sae(sae_full_mat, toy_model, feat_generator.sample)

In [None]:
from hedging_paper.toy_models.plotting import plot_sae_feat_cos_sims_seaborn

# Plot SAE-latent vs. toy-feature cosine similarities for the detached run and
# write the figure to the given PDF path.
# NOTE(review): assumes plots/ already exists (or that the helper creates it).
detached_title = r"Detached Matryoshka SAE ($\beta = \infty$)"
detached_save_path = "plots/toy_sae_mat_detached.pdf"
plot_sae_feat_cos_sims_seaborn(
    sae_full_mat,
    toy_model,
    title=detached_title,
    save_path=detached_save_path,
    width=3,
    height=1.5,
    show_values=False,
    one_based_indexing=True,
)

In [None]:
import torch
from hedging_paper.saes.matryoshka_sae import MatryoshkaSAE, MatryoshkaSAEConfig, MatryoshkaSAERunnerConfig
from hedging_paper.toy_models.initialization import init_sae_to_match_model
from hedging_paper.toy_models.train_toy_sae import train_toy_sae

# Baseline (plotted below as the standard SAE, beta = 0): a single Matryoshka
# step spanning all DEFAULT_D_SAE latents, so there is no nested inner prefix.
standard_sae_cfg_kwargs = dict(
    d_in=toy_model.embed.weight.shape[0],
    d_sae=DEFAULT_D_SAE,
    l1_coefficient=3e-2,
    normalize_sae_decoder=False,
    scale_sparsity_penalty_by_decoder_norm=True,
    init_encoder_as_decoder_transpose=True,
    apply_b_dec_to_input=True,
    b_dec_init_method="zeros",
    matryoshka_steps=[DEFAULT_D_SAE],
    skip_outer_loss=True,
)
cfg = MatryoshkaSAERunnerConfig(**standard_sae_cfg_kwargs)
sae_full_abs = MatryoshkaSAE(MatryoshkaSAEConfig.from_sae_runner_config(cfg))

# Identity ordering over the 4 toy features -- presumably latent i starts at
# toy feature i; confirm in init_sae_to_match_model.
standard_sae_ordering = torch.tensor([0, 1, 2, 3])
init_sae_to_match_model(
    sae_full_abs,
    toy_model,
    noise_level=0.0,
    feature_ordering=standard_sae_ordering,
)

train_toy_sae(sae_full_abs, toy_model, feat_generator.sample)

In [None]:
from hedging_paper.toy_models.plotting import plot_sae_feat_cos_sims_seaborn

# Plot SAE-latent vs. toy-feature cosine similarities for the standard
# (non-Matryoshka) baseline and write the figure to the given PDF path.
# NOTE(review): assumes plots/ already exists (or that the helper creates it).
standard_plot_kwargs = {
    "title": r"Standard SAE ($\beta = 0$)",
    "show_values": False,
    "width": 3,
    "height": 1.5,
    "save_path": "plots/toy_sae_abs.pdf",
    "one_based_indexing": True,
}
plot_sae_feat_cos_sims_seaborn(sae_full_abs, toy_model, **standard_plot_kwargs)

In [None]:
import torch
from hedging_paper.saes.matryoshka_sae import MatryoshkaSAE, MatryoshkaSAEConfig, MatryoshkaSAERunnerConfig
from hedging_paper.toy_models.initialization import init_sae_to_match_model
from hedging_paper.toy_models.train_toy_sae import train_toy_sae

# Balanced Matryoshka SAE: the same nested steps as the detached run, but
# without use_delta_loss. It uses loss multipliers [0.25, 1.0], presumably one
# per Matryoshka step; 0.25 is the beta shown in the plot title below.
balanced_loss_multipliers = [0.25, 1.0]
balanced_cfg_kwargs = dict(
    d_in=toy_model.embed.weight.shape[0],
    d_sae=DEFAULT_D_SAE,
    l1_coefficient=3e-2,
    normalize_sae_decoder=False,
    scale_sparsity_penalty_by_decoder_norm=True,
    init_encoder_as_decoder_transpose=True,
    apply_b_dec_to_input=True,
    b_dec_init_method="zeros",
    matryoshka_steps=[1, DEFAULT_D_SAE],
    skip_outer_loss=True,
    matryoshka_inner_loss_multipliers=balanced_loss_multipliers,
)
cfg = MatryoshkaSAERunnerConfig(**balanced_cfg_kwargs)
sae_bal_mat = MatryoshkaSAE(MatryoshkaSAEConfig.from_sae_runner_config(cfg))

# Identity ordering over the 4 toy features.
init_sae_to_match_model(
    sae_bal_mat,
    toy_model,
    noise_level=0.0,
    feature_ordering=torch.tensor([0, 1, 2, 3]),
)

train_toy_sae(sae_bal_mat, toy_model, feat_generator.sample)

In [None]:
from hedging_paper.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn

# Plot the balanced SAE twice:
# - first with plot_sae_feat_cos_sims (no save_path given);
# - then with the seaborn variant, written to the given PDF path.
plot_sae_feat_cos_sims(sae_bal_mat, toy_model, "Balanced Matryoshka SAE", show_values=False)

# NOTE(review): assumes plots/ already exists (or that the helper creates it).
balanced_plot_kwargs = {
    "title": r"Balanced Matryoshka SAE ($\beta = 0.25$)",
    "show_values": False,
    "width": 3,
    "height": 1.5,
    "save_path": "plots/toy_mat_balanced.pdf",
    "one_based_indexing": True,
}
plot_sae_feat_cos_sims_seaborn(sae_bal_mat, toy_model, **balanced_plot_kwargs)

## Is it always possible to balance hedging and absorption?

In the above toy models, all three read-out child features fire with the same probability (0.15). What if they fire with different probabilities?

In [8]:
from hedging_paper.toy_models.toy_model import ToyModel
from hedging_paper.toy_models.tree_feature_generator import TreeFeatureGenerator

# Build a new feature hierarchy whose read-out children fire with unequal
# probabilities. No new toy model is created here: the cells below keep
# training against the existing `toy_model`, so only the sampling
# distribution changes relative to `feat_generator`.


children1 = []
children1.append(TreeFeatureGenerator(0.02))  # rare child
children1.append(TreeFeatureGenerator(0.15))  # same rate as in the balanced setup
children1.append(TreeFeatureGenerator(0.5))  # frequent child
children1.append(TreeFeatureGenerator(0.55, is_read_out=False))
parent1 = TreeFeatureGenerator(0.25, children1, mutually_exclusive_children=False)

feat_generator_unbalanced = TreeFeatureGenerator(1.0, [parent1], is_read_out=False)

In [None]:
import torch
from hedging_paper.saes.matryoshka_sae import MatryoshkaSAE, MatryoshkaSAEConfig, MatryoshkaSAERunnerConfig
from hedging_paper.toy_models.initialization import init_sae_to_match_model
from hedging_paper.toy_models.train_toy_sae import train_toy_sae

# Same nested Matryoshka setup as the balanced run, but with loss multipliers
# [0.17, 1.0] (0.17 is the beta in the plot title below). Training uses the
# unequal-probability generator instead of `feat_generator`.
unbalanced_loss_multipliers = [0.17, 1.0]
cfg = MatryoshkaSAERunnerConfig(
    **dict(
        d_in=toy_model.embed.weight.shape[0],
        d_sae=DEFAULT_D_SAE,
        l1_coefficient=3e-2,
        normalize_sae_decoder=False,
        scale_sparsity_penalty_by_decoder_norm=True,
        init_encoder_as_decoder_transpose=True,
        apply_b_dec_to_input=True,
        b_dec_init_method="zeros",
        matryoshka_steps=[1, DEFAULT_D_SAE],
        skip_outer_loss=True,
        matryoshka_inner_loss_multipliers=unbalanced_loss_multipliers,
    )
)
sae_unbal_mat = MatryoshkaSAE(MatryoshkaSAEConfig.from_sae_runner_config(cfg))

# Identity ordering over the 4 toy features.
unbalanced_ordering = torch.tensor([0, 1, 2, 3])
init_sae_to_match_model(
    sae_unbal_mat,
    toy_model,
    noise_level=0.0,
    feature_ordering=unbalanced_ordering,
)


train_toy_sae(sae_unbal_mat, toy_model, feat_generator_unbalanced.sample)

In [None]:
from hedging_paper.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn

# Plot the unbalanceable SAE twice:
# - first with plot_sae_feat_cos_sims (no save_path given);
# - then with the seaborn variant, written to the given PDF path.
unbalanced_label = "Unbalanceable Matryoshka SAE"
plot_sae_feat_cos_sims(sae_unbal_mat, toy_model, unbalanced_label, show_values=False)

# NOTE(review): assumes plots/ already exists (or that the helper creates it).
plot_sae_feat_cos_sims_seaborn(
    sae_unbal_mat,
    toy_model,
    title=r"Unbalanceable Matryoshka SAE ($\beta = 0.17$)",
    save_path="plots/toy_mat_unbalanceable.pdf",
    width=3,
    height=1.5,
    show_values=False,
    one_based_indexing=True,
)