In [1]:
import torch
import numpy as np
from torch import nn
from torch.nn import functional as F
from tqdm.auto import tqdm

# from models import SparseAutoEncoder, MLP, SparseCoding
from utils import generate_data, reconstruction_loss_with_l1
from metrics import mcc, greedy_mcc
from models import SparseAutoEncoder, MLP, SparseCoding

# Autoreload
%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [44]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def calculate_dict_mcc(D_true, D_learned):
    print(D_true.shape, D_learned.shape)
    return greedy_mcc(D_true.cpu().numpy(), D_learned.cpu().numpy())

def train_model(model, X_train, S_train, D_true, num_steps=30_000):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    for _ in range(num_steps):
        S_, X_ = model(X_train, norm_D=True)
        loss = reconstruction_loss_with_l1(X_train, X_, S_, l1_weight=0.1)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    # Calculate final metrics
    with torch.no_grad():
        S_, _ = model(X_train, norm_D=False)
        latent_mcc = greedy_mcc(S_train.cpu().numpy(), S_.cpu().numpy())
        if isinstance(model, (SparseAutoEncoder, MLP)):
            D_learned = model.decoder.weight.data
        elif isinstance(model, SparseCoding):
            D_learned = model.D.data
        dict_mcc = calculate_dict_mcc(D_true, D_learned)
    
    return latent_mcc, dict_mcc, D_learned

def train_and_get_weights(true_N, model_N, M, K, num_data, seed=20240926):
    # Generate data
    S, X, D = generate_data(true_N, M, K, num_data, seed=seed)
    D = D.T
    X_train = X.to(device)
    S_train = S.to(device)
    D_true = D.to(device)

    # Create a D_random with shape model_N x M
    D_random = torch.randn(M, model_N).to(device)

    # Initialize models
    models = {
        'SparseCoding': SparseCoding(X_train, D_random, learn_D=True).to(device),
        'SAE': SparseAutoEncoder(M, model_N, D_true, learn_D=True).to(device),
        'MLP': MLP(M, model_N, 256, D_true, learn_D=True).to(device),
    }

    # Train models and get metrics
    metrics = {}
    decoder_weights = {}
    for name, model in tqdm(models.items(), desc="Training models"):
        latent_mcc, dict_mcc, D_learned = train_model(model, X_train, S_train, D_true)
        metrics[name] = {'latent_mcc': latent_mcc, 'dict_mcc': dict_mcc}
        decoder_weights[name] = D_learned

    return decoder_weights, metrics

In [45]:
# Parameters
true_N, model_N, M, K = 16, 24, 8, 3
num_data = 1024

# Train models and get decoder weights and metrics
decoder_weights, metrics = train_and_get_weights(true_N, model_N, M, K, num_data)

# Print shapes of decoder weights and final metrics
for name, weights in decoder_weights.items():
    print(f"{name}:")
    print(f"  Decoder shape: {weights.shape}")
    print(f"  Final Latent MCC: {metrics[name]['latent_mcc']:.4f}")
    print(f"  Final Dictionary MCC: {metrics[name]['dict_mcc']:.4f}")
    print()

SAE decoder shape = torch.Size([8, 24])
SAE decoder shape = torch.Size([8, 24])
MLP decoder shape = torch.Size([8, 24])
MLP decoder shape = torch.Size([8, 24])


Training models:  33%|███▎      | 1/3 [00:17<00:35, 17.78s/it]

torch.Size([8, 16]) torch.Size([8, 24])


Training models:  67%|██████▋   | 2/3 [00:25<00:11, 11.94s/it]

torch.Size([8, 16]) torch.Size([8, 24])
torch.Size([8, 16]) torch.Size([8, 24])


Training models: 100%|██████████| 3/3 [00:57<00:00, 19.24s/it]

SparseCoding:
  Decoder shape: torch.Size([8, 24])
  Final Latent MCC: 0.7040
  Final Dictionary MCC: 0.8537

SAE:
  Decoder shape: torch.Size([8, 24])
  Final Latent MCC: 0.6079
  Final Dictionary MCC: 0.7789

