In [None]:
import torch
import numpy as np
from toy_model import Tree, TreeDataset
import json


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seed torch and numpy RNGs up front: later cells draw random feature directions,
# initialize SAEs and sample data, so Restart & Run All should reproduce the same run.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Load the ground-truth feature tree; `with` ensures the file handle is closed.
with open('./tree.json', 'r') as tree_file:
    tree_dict = json.load(tree_file)
tree = Tree(tree_dict=tree_dict)

D_MODEL = tree.n_features

matryoshka_config = {
    'n_latents': tree.n_features,
    'target_l0': 1.2338, # L0 of the true features
    'n_prefixes': 10,
    'd_model': D_MODEL,
    'n_steps': 40_000, # You could try fewer steps if this runs too slow for your taste; 15K e.g.
    'lr': 3e-2,
    'permute_latents': True,
    'sparsity_type': 'l1',
    'starting_sparsity_loss_scale': 0.2
}


# Vanilla baseline: identical hyperparameters but a single prefix and no latent permutation.
vanilla_config = matryoshka_config | {'n_prefixes': 1, 'permute_latents': False}

In [None]:
from scipy.optimize import linear_sum_assignment

@torch.no_grad()
def get_latent_perm(sae, tree_ds, include_all_latents=False, n_samples=1000, device=None):
    '''
    Get permutation of sae latents that tries to assign each latent to its closest-matching ground-truth feature
    H.T. Julian D'Costa for the linear_sum_assignment idea here.

    Similarity between latent j and feature i is the co-activation score
    sum_b sae_acts[b, j] * true_acts[b, i], where both activation matrices are
    normalized per column to peak at 1. The matching that maximizes total similarity
    is found with scipy's linear_sum_assignment.

    Args:
        sae: object exposing `get_acts(x)`; presumably returns (batch, n_latents) activations.
        tree_ds: object exposing `tree.sample(n)` and `true_feats`.
        include_all_latents: if True, append latents that were active (max activation > 0)
            on the sample but not matched to any feature, in ascending index order.
        n_samples: number of fresh tree samples used to estimate similarities.
        device: device to compute on; defaults to the notebook-level `DEVICE`.

    Returns:
        1-D integer tensor of latent indices. Entry k is the latent assigned to the k-th
        feature row of the assignment, optionally followed by the unmatched active latents.
    '''
    if device is None:
        device = DEVICE

    sample = tree_ds.tree.sample(n_samples).to(device)
    true_feats = tree_ds.true_feats.to(device)
    x = sample @ true_feats

    # Scale each feature's activation by its direction norm, then normalize each column to
    # peak at 1 (the clamp avoids 0/0 for features that never fire in the sample).
    true_acts = sample * true_feats.norm(dim=-1)[None, :]
    true_acts = true_acts / true_acts.max(dim=0, keepdim=True).values.clamp(min=1e-10)

    sae_acts = sae.get_acts(x)
    sae_max = sae_acts.max(dim=0, keepdim=True).values  # hoisted: reused for the leftover check
    sae_acts = sae_acts / sae_max.clamp(min=1e-10)

    # sims[latent, feature]; convert to a non-negative cost so minimizing cost maximizes similarity.
    sims = (sae_acts.T @ true_acts).cpu()
    cost_matrix = sims.max() - sims

    # The transposed cost has features as rows, so col_ind[k] is the latent matched to feature row k.
    _, col_ind = linear_sum_assignment(cost_matrix.detach().numpy().T)

    if not include_all_latents:
        return torch.tensor(col_ind)

    # Latents that fired at least once but were not matched; sorted for a deterministic order.
    active = (sae_max.squeeze(0) > 0).cpu().tolist()
    matched = set(col_ind.tolist())
    leftover_features = np.array(
        sorted(i for i, is_active in enumerate(active) if is_active and i not in matched),
        dtype=int,
    )
    return torch.tensor(np.concatenate([col_ind, leftover_features]))

In [None]:

from torch.utils.data import DataLoader
from sae import MatryoshkaSAE
from tqdm import tqdm

# Ground-truth feature directions: rows of the orthogonal Q factor from a QR decomposition
# of a random Gaussian matrix, so the directions start out orthonormal.
Q, _ = torch.linalg.qr(torch.randn(D_MODEL, D_MODEL))
true_feats = Q[:tree.n_features].to(DEVICE)

# Randomly perturb each feature's norm (multiplicative noise with std 0.05 around 1).
random_scaling = 1 + torch.randn(tree.n_features, device=DEVICE) * 0.05
true_feats *= random_scaling[:, None]


# Setup dataloader
dataset = TreeDataset(tree, true_feats.cpu(), batch_size=200, num_batches=vanilla_config['n_steps'])
dataloader = DataLoader(dataset, batch_size=None, num_workers=6, pin_memory=True)

# Fixed reference batch used to visualize vanilla and matryoshka latents over the course of training.
ref_sample = dataset.tree.sample(100).to(DEVICE)
ref_x = ref_sample @ dataset.true_feats.to(DEVICE)
# Ground-truth activations scaled by each feature's norm perturbation, shown alongside SAE latents.
ref_acts = ref_sample * random_scaling[None]


