In [120]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from matplotlib import pyplot as plt
from tqdm import tqdm
from collections import defaultdict

In [121]:
torch.manual_seed(1337)

<torch._C.Generator at 0x119848e70>

In [122]:
#generate data for training as hidden state activations via LRH - sum of sparse overcomplete feature directions

def generate_hidden_data(dim = 128, n_features = 512, 
                         n_samples = (2**10), sparsity = 10):
    #basically want features Y times random vector w where w is sparse, then sum resulting vectors for hidden state
    #overcomplete feature basis?
    features = np.random.randn(n_features, dim)
    features = features / np.linalg.norm(features, axis=1, keepdims=True)

    #init sparsity weights
    weights = np.zeros((n_samples, n_features))
    #generate sparsity weights
    for i in tqdm(range(n_samples)):
        active_feats = np.random.choice(n_features, size=sparsity, replace=False)
        weights[i, active_feats] = np.random.randn(sparsity)
    #make hidden data via sum of sparse features
    hidden_data = weights @ features

    return torch.tensor(hidden_data, dtype=torch.float32)

print(generate_hidden_data().shape)

100%|██████████| 1024/1024 [00:00<00:00, 72696.25it/s]

torch.Size([1024, 128])





In [123]:
class SAE(nn.Module):
    def __init__(self, input_dim, width_ratio=4, activation=nn.ReLU()):
        super().__init__()
        self.sae_hidden = input_dim * width_ratio
        self.W_in = nn.Parameter(
            nn.init.kaiming_uniform_(
                torch.empty(input_dim, self.sae_hidden), nonlinearity="relu"
            )
        )
        self.b_in = nn.Parameter(torch.zeros(self.sae_hidden))
        self.W_out = nn.Parameter(
            nn.init.kaiming_uniform_(
                torch.empty(self.sae_hidden, input_dim), nonlinearity="relu"
            )
        )
        self.b_out = nn.Parameter(torch.zeros(input_dim))
        self.nonlinearity = activation

    def _normalize_weights(self):
        with torch.no_grad():
            norms = self.W_out.norm(p=2, dim=0, keepdim=True)
            self.W_out.div_(norms)

    def forward(self, x):
        x = x - self.b_out
        acts = self.nonlinearity(x @ self.W_in + self.b_in)
        l1_regularization = acts.abs().sum()
        l0 = (acts > 0).sum(dim=1).float().mean()
        self._normalize_weights()

        return l0, l1_regularization, acts@self.W_out + self.b_out


In [124]:
def train(model, train_data, test_data, batch_size=128, n_epochs=1000, l1_lam=5e-5, weight_decay=1e-4):
    optimizer = optim.Adam(model.parameters(), weight_decay=weight_decay)
    mse_criterion = nn.MSELoss()

    n_batches = len(train_data) // batch_size
    n_test_batches = len(test_data) // batch_size

    for epoch in range(n_epochs):
        total_loss = 0
        total_test_loss = 0
        total_mse_loss = 0
        total_l1_loss = 0
        total_l0 = 0
        batch_perm = torch.randperm(len(train_data))
        test_batch_perm = torch.randperm(len(test_data))

        for i in range(n_batches):
            # Training
            idx = batch_perm[i*batch_size: (i+1)*batch_size]
            batch = train_data[idx]

            optimizer.zero_grad()
            l0, l1, recon_hiddens = model(batch)

            recon_loss = mse_criterion(recon_hiddens, batch)
            sparsity_loss = l1_lam * l1
            loss = recon_loss + sparsity_loss

            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            total_l1_loss += sparsity_loss.item()
            total_mse_loss += recon_loss.item()
            total_l0 += l0

            # Testing
            if i < n_test_batches:
                test_idx = test_batch_perm[i*batch_size: (i+1)*batch_size]
                test_batch = test_data[test_idx]
                
                with torch.no_grad():
                    _, _, test_recon = model(test_batch)
                    test_loss = mse_criterion(test_recon, test_batch)
                    total_test_loss += test_loss.item()

        if epoch % 10 == 0:
            avg_loss = total_loss / n_batches
            avg_test_loss = total_test_loss / n_test_batches
            avg_l1_loss = total_l1_loss / n_batches
            avg_l0 = total_l0 / n_batches
            
            print(f'Epoch {epoch}, Loss: {avg_loss:.4f}, '
                  f'Test Loss: {avg_test_loss:.4f}, '
                  f'L1: {avg_l1_loss:.4f}, '
                  f'L0: {avg_l0:.4f}')

    return {
        'mse': total_mse_loss/n_batches,
        'L0': total_l0/n_batches
    }

In [125]:
def run_experiment():
    sparsity = 20
    hidden_dim = 128
    width_factor = 4

    data = generate_hidden_data(dim=hidden_dim, sparsity=sparsity)
    train_size = int(0.8 * len(data))
    train_data, test_data = data[:train_size], data[train_size:]
    
    relu_model = SAE(hidden_dim, width_factor, nn.ReLU())
    print("Training ReLU model...")
    result = train(relu_model, train_data, test_data)
            

def run_DOE():
    sparsities = [5, 10, 20, 30, 40, 50]
    results = defaultdict(list)
    hidden_dim = 128
    width_factor = 4

    for sparsity in sparsities:
        for trial in range(10):
            data = generate_hidden_data(dim=hidden_dim, sparsity=sparsity)
            train_size = int(0.8 * len(data))
            train_data, test_data = data[:train_size], data[train_size:]
            
            relu_model = SAE(hidden_dim, width_factor, nn.ReLU())
            print("Training ReLU model...")
            result = train(relu_model, train_data)
            results[sparsity].append(result)
            

In [126]:
run_experiment()

100%|██████████| 1024/1024 [00:00<00:00, 72970.44it/s]


Training ReLU model...
Epoch 0, Loss: 0.5255, Test Loss: 0.1853, L1: 0.3480, L0: 253.4700
Epoch 10, Loss: 0.1833, Test Loss: 0.1245, L1: 0.0559, L0: 83.3477
Epoch 20, Loss: 0.1593, Test Loss: 0.1372, L1: 0.0163, L0: 28.8008
Epoch 30, Loss: 0.1508, Test Loss: 0.1365, L1: 0.0212, L0: 32.2318
Epoch 40, Loss: 0.1416, Test Loss: 0.1304, L1: 0.0248, L0: 31.0703
Epoch 50, Loss: 0.1351, Test Loss: 0.1225, L1: 0.0283, L0: 30.0286
Epoch 60, Loss: 0.1299, Test Loss: 0.1183, L1: 0.0311, L0: 28.9271
Epoch 70, Loss: 0.1263, Test Loss: 0.1155, L1: 0.0328, L0: 27.6419
Epoch 80, Loss: 0.1233, Test Loss: 0.1129, L1: 0.0339, L0: 26.5443
Epoch 90, Loss: 0.1213, Test Loss: 0.1106, L1: 0.0348, L0: 25.7292
Epoch 100, Loss: 0.1201, Test Loss: 0.1098, L1: 0.0357, L0: 25.5664
Epoch 110, Loss: 0.1184, Test Loss: 0.1112, L1: 0.0360, L0: 24.8607
Epoch 120, Loss: 0.1173, Test Loss: 0.1106, L1: 0.0362, L0: 24.4492
Epoch 130, Loss: 0.1169, Test Loss: 0.1085, L1: 0.0369, L0: 24.0521
Epoch 140, Loss: 0.1161, Test Loss: