In [1]:
import torch
import numpy as np
from torch import nn
from torch.nn import functional as F
from tqdm.auto import tqdm

from models import SparseAutoEncoder, MLP, SparseCoding
from utils import generate_data, reconstruction_loss_with_l1
from metrics import mcc, greedy_mcc

# Autoreload
%load_ext autoreload
%autoreload 2

# Run on the GPU when PyTorch reports one as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

def calculate_dict_mcc(D_true, D_learned):
    """Score a learned dictionary against the reference one with ``greedy_mcc``.

    Both tensors are transposed, moved to the CPU and converted to NumPy
    arrays before being handed to ``greedy_mcc``; its result is returned
    unchanged.
    """
    reference_np = D_true.T.cpu().numpy()
    learned_np = D_learned.T.cpu().numpy()
    return greedy_mcc(reference_np, learned_np)

def train_model(model, X_train, S_train, D_true, num_steps=30000):
    """Train ``model`` on ``X_train`` with Adam, then score latents and dictionary.

    Every step runs the full ``X_train`` through the model (no mini-batching),
    computes ``reconstruction_loss_with_l1`` on the result and takes one Adam
    step with learning rate 1e-3.

    Args:
        model: A ``SparseAutoEncoder``, ``MLP`` or ``SparseCoding`` instance.
            Calling it on ``X_train`` must return a pair ``(S_, X_)``.
        X_train: Training inputs passed to the model.
        S_train: Reference latents that the model's ``S_`` output is compared to.
        D_true: Reference dictionary that the learned dictionary is compared to.
        num_steps: Number of optimisation steps.

    Returns:
        Tuple ``(latent_mcc, dict_mcc)`` as computed by ``greedy_mcc`` and
        ``calculate_dict_mcc``.

    Raises:
        TypeError: If ``model`` is not one of the supported model classes.
    """
    # Fail fast: the learned dictionary can only be read from these classes.
    # Without this check an unsupported model would train for all `num_steps`
    # and only then crash on an unbound `D_learned`.
    if not isinstance(model, (SparseAutoEncoder, MLP, SparseCoding)):
        raise TypeError(
            f"Unsupported model type {type(model).__name__}; expected "
            "SparseAutoEncoder, MLP or SparseCoding"
        )

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    for _ in range(num_steps):
        S_, X_ = model(X_train)
        loss = reconstruction_loss_with_l1(X_train, X_, S_)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Calculate final metrics without tracking gradients.
    with torch.no_grad():
        S_, _ = model(X_train)
        latent_mcc = greedy_mcc(S_train.cpu().numpy(), S_.cpu().numpy())
        if isinstance(model, (SparseAutoEncoder, MLP)):
            # Transposed so it is oriented like `SparseCoding.D` below.
            D_learned = model.decoder.weight.data.T
        else:
            # Only SparseCoding can reach this branch (validated above).
            D_learned = model.D.data
        dict_mcc = calculate_dict_mcc(D_true, D_learned)

    return latent_mcc, dict_mcc

def get_decoder_weights(models):
    """Collect the learned dictionary tensor of each model, keyed by model name.

    For ``SparseAutoEncoder`` and ``MLP`` instances the transposed decoder
    weight is taken; for ``SparseCoding`` instances the ``D`` tensor is taken.
    Entries of any other type are left out of the result.

    Args:
        models: Mapping from model name to model instance.

    Returns:
        Dict mapping each recognised model's name to its dictionary tensor.
    """
    collected = {}
    for label, net in models.items():
        if isinstance(net, (SparseAutoEncoder, MLP)):
            collected[label] = net.decoder.weight.data.T
            continue
        if isinstance(net, SparseCoding):
            collected[label] = net.D.data
    return collected

