In [79]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as dists

from interp_utils import reload_module

reload_module('toy_models')
reload_module('discrete_tree')
from discrete_tree import Tree

import json

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = 'cpu'

# Read the tree specification from disk. The context manager closes the file
# handle as soon as parsing finishes instead of leaving it open for the
# lifetime of the kernel.
with open('./tree.json', 'r') as tree_file:
    tree_dict = json.load(tree_file)

# Build the tree object from the parsed specification.
tree = Tree(tree_dict=tree_dict)

# Draw an initial batch; the argument 30 is passed straight to Tree.sample.
batch = tree.sample(30,)

# Bare last expression: the notebook renders the tree's repr as the cell output.
tree

  B  1.0
  0 BA 0.15
    1 B  0.2
    2 B  0.2
    3 B  0.2
      B  0.4
  4 BA 0.15
    5 B  0.2
    6 B  0.2
    7 B  0.2
      B  0.4
  8 BA 0.15
    9 B  0.2
    10 B  0.2
    11 B  0.2
      B  0.4
  12 B  0.05
  13 B  0.05
  14 B  0.05
  15 B  0.05
  16 B  0.05
  17 B  0.05

In [80]:
from interp_utils import heatmap

# Plot a fresh draw from the tree as a heatmap. Note the sample argument here
# is the tuple (30,), whereas the first cell passes the plain integer 30.
heatmap(tree.sample((30,)))


In [81]:
reload_module('interp_models')
from interp_models import Autoencoder, SparseAutoencoder, SparseNNMF
from tqdm import tqdm
from interp_utils import get_scheduler

# Training configuration for the storage autoencoder.
D_MODEL = 40      # passed to Autoencoder as d_model
N_STEPS = 300     # number of optimiser steps (also handed to the LR scheduler)
BATCH_SIZE = 100  # samples drawn from the tree per step

storage_autoencoder = Autoencoder(n_features=tree.n_features, d_model=D_MODEL).to(device)

optimizer = optim.AdamW(storage_autoencoder.parameters(), lr=1e-2)
scheduler = get_scheduler(optimizer, N_STEPS)

pbar = tqdm(range(N_STEPS))
for i in pbar:
    optimizer.zero_grad()
    # Draw a fresh batch every step. The size now comes from BATCH_SIZE; it was
    # previously a hard-coded 100, so editing the constant had no effect.
    batch = tree.sample(BATCH_SIZE).to(device)
    reconstruction = storage_autoencoder(batch)

    # Plain reconstruction objective: the autoencoder must reproduce its input.
    loss = F.mse_loss(reconstruction, batch)
    loss.backward()
    optimizer.step()
    scheduler.step()

    # Refresh the progress-bar label only every 10th step.
    if i % 10 == 0:
        pbar.set_description(f'loss: {loss.item():.3f}')

# Pull the learned parameters onto the CPU for inspection.
# write_out is the encoder weight transposed; read_in is the decoder weight as stored.
write_out = storage_autoencoder.encoder.weight.data.cpu().T
write_out_b = storage_autoencoder.encoder.bias.data.cpu().T

read_in = storage_autoencoder.decoder.weight.cpu().data
read_in_b = storage_autoencoder.decoder.bias.cpu().data


# Unit-normalise write_out along dim 1.
normed_write_out = F.normalize(write_out, dim=1)
# NOTE(review): read_in is later multiplied as `write_out @ read_in.T`, so it
# appears to share write_out's layout; normalising it along dim=0 (rather than
# dim=1 as above) then looks inconsistent -- confirm which axis is intended.
normed_read_in = F.normalize(read_in, dim=0)

loss: 0.000: 100%|██████████| 300/300 [00:07<00:00, 41.57it/s]


In [89]:
# Re-display the tree's repr for reference next to the heatmaps that follow.
tree

  B  1.0
  0 BA 0.15
    1 B  0.2
    2 B  0.2
    3 B  0.2
      B  0.4
  4 BA 0.15
    5 B  0.2
    6 B  0.2
    7 B  0.2
      B  0.4
  8 BA 0.15
    9 B  0.2
    10 B  0.2
    11 B  0.2
      B  0.4
  12 B  0.05
  13 B  0.05
  14 B  0.05
  15 B  0.05
  16 B  0.05
  17 B  0.05

In [90]:
from interp_utils import heatmap


# Product of the transposed encoder weights with the transposed decoder weights,
# shown as a heatmap with axes labelled 'write_out' and 'read_in'.
heatmap(write_out @ read_in.T, dim_names=('write_out', 'read_in'))

