In [335]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as dists

from interp_utils import reload_module

reload_module('toy_models')
reload_module('discrete_tree')
from discrete_tree import Tree

import json

# Device selection for the storage-autoencoder experiments. This cell deliberately
# pins the CPU; set FORCE_CPU = False to use a GPU when one is available.
FORCE_CPU = True
device = "cpu" if FORCE_CPU or not torch.cuda.is_available() else "cuda"

# Load the tree specification; the context manager guarantees the file handle is closed.
with open('./tree.json', 'r') as tree_file:
    tree_dict = json.load(tree_file)

tree = Tree(tree_dict=tree_dict)

# Quick sanity-check sample; later cells rebind `batch`.
batch = tree.sample(30)

In [356]:
from interp_utils import heatmap

# Draw a fresh sample from the tree (sample shape argument (30,)) and visualize it.
tree_sample = tree.sample((30,))
heatmap(tree_sample)


In [357]:
reload_module('interp_models')
from interp_models import Autoencoder, SparseAutoencoder, SparseNNMF
from tqdm import tqdm
from interp_utils import get_scheduler

D_MODEL = 40
N_STEPS = 1000
TRAIN_BATCH_SIZE = 100
LR = 1e-3

# Autoencoder that maps the tree's n_features-dim vectors through a D_MODEL-dim space.
storage_autoencoder = Autoencoder(n_features=tree.n_features, d_model=D_MODEL).to(device)

optimizer = optim.AdamW(storage_autoencoder.parameters(), lr=LR)
# The scheduler is built for N_STEPS steps, so the loop below is driven by the
# same constant to keep the schedule length and the number of updates in sync.
scheduler = get_scheduler(optimizer, N_STEPS)

pbar = tqdm(range(N_STEPS))
for step in pbar:
    optimizer.zero_grad()
    batch = tree.sample(TRAIN_BATCH_SIZE).to(device)
    reconstruction = storage_autoencoder(batch)

    loss = F.mse_loss(reconstruction, batch)
    loss.backward()
    optimizer.step()
    scheduler.step()

    pbar.set_description(f'loss: {loss.item():.3f}')

# Snapshot the trained weights on the CPU, detached from autograd.
write_out = storage_autoencoder.encoder.weight.detach().cpu().T
# The bias is 1-D; `.T` on a non-2-D tensor is a deprecated no-op in PyTorch, so it is omitted.
write_out_b = storage_autoencoder.encoder.bias.detach().cpu()

read_in = storage_autoencoder.decoder.weight.detach().cpu()
read_in_b = storage_autoencoder.decoder.bias.detach().cpu()

# Unit-normalize write_out along dim=1 and read_in along dim=0.
normed_write_out = F.normalize(write_out, dim=1)
normed_read_in = F.normalize(read_in, dim=0)

loss: 0.001: 100%|██████████| 1000/1000 [01:27<00:00, 11.47it/s]


In [359]:
from interp_utils import heatmap

# Product of the encoder ("write_out") and decoder ("read_in") weight snapshots.
write_read_product = write_out @ read_in.T
heatmap(write_read_product, dim_names=('write_out', 'read_in'))

In [392]:
reload_module('interp_models')
from interp_models import Autoencoder, SparseAutoencoder, SparseNNMF
from tqdm import tqdm
from interp_utils import get_scheduler

# Sparse-NNMF fit on the storage autoencoder's encoder outputs. Each epoch draws a
# fresh feature batch, encodes it without gradients, fits the codes, and from the
# second epoch on runs an additional atom-update pass.
N_EPOCHS = 1
CODE_STEPS = 7000
ATOM_STEPS = 4000
BATCH_SIZE = 30000
CODE_LR = 1e-2

SPARSE_COEF = 1

ORTHOG_K = 3
ORTHOG_COEF = 2

torch.manual_seed(0)

# Prefer the first GPU, but fall back to CPU so the cell still runs without CUDA
# (moving a module to 'cuda:0' raises on CPU-only machines).
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

storage_autoencoder = storage_autoencoder.to(device=device)
nnmf = SparseNNMF(d_model=storage_autoencoder.d_model, n_features=tree.n_features, orthog_k=ORTHOG_K, bias=True).to(device=device)