def train_and_get_weights(true_N, model_N, M, K, num_data, seed=20240926):
    """Generate data, train all three model types on it, and collect results.

    Args:
        true_N, M, K, num_data: Forwarded to ``generate_data`` (their exact
            meaning is defined there).
        model_N: Number of rows of the random initial dictionary; also passed
            to the ``SparseAutoEncoder`` and ``MLP`` constructors.
        M: Also the number of columns of the random initial dictionary.
        seed: Seed for ``generate_data`` and for the random initial dictionary.

    Returns:
        Tuple ``(decoder_weights, metrics)``: ``decoder_weights`` maps model
        name to its learned dictionary tensor, ``metrics`` maps model name to
        ``{'latent_mcc': ..., 'dict_mcc': ...}``.
    """
    # Generate data
    S, X, D = generate_data(true_N, M, K, num_data, seed=seed)
    X_train = X.to(device)
    S_train = S.to(device)
    D_true = D.to(device)

    # Create a D_random with shape model_N x M and unit-norm rows.
    # It is drawn from a dedicated generator seeded with `seed`, so the
    # initialisation is reproducible and the global RNG state is untouched.
    init_rng = torch.Generator().manual_seed(seed)
    D_random = torch.randn(model_N, M, generator=init_rng).to(device)
    D_random = D_random / torch.linalg.norm(D_random, dim=1, keepdim=True)

    # Initialize models
    # NOTE(review): SparseCoding is given D_random while SAE and MLP are given
    # D_true. Whether those two actually use it when learn_D=True depends on
    # their constructors in `models` -- confirm this asymmetry is intended.
    models = {
        'SparseCoding': SparseCoding(X_train, D_random, learn_D=True).to(device),
        'SAE': SparseAutoEncoder(M, model_N, D_true, learn_D=True).to(device),
        'MLP': MLP(M, model_N, 256, D_true, learn_D=True).to(device),
    }

    # Train models and get metrics
    metrics = {}
    for name, model in tqdm(models.items(), desc="Training models"):
        latent_mcc, dict_mcc = train_model(model, X_train, S_train, D_true)
        metrics[name] = {'latent_mcc': latent_mcc, 'dict_mcc': dict_mcc}

    # Get decoder weights
    decoder_weights = get_decoder_weights(models)

    return decoder_weights, metrics

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Experiment configuration, passed straight to train_and_get_weights
# (presumably sizes and a sparsity level used by generate_data -- confirm in utils).
true_N = 16
model_N = 20
M = 8
K = 3
num_data = 1024

# Train all models; collect their learned dictionaries and final scores.
decoder_weights, metrics = train_and_get_weights(true_N, model_N, M, K, num_data)

# Report, per model, the dictionary shape and both final MCC scores,
# followed by one blank line.
for name, weights in decoder_weights.items():
    report_lines = [
        f"{name}:",
        f"  Decoder shape: {weights.shape}",
        f"  Final Latent MCC: {metrics[name]['latent_mcc']:.4f}",
        f"  Final Dictionary MCC: {metrics[name]['dict_mcc']:.4f}",
    ]
    print("\n".join(report_lines) + "\n")

Training models: 100%|██████████| 3/3 [00:38<00:00, 12.84s/it]

SparseCoding:
  Decoder shape: torch.Size([20, 8])
  Final Latent MCC: 0.7048
  Final Dictionary MCC: 0.8596

SAE:
  Decoder shape: torch.Size([20, 8])
  Final Latent MCC: 0.4932
  Final Dictionary MCC: 0.6594

MLP:
  Decoder shape: torch.Size([20, 8])
  Final Latent MCC: 0.7137
  Final Dictionary MCC: 0.8909






In [3]:
# Display the dict of learned dictionaries (model name -> tensor) returned by
# train_and_get_weights. As the last expression in the cell it is rendered in
# full, so every tensor is printed in its entirety.
decoder_weights

