In [556]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as dists
import numpy as np

from interp_utils import reload_module
reload_module('toy_models')
from toy_models import MonsterToy

In [557]:
from interp_utils import reload_module

# Reload project modules before importing from them so edits are picked up
# without restarting the kernel. Every reload here pairs with the module that is
# imported right after it (cf. `toy_models` below). Previously only
# `sparse_features` was reloaded while the classes come from `interp_models`, so
# stale `interp_models` definitions stayed bound. `sparse_features` is still
# reloaded first in case `interp_models` depends on it.
# NOTE(review): confirm `sparse_features` is still a real dependency.
reload_module('sparse_features')
reload_module('interp_models')
from interp_models import Autoencoder, SparseAutoencoder, SparseNNMF
from tqdm import tqdm
from interp_utils import get_scheduler

reload_module('toy_models')
from toy_models import SparseIndependent

# --- Experiment configuration ---------------------------------------------
N_FEATURES = 600     # `n_features` of the SparseNNMF (learned dictionary size)
D_MODEL = 100        # `d_model` of the SparseNNMF / toy hidden size
N_EPOCHS = 4         # rounds of alternating code / atom updates
CODE_STEPS = 1000    # `n_steps` for each code-update call to nnmf.train
ATOM_STEPS = 300     # `n_steps` for each atom-update call to nnmf.train
BATCH_SIZE = 10000   # samples drawn from the toy per epoch
SPARSE_COEF = 1      # `sparse_coef` passed to every nnmf.train call
SEED = 0             # fixed seed so Restart & Run All reproduces the outputs

# Seed torch's global RNG before the toy and the NNMF are created, so data
# sampling and parameter initialisation are reproducible.
torch.manual_seed(SEED)

# Alternative toy with independent sparse features:
# toy = SparseIndependent(n_features=N_FEATURES, d_model=D_MODEL, feature_sparsity=0.04)

# Use D_MODEL (not a hard-coded 100) so the toy's hidden size always matches the
# d_model the SparseNNMF dictionary is built with.
# NOTE(review): the toy has 300 ground-truth features while the dictionary has
# N_FEATURES atoms; confirm this overcompleteness ratio is intentional.
toy = MonsterToy(d_model=D_MODEL, n_features=300, feature_prob=0.04, n_monster_features=2)

# NOTE(review): toy(n) appears to take a sample count (cf. toy(BATCH_SIZE) in the
# training loop), so this draws D_MODEL samples. Neither output is used in the
# cells shown; confirm it is still needed.
hidden_state, ground_truth = toy(D_MODEL)

nnmf = SparseNNMF(n_features=N_FEATURES, d_model=D_MODEL)

# Alternating optimisation. Each epoch draws a fresh batch, fits the codes, then
# fits the atoms with the codes frozen.
# No optimizer is built in this cell: nnmf.train is not given one in these
# calls, so an optimizer/scheduler created here would never be stepped.
for epoch in range(N_EPOCHS):
    batch, features = toy(BATCH_SIZE)

    # Update codes. reinit_codes=True presumably restarts the codes for the new
    # batch. frozen_atoms becomes True from the third epoch (epoch index 2)
    # onward, so only the first two code phases may also move the atoms.
    nnmf.train(batch, frozen_atoms=epoch > 1, sparse_coef=SPARSE_COEF,
               n_steps=CODE_STEPS, reinit_codes=True)

    # Update atoms with the codes held fixed.
    nnmf.train(batch, frozen_codes=True, sparse_coef=SPARSE_COEF, n_steps=ATOM_STEPS)



loss: 0.014, mse: 0.003, sparse: 0.011: 100%|██████████| 1000/1000 [00:23<00:00, 42.75it/s]
loss: 0.003, mse: 0.003: 100%|██████████| 300/300 [00:04<00:00, 65.74it/s]
loss: 0.011, mse: 0.001, sparse: 0.009: 100%|██████████| 1000/1000 [00:26<00:00, 37.76it/s]
loss: 0.002, mse: 0.002: 100%|██████████| 300/300 [00:04<00:00, 66.22it/s]
loss: 0.011, mse: 0.002, sparse: 0.010: 100%|██████████| 1000/1000 [00:20<00:00, 48.02it/s]
loss: 0.002, mse: 0.002: 100%|██████████| 300/300 [00:05<00:00, 57.08it/s]
loss: 0.012, mse: 0.002, sparse: 0.010: 100%|██████████| 1000/1000 [00:21<00:00, 46.64it/s]
loss: 0.002, mse: 0.002: 100%|██████████| 300/300 [00:04<00:00, 66.74it/s]


