In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as dists

from interp_utils import reload_module

reload_module('toy_models')
reload_module('discrete_tree')
from discrete_tree import Tree

import json

# Load the tree specification. The context manager closes the file handle as
# soon as the JSON has been parsed instead of leaving it open for the lifetime
# of the kernel.
with open('./simple_tree.json', 'r') as tree_file:
    tree_dict = json.load(tree_file)

tree = Tree(tree_dict=tree_dict)

# Initial batch drawn from the tree (size argument 30).
batch = tree.sample(30)

In [2]:
from interp_utils import heatmap

# Visualise a fresh sample from the tree as a heatmap.
# NOTE(review): the size is passed as the tuple `(30,)` here, while the other
# cells pass a plain int to `tree.sample` — confirm both forms are accepted.
heatmap(tree.sample((30,)))


In [20]:
reload_module('interp_models')
from interp_models import Autoencoder, SparseAutoencoder, SparseNNMF
from tqdm import tqdm
from interp_utils import get_scheduler

D_MODEL = 20
N_STEPS = 1000

# Autoencoder whose encoder/decoder weights are inspected after training.
storage_autoencoder = Autoencoder(n_features=tree.n_features, d_model=D_MODEL)

optimizer = optim.AdamW(storage_autoencoder.parameters(), lr=1e-3)
scheduler = get_scheduler(optimizer, N_STEPS)

# Drive the loop from N_STEPS (not a second hard-coded 1000) so the number of
# optimisation steps always matches the horizon the scheduler was built with.
pbar = tqdm(range(N_STEPS))
for i in pbar:
    optimizer.zero_grad()
    batch = tree.sample(100)
    reconstruction = storage_autoencoder(batch)

    # Reconstruction objective: the output is compared against the input batch.
    loss = F.mse_loss(reconstruction, batch)
    loss.backward()
    optimizer.step()
    scheduler.step()

    pbar.set_description(f'loss: {loss.item():.3f}')

# Encoder parameters, transposed so that `write_out` lines up with `read_in`
# in the `write_out @ read_in.T` products used elsewhere in the notebook.
write_out = storage_autoencoder.encoder.weight.data.T
write_out_b = storage_autoencoder.encoder.bias.data.T

# Decoder parameters, taken as stored.
read_in = storage_autoencoder.decoder.weight.data
read_in_b = storage_autoencoder.decoder.bias.data

# Both matrices are consumed as `x @ normed_*.T`, i.e. each ROW is one
# direction vector. Normalise along dim=1 for both so every row has unit norm;
# the previous dim=0 for `read_in` normalised columns instead, which is
# inconsistent with `normed_write_out`.
# (assumes encoder/decoder are nn.Linear-style layers — TODO confirm)
normed_write_out = F.normalize(write_out, dim=1)
normed_read_in = F.normalize(read_in, dim=1)

loss: 0.000: 100%|██████████| 1000/1000 [00:05<00:00, 174.66it/s]


In [4]:
from interp_utils import heatmap

# Heatmap of `write_out` (the transposed encoder weights) multiplied by the
# transpose of `read_in` (the decoder weights).
heatmap(write_out @ read_in.T)

In [24]:
reload_module('interp_models')
from interp_models import Autoencoder, SparseAutoencoder, SparseNNMF
from tqdm import tqdm
from interp_utils import get_scheduler

# Hyper-parameters for fitting the sparse NNMF to the encoder activations.
N_EPOCHS = 1
CODE_STEPS = 10000
ATOM_STEPS = 2000  # only referenced by the disabled atom-update call below
BATCH_SIZE = 1000

SPARSE_COEF = 1e-1

ORTHOG_K = 0
ORTHOG_COEF = 1e-2


nnmf = SparseNNMF(d_model=storage_autoencoder.d_model, n_features=tree.n_features-2, orthog_k=ORTHOG_K, bias=True)


