In [68]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as dists
import numpy as np

from interp_utils import reload_module


In [329]:

# Reload the project modules so edits made on disk are picked up without
# restarting the kernel; each reload must precede the matching import.
reload_module('interp_models')
from interp_models import Autoencoder, SparseAutoencoder, SparseNNMF
from tqdm import tqdm
from interp_utils import get_scheduler

reload_module('toy_models')
from toy_models import SparseIndependent, MonsterToy

# ---- Experiment configuration ----
SEED = 0            # seed for the torch / numpy global RNGs

N_FEATURES = 300    # n_features passed to both the toy model and the NNMF
D_MODEL = 100       # d_model passed to both the toy model and the NNMF
N_EPOCHS = 1        # outer iterations; each one samples a fresh batch
CODE_STEPS = 3000   # n_steps for the code-update call
ATOM_STEPS=1000     # n_steps for the atom-update call
BATCH_SIZE=10000    # samples drawn from the toy per epoch
SPARSE_COEF = 1     # forwarded as sparse_coef to nnmf.train

# ORTHOG_K=False leaves the orthogonality term off: the coefficient is 0.0
# unless ORTHOG_K is set to something other than False.
ORTHOG_K=False
ORTHOG_COEF = 1e-1 if ORTHOG_K is not False else 0.0

# Seed before anything is sampled. NOTE(review): this assumes the toy model and
# SparseNNMF draw from the global torch / numpy generators — confirm.
torch.manual_seed(SEED)
np.random.seed(SEED)


# toy = SparseIndependent(n_features=N_FEATURES, d_model=D_MODEL, feature_sparsity=0.04)
# Use D_MODEL (not a literal 100) so the toy and the NNMF cannot drift apart.
toy = MonsterToy(d_model=D_MODEL, n_features=N_FEATURES, feature_prob=0.04, n_monster_features=6)


# NOTE(review): the argument here is D_MODEL, whereas the training loop calls
# toy(BATCH_SIZE); neither returned value is used in the cells shown — confirm
# this draw is still needed.
hidden_state, ground_truth = toy(D_MODEL)

nnmf = SparseNNMF(n_features=N_FEATURES, d_model=D_MODEL, orthog_k=ORTHOG_K, bias=True)


for epoch in range(N_EPOCHS):
    # NOTE(review): optimizer and scheduler are built here but are not passed
    # to nnmf.train and are never stepped in this cell — confirm whether they
    # are still needed.
    optimizer = optim.Adam(nnmf.parameters(), lr=1e-2)
    scheduler = get_scheduler(optimizer, CODE_STEPS)

    # Fresh batch plus the toy's second return value, which a later cell
    # inspects as the ground-truth codes.
    batch, true_codes = toy(BATCH_SIZE)


    # update codes
    # NOTE(review): `epoch > 1` is False for epochs 0 and 1, so with
    # N_EPOCHS = 1 frozen_atoms is always False here — confirm whether
    # `epoch > 0` was intended.
    nnmf.train(batch, frozen_atoms=epoch > 1, sparse_coef=SPARSE_COEF, n_steps=CODE_STEPS, reinit_codes=True, orthog_coef=ORTHOG_COEF, mean_init=False)

    # update atoms
    nnmf.train(batch, frozen_codes=True, sparse_coef=SPARSE_COEF, n_steps=ATOM_STEPS, orthog_coef=ORTHOG_COEF)



loss: 0.020, mse: 0.020, sparse: 0.013: 100%|██████████| 3000/3000 [00:42<00:00, 69.92it/s]
loss: 0.007, mse: 0.007: 100%|██████████| 1000/1000 [00:09<00:00, 100.44it/s]


In [330]:
# # (1-(0.9)**np.arange(300))
# import matplotlib.pyplot as plt

# N_DIMS = 300

# plt.plot(np.arange(1, N_DIMS+1)/N_DIMS)

In [213]:
# normed_atoms = nnmf.normed_atoms
# codes = nnmf.codes.detach()

# orthog_k=5

# topk_vals, topk_idx = codes.topk(dim=-1, k=orthog_k)

# # topk_idx.shape


# active_atoms = torch.index_select(normed_atoms, dim=0, index=topk_idx.view(-1)).view(*topk_idx.shape, -1)
# mask = 1-torch.eye(orthog_k)
# see(active_atoms)

# active_atom_sims = torch.einsum('bkd,bld,kl->bkl', active_atoms, active_atoms, mask).abs().mean()*((orthog_k**2 - orthog_k)/orthog_k**2)

>> active_atoms: (10000, 5, 100)


In [222]:
# Recompute the statistic inside this cell: the only other assignment of
# `active_atom_sims` is in the commented-out cell above, so on a fresh kernel a
# bare reference to the name raises NameError.
ACTIVE_TOPK = 5  # how many of the largest codes per sample are compared

normed_atoms = nnmf.normed_atoms
active_codes = nnmf.codes.detach()

# Indices of the ACTIVE_TOPK largest codes along the last axis of each row.
topk_idx = active_codes.topk(dim=-1, k=ACTIVE_TOPK).indices

