In [68]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as dists
import numpy as np

from interp_utils import reload_module


In [274]:

# Reload the project modules, then re-import the names used in this notebook.
# NOTE(review): reload_module presumably re-imports the named module so that edits
# to its source are picked up without restarting the kernel — confirm in interp_utils.
reload_module('interp_models')
from interp_models import Autoencoder, SparseAutoencoder, SparseNNMF
from tqdm import tqdm
from interp_utils import get_scheduler

# Same reload-then-import pattern for the toy data generators.
reload_module('toy_models')
from toy_models import SparseIndependent, MonsterToy

# --- Experiment configuration ---------------------------------------------------
SEED = 0            # fixed seed so a fresh-kernel re-run draws the same random numbers
N_FEATURES = 300    # n_features for both the toy generator and the NNMF
D_MODEL = 100       # d_model for both the toy generator and the NNMF
N_EPOCHS = 1        # number of (code update, atom update) alternations
CODE_STEPS = 1000   # n_steps for each code update
ATOM_STEPS = 300    # n_steps for each atom update
BATCH_SIZE = 10000  # samples drawn from the toy per epoch
SPARSE_COEF = 1

ORTHOG_COEF = 0.0
ORTHOG_K = False

# Seed both global RNGs before any sampling or parameter initialisation.
# NOTE(review): which RNG the toy models draw from is not visible here, so both
# torch and numpy are seeded — confirm against toy_models.
torch.manual_seed(SEED)
np.random.seed(SEED)

# toy = SparseIndependent(n_features=N_FEATURES, d_model=D_MODEL, feature_sparsity=0.04)
# Use D_MODEL (not a literal 100) so the toy and the NNMF cannot drift apart.
toy = MonsterToy(d_model=D_MODEL, n_features=N_FEATURES, feature_prob=0.04, n_monster_features=2)

nnmf = SparseNNMF(n_features=N_FEATURES, d_model=D_MODEL, orthog_k=ORTHOG_K)

for epoch in range(N_EPOCHS):
    # Fresh batch every epoch; `true_codes` is read again by the analysis cell below.
    batch, true_codes = toy(BATCH_SIZE)

    # update codes
    # NOTE(review): `epoch > 1` leaves the atoms unfrozen during the code step for
    # epochs 0 and 1 (i.e. always, while N_EPOCHS == 1) — confirm this is intended.
    nnmf.train(
        batch,
        frozen_atoms=epoch > 1,
        sparse_coef=SPARSE_COEF,
        n_steps=CODE_STEPS,
        reinit_codes=True,
        orthog_coef=ORTHOG_COEF,
        mean_init=True,
    )

    # update atoms
    nnmf.train(
        batch,
        frozen_codes=True,
        sparse_coef=SPARSE_COEF,
        n_steps=ATOM_STEPS,
        orthog_coef=ORTHOG_COEF,
    )



loss: 0.021, mse: 0.021, sparse: 0.017: 100%|██████████| 1000/1000 [00:14<00:00, 67.40it/s]
loss: 0.006, mse: 0.006: 100%|██████████| 300/300 [00:02<00:00, 105.56it/s]


In [213]:
# Similarity between the atoms that are simultaneously active on each sample.
# This cell must run before the next one, which displays `active_atom_sims`.
from interp_utils import see

normed_atoms = nnmf.normed_atoms
codes = nnmf.codes.detach()

# Number of strongest codes per sample whose atoms are compared with each other.
orthog_k = 5

# Indices of the `orthog_k` largest codes in every row of `codes`.
topk_idx = codes.topk(dim=-1, k=orthog_k).indices

# Gather the atom of every selected code and regroup them per sample;
# the recorded output below shows the result as (10000, 5, 100).
active_atoms = torch.index_select(normed_atoms, dim=0, index=topk_idx.view(-1)).view(*topk_idx.shape, -1)

# Zero on the diagonal so an atom's product with itself is excluded.
mask = 1 - torch.eye(orthog_k)
see(active_atoms)

