In [151]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from matplotlib import pyplot as plt
from tqdm import tqdm
from collections import defaultdict

In [152]:
torch.manual_seed(1337)

<torch._C.Generator at 0x119848e70>

In [153]:
#generate data for training as hidden state activations via LRH - sum of sparse overcomplete feature directions

def generate_hidden_data(dim = 128, n_features = 512, 
                         n_samples = (2**10), sparsity = 10):
    #basically want features Y times random vector w where w is sparse, then sum resulting vectors for hidden state
    #overcomplete feature basis?
    features = np.random.randn(n_features, dim)
    features = features / np.linalg.norm(features, axis=1, keepdims=True)

    #init sparsity weights
    weights = np.zeros((n_samples, n_features))
    #generate sparsity weights
    for i in tqdm(range(n_samples)):
        active_feats = np.random.choice(n_features, size=sparsity, replace=False)
        weights[i, active_feats] = np.random.randn(sparsity)
    #make hidden data via sum of sparse features
    hidden_data = weights @ features

    return torch.tensor(hidden_data, dtype=torch.float32)

print(generate_hidden_data().shape)

100%|██████████| 1024/1024 [00:00<00:00, 73930.07it/s]

torch.Size([1024, 128])





In [154]:
class SAE(nn.Module):
    def __init__(self, input_dim, width_ratio=4, activation=nn.ReLU()):
        super().__init__()
        self.sae_hidden = input_dim * width_ratio
        self.W_in = nn.Parameter(
            nn.init.kaiming_uniform_(
                torch.empty(input_dim, self.sae_hidden), nonlinearity="relu"
            )
        )
        self.b_in = nn.Parameter(torch.zeros(self.sae_hidden))
        self.W_out = nn.Parameter(
            nn.init.kaiming_uniform_(
                torch.empty(self.sae_hidden, input_dim), nonlinearity="relu"
            )
        )
        self.b_out = nn.Parameter(torch.zeros(input_dim))
        self.nonlinearity = activation

    def _normalize_weights(self):
        with torch.no_grad():
            norms = self.W_out.norm(p=2, dim=0, keepdim=True)
            self.W_out.div_(norms)

    def forward(self, x):
        x = x - self.b_out
        acts = self.nonlinearity(x @ self.W_in + self.b_in)
        l1_regularization = acts.abs().sum()
        l0 = (acts > 0).sum(dim=1).float().mean()
        self._normalize_weights()

        return l0, l1_regularization, acts@self.W_out + self.b_out


In [155]:
def train(model, train_data, test_data, batch_size=128, n_epochs=1000, l1_lam=5e-5, weight_decay=1e-4):
    optimizer = optim.Adam(model.parameters(), weight_decay=weight_decay)
    mse_criterion = nn.MSELoss()

    n_batches = len(train_data) // batch_size
    n_test_batches = len(test_data) // batch_size

    for epoch in range(n_epochs):
        total_loss = 0
        total_test_loss = 0
        total_mse_loss = 0
        total_l1_loss = 0
        total_l0 = 0
        batch_perm = torch.randperm(len(train_data))
        test_batch_perm = torch.randperm(len(test_data))

        for i in range(n_batches):
            # Training
            idx = batch_perm[i*batch_size: (i+1)*batch_size]
            batch = train_data[idx]

            optimizer.zero_grad()
            l0, l1, recon_hiddens = model(batch)

            recon_loss = mse_criterion(recon_hiddens, batch)
            sparsity_loss = l1_lam * l1
            loss = recon_loss + sparsity_loss

            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            total_l1_loss += sparsity_loss.item()
            total_mse_loss += recon_loss.item()
            total_l0 += l0

            # Testing
            if i < n_test_batches:
                test_idx = test_batch_perm[i*batch_size: (i+1)*batch_size]
                test_batch = test_data[test_idx]
                
                with torch.no_grad():
                    _, _, test_recon = model(test_batch)
                    test_loss = mse_criterion(test_recon, test_batch)
                    total_test_loss += test_loss.item()

        #if epoch % 10 == 0:
        #    avg_loss = total_loss / n_batches
        #    avg_test_loss = total_test_loss / n_test_batches
        #    avg_l1_loss = total_l1_loss / n_batches
        #    avg_l0 = total_l0 / n_batches
            
            #print(f'Epoch {epoch}, Loss: {avg_loss:.4f}, '
            #      f'Test Loss: {avg_test_loss:.4f}, '
            #      f'L1: {avg_l1_loss:.4f}, '
            #      f'L0: {avg_l0:.4f}')

    return {
        'mse': total_mse_loss/n_batches,
        'L0': total_l0/n_batches
    }