for epoch in range(N_EPOCHS):
    # The Adam optimizer / scheduler that used to be built here were never
    # handed to `nnmf.train`, so they had no effect on training and only
    # rebound the notebook-level `optimizer` / `scheduler` names; removed.

    # Sample ground-truth features and encode them without tracking gradients;
    # the encoder output is the data the NNMF is fitted to.
    features_batch = tree.sample(BATCH_SIZE)
    with torch.no_grad():
        batch = storage_autoencoder.encoder(features_batch)

    # update codes
    # NOTE(review): `epoch > 1` is False for epochs 0 and 1 — confirm that is
    # intended rather than `epoch > 0`.
    nnmf.train(batch, frozen_atoms=epoch > 1, sparse_coef=SPARSE_COEF, n_epochs=CODE_STEPS, orthog_coef=ORTHOG_COEF,)

    # update atoms
    # nnmf.train(batch, frozen_codes=True, sparse_coef=SPARSE_COEF, n_epochs=ATOM_STEPS, orthog_coef=ORTHOG_COEF, )



loss: 0.021, mse: 0.021, sparse: 0.172: 100%|██████████| 10000/10000 [00:19<00:00, 516.35it/s]


In [25]:
# Heatmap of the sampled feature batch whose encoding was passed to `nnmf.train`.
heatmap(features_batch)

In [26]:
import numpy as np
from scipy.optimize import linear_sum_assignment
# Dot products between every NNMF atom (rows) and every row-normalised
# encoder direction (columns).
sims = nnmf.atoms @ normed_write_out.T

# linear_sum_assignment minimises total cost, so turn similarities into costs
# by subtracting them from the largest similarity.
max_value = sims.max()
cost_matrix = max_value - sims

# The solver receives the transposed cost matrix, so its row indices refer to
# features and its column indices to atoms.
cost_by_feature = cost_matrix.detach().numpy().T
row_ind, col_ind = linear_sum_assignment(cost_by_feature)

# Atom index chosen for each matched feature.
atom_perm = col_ind

# Similarities again, with the atoms reordered by the assignment.
permuted_sims = nnmf.atoms[atom_perm] @ normed_write_out.T
heatmap(permuted_sims, dim_names=('atoms', 'features'))



In [27]:
# Short aliases: the NNMF atoms and the row-normalised encoder directions.
atoms = nnmf.atoms
features = normed_write_out
# def get_atom_perm(features, atoms):
#     # greedily match atoms to features
#     features = normed_write_out
#     atoms = nnmf.atoms

#     perm = []
#     dots = features @ atoms.T
#     for i in range(len(features)):
#         for best_matching_atom_idx in dots[i].sort(descending=True).indices:
#             if best_matching_atom_idx not in perm:
#                 perm.append(best_matching_atom_idx)
#                 break
#     perm = torch.tensor(perm)

#     return perm

# def invert_perm(perm):
#     print(perm)
#     print(len(perm))
#     print(max(perm))
#     inv_perm = torch.zeros_like(perm)
#     for i, p in enumerate(perm):
#         inv_perm[p] = i
#     return inv_perm

# def get_atom_perm(features, atoms):
#     # greedily match atoms to features
#     features = normed_write_out
#     atoms = nnmf.atoms

#     perm = []
#     dots = features @ atoms.T
#     for i in range(len(atoms)):
#         print(dots.T[i].sort(descending=True).indices)
#         for best_matching_feature_idx in dots.T[i].sort(descending=True).indices:
#             if best_matching_feature_idx not in perm:
#                 perm.append(best_matching_feature_idx)
#                 break
#     perm = torch.tensor(perm)
    
#     return perm

# atom_perm = get_atom_perm(features[1:], atoms)

# atoms = atoms[atom_perm]



In [1176]:
# Display the current atom permutation.
atom_perm

array([ 1,  5, 10,  2,  9,  0,  6,  8,  3,  4,  7])

In [28]:
# Heatmap of the permuted atoms multiplied by `normed_read_in.T`
# (the normalised decoder weights, transposed).
heatmap(atoms[atom_perm] @ normed_read_in.T)

In [29]:
# Work on a detached copy of the NNMF codes so no autograd graph is carried
# into the plots.
codes = nnmf.codes().detach()

