In [1]:
import torch
import torch.nn as nn

class SparseAutoencoder(nn.Module):
    """
    Single-layer sparse autoencoder (SAE) for decomposing model activations.

    The encoder maps activations into a (typically wider) hidden space and applies a
    ReLU; the decoder maps the hidden features back linearly. Training combines an
    MSE reconstruction term with an L1 penalty on the hidden features.

    Inputs may carry arbitrary leading dimensions, e.g. [batch, input_dim] or
    [batch, seq, input_dim]; the feature dimension is always the last one.
    """

    def __init__(self, input_dim=768, hidden_dim=3072, sparsity_coef=1e-3):
        """
        Args:
            input_dim: Size of the input activation vector (768 for GPT-2).
            hidden_dim: Size of the hidden layer (expansion factor * input_dim).
            sparsity_coef: Weight of the L1 sparsity term (beta) in the total loss.
        """
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.sparsity_coef = sparsity_coef

        # Encoder: projects activations into the higher-dimensional feature space.
        self.encoder = nn.Linear(input_dim, hidden_dim, bias=True)

        # Decoder: projects features back to the original activation space.
        # NOTE(review): decoder columns are not norm-constrained. Because ReLU is
        # positively homogeneous, scaling the encoder (weight and bias) down by c and
        # the decoder weight up by 1/c leaves reconstructions unchanged while shrinking
        # the L1 term by c, so the penalty alone cannot enforce sparsity during
        # training. Consider renormalizing decoder.weight columns after each step.
        self.decoder = nn.Linear(hidden_dim, input_dim, bias=True)

        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-uniform weights and zero biases for both linear layers."""
        nn.init.xavier_uniform_(self.encoder.weight)
        nn.init.zeros_(self.encoder.bias)
        nn.init.xavier_uniform_(self.decoder.weight)
        nn.init.zeros_(self.decoder.bias)

    def encode(self, x):
        """
        Encode activations into non-negative hidden features.

        Args:
            x: Input activations, shape [..., input_dim].
        Returns:
            h: Hidden features after ReLU, shape [..., hidden_dim].
        """
        h = self.encoder(x)
        # ReLU zeroes negative pre-activations; together with the L1 penalty this is
        # what lets most features be exactly zero.
        h = torch.relu(h)
        return h

    def decode(self, h):
        """
        Decode hidden features back to activation space.

        Args:
            h: Hidden features, shape [..., hidden_dim].
        Returns:
            x_reconstructed: Reconstructed activations, shape [..., input_dim].
        """
        return self.decoder(h)

    def forward(self, x):
        """
        Full encode/decode pass.

        Args:
            x: Input activations, shape [..., input_dim].
        Returns:
            x_reconstructed: Reconstructed activations, shape [..., input_dim].
            h: Hidden features, shape [..., hidden_dim].
        """
        h = self.encode(x)
        x_reconstructed = self.decode(h)
        return x_reconstructed, h

    def compute_loss(self, x, x_reconstructed, h):
        """
        Compute the SAE training loss and logging metrics.

        Args:
            x: Original input, shape [..., input_dim].
            x_reconstructed: Reconstructed input, same shape as x.
            h: Hidden features, shape [..., hidden_dim] (feature axis last).
        Returns:
            loss: Scalar tensor, recon_loss + sparsity_coef * sparsity_loss
                (differentiable when the inputs are).
            loss_dict: Python floats for logging with keys 'total_loss',
                'recon_loss', 'sparsity_loss' and 'l0'.
        """
        # Reconstruction loss: MSE averaged over every element.
        recon_loss = torch.mean((x - x_reconstructed) ** 2)

        # Sparsity loss: L1 penalty pushing as many hidden features as possible to
        # zero -- this is where the "sparse" encoding comes from. It averages |h|
        # over all elements, i.e. the per-vector L1 norm divided by hidden_dim.
        sparsity_loss = torch.mean(torch.abs(h))

        total_loss = recon_loss + self.sparsity_coef * sparsity_loss

        # L0: mean number of active (positive) features per activation vector.
        # Count over the last (feature) axis so inputs with extra leading dims such
        # as [batch, seq, hidden_dim] are handled; dim=1 would count over the
        # sequence axis instead (and fail outright on a 1-D h).
        l0 = (h > 0).float().sum(dim=-1).mean()

        return total_loss, {
            'total_loss': total_loss.item(),
            'recon_loss': recon_loss.item(),
            'sparsity_loss': sparsity_loss.item(),
            'l0': l0.item()
        }

In [2]:
# Smoke test: push a random batch through a freshly initialized SAE and check shapes.
torch.manual_seed(0)  # reproducible weight init and random batch across re-runs

sae = SparseAutoencoder(input_dim=768, hidden_dim=3072)

x = torch.randn(32, 768)  # batch of 32 random stand-in activation vectors
x_recon, h = sae(x)

assert x_recon.shape == x.shape
assert h.shape == (x.shape[0], sae.hidden_dim)

print(f"Input shape: {x.shape}")
print(f"Hidden shape: {h.shape}")
print(f"Output shape: {x_recon.shape}")

# Loss components for the untrained model
loss, loss_dict = sae.compute_loss(x, x_recon, h)
print("\nLoss components:")
for k, v in loss_dict.items():
    print(f"  {k}: {v:.6f}")


Input shape: torch.Size([32, 768])
Hidden shape: torch.Size([32, 3072])
Output shape: torch.Size([32, 768])

Loss components:
  total_loss: 1.316933
  recon_loss: 1.316681
  sparsity_loss: 0.252120
  l0: 1530.093750


In [None]:
# TODO(review): this cell was left unfinished -- a bare `data = ` is a SyntaxError
# that breaks "Restart & Run All". Replace the placeholder with the activation
# dataset the SAE should be trained on.
data = None