In [None]:
from IPython.display import clear_output
from heatmap import heatmap


vanilla_sae = MatryoshkaSAE(**vanilla_config).to(DEVICE)
matryoshka_sae = MatryoshkaSAE(**matryoshka_config).to(DEVICE)

PLOT_EVERY = 150  # steps between refreshes of the live activation heatmaps
N_STEPS = vanilla_config['n_steps']


def _show_latent_heatmap(sae, label, step):
    '''Show `sae`'s latent activations on `ref_x`, columns permuted to line up with the ground-truth features.'''
    perm = get_latent_perm(sae, dataset)
    with torch.no_grad():  # visualization only; avoid building an autograd graph
        acts = sae.get_acts(ref_x)[:, perm].cpu()
    title = f'{label} Latents  |  Sparsity Reg: {sae.sparsity_controller():.2f}  |  Step {step}'
    heatmap(acts, title=title).show()


for step, batch in tqdm(enumerate(dataloader), total=N_STEPS):
    batch = batch.to(DEVICE)

    # Both SAEs train on the identical batch so their trajectories are directly comparable.
    matryoshka_sae.step(batch)
    vanilla_sae.step(batch)

    # Refresh periodically, and on the final step so the finished state is always displayed.
    if step % PLOT_EVERY == 0 or step == N_STEPS - 1:
        clear_output(wait=True)

        heatmap(ref_acts.cpu(), title='Ground-Truth Features').show()
        _show_latent_heatmap(matryoshka_sae, 'Matryoshka', step)
        _show_latent_heatmap(vanilla_sae, 'Vanilla', step)

In [None]:
import torch.nn.functional as F  # F.normalize is used below and F is not imported elsewhere

matryoshka_perm = get_latent_perm(matryoshka_sae, dataset)
vanilla_perm = get_latent_perm(vanilla_sae, dataset)

# Unit-norm ground-truth directions, shared by every comparison in this cell.
true_dirs = F.normalize(true_feats, dim=1)


def cosine_to_truth(weights, perm):
    '''Cosine similarity of each row of `weights` (reordered by `perm`) against every ground-truth direction.

    NOTE(review): assumes `weights` is laid out (n_latents, d_model) - confirm for W_dec and W_enc.T.
    '''
    return F.normalize(weights, dim=1)[perm] @ true_dirs.T


# Compute cosine similarity matrix of ground truth features
gt_cosine = true_dirs @ true_dirs.T
gt_sims = heatmap(gt_cosine.cpu(), title='Ground Truth Feature Cosine Similarity')
gt_sims.show()

# Compute cosine similarity between learned decoder directions and ground truth features
matryoshka_feats = matryoshka_sae.W_dec.data
vanilla_feats = vanilla_sae.W_dec.data

matryoshka_cosine = cosine_to_truth(matryoshka_feats, matryoshka_perm)
vanilla_cosine = cosine_to_truth(vanilla_feats, vanilla_perm)

m_dec_sims = heatmap(matryoshka_cosine.cpu(), title='Matryoshka Decoder, Ground-Truth Cosine Similarity', dim_names=('Matryoshka', 'True Feature'))
m_dec_sims.show()

v_dec_sims = heatmap(vanilla_cosine.cpu(), title='Vanilla Decoder, Ground-Truth Cosine Similarity', dim_names=('Vanilla', 'Ground-Truth'))
v_dec_sims.show()

# Compare encoder weights against ground truth; transposed so rows index latents
# (assumes W_enc is (d_model, n_latents) - TODO confirm).
matryoshka_enc = matryoshka_sae.W_enc.data.T
vanilla_enc = vanilla_sae.W_enc.data.T

matryoshka_enc_true = cosine_to_truth(matryoshka_enc, matryoshka_perm)
vanilla_enc_true = cosine_to_truth(vanilla_enc, vanilla_perm)


m_enc_sims = heatmap(matryoshka_enc_true.cpu(), title='Matryoshka Encoder, Ground-Truth Cosine Similarity', dim_names=('Latent', 'Feature'))
m_enc_sims.show()
v_enc_sims = heatmap(vanilla_enc_true.cpu(), title='Vanilla Encoder, Ground-Truth Cosine Similarity', dim_names=('Latent', 'Feature'))
v_enc_sims.show()


In [None]:
from pathlib import Path

# Save plotly figures locally, as interactive HTML and as PNG (scale=4).
# PNG export goes through plotly's static image backend (kaleido), which must be installed.
FIG_DIR = Path("figures")
FIG_DIR.mkdir(parents=True, exist_ok=True)  # the writes below fail if the directory is missing

figures_to_save = {
    "gt_feature_cosine": gt_sims,
    "matryoshka_decoder_gt_cosine": m_dec_sims,
    "vanilla_decoder_gt_cosine": v_dec_sims,
    "matryoshka_encoder_gt_cosine": m_enc_sims,
    "vanilla_encoder_gt_cosine": v_enc_sims,
}

for stem, fig in figures_to_save.items():
    fig.write_html(str(FIG_DIR / f"{stem}.html"))

for stem, fig in figures_to_save.items():
    fig.write_image(str(FIG_DIR / f"{stem}.png"), scale=4)