# Scale every row by its largest entry; the clamp keeps the divisor at 1 or
# above, so rows whose maximum is below 1 are left unscaled.
row_peak = codes.max(dim=1, keepdim=True).values
codes = codes / row_peak.clamp(min=1)

print(tree)

# First 100 rows of the scaled codes (columns reordered by `atom_perm`),
# followed by the first 100 rows of the ground-truth features.
codes_fig = heatmap(codes[:100, atom_perm], title='NNMF activations', dim_names=('batch', 'atom'))
codes_fig.show()
truth_fig = heatmap(features_batch[:100], title='Ground truth feature activations')
truth_fig.show()


 B  1.0
  0 BA 0.2
    1 B  0.5
    2 B  0.5
  3 BA 0.2
    4 B  0.5
    5 B  0.5


In [1145]:
# Hand-written matching notes. Each `k: v` line below corresponds to
# `perm[k] == v` in the tensor at the end of this cell.
# NOTE(review): the trailing question marks presumably flag uncertain
# matches — confirm with the author.
# atom 8: 5
# atom 9: 4
# 10: 1
# 7: 6
# 6: 7????
# 5: 9??
# 4: 2
# 3: 10
# 2: 8
# 0: 3
# 1: 0???

perm = torch.tensor([3, 0, 8, 10, 2, 9, 7, 6, 5, 4, 1])

In [1045]:
# Mean of `batch` over dim 0, rescaled to unit norm.
batch_mean = batch.mean(dim=0)
mean_norm = torch.norm(batch_mean)
mean_dir = batch_mean / mean_norm

# Dot product of `nnmf.normed_atoms` with that mean direction.
nnmf.normed_atoms @ mean_dir


tensor([-0.3209,  0.0588,  0.2081,  0.0929], grad_fn=<MvBackward0>)

In [762]:
# First 20 rows, last 3 columns of the sampled feature batch.
features_batch[:20, -3:]

tensor([[0., 0., 0.],
        [1., 1., 0.],
        [0., 1., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 0.],
        [1., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 1.],
        [0., 0., 0.],
        [0., 0., 1.],
        [0., 0., 0.]])

In [705]:
# Display the current NNMF codes.
nnmf.codes()

tensor([[1.2809e+00, 5.2178e-01, 8.3384e-04, 8.2154e-04],
        [1.1812e-04, 5.4305e-02, 1.8370e+00, 4.4291e-04],
        [1.2233e-03, 1.3536e+00, 6.4142e-04, 4.8705e-04],
        ...,
        [5.7441e-04, 1.3546e+00, 4.7168e-04, 4.7056e-04],
        [1.1720e-03, 1.3542e+00, 5.3459e-04, 5.8248e-04],
        [9.2144e-04, 1.3545e+00, 4.8252e-04, 5.8358e-04]],
       grad_fn=<AbsBackward0>)

In [707]:
# Heatmap of `read_in` (decoder weights) multiplied by `write_out.T`;
# since `write_out` is the transposed encoder weight, `write_out.T` is the
# encoder weight as stored.
heatmap(read_in @ write_out.T)

In [22]:
# Display the tree.
tree

root
├── 0 B 0.2439
├── 1 B 0.1836
└── 2 BA 0.0395
    ├── 3 B 0.9133
    └── 4 B 0.0867

In [36]:
# Index of the 30-row window to inspect.
I = 8

# Rows 30*I up to (but excluding) 30*(I+1); the same window is taken from the
# ground-truth features and from the NNMF codes.
subset_rows = slice(30 * I, 30 * (I + 1))
f_subset = features_batch[subset_rows]
c_subset = nnmf.codes()[subset_rows]

heatmap(f_subset)

In [25]:
# Display the tree.
tree

root
├── 0 B 0.2439
├── 1 B 0.1836
└── 2 BA 0.0395
    ├── 3 B 0.9133
    └── 4 B 0.0867

In [24]:
# Heatmap of the first 30 rows of the NNMF codes.
heatmap(nnmf.codes()[:30])