In [2]:
from interp_utils import *

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as dists


In [3]:
from interp_utils import hist

alpha = 3
beta = 10

distr= dists.Beta(alpha, beta)


sample = distr.sample((10000,))

hist(sample)

In [4]:
reload_module('toy_models')
from toy_models import Tree, SparseIndependent


tree = Tree(n_growths=3, root_is_feature=True)
model = SparseIndependent(n_features=300, d_model=100)
hidden_state, ground_truth = model(10)



In [5]:
from interp_utils import heatmap
feats = model.normed_features

heatmap(feats @ feats.T, dim_names=('features', 'features'))

In [15]:
from interp_utils import see, asee
p=print

In [6]:
reload_module('interp_models')
from interp_models import Autoencoder, SparseAutoencoder, SparseNNMF
from tqdm import tqdm
from interp_utils import get_scheduler

D_MODEL = 6
N_STEPS = 1000


storage_autoencoder = Autoencoder(n_features=tree.n_features, d_model=D_MODEL)

optimizer = optim.AdamW(storage_autoencoder.parameters(), lr=1e-3)
scheduler = get_scheduler(optimizer, N_STEPS)

pbar = tqdm(range(1000))
for i in pbar:
    optimizer.zero_grad()
    batch = tree.sample(100)
    reconstruction = storage_autoencoder(batch)
    
    loss = F.mse_loss(reconstruction, batch)
    loss.backward()
    optimizer.step()
    scheduler.step()

    pbar.set_description(f'loss: {loss.item():.3f}')

write_out = storage_autoencoder.encoder.weight.data.T
write_out_b = storage_autoencoder.encoder.bias.data.T

read_in = storage_autoencoder.decoder.weight.data
read_in_b = storage_autoencoder.decoder.bias.data


IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html

loss: 0.005: 100%|██████████| 1000/1000 [00:11<00:00, 85.69it/s]

The use of `x.T` on tensors of dimension other than 2 to reverse their shape is deprecated and it will throw an error in a future release. Consider `x.mT` to transpose batches of matrices or `x.permute(*torch.arange(x.ndim - 1, -1, -1))` to reverse the dimensions of a tensor. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3571.)



In [198]:
reload_module('interp_models')
from interp_models import Autoencoder, SparseAutoencoder, SparseNNMF
from tqdm import tqdm
from interp_utils import get_scheduler

D_MODEL = 6
N_EPOCHS = 100
CODE_STEPS = 3000
ATOM_STEPS=3000
BATCH_SIZE=1000
L1_LAMBDA = 1e-2


nnmf = SparseNNMF(n_features=4, d_model=D_MODEL, n_codes=BATCH_SIZE)


for epoch in range(N_EPOCHS):
    optimizer = optim.Adam(nnmf.parameters(), lr=1e-2)
    scheduler = get_scheduler(optimizer, CODE_STEPS)

    features_batch = tree.sample(BATCH_SIZE)
    with torch.no_grad():
        batch = storage_autoencoder.encoder(features_batch)

    # update codes
    nnmf.train(batch, frozen_atoms=epoch > 1, l1_lambda=L1_LAMBDA, n_steps=CODE_STEPS)

    # update atoms
    nnmf.train(batch, frozen_codes=True, l1_lambda=L1_LAMBDA, n_steps=ATOM_STEPS)



loss: 0.009, mse: 0.002, sparse: 0.655: 100%|██████████| 3000/3000 [00:03<00:00, 994.93it/s] 
loss: 0.002, mse: 0.002: 100%|██████████| 3000/3000 [00:02<00:00, 1167.20it/s]
loss: 0.008, mse: 0.003, sparse: 0.523: 100%|██████████| 3000/3000 [00:02<00:00, 1065.33it/s]
loss: 0.003, mse: 0.003: 100%|██████████| 3000/3000 [00:02<00:00, 1168.67it/s]
loss: 0.008, mse: 0.003, sparse: 0.510: 100%|██████████| 3000/3000 [00:02<00:00, 1050.64it/s]
loss: 0.003, mse: 0.003: 100%|██████████| 3000/3000 [00:02<00:00, 1186.39it/s]
loss: 0.008, mse: 0.003, sparse: 0.522: 100%|██████████| 3000/3000 [00:02<00:00, 1048.81it/s]
loss: 0.003, mse: 0.003: 100%|██████████| 3000/3000 [00:02<00:00, 1151.52it/s]
loss: 0.008, mse: 0.003, sparse: 0.525: 100%|██████████| 3000/3000 [00:02<00:00, 1087.96it/s]
loss: 0.003, mse: 0.003: 100%|██████████| 3000/3000 [00:02<00:00, 1176.28it/s]
loss: 0.008, mse: 0.003, sparse: 0.524: 100%|██████████| 3000/3000 [00:02<00:00, 1018.31it/s]
loss: 0.003, mse: 0.003: 100%|██████████|

KeyboardInterrupt: 

In [199]:
feature_batch = tree.sample(1000,)
with torch.no_grad():
    batch = storage_autoencoder.encoder(feature_batch)

nnmf.train(batch, frozen_atoms=True, n_steps=3000)

pred, codes = nnmf(frozen_codes=True)

loss: 0.054, mse: 0.005, sparse: 0.490: 100%|██████████| 3000/3000 [00:02<00:00, 1051.95it/s]


In [200]:
freqs = (feature_batch > 1e-2).float().mean(dim=0)
freqs

tensor([0.9930, 0.9910, 0.9890, 0.7830, 0.2120, 0.9040, 0.0830])

In [202]:
(codes > 1e-3).float().mean(dim=0)

tensor([0.7870, 0.9280, 0.9080, 0.8990])

In [196]:
tree

0 root
├── 1 B 1.0
├── 2 BA 1.0
│   ├── 3 B 0.7816
│   └── 4 0.2184
└── BA 1.0
    ├── 5 B 0.912
    └── 6 0.088

In [195]:
heatmap(write_out @ read_in.T, dim_names=('write_out', 'read_in'))

In [76]:
hist(write_out.norm(dim=-1))

In [54]:
write_out = 

tensor(-0.3275)

tensor(-4.3810)

In [47]:
# import plotly.graph_objects as go

# # define three 3D vectors
# vector1 = [1, 2, 3]
# vector2 = [4, 5, 6]
# vector3 = [7, 8, 9]

# fig = go.Figure()

# # Add vectors
# fig.add_trace(go.Scatter3d(x=[0, vector1[0]], y=[0, vector1[1]], z=[0, vector1[2]], mode='lines', name='vector1'))
# fig.add_trace(go.Scatter3d(x=[0, vector2[0]], y=[0, vector2[1]], z=[0, vector2[2]], mode='lines', name='vector2'))
# fig.add_trace(go.Scatter3d(x=[0, vector3[0]], y=[0, vector3[1]], z=[0, vector3[2]], mode='lines', name='vector3'))

# # Set axes title
# fig.update_layout(scene=dict(xaxis_title='X',
#                              yaxis_title='Y',
#                              zaxis_title='Z'))

# fig.show()