for epoch in range(N_EPOCHS):
    features_batch = tree.sample(BATCH_SIZE).to(device=device)
    with torch.no_grad():
        batch = storage_autoencoder.encoder(features_batch)
    # update codes; frozen_atoms is passed True from the second epoch onward
    nnmf.train(batch, frozen_atoms=(epoch > 0), sparse_coef=SPARSE_COEF, n_epochs=CODE_STEPS, orthog_coef=ORTHOG_COEF, lr=CODE_LR)

    if epoch > 0:
        # update atoms
        # NOTE(review): frozen_codes=False presumably leaves the codes trainable during
        # this atom pass, and no lr is passed (the library default is used rather
        # than CODE_LR) — confirm both are intended.
        nnmf.train(batch, frozen_codes=False, sparse_coef=SPARSE_COEF, n_epochs=ATOM_STEPS, orthog_coef=ORTHOG_COEF)



loss: 0.074, mse: 0.074, sparse: 0.001, orthog: 0.020: 100%|██████████| 7000/7000 [00:54<00:00, 128.96it/s]


In [393]:
import numpy as np
from scipy.optimize import linear_sum_assignment

# Match learned NNMF atoms to feature directions by solving a linear assignment
# problem that maximizes total similarity.
atoms = nnmf.atoms.cpu().detach()
# NOTE(review): raw (unnormalized) atoms are compared against unit-normalized feature
# directions, so atoms with larger norms weigh more in the matching — confirm intended.
sims = atoms @ normed_write_out.T

# `maximize=True` replaces the manual `max - sims` cost matrix (same optimum).
# The transpose puts features on the rows, so col_ind[k] is the atom assigned to
# feature row_ind[k].
row_ind, col_ind = linear_sum_assignment(sims.detach().numpy().T, maximize=True)
atom_perm = col_ind

features = normed_write_out.detach()
heatmap(atoms[atom_perm] @ features.T, dim_names=('atoms', 'features')).show()
heatmap(write_out @ read_in.T, dim_names=('write_out', 'read_in'))


In [215]:
# Display the tree's repr (node hierarchy with per-node values, as printed below).
tree

  B  1.0
  0 B  0.15
    1 BA 0.5
      2 BA 0.5
        3 B  0.5
        4 B  0.5
      5 BA 0.5
        6 B  0.5
        7 B  0.5
    8 BA 0.5
      9 BA 0.5
        10 B  0.5
        11 B  0.5
      12 BA 0.5
        13 B  0.5
        14 B  0.5
  15 B  0.15
    16 BA 0.5
      17 BA 0.5
        18 B  0.5
        19 B  0.5
      20 BA 0.5
        21 B  0.5
        22 B  0.5
    23 BA 0.5
      24 BA 0.5
        25 B  0.5
        26 B  0.5
      27 BA 0.5
        28 B  0.5
        29 B  0.5

In [394]:


# Build heatmap inputs: a band of header rows showing each feature's mean activation
# in features_batch (scaled so the largest is 1), stacked on top of the NNMF codes
# and on top of the ground-truth features.
N_PROB_ROWS = 7          # height of the probability header band
CODE_SCALE_FLOOR = 0.6   # lower bound on each atom's normalizer, so weak atoms aren't inflated to 1

probabilities = features_batch.mean(dim=0)
vis_probs = probabilities / probabilities.max()
vis_probs = vis_probs[None].repeat(N_PROB_ROWS, 1)

codes = nnmf.codes().detach()
# Scale each atom column by its max |activation|, floored at CODE_SCALE_FLOOR.
codes = codes / codes.abs().max(dim=0, keepdim=True).values.clamp(min=CODE_SCALE_FLOOR)
# Reorder atom columns using the permutation from the assignment cell.
codes = codes[:, atom_perm]
codes = torch.cat([vis_probs, codes], dim=0)
f_batch = torch.cat([vis_probs, features_batch], dim=0)


In [395]:


