In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

import plotly.express as px
from tqdm import tqdm

from solu_moe import MoeAutoencoder, get_scheduler, heatmap

In [8]:
N_FEATURES = 70
SPARSITY = 0.99
IMPORTANCE_BASE = 0.99

LR = 3e-2
N_STEPS = 10000
BATCH_FEATURE_MULTIPLIER = 2

BATCH_SIZE = BATCH_FEATURE_MULTIPLIER*N_FEATURES + 1

model_config = {
    'd_model': 5,
    'n_experts': 6,
    'k_experts': 2,
    'd_expert': 5,
    'n_features': N_FEATURES,
    'activation': 'solu',
    'gate_jitter_eps': 1e-1,
    'solu_jitter_eps': 0,
    'w_in_bias_init': False,
    'use_norm': False,
}


model = MoeAutoencoder(**model_config)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
scheduler = get_scheduler(optimizer, N_STEPS)

feature_importance_weights = IMPORTANCE_BASE**torch.arange(model_config['n_features'])
feature_importance_weights = feature_importance_weights/feature_importance_weights.mean()

def get_batch_and_batch_weights():
    global BATCH_SIZE, N_FEATURES, SPARSITY, BATCH_FEATURE_MULTIPLIER
    '''Instead of sampling sums of sparse features with most batch elements being the zero vector,
    we let one element of the batch be the zero vector and make all other batch elements include at least one non-zero feature.
    Then we create weights for each element of the batch proportional to their probability.'''
    zero_feature = torch.zeros(1, N_FEATURES)
    base_features = torch.eye(N_FEATURES).repeat(BATCH_FEATURE_MULTIPLIER, 1) * torch.rand(BATCH_SIZE-1, N_FEATURES)

    # conditioned features
    conditioned_features = torch.rand(BATCH_SIZE-1, N_FEATURES)
    conditioned_features[(torch.rand(BATCH_SIZE-1, N_FEATURES) < SPARSITY)+(base_features != 0.0)] = 0.0
    conditioned_features = conditioned_features + base_features

    batch = torch.cat((zero_feature, conditioned_features), dim=0)

    batch_weights = torch.ones(BATCH_SIZE)
    batch_weights[0] = SPARSITY**N_FEATURES
    batch_weights[1:] = (1-(SPARSITY**N_FEATURES))/(BATCH_SIZE-1)

    return batch, batch_weights

def weighted_MSE(x, y, batch_weights=False):
    '''MSE with feature importance weights and batch weights'''
    global feature_importance_weights
    if batch_weights is not False:
        return (((x-y)**2)*feature_importance_weights[None,:]*batch_weights[:,None]).sum()
    else:
        return (((x-y)**2)*feature_importance_weights[None,:]).mean()


pbar = tqdm(range(N_STEPS))
for epoch in pbar:
    batch, batch_weights = get_batch_and_batch_weights()

    est_batch, load_balancing_loss = model(batch)
    load_balancing_loss = load_balancing_loss - model_config['k_experts']
    mse_loss = weighted_MSE(est_batch, batch, batch_weights)

    loss = mse_loss + 1e0*load_balancing_loss
    loss.backward()
    if epoch % 100 == 0:
        pbar.set_description(f"Loss: {loss.item():.4f} | MSE: {mse_loss.item():.4f} | Load Balancing: {load_balancing_loss.item():.4f}")

    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()

Loss: 0.1811 | MSE: 0.1795 | Load Balancing: 0.0016: 100%|██████████| 10000/10000 [00:21<00:00, 466.68it/s]


In [9]:
# Evaluation

def get_features(batch_size, n_features, sparsity):
    features = torch.rand(batch_size, n_features)
    features[(torch.rand(batch_size, n_features) < sparsity)] = 0.0
    return features

model.eval()
features = get_features(3000*BATCH_SIZE, model_config['n_features'], SPARSITY)
est_features, _ = model(features)
mse_loss = MSE(est_features, features)
print(f"Final Loss: {mse_loss.item():.4f}")

Final Loss: 0.0022


In [10]:
import plotly.express as px
from interp_utils import heatmap

model.eval()

features = torch.eye(model_config['n_features'])

pred_features, _ = model(features)
feature_reconstruction_fig = heatmap(pred_features, title='Feature reconstruction', dim_names=('which feature is set to 1 (one-hot features)', 'model feature predictions'))

full_acts, k_expert_indices, k_expert_weights = model.get_full_acts_for_rendering(features, col_weight=-0.05, null_expert_activation=0.0)
full_acts_fig = heatmap(full_acts, dim_names=('features', 'expert neurons'), title='Activations')

value_weighted_full_acts, k_expert_indices, k_expert_weights = model.get_full_acts_for_rendering(features, value_weighted=True, col_weight=-0.25, null_expert_activation=0.0)
value_weighted_full_acts_fig = heatmap(value_weighted_full_acts, dim_names=('features', 'expert neurons'), title='Value-weighted activations')

feature_reconstruction_fig.show()
full_acts_fig.show()
value_weighted_full_acts_fig.show()