In [156]:
def run_experiment():
    sparsity = 128
    hidden_dim = 128
    width_factor = 4

    data = generate_hidden_data(dim=hidden_dim, sparsity=sparsity)
    train_size = int(0.8 * len(data))
    train_data, test_data = data[:train_size], data[train_size:]
    
    relu_model = SAE(hidden_dim, width_factor, nn.ReLU())
    print("Training ReLU model...")
    result = train(relu_model, train_data, test_data)
            

def run_DOE():
    sparsities = [5, 10, 20, 30, 40, 50]
    seeds = [1337, 42, 69, 420, 666]
    results = defaultdict(list)
    hidden_dim = 128
    width_factor = 4

    for sparsity in tqdm(sparsities):
        for i, trial in tqdm(enumerate(range(seeds))):
            torch.manual_seed(trial)
            data = generate_hidden_data(dim=hidden_dim, sparsity=sparsity)
            
            # Shuffle data
            indices = torch.randperm(len(data))
            data = data[indices]
            
            # Split into train/test
            train_size = int(0.8 * len(data))
            train_data, test_data = data[:train_size], data[train_size:]
            
            relu_model = SAE(hidden_dim, width_factor, nn.ReLU())
            print(f"Training ReLU model with {sparsity} sparsity on iteration {i}")
            result = train(relu_model, train_data, test_data)
            results[sparsity].append(result)

    return results
            

In [157]:
results = run_DOE()

  0%|          | 0/6 [00:00<?, ?it/s]
100%|██████████| 1024/1024 [00:00<00:00, 64112.60it/s]


Training ReLU model with 5 sparsity on iteration 0



100%|██████████| 1024/1024 [00:00<00:00, 64243.98it/s]


Training ReLU model with 5 sparsity on iteration 1



100%|██████████| 1024/1024 [00:00<00:00, 61952.30it/s]


Training ReLU model with 5 sparsity on iteration 2



100%|██████████| 1024/1024 [00:00<00:00, 65427.18it/s]


Training ReLU model with 5 sparsity on iteration 3



100%|██████████| 1024/1024 [00:00<00:00, 72007.63it/s]


Training ReLU model with 5 sparsity on iteration 4


100%|██████████| 5/5 [00:27<00:00,  5.41s/it]
 17%|█▋        | 1/6 [00:27<02:15, 27.05s/it]
100%|██████████| 1024/1024 [00:00<00:00, 69745.01it/s]


Training ReLU model with 10 sparsity on iteration 0



100%|██████████| 1024/1024 [00:00<00:00, 63264.55it/s]


Training ReLU model with 10 sparsity on iteration 1



100%|██████████| 1024/1024 [00:00<00:00, 64489.00it/s]


Training ReLU model with 10 sparsity on iteration 2



100%|██████████| 1024/1024 [00:00<00:00, 65591.05it/s]


Training ReLU model with 10 sparsity on iteration 3



100%|██████████| 1024/1024 [00:00<00:00, 61560.71it/s]


Training ReLU model with 10 sparsity on iteration 4


100%|██████████| 5/5 [00:27<00:00,  5.55s/it]
 33%|███▎      | 2/6 [00:54<01:49, 27.45s/it]
100%|██████████| 1024/1024 [00:00<00:00, 62560.52it/s]


Training ReLU model with 20 sparsity on iteration 0



100%|██████████| 1024/1024 [00:00<00:00, 62618.89it/s]


Training ReLU model with 20 sparsity on iteration 1



100%|██████████| 1024/1024 [00:00<00:00, 61199.31it/s]


Training ReLU model with 20 sparsity on iteration 2



100%|██████████| 1024/1024 [00:00<00:00, 64265.13it/s]