In [97]:
reload_module('interp_models')
from interp_models import Autoencoder, SparseAutoencoder, SparseNNMF, SparseNNMFWCodeBias
from tqdm import tqdm
from interp_utils import get_scheduler

# --- Hyperparameters for the sparse NNMF fit --------------------------------
# The model width is not set here; it is taken from storage_autoencoder.d_model.
N_EPOCHS = 1        # only referenced by the disabled alternating schedule below
CODE_STEPS = 30000  # passed to nnmf.train as n_epochs
ATOM_STEPS = 1000   # only referenced by the disabled alternating schedule below
BATCH_SIZE = 50000  # number of tree samples encoded and factorised

SPARSE_COEF = 3e-1

# Orthogonality settings for this run. These are the values that were actually
# in effect before; the earlier assignments were immediately overwritten.
# Previously tried: ORTHOG_K = 2 with ORTHOG_COEF = 1e-2,
#                   ORTHOG_K = 3 with ORTHOG_COEF = 1.
ORTHOG_K = 0
ORTHOG_COEF = 0

torch.manual_seed(0)

# Use the first GPU when CUDA is available and fall back to the CPU otherwise,
# so the cell no longer fails on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = 'cpu'

# Encode one large batch of tree samples with the trained encoder. no_grad keeps
# the encoder out of the autograd graph: only the NNMF is fitted below.
storage_autoencoder = storage_autoencoder.to(device=device)
features_batch = tree.sample(BATCH_SIZE).to(device=device)
with torch.no_grad():
    batch = storage_autoencoder.encoder(features_batch)


nnmf = SparseNNMF(d_model=storage_autoencoder.d_model, n_features=tree.n_features, orthog_k=ORTHOG_K, bias=True).to(device=device)

# Single training call with atoms left unfrozen (frozen_atoms=False).
nnmf.train(batch, frozen_atoms=False, sparse_coef=SPARSE_COEF, n_epochs=CODE_STEPS, orthog_coef=ORTHOG_COEF, lr=1e-2, mean_init=True)

# Disabled alternative: alternate code updates and atom updates per epoch.
# for epoch in range(N_EPOCHS):

#     # update codes
#     nnmf.train(batch, frozen_atoms=(epoch > 0), sparse_coef=SPARSE_COEF, n_epochs=CODE_STEPS, orthog_coef=ORTHOG_COEF, lr=1e-2, mean_init=True)

#     # if epoch > 0:
#     #     # update atoms
#     #     nnmf.train(batch, frozen_codes=False, sparse_coef=SPARSE_COEF, n_epochs=ATOM_STEPS, orthog_coef=ORTHOG_COEF, )



loss: 0.019, mse: 0.019, sparse: 0.044, orthog: 0.058: 100%|██████████| 30000/30000 [01:28<00:00, 337.51it/s]


In [98]:
import numpy as np
from scipy.optimize import linear_sum_assignment

# Learned NNMF atoms, moved to the CPU and detached from the autograd graph.
atoms = nnmf.atoms.cpu().detach()

# Dot product of every atom with every unit-normalised write_out row.
sims = atoms @ normed_write_out.T

# linear_sum_assignment minimises total cost, so convert similarities into
# non-negative costs by subtracting them from the largest similarity.
max_value = sims.max()
cost_matrix = max_value - sims

# The solver receives the transposed costs, so its row indices run over the
# write_out rows and col_ind holds the atom assigned to each of them.
transposed_costs = cost_matrix.detach().numpy().T
row_ind, col_ind = linear_sum_assignment(transposed_costs)
atom_perm = col_ind

features = normed_write_out.detach()

# Atoms reordered by the assignment, compared against the feature directions.
matched_similarity = atoms[atom_perm] @ features.T
heatmap(matched_similarity, dim_names=('atoms', 'features')).show()
# heatmap(features @ features.T, dim_names=('features', 'features'))

# Reference picture: encoder/decoder weight product from the autoencoder.
heatmap(write_out @ read_in.T, dim_names=('write_out', 'read_in'))


In [99]:


# Per-feature mean over the sampled batch.
probabilities = features_batch.mean(dim=0)

# Header strip for the plots: the means rescaled so the largest equals 1,
# repeated as 7 identical rows.
vis_probs = (probabilities / probabilities.max()).unsqueeze(0).repeat(7, 1)