{'SparseCoding': tensor([[ 0.2234,  0.3667,  0.0450, -0.2941, -0.0321,  0.2738,  0.0556,  0.2871],
         [-0.0970, -0.1849, -0.2685,  0.3301,  0.0774,  0.2291,  0.3098,  0.1509],
         [-0.2222,  0.0399,  0.2431,  0.2028,  0.1300,  0.0015,  0.2777, -0.0974],
         [-0.2492,  0.1082, -0.0298,  0.2071, -0.3385,  0.1348, -0.1177, -0.0788],
         [ 0.2489,  0.1149, -0.2748,  0.2558, -0.0460,  0.0373, -0.1843,  0.3545],
         [ 0.2334, -0.3533,  0.3062, -0.2582,  0.3227,  0.2631,  0.0466,  0.2855],
         [ 0.2668, -0.3905,  0.0076, -0.0749, -0.2025,  0.2684, -0.2837, -0.2519],
         [ 0.2138,  0.3291,  0.2982, -0.2760,  0.1080,  0.2879,  0.1506, -0.2685],
         [-0.0097, -0.0104, -0.0841, -0.0602, -0.1967,  0.1617, -0.1179, -0.2713],
         [-0.2296, -0.3943,  0.1645, -0.2702, -0.2002, -0.1667,  0.2784, -0.2860],
         [-0.2259,  0.0447,  0.0752,  0.0097, -0.0418, -0.2788,  0.0229,  0.3117],
         [-0.1768, -0.0964, -0.1651, -0.2593, -0.3454, -0.2957, -0.1993

In [7]:
# Learned SAE dictionary, as returned by get_decoder_weights.
d = decoder_weights['SAE']

# Gram matrix of the rows of `d`: entry (i, j) is the dot product of row i and row j.
DT_D = torch.matmul(d, d.T)

# Ordering by the diagonal values is currently switched off, so the "sorted"
# matrix is the Gram matrix in its original order.
# To re-enable, index with: [np.argsort(DT_D.diagonal())]
DT_D_sorted = DT_D

In [8]:
import plotly.express as px

# Heatmap of the decoder Gram matrix.
# The matrix is a torch tensor that lives on `device`; it is copied to host
# memory as a NumPy array first, because a CUDA tensor cannot be converted to
# NumPy implicitly by the plotting library.
gram_np = DT_D_sorted.detach().cpu().numpy()

fig = px.imshow(gram_np, color_continuous_scale='Blues')
fig.show()

In [9]:
# Show the norm of the columns of D.
# `d` stores one decoder vector per ROW (the next cell indexes them as
# d[i, :]), so the columns of D are the rows of `d` and the norm must be taken
# along axis=1, giving one value per decoder vector.
# The tensor is copied to host memory first so this also works when it lives
# on the GPU.
d_np = d.detach().cpu().numpy()
D_norms = np.linalg.norm(d_np, axis=1)

# Bar chart of the norms of the columns of D
fig = px.bar(x=np.arange(len(D_norms)), y=D_norms, labels={'x':'Column Index', 'y':'Norm'}, title='Norm of Columns of D')
fig.update_layout(width=800, height=500)
fig.show()

In [10]:
import einops

# Superposition of each decoder vector: for vector i, the sum over every
# OTHER vector j of (vec_i . vec_j) ** 2.
#
# Computed in one shot from the Gram matrix instead of a Python double loop.
# The decoder vectors are the rows of `d`; their count is taken from `d`
# itself rather than from the global `model_N`. The data is copied to host
# memory as float64 so the result is a plain NumPy array on any device.
atoms = d.detach().cpu().numpy().astype(np.float64)

pairwise_dots = atoms @ atoms.T
np.fill_diagonal(pairwise_dots, 0.0)  # exclude the i == j terms
superpositions = einops.reduce(pairwise_dots ** 2, 'i j -> i', 'sum')

# Bar chart of the superposition of each decoder column vector
fig = px.bar(x=np.arange(len(superpositions)), y=superpositions, labels={'x':'Column Index', 'y':'Superposition'}, title='Superposition of Decoder Column Vectors')
fig.update_layout(width=800, height=500)
fig.show()