Training ReLU model with 20 sparsity on iteration 3



100%|██████████| 1024/1024 [00:00<00:00, 57146.60it/s]


Training ReLU model with 20 sparsity on iteration 4


100%|██████████| 5/5 [00:30<00:00,  6.15s/it]
 50%|█████     | 3/6 [01:25<01:26, 28.95s/it]
100%|██████████| 1024/1024 [00:00<00:00, 59653.98it/s]


Training ReLU model with 30 sparsity on iteration 0



100%|██████████| 1024/1024 [00:00<00:00, 60232.06it/s]


Training ReLU model with 30 sparsity on iteration 1



100%|██████████| 1024/1024 [00:00<00:00, 61850.60it/s]


Training ReLU model with 30 sparsity on iteration 2



100%|██████████| 1024/1024 [00:00<00:00, 62033.73it/s]


Training ReLU model with 30 sparsity on iteration 3



100%|██████████| 1024/1024 [00:00<00:00, 62120.76it/s]


Training ReLU model with 30 sparsity on iteration 4


100%|██████████| 5/5 [00:30<00:00,  6.16s/it]
 67%|██████▋   | 4/6 [01:56<00:59, 29.68s/it]
100%|██████████| 1024/1024 [00:00<00:00, 57101.02it/s]


Training ReLU model with 40 sparsity on iteration 0



100%|██████████| 1024/1024 [00:00<00:00, 60323.42it/s]


Training ReLU model with 40 sparsity on iteration 1



100%|██████████| 1024/1024 [00:00<00:00, 61314.63it/s]


Training ReLU model with 40 sparsity on iteration 2



100%|██████████| 1024/1024 [00:00<00:00, 60204.20it/s]


Training ReLU model with 40 sparsity on iteration 3



100%|██████████| 1024/1024 [00:00<00:00, 59297.36it/s]


Training ReLU model with 40 sparsity on iteration 4


100%|██████████| 5/5 [00:31<00:00,  6.26s/it]
 83%|████████▎ | 5/6 [02:27<00:30, 30.26s/it]
100%|██████████| 1024/1024 [00:00<00:00, 57505.45it/s]


Training ReLU model with 50 sparsity on iteration 0



100%|██████████| 1024/1024 [00:00<00:00, 56806.48it/s]


Training ReLU model with 50 sparsity on iteration 1



100%|██████████| 1024/1024 [00:00<00:00, 56051.78it/s]


Training ReLU model with 50 sparsity on iteration 2



100%|██████████| 1024/1024 [00:00<00:00, 58029.90it/s]


Training ReLU model with 50 sparsity on iteration 3



100%|██████████| 1024/1024 [00:00<00:00, 58191.87it/s]


Training ReLU model with 50 sparsity on iteration 4


100%|██████████| 5/5 [00:29<00:00,  5.95s/it]
100%|██████████| 6/6 [02:57<00:00, 29.56s/it]


In [160]:
import pandas as pd

results

defaultdict(list,
            {5: [{'mse': 0.026843073467413586, 'L0': tensor(5.7174)},
              {'mse': 0.026320386677980423, 'L0': tensor(5.4258)},
              {'mse': 0.026732576079666615, 'L0': tensor(5.5404)},
              {'mse': 0.02660268358886242, 'L0': tensor(5.7930)},
              {'mse': 0.02716939647992452, 'L0': tensor(5.7018)}],
             10: [{'mse': 0.044039856642484665, 'L0': tensor(10.5352)},
              {'mse': 0.044560532396038376, 'L0': tensor(10.7474)},
              {'mse': 0.04403712724645933, 'L0': tensor(10.8307)},
              {'mse': 0.0442103153715531, 'L0': tensor(10.7344)},
              {'mse': 0.0443751011043787, 'L0': tensor(10.7839)}],
             20: [{'mse': 0.07146312793095906, 'L0': tensor(18.5846)},
              {'mse': 0.07112449655930202, 'L0': tensor(18.4102)},
              {'mse': 0.06989554439981778, 'L0': tensor(18.3620)},
              {'mse': 0.0698989989856879, 'L0': tensor(18.1250)},
              {'mse': 0.0707064146