# Pairwise dot products between co-active atoms, masked, then averaged in absolute value.
# NOTE(review): .mean() already averages over all orthog_k**2 entries (including the
# zeroed diagonal); multiplying by (k**2 - k) / k**2 shrinks the value further. If the
# intent is the mean over off-diagonal pairs only, the factor should be the reciprocal.
# Left unchanged so the value matches the recorded output — confirm.
active_atom_sims = torch.einsum('bkd,bld,kl->bkl', active_atoms, active_atoms, mask).abs().mean() * ((orthog_k**2 - orthog_k) / orthog_k**2)

>> active_atoms: (10000, 5, 100)


In [222]:
# NOTE(review): `active_atom_sims` is only assigned in the In [213] cell above — that
# cell has to be run (uncommented) first, otherwise this raises NameError.
active_atom_sims

tensor(0.0971, grad_fn=<MulBackward0>)

In [224]:
# from interp_utils import heatmap
# # (batch @ batch.T)
# batch_subset = batch[:100]

# heatmap(batch_subset @ batch_subset.T)


In [275]:
# Ground-truth feature to inspect.
feature_idx = 5

# Detached views of the learned factors, plus the toy's feature directions.
atoms = nnmf.atoms.detach()
codes = nnmf.codes.detach()
toy_features = toy.features

# Learned atom with the largest dot product against the chosen true feature.
atom_feature_dots = atoms @ toy_features.T
atom_idx = torch.argmax(atom_feature_dots[:, feature_idx])

# The 20 samples on which that atom's code is largest.
topk_codes = torch.topk(codes[:, atom_idx], k=20).indices

# True codes of those samples for columns 0, 1 and the inspected feature.
# NOTE(review): columns 0 and 1 are presumably the two monster features
# (n_monster_features=2) — confirm against MonsterToy.
true_codes[topk_codes][:, [0, 1, feature_idx]]

tensor([[1.6138, 1.8229, 0.9078],
        [1.9218, 1.7812, 0.9367],
        [1.6823, 1.6527, 0.9919],
        [1.7219, 1.7707, 0.9479],
        [1.7598, 1.9963, 0.9973],
        [1.8471, 1.6432, 0.9504],
        [1.8497, 1.9421, 0.9621],
        [1.9194, 1.8449, 0.9535],
        [1.8169, 1.9619, 0.9817],
        [1.8431, 1.8067, 0.9909],
        [2.0978, 1.9129, 0.9723],
        [1.9847, 1.8382, 0.8772],
        [2.0560, 2.0116, 0.8691],
        [1.9025, 1.9260, 0.9565],
        [1.6863, 1.8947, 0.9071],
        [1.8302, 1.8069, 0.9739],
        [1.9795, 2.0230, 0.9700],
        [1.5922, 1.8007, 0.9831],
        [1.8085, 1.8027, 0.9878],
        [2.0341, 1.8066, 0.8857]])

In [543]:
# batch, features = toy(BATCH_SIZE)
# nnmf.train(batch, frozen_atoms=epoch > 1, sparse_coef=SPARSE_COEF, n_steps=CODE_STEPS, reinit_codes=True)

loss: 0.014, mse: 0.002, sparse: 0.012: 100%|██████████| 1000/1000 [00:24<00:00, 40.39it/s]


In [276]:
# Imported in this cell so it runs on a fresh kernel: no earlier cell has an
# active import of `heatmap`.
from interp_utils import heatmap

# Pairwise dot products between all learned atoms.
heatmap(nnmf.atoms @ nnmf.atoms.T)

In [277]:
from interp_utils import hist, heatmap

# Dot product of every learned atom (rows) with every normalised true feature (columns).
atom_feature_overlap = nnmf.atoms @ toy.normed_features.T
heatmap(atom_feature_overlap, dim_names=('learned atoms', 'true features'))

In [411]:
# from interp_utils import see, asee





tensor([[2.4700, 0.5899, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [2.6842, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [2.6518, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [2.5777, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [2.2689, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [2.5382, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])

In [13]:
# Small sanity check of Bernoulli sampling with a shared activation probability.
N_FEATURES = 10
FEATURE_PROB = 0.1

# One activation probability per feature, all equal to FEATURE_PROB.
probs = torch.full((N_FEATURES,), FEATURE_PROB)

# Ten independent draws; each row holds one 0/1 value per feature.
dists.Bernoulli(probs=probs).sample(torch.Size([10]))

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]])