# Gather the corresponding rows of normed_atoms and restore the
# (rows, ACTIVE_TOPK, -1) layout of topk_idx.
active_atoms = torch.index_select(normed_atoms, dim=0, index=topk_idx.view(-1)).view(*topk_idx.shape, -1)

# 1 off the diagonal, 0 on it: removes each atom's product with itself.
offdiag_mask = 1 - torch.eye(ACTIVE_TOPK)

# Mean absolute pairwise dot product between the selected atoms.
# NOTE(review): .mean() already averages over all k*k entries (masked diagonal
# included), so the trailing (k^2 - k) / k^2 factor shrinks the value further
# instead of renormalising to the off-diagonal count. Kept exactly as in the
# original formula — confirm the intended normalisation.
active_atom_sims = torch.einsum('bkd,bld,kl->bkl', active_atoms, active_atoms, offdiag_mask).abs().mean()*((ACTIVE_TOPK**2 - ACTIVE_TOPK)/ACTIVE_TOPK**2)

active_atom_sims

tensor(0.0971, grad_fn=<MulBackward0>)

In [224]:
# from interp_utils import heatmap
# # (batch @ batch.T)
# batch_subset = batch[:100]

# heatmap(batch_subset @ batch_subset.T)


In [331]:
# Choose one toy feature and find the learned atom that lines up with it best.
feature_idx = 2

# Detached views of the fitted factors, plus the toy's feature matrix.
atoms = nnmf.atoms.detach()
codes = nnmf.codes.detach()
toy_features = toy.features

# Dot product of every atom with every toy feature; keep the atom that scores
# highest against the chosen feature.
atom_feature_dots = torch.matmul(atoms, toy_features.T)
atom_idx = atom_feature_dots[:, feature_idx].argmax()

# Row indices of the 20 largest codes for that atom.
topk_codes = torch.topk(codes[:, atom_idx], k=20).indices

# Show true_codes (from the training cell) for those rows, restricted to
# columns 0, 1 and the chosen feature.
shown_columns = [0, 1, feature_idx]
true_codes[topk_codes][:, shown_columns]

tensor([[2.2212, 2.4398, 2.8381],
        [1.9418, 2.4799, 2.6299],
        [2.1087, 2.5530, 2.8127],
        [2.5247, 1.8618, 2.7629],
        [2.2984, 2.0024, 2.8821],
        [2.3457, 2.1689, 2.5403],
        [2.1748, 2.1496, 2.8548],
        [2.3685, 2.6578, 2.5833],
        [2.2339, 2.7037, 2.4627],
        [2.1698, 2.4599, 2.6256],
        [2.2727, 2.1225, 2.7570],
        [2.7227, 2.2803, 2.7523],
        [2.1106, 2.3806, 2.6335],
        [2.4311, 2.4945, 2.7650],
        [2.1324, 2.6821, 2.6609],
        [2.5612, 2.1501, 2.7052],
        [2.1957, 2.0654, 2.6470],
        [1.9471, 2.0592, 2.7063],
        [2.5499, 2.5105, 2.9398],
        [1.8888, 2.4015, 2.6807]])

In [543]:
# batch, features = toy(BATCH_SIZE)
# nnmf.train(batch, frozen_atoms=epoch > 1, sparse_coef=SPARSE_COEF, n_steps=CODE_STEPS, reinit_codes=True)

loss: 0.014, mse: 0.002, sparse: 0.012: 100%|██████████| 1000/1000 [00:24<00:00, 40.39it/s]


In [332]:
# Import heatmap in this cell: the only other active import of it sits in the
# following cell, so running the notebook top to bottom on a fresh kernel
# would otherwise fail here with NameError.
from interp_utils import heatmap

# Pairwise dot products between all learned atoms (atoms times its transpose).
heatmap(nnmf.atoms @ nnmf.atoms.T)

In [333]:
from interp_utils import hist, heatmap

# Dot product of each learned atom with each of the toy's normed_features;
# rows index learned atoms, columns index true features.
atom_feature_sims = nnmf.atoms @ toy.normed_features.T
heatmap(
    atom_feature_sims,
    dim_names=('learned atoms', 'true features'),
)

In [334]:
# For every true feature (column), take the largest absolute dot product over
# all learned atoms (rows), then plot the distribution of those best matches.
atom_feature_sims = nnmf.atoms @ toy.normed_features.T
best_match_per_feature = atom_feature_sims.abs().max(dim=0).values
hist(best_match_per_feature, info=range(N_FEATURES))

In [411]:
# from interp_utils import see, asee





tensor([[2.4700, 0.5899, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [2.6842, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [2.6518, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [2.5777, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [2.2689, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [2.5382, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])

In [13]:
# Scratch demo of Bernoulli sampling. The size uses its own name so that it
# does not rebind the experiment-wide N_FEATURES constant that the training
# and histogram cells read.
N_DEMO_FEATURES = 10
FEATURE_PROB = 0.1

# One activation probability per feature, all equal to FEATURE_PROB.
probs = torch.full((N_DEMO_FEATURES,), FEATURE_PROB)

# Draw 10 independent rows; each entry is 1 with probability FEATURE_PROB.
dists.Bernoulli(probs).sample((10,))

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]])