MLP:
  Decoder shape: torch.Size([8, 24])
  Final Latent MCC: 0.8200
  Final Dictionary MCC: 0.9226






In [46]:
# Print column norms of the decoder weight matrix
d = decoder_weights['SAE']
print(f"SAE decoder column norms = {torch.linalg.norm(d, dim=0)}")
d = decoder_weights['MLP']
print(f"MLP decoder column norms = {torch.linalg.norm(d, dim=0)}")
d = decoder_weights['SparseCoding']
print(f"SparseCoding decoder column norms = {torch.linalg.norm(d, dim=0)}")

SAE decoder column norms = tensor([1.0012, 1.0032, 1.0027, 1.0000, 1.0028, 1.0027, 1.0012, 1.0030, 1.0030,
        1.0004, 1.0000, 1.0037, 1.0024, 1.0033, 1.0031, 1.0031, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0027, 1.0006, 1.0018, 1.0024])
MLP decoder column norms = tensor([1.0022, 1.0022, 1.0019, 1.0021, 1.0000, 1.0022, 1.0018, 1.0020, 1.0024,
        1.0025, 1.0021, 1.0017, 1.0000, 1.0021, 1.0018, 1.0026, 1.0025, 1.0025,
        1.0020, 1.0000, 1.0021, 1.0021, 1.0000, 1.0023])
SparseCoding decoder column norms = tensor([1.0023, 1.0027, 1.0025, 1.0026, 1.0022, 1.0027, 1.0024, 1.0026, 1.0025,
        1.0030, 1.0029, 1.0024, 1.0021, 1.0024, 1.0026, 1.0027, 1.0025, 1.0022,
        1.0027, 1.0027, 1.0028, 1.0026, 1.0027, 1.0021])


In [47]:
d = decoder_weights['SAE']

DT_D = d.T @ d #d @ d.T
DT_D.shape

# Sort DT_D by the value on the diagonal
DT_D_sorted = DT_D#[np.argsort(DT_D.diagonal())]

In [48]:
import plotly.express as px
import plotly.subplots as sp

# Create a subplot with 1 row and 3 columns
fig = sp.make_subplots(rows=1, cols=3, subplot_titles=('SAE', 'MLP', 'SparseCoding'))

# Calculate the global color scale range
all_values = []
for model_name in ['SAE', 'MLP', 'SparseCoding']:
    d = decoder_weights[model_name]
    DT_D = d.T @ d
    all_values.extend(DT_D.flatten().tolist())

color_scale_min = min(all_values)
color_scale_max = max(all_values)

# Plot for each model
for i, model_name in enumerate(['SAE', 'MLP', 'SparseCoding'], start=1):
    d = decoder_weights[model_name]
    DT_D = d.T @ d
    
    fig.add_trace(
        px.imshow(DT_D, color_continuous_scale='Blues',
                  zmin=color_scale_min, zmax=color_scale_max).data[0],
        row=1, col=i
    )

# Update layout
fig.update_layout(height=400, width=1200, title_text="DT_D for Different Models")
fig.update_xaxes(title_text="Column Index")
fig.update_yaxes(title_text="Column Index")

# Show the plot
fig.show()

In [37]:
d.shape

torch.Size([8, 24])

In [38]:
# Show the norm of the columns of D
D_norms = np.linalg.norm(d, axis=0)
# Bar chart of the norms of the columns of D
fig = px.bar(x=np.arange(len(D_norms)), y=D_norms, labels={'x':'Column Index', 'y':'Norm'}, title='Norm of Columns of D')
fig.update_layout(width=800, height=500)
fig.show()

In [25]:
d.shape

torch.Size([8, 20])

In [42]:
import einops

# Calculate the superposition of each decoder column vector
superpositions = np.zeros(model_N)

for i in range(model_N):
    vec_i = d[:, i]
    for j in range(model_N):
        if i != j:
            vec_j = d[:, j]
            superposition_measure = (vec_i @ vec_j)**2 #einops.reduce(vec_i * vec_j, 'n ->', 'sum') ** 2
            superpositions[i] += superposition_measure

