In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as dists

from noa_tools import reload_module

# reload_module('toy_models')
reload_module('discrete_tree')
from discrete_tree import Tree

import json

# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = 'cpu'

# Load the tree specification. The context manager guarantees the file handle
# is closed, which a bare `json.load(open(...))` does not.
with open('./tree.json', 'r') as tree_file:
    tree_dict = json.load(tree_file)

tree = Tree(tree_dict=tree_dict)

# Small sample drawn from the tree (rebound again by later cells).
batch = tree.sample(30,)

In [None]:
# Multiplicative Gaussian jitter: every entry of the sampled batch is scaled
# by (1 + eps * z) with z ~ N(0, 1).
eps = 0.05
batch = tree.sample(100,)
noise = torch.randn_like(batch) * eps + 1

noisy_batch = noise * batch

In [None]:
from noa_tools import heatmap

# Visual check of the multiplicatively-noised sample built in the previous cell.
heatmap(noisy_batch)


In [389]:
reload_module('models')
from models import SegmentedSAE
from noa_tools import grid_from_configs
import numpy as np

# Hyperparameter grid. Every value is wrapped in a list so it can be handed to
# `grid_from_configs`; each list has a single entry here, and only the first
# resulting config is used below.
default_config = {
    'n_features': [tree.n_features + 20],
    'target_l0': [10],
    'n_blocks': [5],
    'min_features': [1],
    'd_model': [30],
    'beta1': [1],
    'beta2': [4],
    'project_name': ['tree_test'],
    'n_epochs': [1],
    'n_steps': [5000],
    'log_interval': [50],
    'lr': [1e-2],
    'block_weight_power': [1.35]
}

grid = grid_from_configs(default_config)

config = grid[0]

# Adam betas are parameterised as 1 - 2**-k: beta1=1 -> 0.5, beta2=4 -> 0.9375.
config['adam_beta1'] = (1 - 1 / (2 ** config['beta1']))
config['adam_beta2'] = (1 - 1 / (2 ** config['beta2']))

# Use the notebook-wide `device` (GPU when available, else CPU) instead of a
# hard-coded 'cuda:0', so the cell also runs on machines without CUDA.
sae = SegmentedSAE(**config).to(device)

# Random ground-truth feature directions; dividing by sqrt(d_model) gives each
# row a norm of roughly 1 before the per-feature rescaling below.
true_feats = torch.randn(tree.n_features, config['d_model']).to(sae.W_enc.device) / np.sqrt(config['d_model'])
# Randomize per-feature scales uniformly in [1, 2).
true_feat_scales = torch.rand(tree.n_features, device=sae.W_enc.device) * 1 + 1
true_feats *= true_feat_scales[:, None]


sae.block_weights


`torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.



Parameter containing:
tensor([0.1139, 0.2903, 0.5018, 0.7399, 1.0000], device='cuda:0')

In [390]:
from tqdm import tqdm

from torch.utils.data import Dataset, DataLoader

class TreeDataset(Dataset):
    """Map-style dataset in which every item is one freshly sampled batch.

    Each access draws `batch_size` activation rows from `tree.sample` and
    projects them through the `true_feats` matrix, so the requested index is
    never used to select data.

    Args:
        tree: object exposing `sample(batch_size)`.
        true_feats: tensor that the sampled activations are right-multiplied
            by; a CPU copy is stored on the dataset.
        batch_size: number of rows requested from `tree.sample` per item.
        num_batches: reported dataset length, i.e. items per epoch.
    """

    def __init__(self, tree, true_feats, batch_size, num_batches):
        self.tree, self.batch_size, self.num_batches = tree, batch_size, num_batches
        # Keep the projection matrix on the CPU.
        self.true_feats = true_feats.cpu()

    def __len__(self):
        """Return the number of batches one pass over the dataset yields."""
        return self.num_batches

    def __getitem__(self, idx):
        """Sample a new batch and project it; `idx` is ignored."""
        return self.tree.sample(self.batch_size) @ self.true_feats

# One dataset item is already a full training batch, so automatic batching is
# switched off in the loader with batch_size=None.
dataset = TreeDataset(
    tree,
    true_feats.cpu(),
    batch_size=200,
    num_batches=config['n_steps'],
)
dataloader = DataLoader(
    dataset,
    batch_size=None,
    num_workers=6,
    pin_memory=True,
)

# Train: one optimisation step per sampled batch, moved to the SAE's device.
# Note that `batch` is rebound here, shadowing the earlier sample of that name.
progress = tqdm(enumerate(dataloader), total=config['n_steps'])
for step, batch in progress:
    sae.step(batch.to(sae.W_enc.device))


100%|██████████| 5000/5000 [01:17<00:00, 64.83it/s]


In [391]:
def _unit_rows(mat):
    """Divide every row of `mat` by its own L2 norm (no epsilon guard)."""
    return mat / mat.norm(dim=-1, keepdim=True)


# Unit-norm copies of the ground-truth directions and of the SAE decoder rows.
true_feats_n = _unit_rows(true_feats)
sae_feats_n = _unit_rows(sae.W_dec)

In [392]:
from scipy.optimize import linear_sum_assignment
import numpy as np

# Cosine similarity between every SAE decoder direction (rows) and every
# true feature (columns), moved to the CPU for scipy.
sims = (sae_feats_n @ true_feats_n.T).cpu()
# Turn similarities into non-negative costs for the assignment solver.
max_value = sims.max()
cost_matrix = max_value - sims

# Hungarian matching on the transposed cost matrix: `row_ind` indexes true
# features and `col_ind[i]` is the SAE feature assigned to true feature
# `row_ind[i]`. (The original duplicated the assignment targets.)
row_ind, col_ind = linear_sum_assignment(cost_matrix.detach().numpy().T)

# SAE features left unmatched. `np.setdiff1d` keeps an integer dtype even when
# nothing is left over, unlike building an array from an empty Python set.
leftover_features = np.setdiff1d(np.arange(sae_feats_n.shape[0]), col_ind)

matching_perm = torch.tensor(col_ind)
# Matched features first, then the leftovers. It is kept under its own name
# because the original built this into `feat_perm` and then overwrote it.
full_feat_perm = torch.tensor(np.concatenate([col_ind, leftover_features]))

# Downstream cells use only the matched SAE features, ordered by true feature.
feat_perm = matching_perm



In [393]:
# Cosine similarity of each true feature (rows) with its matched SAE feature
# ordering (columns). This rebinds `sims`, which the matching cell defined with
# the opposite orientation (SAE x true) and moved to the CPU.
sims = (true_feats_n @ sae_feats_n[feat_perm].T)

In [394]:
# Plot the true-vs-matched-SAE similarity matrix computed in the previous cell.
heatmap(sims, dim_names=('true', 'sae'))

In [395]:
# Per-feature L2 norm of the ground-truth directions (after the per-feature
# rescaling applied when `true_feats` was created).
true_feats.norm(dim=-1)

tensor([1.8080, 1.6467, 1.7225, 1.3327, 1.9697, 1.8803, 1.7138, 2.2196, 1.3900,
        1.5465, 0.8768, 0.8329, 1.3875, 1.9596, 1.3327, 1.3107, 1.2169, 1.8947,
        1.2024, 1.4808, 1.1072, 1.6775], device='cuda:0')

In [396]:
# NOTE(review): `sae_acts` is only assigned in the following cell (In [398]),
# so on a fresh Restart & Run All this cell raises NameError; it should be
# moved below that cell. It also echoes the whole tensor — consider showing
# `sae_acts.shape` or a small slice instead.
sae_acts

tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ..., 10.6564,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ..., 10.4083,  0.0000,  0.0000]],
       device='cuda:0')

In [398]:
# Fresh evaluation batch of ground-truth activations, placed on the SAE's device.
sample = tree.sample(50).to(sae.W_enc.device)

with torch.no_grad():
    # SAE activations for the synthesised inputs, rescaled by the norm of each
    # decoder row, then restricted and reordered to the matched features.
    raw_acts = sae.get_acts(sample @ true_feats)
    decoder_norms = sae.W_dec.norm(dim=-1)
    sae_acts = (raw_acts * decoder_norms[None])[:, feat_perm]

heatmap(sae_acts, dim_names=('batch', 'sae_feat'), title='sae acts').show()

# Scale the true activations by each true feature's norm so the second plot can
# be compared against the decoder-norm-scaled SAE activations above.
feat_norms = true_feats.norm(dim=-1)
scaled_sample = sample * feat_norms.unsqueeze(0)
heatmap(scaled_sample, title='true feats (scaled)', dim_names=('batch', 'true_feat'))


In [168]:
# Pairwise cosine similarity among the true features themselves; the diagonal
# is 1 because the rows were normalised to unit length.
heatmap(true_feats_n @ true_feats_n.T)

In [None]:
# Cosine similarity of true features against the matched SAE features; this
# recomputes the same product as the earlier `sims = ...` cell, without axis names.
heatmap(true_feats_n @ sae_feats_n[feat_perm].T)