# Divide each code column by its largest absolute value over the batch, then
# reorder the columns with the atom permutation computed earlier.
codes = nnmf.codes().detach()
column_peaks = codes.abs().max(dim=0, keepdim=True).values  # .clamp(min=0.6)
codes = (codes / column_peaks)[:, atom_perm]

# Stack the header strip on top of both the codes and the ground-truth features.
codes = torch.cat([vis_probs, codes], dim=0)
f_batch = torch.cat([vis_probs, features_batch], dim=0)


In [100]:


# print(tree)
# heatmap(sims[:200], title='NNMF activations', dim_names=('batch', 'atom'), info_1={'prob': probabilities}).show()

# Show the first 200 rows of the NNMF codes and of the ground-truth features.
# Both tensors had the 7 vis_probs header rows concatenated on top in the
# preceding cell, so those rows are included in the 200.
heatmap(codes[:200], title='NNMF activations', dim_names=('batch', 'atom'), info_1={'prob': probabilities}).show()
heatmap(f_batch[:200], title='Ground truth feature activations', info_1={'prob': probabilities}).show()


In [None]:
# Hand-recorded pairing notes, written as "<atom index>: <matched index>".
# The right-hand values are presumably feature indices -- TODO confirm.
# Entries ending in question marks were flagged as uncertain.
# atom 8: 5
# atom 9: 4
# 10: 1
# 7: 6
# 6: 7????
# 5: 9??
# 4: 2
# 3: 10
# 2: 8
# 0: 3
# 1: 0???

# The same pairing as a tensor: perm[atom index] is the value noted above.
perm = torch.tensor([3, 0, 8, 10, 2, 9, 7, 6, 5, 4, 1])

In [None]:
# Unit vector along the mean of the encoded batch.
batch_mean = batch.mean(dim=0)
mean_dir = batch_mean / batch_mean.norm()

# Projection of each normalised atom onto that mean direction.
nnmf.normed_atoms @ mean_dir


tensor([-0.3209,  0.0588,  0.2081,  0.0929], grad_fn=<MvBackward0>)

In [None]:
# First 20 samples, last 3 feature columns.
features_batch[:20, -3:]

tensor([[0., 0., 0.],
        [1., 1., 0.],
        [0., 1., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 0.],
        [1., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 1.],
        [0., 0., 0.],
        [0., 0., 1.],
        [0., 0., 0.]])

In [None]:
# Inspect the NNMF codes as returned by the model (not detached).
nnmf.codes()

tensor([[1.2809e+00, 5.2178e-01, 8.3384e-04, 8.2154e-04],
        [1.1812e-04, 5.4305e-02, 1.8370e+00, 4.4291e-04],
        [1.2233e-03, 1.3536e+00, 6.4142e-04, 4.8705e-04],
        ...,
        [5.7441e-04, 1.3546e+00, 4.7168e-04, 4.7056e-04],
        [1.1720e-03, 1.3542e+00, 5.3459e-04, 5.8248e-04],
        [9.2144e-04, 1.3545e+00, 4.8252e-04, 5.8358e-04]],
       grad_fn=<AbsBackward0>)

In [None]:
# Decoder weights times the encoder weights: read_in @ write_out.T is the
# transpose of the write_out @ read_in.T product plotted earlier.
heatmap(read_in @ write_out.T)

In [None]:
# Display the tree's repr.
# NOTE(review): the output saved under this cell shows a much smaller tree than
# the one printed at the top of the notebook -- presumably stale output from an
# earlier run; re-run the cell to refresh it.
tree

root
├── 0 B 0.2439
├── 1 B 0.1836
└── 2 BA 0.0395
    ├── 3 B 0.9133
    └── 4 B 0.0867

In [None]:
# Index of the 30-row window to inspect.
I = 8
window = slice(30 * I, 30 * (I + 1))

# The same rows taken from the ground-truth features and from the NNMF codes.
f_subset = features_batch[window]
c_subset = nnmf.codes()[window]

# Only the ground-truth window is plotted here.
heatmap(f_subset)

In [None]:
# Display the tree's repr.
# NOTE(review): the output saved under this cell shows a much smaller tree than
# the one printed at the top of the notebook -- presumably stale output from an
# earlier run; re-run the cell to refresh it.
tree

root
├── 0 B 0.2439
├── 1 B 0.1836
└── 2 BA 0.0395
    ├── 3 B 0.9133
    └── 4 B 0.0867

In [None]:
# Heatmap of the first 30 rows of the raw NNMF codes (no rescaling or column
# reordering applied).
heatmap(nnmf.codes()[:30])