# Take the square root of each entry
superpositions = np.sqrt(superpositions)

# Bar chart of the superposition of each decoder column vector
fig = px.bar(x=np.arange(len(superpositions)), y=superpositions, labels={'x':'Column Index', 'y':'Superposition'}, title='Superposition of Decoder Column Vectors')
fig.update_layout(width=800, height=500)
fig.show()

## Activation shrinkage experiment

In [41]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def calculate_dict_mcc(D_true, D_learned):
    print(D_true.shape, D_learned.shape)
    return greedy_mcc(D_true.cpu().numpy(), D_learned.cpu().numpy())

def train_model(model, X_train, S_train, D_true, num_steps=30_000):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    for i in range(num_steps):
        S_, X_ = model(X_train, norm_D=True)
        loss = reconstruction_loss_with_l1(X_train, X_, S_, l1_weight=0.01)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i+1 == num_steps:
            print(f"Loss at step {i+1}: {loss.item()}")
    
    # Calculate final metrics
    with torch.no_grad():
        S_, _ = model(X_train, norm_D=False)
        latent_mcc = greedy_mcc(S_train.cpu().numpy(), S_.cpu().numpy())
        if isinstance(model, (SparseAutoEncoder, MLP)):
            D_learned = model.decoder.weight.data
        elif isinstance(model, SparseCoding):
            D_learned = model.D.data
        dict_mcc = calculate_dict_mcc(D_true, D_learned)
        print(f"Latent MCC: {latent_mcc}")
        print(f"Dict MCC: {dict_mcc}")
    
    return latent_mcc, dict_mcc, model

def train_and_get_weights(true_N, model_N, M, K, num_data, seed=20240926):
    # Generate data
    S, X, D = generate_data(true_N, M, K, num_data, seed=seed)
    D = D.T
    X_train = X.to(device)
    S_train = S.to(device)
    D_true = D.to(device)

    # Create a D_random with shape model_N x M
    D_random = torch.randn(M, model_N).to(device)

    # Initialize models
    models = {
        'SparseCoding': SparseCoding(X_train, D_random, learn_D=True).to(device),
        'SAE': SparseAutoEncoder(M, model_N, D_true, learn_D=True).to(device),
        'MLP': MLP(M, model_N, 256, D_true, learn_D=True).to(device),
    }

    # Train models and get metrics
    return_dict = {}
    for name, model in tqdm(models.items(), desc="Training models"):
        latent_mcc, dict_mcc, model = train_model(model, X_train, S_train, D_true)
        return_dict[name] = {'latent_mcc': latent_mcc, 'dict_mcc': dict_mcc, 'model': model}

    return return_dict, S, X, D

In [42]:
# Parameters
true_N, model_N, M, K = 16, 16, 8, 3
num_data = 1024

# Train models and get decoder weights and metrics
return_dict, S, X, D = train_and_get_weights(true_N, model_N, M, K, num_data)

SAE decoder shape = torch.Size([8, 16])
SAE decoder shape = torch.Size([8, 16])
MLP decoder shape = torch.Size([8, 16])
MLP decoder shape = torch.Size([8, 16])


Training models:   0%|          | 0/3 [00:00<?, ?it/s]

Loss at step 30000: 0.0021204138174653053
torch.Size([8, 16]) torch.Size([8, 16])
Latent MCC: 0.6385076370390954
Dict MCC: 0.8257747379939246


Training models:  33%|███▎      | 1/3 [00:08<00:17,  8.57s/it]

Loss at step 30000: 0.002083843108266592
torch.Size([8, 16]) torch.Size([8, 16])
Latent MCC: 0.5988417538932449
Dict MCC: 0.7496130742365494


Training models:  67%|██████▋   | 2/3 [00:15<00:07,  7.45s/it]

Loss at step 30000: 0.001920610200613737
torch.Size([8, 16]) torch.Size([8, 16])


Training models: 100%|██████████| 3/3 [00:37<00:00, 12.52s/it]

Latent MCC: 0.643156280803017
Dict MCC: 0.8523968125164174





In [44]:
# Get the activations on some test data
S_test, X_test, D_test = generate_data(true_N, M, K, num_data, seed=20240926+1)