# Side-by-side view of the first 200 rows: permuted NNMF codes vs. ground-truth features.
print(tree)
row_limit = 200
heatmap(
    codes[:row_limit],
    title='NNMF activations',
    dim_names=('batch', 'atom'),
    info_1={'prob': probabilities},
).show()
heatmap(
    f_batch[:row_limit],
    title='Ground truth feature activations',
    info_1={'prob': probabilities},
).show()


  B  1.0
  0 BA 0.15
    1 BA 0.2
      2 B  0.25
      3 B  0.25
      4 B  0.25
      5 B  0.25
    6 BA 0.2
      7 B  0.25
      8 B  0.25
      9 B  0.25
      10 B  0.25
    11 BA 0.2
      12 B  0.25
      13 B  0.25
      14 B  0.25
      15 B  0.25
    16 BA 0.2
      17 B  0.25
      18 B  0.25
      19 B  0.25
      20 B  0.25
    21 BA 0.2
      22 B  0.25
      23 B  0.25
      24 B  0.25
      25 B  0.25
  26 BA 0.15
    27 BA 0.2
      28 B  0.25
      29 B  0.25
      30 B  0.25
      31 B  0.25
    32 BA 0.2
      33 B  0.25
      34 B  0.25
      35 B  0.25
      36 B  0.25
    37 BA 0.2
      38 B  0.25
      39 B  0.25
      40 B  0.25
      41 B  0.25
    42 BA 0.2
      43 B  0.25
      44 B  0.25
      45 B  0.25
      46 B  0.25
    47 BA 0.2
      48 B  0.25
      49 B  0.25
      50 B  0.25
      51 B  0.25
  52 BA 0.15
    53 BA 0.2
      54 B  0.25
      55 B  0.25
      56 B  0.25
      57 B  0.25
    58 BA 0.2
      59 B  0.25
      60 B  0.25
      61 B  

In [1145]:
# Manually recorded mapping from atom index to its matched index
# (presumably the ground-truth feature). Entries marked '?' were uncertain.
atom_to_feature = {
    0: 3,
    1: 0,   # ?
    2: 8,
    3: 10,
    4: 2,
    5: 9,   # ?
    6: 7,   # ?
    7: 6,
    8: 5,
    9: 4,
    10: 1,
}

# perm[atom] is the index matched to that atom.
perm = torch.tensor([atom_to_feature[atom] for atom in range(len(atom_to_feature))])

In [1045]:
# Project each normalized atom onto the unit-length mean direction of `batch`
# (a cosine similarity if normed_atoms rows are unit norm — presumably so).
batch_mean = batch.mean(dim=0)
mean_dir = batch_mean / batch_mean.norm()
nnmf.normed_atoms @ mean_dir


tensor([-0.3209,  0.0588,  0.2081,  0.0929], grad_fn=<MvBackward0>)

In [762]:
# Last three feature columns for the first 20 samples (single multi-axis index).
features_batch[:20, -3:]

tensor([[0., 0., 0.],
        [1., 1., 0.],
        [0., 1., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 0.],
        [1., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 1.],
        [0., 0., 0.],
        [0., 0., 1.],
        [0., 0., 0.]])

In [705]:
# Display the current NNMF code matrix (output below shows one row per sample, one column per atom).
nnmf.codes()

tensor([[1.2809e+00, 5.2178e-01, 8.3384e-04, 8.2154e-04],
        [1.1812e-04, 5.4305e-02, 1.8370e+00, 4.4291e-04],
        [1.2233e-03, 1.3536e+00, 6.4142e-04, 4.8705e-04],
        ...,
        [5.7441e-04, 1.3546e+00, 4.7168e-04, 4.7056e-04],
        [1.1720e-03, 1.3542e+00, 5.3459e-04, 5.8248e-04],
        [9.2144e-04, 1.3545e+00, 4.8252e-04, 5.8358e-04]],
       grad_fn=<AbsBackward0>)

In [707]:
# Decoder-by-encoder weight product; algebraically the transpose of write_out @ read_in.T.
heatmap(torch.matmul(read_in, write_out.T))

In [22]:
# Display the tree's repr (node hierarchy with per-node values, as printed below).
tree

root
├── 0 B 0.2439
├── 1 B 0.1836
└── 2 BA 0.0395
    ├── 3 B 0.9133
    └── 4 B 0.0867

In [36]:
# Inspect window I of SUBSET_SIZE consecutive samples: ground-truth features
# (plotted) and the corresponding NNMF codes (kept in c_subset for inspection).
SUBSET_SIZE = 30
I = 8
window = slice(SUBSET_SIZE * I, SUBSET_SIZE * (I + 1))
f_subset = features_batch[window]
c_subset = nnmf.codes()[window]

heatmap(f_subset)

In [25]:
# Display the tree's repr (node hierarchy with per-node values, as printed below).
tree

root
├── 0 B 0.2439
├── 1 B 0.1836
└── 2 BA 0.0395
    ├── 3 B 0.9133
    └── 4 B 0.0867

In [24]:
# NNMF codes for the first 30 rows of the fitted batch.
leading_codes = nnmf.codes()[:30]
heatmap(leading_codes)