In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as dists
import numpy as np

from noa_tools import reload_module
from toy_model import Tree, TreeDataset

import json

# Compute device for every tensor in this notebook. CUDA auto-selection is disabled;
# everything currently runs on CPU.
# DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE = 'cpu'

# Load the ground-truth feature tree. The context manager closes the file once parsed.
with open('./alt_tree.json', 'r') as tree_file:
    tree_dict = json.load(tree_file)

tree = Tree(tree_dict=tree_dict)

# The model dimension is set equal to the number of true features.
D_MODEL = tree.n_features

matryoshka_config = {
    'n_latents': tree.n_features,
    'target_l0': 1.2338, # L0 of the true features
    'n_prefixes': 10,
    'd_model': D_MODEL,
    'n_steps': 10000,
    'lr': 3e-2,
    'permute_latents': True,
    'sparsity_type': 'l1'
}

# Baseline: identical hyperparameters, but a single prefix and no latent permutation
# (presumably equivalent to an ordinary SAE -- confirm against MatryoshkaSAE).
vanilla_config = matryoshka_config | {'n_prefixes': 1, 'permute_latents': False}



In [17]:
# Display the derived baseline config: matryoshka_config with n_prefixes=1 and permute_latents=False.
vanilla_config

{'n_latents': 20,
 'target_l0': 1.2338,
 'n_prefixes': 1,
 'd_model': 20,
 'n_steps': 10000,
 'lr': 0.03,
 'permute_latents': False,
 'sparsity_type': 'l1'}

In [23]:
from scipy.optimize import linear_sum_assignment

@torch.no_grad()
def get_latent_perm(sae, tree_ds, include_all_latents=False, n_samples=1000, device=None):
    '''
    Get a permutation of SAE latents that assigns each latent to its closest-matching ground-truth feature.

    Similarity between a latent and a true feature is the cosine similarity of their
    activation columns over `n_samples` samples drawn from the tree. A one-to-one
    maximum-similarity matching (linear sum assignment) then pairs features with latents.

    Args:
        sae: model exposing `get_acts(x)` returning (n_samples, n_latents) activations.
        tree_ds: dataset exposing `.tree.sample(n)` -> (n, n_features) feature activations
            and `.true_feats` -> (n_features, d_model) feature directions.
        include_all_latents: if True, append the latents not matched to any feature,
            in ascending index order, after the matched ones.
        n_samples: number of tree samples used to estimate the similarities.
        device: device to compute on; defaults to the module-level DEVICE.

    Returns:
        Integer tensor of latent indices. When n_latents >= n_features, entry i is the
        latent matched to true feature i, so `acts[:, perm]` lines latents up with features.
    '''
    if device is None:
        device = DEVICE

    sample = tree_ds.tree.sample(n_samples).to(device)
    true_feats = tree_ds.true_feats.to(device)
    x = sample @ true_feats

    # Weight each feature's activations by its direction norm. NOTE: a positive per-column
    # scale cancels under the column normalization below, so this only matters for
    # zero-norm features (whose similarities become 0).
    feat_norms = true_feats.norm(dim=-1)
    true_acts = sample * feat_norms[None, :]
    # Unit-normalize each column; the clamp avoids 0/0 for columns that never fire.
    true_acts = true_acts / true_acts.norm(dim=0, keepdim=True).clamp(min=1e-10)

    sae_acts = sae.get_acts(x)
    sae_acts = sae_acts / sae_acts.norm(dim=0, keepdim=True).clamp(min=1e-10)

    # (n_latents, n_features) cosine similarities between activation patterns.
    sims = (sae_acts.T @ true_acts).cpu().numpy()

    # Rows = features, cols = latents, so col_ind[k] is the latent assigned to feature row_ind[k]
    # (row_ind is sorted, i.e. 0..n_features-1 when there are at least as many latents).
    _, col_ind = linear_sum_assignment(sims.T, maximize=True)

    if include_all_latents:
        # np.setdiff1d returns sorted unique values -> deterministic order for unmatched latents.
        leftover_latents = np.setdiff1d(np.arange(sims.shape[0]), col_ind)
        return torch.tensor(np.concatenate([col_ind, leftover_latents]))
    return torch.tensor(col_ind)