activations_dict = {}
decoder_dict = {}

# Get the activations of the test data
for name, model_dict in return_dict.items():
    model = model_dict['model']
    if isinstance(model, SparseCoding):
        S_final = model.optimize_codes(X_test, num_iterations=30_000, lr=1e-3)
        decoder_dict[name] = model.D.data
    elif isinstance(model, SparseAutoEncoder) or isinstance(model, MLP):
        S_final, X_ = model(X_test, norm_D=False)
        decoder_dict[name] = model.decoder.weight.data

    # Set everything below epsilon threshold to zero
    S_final = S_final.clone()
    S_final[S_final < 5e-5] = 0

    print(f"{name} activations shape: {S_final.shape}")
    recon_loss = reconstruction_loss_with_l1(X_test, X_, S_final, l1_weight=0.01)
    print(f"{name} reconstruction loss: {recon_loss}")
    print(f"{name} activations L0 norm: {torch.norm(S_final, p=0)}")
    print(f"{name} activations L1 norm: {torch.norm(S_final, p=1)}")
    print(f"{name} activations mean: {S_final.mean()}")
    print(f"{name} activations std: {S_final.std()}")
    print(f"{name} activations min: {S_final.min()}")
    print(f"{name} activations max: {S_final.max()}\n\n")

    activations_dict[name] = S_final

SparseCoding activations shape: torch.Size([1024, 16])
SparseCoding reconstruction loss: 0.0642855316400528
SparseCoding activations L0 norm: 8521.0
SparseCoding activations L1 norm: 4765.7001953125
SparseCoding activations mean: 0.2908768057823181
SparseCoding activations std: 0.540534257888794
SparseCoding activations min: 0.0
SparseCoding activations max: 6.481787204742432


SAE activations shape: torch.Size([1024, 16])
SAE reconstruction loss: 0.0024461480788886547
SAE activations L0 norm: 8155.0
SAE activations L1 norm: 3929.6337890625
SAE activations mean: 0.2398461550474167
SAE activations std: 0.3780081570148468
SAE activations min: 0.0
SAE activations max: 3.5633182525634766


MLP activations shape: torch.Size([1024, 16])
MLP reconstruction loss: 0.06424184888601303
MLP activations L0 norm: 7867.0
MLP activations L1 norm: 4694.14990234375
MLP activations mean: 0.28650858998298645
MLP activations std: 0.4674675464630127
MLP activations min: 0.0
MLP activations max: 3.7930803298

In [14]:
import plotly.express as px

print(activations_dict['SparseCoding'].shape)

# Histogram of the activations
fig = px.histogram(activations_dict['SparseCoding'].flatten().cpu().numpy(), nbins=100, title='Histogram of SparseCoding Activations')
fig.update_layout(width=800, height=500)
fig.show()

torch.Size([1024, 16])


In [11]:
model_name = 'SparseCoding'

# Get the activations of the test data
S_ = activations_dict[model_name]
D_ = decoder_dict[model_name]

# Get the mask of where the activations are non-zero
mask = S_ > 0
# Convert mask to an integer tensor
mask = mask.to(torch.int)

X_hat = S_ @ D_.T
loss = reconstruction_loss_with_l1(X_test, X_hat, S_, l1_weight=0.01)
print(f"Loss: {loss}")

Loss: 0.36694350838661194


In [12]:
# Define an optimisation loop to optimise the non-zero codes to minimise the loss
S_opt = S_.clone().detach().requires_grad_(True)
optimizer = torch.optim.Adam([S_opt], lr=1e-3)
num_iterations = 1000

for i in range(num_iterations):
    optimizer.zero_grad()
    # Apply ReLU to the activations
    #S_opt = F.relu(S_opt)
    X_hat = S_opt @ D_.T
    loss = reconstruction_loss_with_l1(X_test, X_hat, S_opt, l1_weight=0.01)
    loss.backward()
    
    # Apply the mask to the gradients to only update non-zero codes
    S_opt.grad *= mask
    
    optimizer.step()

    # Project back to non-negative values
    with torch.no_grad():
        S_opt.clamp_(min=0)

    if i % 100 == 0:
        print(f"Loss at iteration {i}: {loss.item()}")
    
S_opt = S_opt.detach()
print(f"Final loss: {loss.item()}")


Loss at iteration 0: 0.36694350838661194
Loss at iteration 100: 0.2931056022644043
Loss at iteration 200: 0.2445969134569168
Loss at iteration 300: 0.21021294593811035
Loss at iteration 400: 0.1847432255744934
Loss at iteration 500: 0.1653081327676773
Loss at iteration 600: 0.1501605063676834
Loss at iteration 700: 0.138151153922081
Loss at iteration 800: 0.12848083674907684
Loss at iteration 900: 0.12059114128351212
Final loss: 0.11414305120706558


In [13]:
S_opt

tensor([[0.0000, 0.0000, 0.0000,  ..., 0.2700, 0.0679, 0.0000],
        [0.0000, 0.4185, 0.3180,  ..., 0.0000, 0.0000, 0.3530],
        [0.0000, 0.0933, 0.7430,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.5826, 0.0850,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.4312,  ..., 0.1912, 0.3534, 0.0000],
        [0.0000, 0.0000, 0.2751,  ..., 0.0000, 0.3502, 0.0000]])

In [14]:
S_

tensor([[4.5400e-05, 4.5400e-05, 4.5400e-05,  ..., 4.5400e-05, 4.5400e-05,
         4.5400e-05],
        [4.5400e-05, 4.5400e-05, 4.5400e-05,  ..., 4.5400e-05, 4.5400e-05,
         4.5400e-05],
        [4.5400e-05, 4.5400e-05, 4.5400e-05,  ..., 4.5400e-05, 4.5400e-05,
         4.5400e-05],
        ...,
        [4.5400e-05, 4.5400e-05, 4.5400e-05,  ..., 4.5400e-05, 4.5400e-05,
         4.5400e-05],
        [4.5400e-05, 4.5400e-05, 4.5400e-05,  ..., 4.5400e-05, 4.5400e-05,
         4.5400e-05],
        [4.5400e-05, 4.5400e-05, 4.5400e-05,  ..., 4.5400e-05, 4.5400e-05,
         4.5400e-05]])

In [15]:
import plotly.graph_objects as go

# Use the mask to get corresponding values in S_ and S_opt
S_masked = S_.flatten()[mask.flatten()]
S_opt_masked = S_opt.flatten()[mask.flatten()]

print(S_masked.shape)
print(S_opt_masked.shape)

# Calculate the difference between corresponding activations where S_ is non-zero
activation_diff = []
for i in range(S_.numel()):
    if S_.flatten()[i] != 0:
        diff = S_opt.flatten()[i] - S_.flatten()[i]
        activation_diff.append(diff)

activation_diff = torch.tensor(activation_diff)
print(activation_diff)

print(f"Activation difference mean: {activation_diff.mean()}")
print(f"Activation difference shape: {activation_diff.shape}")

# Plotly histogram of the difference in activations
fig = go.Figure()
fig.add_trace(go.Histogram(x=activation_diff.tolist(), name='Activation Difference', opacity=0.7))

# Add a vertical line for the mean
mean_value = activation_diff.mean()
fig.add_vline(x=mean_value, line_dash="dash", line_color="red", annotation_text=f"Mean: {mean_value:.4f}", annotation_position="top right")

fig.update_layout(
    title='Histogram of Activation Differences (Optimized - Original)',
    xaxis_title='Activation Difference',
    yaxis_title='Count',
    width=800,
    height=500
)

fig.show()

torch.Size([16384])
torch.Size([16384])
tensor([-4.5400e-05, -4.5400e-05, -4.5400e-05,  ..., -4.5400e-05,
         3.5015e-01, -4.5400e-05])
Activation difference mean: 0.1945384442806244
Activation difference shape: torch.Size([16384])


In [75]:
activation_diff

tensor([-3.5070e-05,  6.1280e-02, -3.7076e-05,  ..., -4.5573e-05,
         3.3415e-01, -4.9380e-05])