In [543]:
# Disabled experiment: draw a fresh batch and re-fit the codes on it.
# It relies on `epoch` leaking out of the training loop above. The stored output
# below this cell is left over from when these lines were last run uncommented.
# batch, features = toy(BATCH_SIZE)
# nnmf.train(batch, frozen_atoms=epoch > 1, sparse_coef=SPARSE_COEF, n_steps=CODE_STEPS, reinit_codes=True)

loss: 0.014, mse: 0.002, sparse: 0.012: 100%|██████████| 1000/1000 [00:24<00:00, 40.39it/s]


In [558]:
# Imported here because, in document order, this cell runs before the cell that
# first imports `heatmap`; without this, Restart & Run All raises NameError.
from interp_utils import heatmap

# Pairwise dot products (Gram matrix) between the first N_GRAM_SAMPLES rows of
# the last batch drawn in the training loop.
N_GRAM_SAMPLES = 100
batch_subset = batch[:N_GRAM_SAMPLES]

heatmap(batch_subset @ batch_subset.T)

In [562]:
# For one learned atom, find the samples whose code for that atom is largest.
# Then show those samples' values in the first few columns of `features`.
# Row indices of nnmf.codes are used directly to index rows of `features`.
ATOM_IDX = 172
TOP_K = 30
SHOWN_FEATURES = [0, 1, 2]

atom_column = nnmf.codes[:, ATOM_IDX]
topk_codes = atom_column.topk(k=TOP_K).indices
top_rows = features[topk_codes]
top_rows[:, SHOWN_FEATURES]

tensor([[1.2102, 1.4060, 0.9987],
        [1.7716, 1.3988, 0.9163],
        [1.7141, 1.3502, 0.9768],
        [1.2433, 1.6661, 0.9631],
        [1.6535, 1.8584, 0.9559],
        [1.0638, 1.7815, 0.9544],
        [1.3805, 1.1292, 0.8338],
        [1.4810, 1.8474, 0.9757],
        [1.4764, 1.7212, 0.9664],
        [1.6965, 1.4039, 0.8986],
        [1.8118, 1.9962, 0.9260],
        [1.6731, 1.1524, 0.9400],
        [1.7570, 1.9937, 0.9597],
        [1.7809, 1.1539, 0.9566],
        [1.3433, 1.5339, 0.8570],
        [1.6447, 1.8573, 0.9917],
        [1.5770, 1.4160, 0.9622],
        [1.3407, 1.3756, 0.8373],
        [1.2254, 1.7823, 0.9514],
        [1.4192, 1.1243, 0.9046],
        [1.4260, 1.8068, 0.9446],
        [1.3758, 1.4663, 0.8543],
        [1.2823, 1.3893, 0.9102],
        [1.0350, 1.4892, 0.9742],
        [1.4036, 1.2398, 0.9101],
        [1.5026, 1.9229, 0.8352],
        [1.0147, 1.9744, 0.8811],
        [1.4987, 1.6344, 0.9861],
        [1.1361, 1.8577, 0.9735],
        [1.058

In [519]:
# Dot product between rows 0 and 10 of the last `features` batch.
row_a = features[0]
row_b = features[10]
row_a @ row_b

tensor(48.4449)

In [None]:
# Inspect a single row of the last `features` batch.
row_idx = 0
features[row_idx]

In [560]:
from interp_utils import heatmap, hist

# Similarity matrix between learned atoms (rows) and the toy's normalised
# ground-truth feature directions (columns).
atom_vs_true = nnmf.atoms @ toy.normed_features.T
axis_labels = ('learned atoms', 'true features')
heatmap(atom_vs_true, dim_names=axis_labels)

In [411]:
# NOTE(review): this cell contains only a commented-out import and produces no
# output, so the tensor output stored below it is stale (left over from earlier
# contents of this cell). Re-running the cell clears it.
# from interp_utils import see, asee





tensor([[2.4700, 0.5899, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [2.6842, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [2.6518, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [2.5777, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [2.2689, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [2.5382, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])

In [13]:
# Quick look at sparse Bernoulli feature masks. DEMO_* names are used so this
# scratch cell does not overwrite the N_FEATURES config constant (600) that
# sizes the SparseNNMF.
DEMO_N_FEATURES = 10
DEMO_FEATURE_PROB = 0.1
probs = torch.full((DEMO_N_FEATURES,), DEMO_FEATURE_PROB)

# 10 samples; each entry is independently 1 with probability DEMO_FEATURE_PROB.
demo_masks = dists.Bernoulli(probs).sample((10,))
demo_masks

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]])