In [24]:

from torch.utils.data import DataLoader
from sae import MatryoshkaSAE
from tqdm import tqdm

# Seed torch's global RNG so the random feature directions and the reference batch below
# (and anything downstream that draws from the same RNG) are reproducible.
SEED = 0
torch.manual_seed(SEED)

# Ground-truth feature directions: rows of a random orthogonal matrix, obtained from the
# QR decomposition of a Gaussian matrix.
Q, _ = torch.linalg.qr(torch.randn(D_MODEL, D_MODEL))
true_feats = Q[:tree.n_features].to(DEVICE)
assert true_feats.shape == (tree.n_features, D_MODEL)
assert torch.allclose(true_feats @ true_feats.T, torch.eye(tree.n_features, device=true_feats.device), atol=1e-5)

# Optional (disabled): jitter each feature's norm by ~5%.
# random_scaling = 1+torch.randn(tree.n_features, device=DEVICE)* 0.05
# true_feats *= random_scaling[:, None]

dataset = TreeDataset(tree, true_feats.cpu(), batch_size=200, num_batches=vanilla_config['n_steps'])
# NOTE(review): with num_workers > 0, confirm TreeDataset splits work / seeds its sampler per
# worker; otherwise workers may produce duplicated batches.
dataloader = DataLoader(
    dataset,
    batch_size=None,  # batches are built by TreeDataset (batch_size=200), so disable auto-batching
    num_workers=6,
    pin_memory=torch.device(DEVICE).type == 'cuda',  # pinning only helps host->GPU copies
)

# Fixed reference batch for the heatmaps: 100 tree samples, each activation multiplied by
# (1 + 0.05 * N(0, 1)) noise, then projected onto the true feature directions.
ref_acts = dataset.tree.sample(100).to(DEVICE)
random_scale = torch.randn_like(ref_acts, device=DEVICE) * 0.05 + 1
ref_acts = ref_acts * random_scale
ref_x = ref_acts @ dataset.true_feats.to(DEVICE)

In [25]:
# Shape of the reference activation batch: (n_reference_samples, n_features).
ref_acts.size()

torch.Size([100, 20])

In [26]:
from IPython.display import clear_output
from heatmap import heatmap


N_STEPS = vanilla_config['n_steps']
PLOT_EVERY = 300  # training steps between heatmap refreshes

vanilla_sae = MatryoshkaSAE(**vanilla_config).to(DEVICE)
matryoshka_sae = MatryoshkaSAE(**matryoshka_config).to(DEVICE)


def plot_sae_latents(name, sae, step):
    '''Show a heatmap of `sae`'s activations on `ref_x`, with latents reordered to line up with the true features.'''
    perm = get_latent_perm(sae, dataset, include_all_latents=True)
    with torch.no_grad():  # display only -- don't build an autograd graph
        acts = sae.get_acts(ref_x)[:, perm].cpu()
    # Report each SAE's own sparsity scale (previously the vanilla plot showed the matryoshka value).
    title = f'{name} Latents  |  Sparsity: {sae.sparsity_controller.sparsity_loss_scale:.2f}  |  Step {step}'
    heatmap(acts, title=title).show()


for step, batch in tqdm(enumerate(dataloader), total=N_STEPS):
    batch = batch.to(DEVICE)

    # Both SAEs train on the identical batch each step.
    matryoshka_sae.step(batch)
    vanilla_sae.step(batch)

    # Refresh periodically, and always on the final step so the end state is shown.
    if step % PLOT_EVERY == 0 or step == N_STEPS - 1:
        clear_output(wait=True)
        heatmap(ref_acts.cpu(), title='true feats').show()
        plot_sae_latents('Matryoshka', matryoshka_sae, step)
        plot_sae_latents('Vanilla', vanilla_sae, step)

100%|██████████| 10000/10000 [07:03<00:00, 23.63it/s]
