# Stochastic Parameter Decomposition

## Imports

In [None]:
# import all the torch stuff, as well as the ViT stuff and such
# import plotly.express as px
# import torch
# from jaxtyping import Int, Float
# from typing import List, Optional, Tuple
# from tqdm import tqdm
# from transformer_lens.hook_points import HookPoint
# from transformer_lens import utils, HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
# import circuitsvis as cv

In [81]:
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import einops
import typing
from tqdm import tqdm

## Toy Model MLP

In [1]:
# create a SPDHookedTransformer Class
# that extends the hookedTransformer
# which inherits all the previous class stuff
# but also contains the SPD training algo?

# how is everything actually structured? 
"""
Maybe I should start with implementing it on a simple MLP setup. 

Orig network:
- 1 layer MLP 5-2-5 (defined as usual with torch.Sequential presumably)

SPD network:
let's give it 10 subcomponents per layer
then it's just like 10 matmuls? 
maybe i should define it like ant did in the toy models of superposition paper
stick them all into a trenchcoat
this might be easier once I have defined the toy modle

"""

In [82]:
"""config = {
    "num_layers": 1,
    "pre_embed_size": 100,
    "in_size": 1000,
    "hidden_size": 50,
    "subcomponents_per_layer": 10, 
    "beta_1": 1, 
    "beta_2": 1, 
    "beta_3": 1, 
    "causal_imp_min": 1, 
    "num_mask_samples": 20,
    "importance_mlp_size": 10,
}"""

In [130]:
class ToyResidMLP(nn.Module):
    def __init__(self, config, device="cpu"):
        super().__init__()
        # Initialize Weights for the
        self.num_layers, self.pre_embed_size, self.in_size, self.hidden_size = config["num_layers"], config["pre_embed_size"], config["in_size"], config["hidden_size"]
        self.device = device
        self.W_embed = nn.Parameter(torch.empty((self.pre_embed_size, self.in_size)))
        self.W_unembed = nn.Parameter(torch.empty((self.in_size, self.pre_embed_size)))
        self.W_in = nn.ParameterList([torch.empty((self.in_size, self.hidden_size), device=device) for i in range(self.num_layers)])
        self.W_out = nn.ParameterList([torch.empty((self.hidden_size, self.in_size), device=device) for i in range(self.num_layers)])
        self.b = nn.ParameterList([torch.zeros((self.hidden_size,), device=device) for i in range(self.num_layers)])

        for param in [self.W_embed, self.W_unembed] + list(self.W_in) + list(self.W_out): 
            nn.init.xavier_normal_(param)
        
    def forward(self, x): 
        
        assert x.shape[1] == self.pre_embed_size, f"Input shape {x.shape[0]} does not match model's accepted size {self.pre_embed_size}"
        # embed 
        x_resid = torch.einsum("np,pi->ni", x.clone(), self.W_embed)
        N, D = x_resid.shape

        
        for l in range(self.num_layers):
            hidden = F.relu(torch.einsum("nd,dh -> nh", x_resid, self.W_in[l]) + self.b[l])
            layer_out = torch.einsum("nh,hd -> nd", hidden, self.W_out[l])
            x_resid = x_resid + layer_out
        # am I supposed to have a embed and out?
        x_out = torch.einsum("ni,ip->np", x_resid, self.W_unembed) 
        return x_out


# Generated by LLM bc I am lazy and this is not important to me learning stuff atm
class SparseAutoencoderDataset(Dataset):
    """
    Dataset for learning to reconstruct sparse inputs.
    Each item is (input, target) where target is the input itself (or ReLU of input).
    """
    def __init__(self, in_dim=100, n_samples=10000, sparsity=0.9, device="cpu"):
        super().__init__()
        self.in_dim = in_dim
        self.n_samples = n_samples
        self.device = device

        # Pre-generate all samples
        self.inputs = []
        self.targets = []
        
        for _ in range(n_samples):
            # Sparse input: each entry is -1, 0, or 1, with sparsity
            x = np.random.choice([0, -1, 1], size=(in_dim,), p=[1-sparsity, sparsity/2, sparsity/2])
            x = torch.tensor(x, dtype=torch.float32, device=device)
            target = F.relu(x)
            
            self.inputs.append(x)
            self.targets.append(target)
    
    def __len__(self):
        return self.n_samples
    
    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]

def train_toy_resid_mlp(
    model,
    dataloader,
    lr=1e-3,
    num_epochs=10,
    device="cpu",
    print_every=1
):
    model.train()
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    for epoch in range(num_epochs):
        total_loss = 0.0
        for x, y in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * x.size(0)
        avg_loss = total_loss / len(dataloader.dataset)
        if (epoch+1) % print_every == 0:
            print(f"Epoch {epoch+1}: avg MSE loss = {avg_loss:.6f}")

In [272]:
## LLM-Generated Usage Example

torch.autograd.set_detect_anomaly(True)

if __name__ == "__main__":
    # Config
    # device = "mps" if torch.backends.mps.is_available() else "cpu"
    device = "cpu"
    # Dataset and DataLoader
    # Dataset and DataLoader
    dataset = SparseAutoencoderDataset(
        in_dim=100,
        n_samples=10000,
        sparsity=0.9,
        device=device,
    )
    print(device)
    
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
    # Model
    toy_model = ToyResidMLP(config, device=device)
    # Train
    train_toy_resid_mlp(toy_model, dataloader, lr=8e-2, num_epochs=20, device=device)

cpu


Epoch 1/20: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 114.36it/s]


Epoch 1: avg MSE loss = 1457925.062243


Epoch 2/20: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 115.52it/s]


Epoch 2: avg MSE loss = 755.598002


Epoch 3/20: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 116.34it/s]


Epoch 3: avg MSE loss = 95.474196


Epoch 4/20: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 115.75it/s]


Epoch 4: avg MSE loss = 45.635807


Epoch 5/20: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 107.14it/s]


Epoch 5: avg MSE loss = 25.904420


Epoch 6/20: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 116.48it/s]


Epoch 6: avg MSE loss = 16.163718


Epoch 7/20: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 116.98it/s]


Epoch 7: avg MSE loss = 10.680986


Epoch 8/20: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 115.97it/s]


Epoch 8: avg MSE loss = 7.360531


Epoch 9/20: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 113.03it/s]


Epoch 9: avg MSE loss = 5.218640


Epoch 10/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 115.02it/s]


Epoch 10: avg MSE loss = 3.806938


Epoch 11/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 116.20it/s]


Epoch 11: avg MSE loss = 2.834863


Epoch 12/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 115.94it/s]


Epoch 12: avg MSE loss = 2.161952


Epoch 13/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 116.39it/s]


Epoch 13: avg MSE loss = 1.674701


Epoch 14/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 116.51it/s]


Epoch 14: avg MSE loss = 1.327738


Epoch 15/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 109.30it/s]


Epoch 15: avg MSE loss = 1.071527


Epoch 16/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 113.86it/s]


Epoch 16: avg MSE loss = 0.883784


Epoch 17/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 115.97it/s]


Epoch 17: avg MSE loss = 0.743674


Epoch 18/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 116.12it/s]


Epoch 18: avg MSE loss = 0.639231


Epoch 19/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 116.19it/s]


Epoch 19: avg MSE loss = 0.561527


Epoch 20/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 116.34it/s]

Epoch 20: avg MSE loss = 0.503866





## SPD Model and Train Function

In [8]:
# generated by LLM

class HardSigmoid(nn.Module):
    """
    Implements the hard sigmoid activation function as described in the paper:
        σ_H(x) = 0 if x <= 0
               = x if 0 < x < 1
               = 1 if x >= 1
    This is equivalent to: torch.clamp(x, min=0.0, max=1.0)
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Clamp values between 0 and 1
        return torch.clamp(x, min=0.0, max=1.0)

tensor([0.0000, 0.0000, 0.0000, 0.3000, 0.7000, 1.0000, 1.0000])


In [330]:
class SPDModelMLP(nn.Module): 
    # fix the betas stuff below that's basically a type hint
    def __init__(self, target_model, config, device="cpu"): 
        super().__init__()
        
        print("Devices: ", device, target_model.device)
        self.device = device
        object.__setattr__(self, "target_model", target_model) # sets pointer to target_model without registering its parameters as subsidiary

        self.target_model = target_model
        assert self.device == target_model.device, "Models not on same device"

        self.num_layers, self.pre_embed_size, self.in_size, self.hidden_size, self.imp_hidden_size = config["num_layers"], config["pre_embed_size"], config["in_size"], config["hidden_size"], config["importance_mlp_size"]
        self.C = config["subcomponents_per_layer"]
        self.hypers = dict(list(config.items())[4:]) # sets the "hypers" to contain all the hyperparameters for the model
        
        # Subcomponent vectors, each of shape C by in_size; to be used
        # with outer product to create our low-rank subcomponent matrices
        self.V_embed = nn.Parameter(torch.empty((self.C, self.pre_embed_size,), device=device))
        self.U_embed = nn.Parameter(torch.empty((self.C, self.in_size,), device=device))
        self.V_unembed = nn.Parameter(torch.empty((self.C, self.in_size,), device=device))
        self.U_unembed = nn.Parameter(torch.empty((self.C, self.pre_embed_size,), device=device))

        self.V_in = nn.ParameterList([torch.empty((self.C, self.in_size,), device=device) for i in range(self.num_layers)])
        self.U_in = nn.ParameterList([torch.empty((self.C, self.hidden_size,), device=device) for i in range(self.num_layers)]) 
        self.V_out = nn.ParameterList([torch.empty((self.C, self.hidden_size,), device=device) for i in range(self.num_layers)])
        self.U_out = nn.ParameterList([torch.empty((self.C, self.in_size,), device=device) for i in range(self.num_layers)])        
        
        # idk what you do with the biases lol
        self.b = nn.ParameterList([torch.zeros((self.hidden_size,), device=device) for i in range(self.num_layers)])
        
        # this is so horrible I'm sorry
        # gate_in_in is the gate_in weights for the in subcomponent of each layer
        # gate_in_out is gate_in weights for the out component
        # they each get an extra one on the end which is for the embed matrix
        # Probably should have just registered these as submodels, but I didn't know that those existed before.
        self.imp_W_gate_in_in = nn.ParameterList([torch.empty(self.C, 1, self.imp_hidden_size) for i in range(self.num_layers+1)])
        self.imp_W_gate_out_in = nn.ParameterList([torch.empty(self.C, self.imp_hidden_size, 1) for i in range(self.num_layers+1)]) 
        self.imp_b_in_in = nn.ParameterList([torch.empty(self.C, self.imp_hidden_size) for i in range(self.num_layers+1)])
        self.imp_b_out_in = nn.ParameterList([torch.empty(self.C, 1) for i in range(self.num_layers+1)])
        
        self.imp_W_gate_in_out = nn.ParameterList([torch.empty(self.C, 1, self.imp_hidden_size) for i in range(self.num_layers+1)])
        self.imp_W_gate_out_out = nn.ParameterList([torch.empty(self.C, self.imp_hidden_size, 1) for i in range(self.num_layers+1)]) 
        self.imp_b_in_out = nn.ParameterList([torch.empty(self.C, self.imp_hidden_size) for i in range(self.num_layers+1)])
        self.imp_b_out_out = nn.ParameterList([torch.empty(self.C, 1) for i in range(self.num_layers+1)])
        
        # self.pred_weights = torch.empty((num_layers, self.C), device=device)

        for param in self.parameters():
            # using xavier anyway -- note that the variance etc is
            # changed because we take the outer product in the 
            # forward pass
            if param.dim() >= 2: # xavier does not work for 1d tensors
                nn.init.xavier_normal_(param)/np.sqrt(self.C) 
            else:
                nn.init.zeros_(param)  # or another appropriate initializer for biases

        
    def forward(self, x, return_activs_and_weights=False, masks=None): # Regular run. Unclear whether I should have masking when I do a regular forward pass.
        v_activations = []
        weight_matrices = []
        x_in = x.clone()
        layerwise_resids = []

        if masks is not None: 
            
            v_activ_embed = torch.einsum("np,cp->nc", x_in, self.V_embed) * masks[-1]["embed"]
            v_activ_embed_nomask = torch.einsum("np,cp->nc", x_in, self.V_embed)

            x_embedded = torch.einsum("nc,ci->ni", v_activ_embed, self.U_embed)
            x_embedded_nomask = torch.einsum("nc,ci->ni", v_activ_embed_nomask, self.U_embed)

            x_resid = x_embedded.clone()
                        
            for l in range(self.num_layers):
                ### Run the forward pass just for this layer
                v_activ_layer_in = torch.einsum("ni,ci->nc", x_resid, self.V_in[l]) * masks[l]["in"] # V activations
                layer_hidden = F.relu(torch.einsum("nc,ch->nh", v_activ_layer_in, self.U_in[l]) + self.b[l]) # U and ReLU
                v_activ_layer_out = torch.einsum("nh,ch->nc", layer_hidden, self.V_out[l]) * masks[l]["out"] # V activ
                layer_out = torch.einsum("nc,ci->ni", v_activ_layer_out, self.U_out[l]) # U
                x_resid = x_resid + layer_out # add to residual stream

                # Run the layerwise forward pass for the whole model, with masking only this layer
                x_resid_layerwise_in = x_embedded_nomask.clone() # pre-embedded
                x_resid_layerwise_out = x_embedded_nomask.clone() # pre-embedded again
                
                for l_2 in range(self.num_layers): # oh god this is so awful. literally sobbing rn. i can't believe I'm writing it like this
                    if l_2 == l: 
                        # run with mask on in
                        v_activ_layer_in_l = torch.einsum("ni,ci->nc", x_resid_layerwise_in, self.V_in[l]) * masks[l]["in"]
                        layer_hidden_l = F.relu(torch.einsum("nc,ch->nh", v_activ_layer_in_l, self.U_in[l])+self.b[l])
                        v_activ_layer_out_l = torch.einsum("nh,ch->nc", layer_hidden_l, self.V_out[l])
                        layer_out_l_masked_in = torch.einsum("nc,ci->ni", v_activ_layer_out_l, self.U_out[l])
                        x_resid_layerwise_in = x_resid_layerwise_in + layer_out_l_masked_in

                        # run with mask on out
                        v_activ_layer_in_l = torch.einsum("ni,ci->nc", x_resid_layerwise_out, self.V_in[l])
                        layer_hidden_l = F.relu(torch.einsum("nc,ch->nh", v_activ_layer_in_l, self.U_in[l])+self.b[l])
                        v_activ_layer_out_l = torch.einsum("nh,ch->nc", layer_hidden_l, self.V_out[l]) * masks[l]["out"]
                        layer_out_l_masked_out = torch.einsum("nc,ci->ni", v_activ_layer_out_l, self.U_out[l])
                        x_resid_layerwise_out = x_resid_layerwise_out + layer_out_l_masked_out

                        
                    else: # THE SAME EXACT THING EXCEPT WITHOUT THE MASKS >_< 
                        v_activ_layer_in_l = torch.einsum("ni,ci->nc", x_resid_layerwise_in, self.V_in[l])
                        layer_hidden_l = F.relu(torch.einsum("nc,ch->nh", v_activ_layer_in_l, self.U_in[l]) + self.b[l])
                        v_activ_layer_out_l = torch.einsum("nh,ch->nc", layer_hidden_l, self.V_out[l])
                        layer_out_l = torch.einsum("nc,ci->ni", v_activ_layer_out_l, self.U_out[l])
                        x_resid_layerwise_in = x_resid_layerwise_in + layer_out_l
                        
                        v_activ_layer_in_l = torch.einsum("ni,ci->nc", x_resid_layerwise_out, self.V_in[l])
                        layer_hidden_l = F.relu(torch.einsum("nc,ch->nh", v_activ_layer_in_l, self.U_in[l]) + self.b[l])
                        v_activ_layer_out_l = torch.einsum("nh,ch->nc", layer_hidden_l, self.V_out[l])
                        layer_out_l = torch.einsum("nc,ci->ni", v_activ_layer_out_l, self.U_out[l])
                        x_resid_layerwise_out = x_resid_layerwise_out + layer_out_l
                        
                    x_out_layerwise_in = torch.einsum("nc,cp->np", torch.einsum("ni,ci->nc", x_resid_layerwise_in, self.V_unembed), self.U_unembed)
                    x_out_layerwise_out = torch.einsum("nc,cp->np", torch.einsum("ni,ci->nc", x_resid_layerwise_out, self.V_unembed), self.U_unembed)
                    layerwise_resids.append({"in":x_out_layerwise_in,"out":x_out_layerwise_out})
                    
                # run layerwise for embed and unembed as well, because very intelligently I gave them their own separate weight matrices 
                # and they don't get included when iterating through layers
                x_embed_resid = x_embedded.clone()
                for l in range(self.num_layers):
                    v_activ_layer_in_l = torch.einsum("ni,ci->nc", x_embed_resid, self.V_in[l])
                    layer_hidden_l = F.relu(torch.einsum("nc,ch->nh", v_activ_layer_in_l, self.U_in[l]) + self.b[l])
                    v_activ_layer_out_l = torch.einsum("nh,ch->nc", layer_hidden_l, self.V_out[l])
                    layer_out_embed_mask = torch.einsum("nc,ci->ni", v_activ_layer_out_l, self.U_out[l])
                    x_embed_resid = x_embed_resid + layer_out_embed_mask
                x_out_embed_mask = torch.einsum("nc,cp->np", torch.einsum("ni,ci->nc", x_embed_resid, self.V_unembed), self.U_unembed)

                x_unembed_resid = x_embedded_nomask.clone()
                for l in range(self.num_layers):
                    v_activ_layer_in_l = torch.einsum("ni,ci->nc", x_unembed_resid, self.V_in[l])
                    layer_hidden_l = F.relu(torch.einsum("nc,ch->nh", v_activ_layer_in_l, self.U_in[l]) + self.b[l])
                    v_activ_layer_out_l = torch.einsum("nh,ch->nc", layer_hidden_l, self.V_out[l])
                    layer_out_unembed_mask = torch.einsum("nc,ci->ni", v_activ_layer_out_l, self.U_out[l])
                    x_unembed_resid = x_unembed_resid + layer_out_unembed_mask
                x_out_unembed_mask = torch.einsum("nc,cp->np", torch.einsum("ni,ci->nc", x_unembed_resid, self.V_unembed)* masks[-1]["unembed"], self.U_unembed)

                layerwise_resids.append({"embed": x_out_embed_mask, "unembed": x_out_unembed_mask})

            v_activ_unembed = torch.einsum("ni,ci->nc", x_resid, self.V_unembed) * masks[-1]["unembed"]
            x_out = torch.einsum("nc,cp->np", v_activ_unembed, self.U_unembed)

                                
        else: 
            W_embed = torch.einsum("cp, ci -> cpi", self.V_embed, self.U_embed).sum(dim=0)
            x_resid = torch.einsum("np, pi -> ni", x_in, W_embed)
            v_activ_embed = torch.einsum("np,cp->nc", x_in, self.V_embed).unsqueeze(-1)
            
            for l in range(self.num_layers):
                # Use outer product to create weights for the layer, then sum all the subcomponents
                W_in = torch.einsum("ci,ch-> cih", self.V_in[l], self.U_in[l]).sum(dim=0) # shape i h
                W_out = torch.einsum("ch,ci-> chi", self.V_out[l], self.U_out[l]).sum(dim=0) # shape h i
                if return_activs_and_weights == True:
                    weight_matrices.append({"in": W_in, "out": W_out})
                # COMPUTE 
                layer_in = F.relu(torch.einsum("ni,ih -> nh", x_resid, W_in) + self.b[l])
                layer_out = torch.einsum("nh,hi->ni", F.relu(layer_in), W_out)
            
                if return_activs_and_weights == True: # LMAO I HAD IT INDENTED WRONG
                    # calculate activations
                    v_activ_layer_in = torch.einsum("ni,ci->nc", x_resid, self.V_in[l]).unsqueeze(-1)
                    v_activ_layer_out = torch.einsum("nh,ch->nc", layer_in, self.V_out[l]).unsqueeze(-1)
                    # both are now shape n,c,1
                    v_activations.append({"in": v_activ_layer_in, "out": v_activ_layer_out})
                    
                    # code to make sure I'm not doing some goofy shit with the activations at the beginning
                    check_layer_in_uv = F.relu(torch.einsum("nc,ch->nh", v_activ_layer_in.squeeze(), self.U_in[l]) + self.b[l])
                    assert torch.allclose(layer_in, check_layer_in_uv, atol=10), f"Subcomponent activations not calculated correctly, max difference is {torch.max(layer_in-check_layer_in_uv)}; {layer_in}{check_layer_in_uv}"
                    # would in theory have one for the second too but I'm being lazy
                
            x_resid = x_resid + layer_out
            W_unembed = torch.einsum("ci, cp-> cip", self.V_unembed, self.U_unembed).sum(dim=0)
            x_out = torch.einsum("ni, ip -> np", x_resid, W_unembed)
            
            if return_activs_and_weights == True:
                weight_matrices.append({"embed": W_embed, "unembed": W_unembed})
                v_activ_unembed = torch.einsum("ni,ci->nc", x_resid, self.V_unembed).unsqueeze(-1)
                v_activations.append({"embed": v_activ_embed, "unembed": v_activ_unembed})
            # putting the embed matrices at the end of the masks. since we iterate through l in range(num_layers) this will not get added! :D

        if return_activs_and_weights:
            return x_out, v_activations, weight_matrices 
        elif masks is not None: 
            return x_out, layerwise_resids
        else: 
            return x_out



if __name__ == "__main__":
    # Config
    #device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    device = "cpu"
    config = {
        "num_layers": 1,
        "pre_embed_size": 100,
        "in_size": 1000,
        "hidden_size": 50,
        "subcomponents_per_layer": 10, 
        "beta_1": 1.0, 
        "beta_2": 1.0, 
        "beta_3": 1.0, 
        "causal_imp_min": 1.0, 
        "num_mask_samples": 20,
        "importance_mlp_size": 5,
    }
    
    dataset = SparseAutoencoderDataset(
        in_dim=100,
        n_samples=10000,
        sparsity=0.9,
        device=device,
    )
    
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
    # Model
    spd_model = SPDModelMLP(toy_model, config, device=device)
    # Train
    train_SPD(spd_model, dataloader, lr=8e-2, num_epochs=20)

Devices:  cpu cpu
Training on device cpu


Epoch 1/20:   1%|█▏                                                                                          | 1/79 [00:01<01:44,  1.34s/it, loss=594]

Masked out min:  -0.2869523763656616 , max:  0.2887800335884094
Faithfulness: 264.47454833984375, Stoch Rec: 64.21952819824219, Stoch Rec Layerwise: 258.2439270019531, Importance Min: 6.957733154296875
Total gradient norm: 59.34287965195927


Epoch 1/20:   8%|██████▋                                                                                 | 6/79 [00:07<01:33,  1.28s/it, loss=1.94e+3]

Masked out min:  -9.320036888122559 , max:  9.596722602844238
Faithfulness: 267.8381652832031, Stoch Rec: 138.2049560546875, Stoch Rec Layerwise: 1532.0335693359375, Importance Min: 0.6614287495613098
Total gradient norm: 3754.5567799030964


Epoch 1/20:  14%|████████████                                                                           | 11/79 [00:14<01:33,  1.38s/it, loss=2.54e+3]

Masked out min:  -10.717123031616211 , max:  9.643646240234375
Faithfulness: 275.00274658203125, Stoch Rec: 141.59588623046875, Stoch Rec Layerwise: 2125.414794921875, Importance Min: 0.8351671099662781
Total gradient norm: 8113.127795851296


Epoch 1/20:  20%|█████████████████▌                                                                     | 16/79 [00:20<01:21,  1.30s/it, loss=7.02e+3]

Masked out min:  -27.065927505493164 , max:  25.073280334472656
Faithfulness: 283.6265869140625, Stoch Rec: 334.8540344238281, Stoch Rec Layerwise: 6402.341796875, Importance Min: 0.3327096700668335
Total gradient norm: 17843.103738707574


Epoch 1/20:  27%|███████████████████████▏                                                               | 21/79 [00:27<01:14,  1.28s/it, loss=2.64e+3]

Masked out min:  -10.594161987304688 , max:  12.317036628723145
Faithfulness: 306.9132080078125, Stoch Rec: 156.77793884277344, Stoch Rec Layerwise: 2171.67138671875, Importance Min: 0.0
Total gradient norm: 4652.010105639884


Epoch 1/20:  33%|████████████████████████████▉                                                           | 26/79 [00:33<01:07,  1.28s/it, loss=2.3e+3]

Masked out min:  -6.0844645500183105 , max:  6.3159966468811035
Faithfulness: 269.41326904296875, Stoch Rec: 141.77806091308594, Stoch Rec Layerwise: 1892.585693359375, Importance Min: 0.0
Total gradient norm: 5194.6129538806


Epoch 1/20:  39%|██████████████████████████████████▏                                                    | 31/79 [00:40<01:01,  1.28s/it, loss=1.99e+3]

Masked out min:  -12.78572940826416 , max:  14.028889656066895
Faithfulness: 217.8966064453125, Stoch Rec: 129.6506805419922, Stoch Rec Layerwise: 1638.853271484375, Importance Min: 0.0
Total gradient norm: 4830.930143878619


Epoch 1/20:  46%|███████████████████████████████████████▋                                               | 36/79 [00:46<00:55,  1.28s/it, loss=1.09e+3]

Masked out min:  -3.549102783203125 , max:  3.073551893234253
Faithfulness: 155.27503967285156, Stoch Rec: 111.92475891113281, Stoch Rec Layerwise: 818.1057739257812, Importance Min: 0.0
Total gradient norm: 1810.1297425784048


Epoch 1/20:  52%|███████████████████████████████████████████████▏                                           | 41/79 [00:53<00:50,  1.32s/it, loss=679]

Masked out min:  -3.1995632648468018 , max:  2.9430887699127197
Faithfulness: 97.11468505859375, Stoch Rec: 89.93598937988281, Stoch Rec Layerwise: 491.500244140625, Importance Min: 0.0
Total gradient norm: 513.6915152961153


Epoch 1/20:  58%|████████████████████████████████████████████████████▉                                      | 46/79 [00:59<00:42,  1.28s/it, loss=529]

Masked out min:  -2.343099594116211 , max:  1.4476081132888794
Faithfulness: 61.500431060791016, Stoch Rec: 80.50706481933594, Stoch Rec Layerwise: 386.6531982421875, Importance Min: 0.0
Total gradient norm: 323.455451669234


Epoch 1/20:  65%|██████████████████████████████████████████████████████████▋                                | 51/79 [01:06<00:35,  1.28s/it, loss=464]

Masked out min:  -0.9663612842559814 , max:  0.8439489006996155
Faithfulness: 36.90113830566406, Stoch Rec: 76.47035217285156, Stoch Rec Layerwise: 350.33258056640625, Importance Min: 0.0
Total gradient norm: 205.1323850084465


Epoch 1/20:  71%|████████████████████████████████████████████████████████████████▌                          | 56/79 [01:12<00:29,  1.29s/it, loss=418]

Masked out min:  -1.484490990638733 , max:  1.0805836915969849
Faithfulness: 22.232925415039062, Stoch Rec: 74.99824523925781, Stoch Rec Layerwise: 320.586181640625, Importance Min: 0.0
Total gradient norm: 118.33990716038137


Epoch 1/20:  77%|██████████████████████████████████████████████████████████████████████▎                    | 61/79 [01:19<00:23,  1.28s/it, loss=393]

Masked out min:  -0.535370409488678 , max:  0.5659753680229187
Faithfulness: 13.538299560546875, Stoch Rec: 74.23875427246094, Stoch Rec Layerwise: 305.703369140625, Importance Min: 0.0
Total gradient norm: 72.03299966809622


Epoch 1/20:  84%|████████████████████████████████████████████████████████████████████████████               | 66/79 [01:25<00:17,  1.34s/it, loss=369]

Masked out min:  -0.4921940565109253 , max:  0.4793298542499542
Faithfulness: 8.119293212890625, Stoch Rec: 71.939453125, Stoch Rec Layerwise: 289.4379577636719, Importance Min: 0.0
Total gradient norm: 36.81383790202253


Epoch 1/20:  90%|█████████████████████████████████████████████████████████████████████████████████▊         | 71/79 [01:32<00:10,  1.29s/it, loss=354]

Masked out min:  -0.37550997734069824 , max:  0.4103220999240875
Faithfulness: 4.841667175292969, Stoch Rec: 69.82166290283203, Stoch Rec Layerwise: 279.3833923339844, Importance Min: 0.0
Total gradient norm: 32.034912511589866


Epoch 1/20:  96%|███████████████████████████████████████████████████████████████████████████████████████▌   | 76/79 [01:38<00:03,  1.28s/it, loss=354]

Masked out min:  -0.34752029180526733 , max:  0.27377161383628845
Faithfulness: 2.9220175743103027, Stoch Rec: 70.76470184326172, Stoch Rec Layerwise: 280.5880126953125, Importance Min: 0.0
Total gradient norm: 29.2870870967505


Epoch 1/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 79/79 [01:42<00:00,  1.29s/it, loss=127]


Epoch 1: avg loss = 1508.908346


AttributeError: 'float' object has no attribute 'item'

In [331]:
def train_SPD(spd_model, dataloader, lr=1e-4, num_epochs=10): # could also implement this by passing in the original model?

    model.train()
    print(f"Training on device {spd_model.device}")
    optimizer = torch.optim.AdamW(spd_model.parameters(), lr = lr)
    
    
    for epoch in range(num_epochs):
        total_loss = 0.0
        total_l_stoch_rec, total_l_stoch_rec_l, total_l_imp, total_l_faith = 0.0, 0.0, 0.0, 0.0

        with tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}") as t:
            for batch_idx, (x,y) in enumerate(t):                
                x = x.to(device)
                y = y.to(device)
                
                N = x.shape[0]     # x is shape N by in_size
                C = spd_model.C
                target_model = spd_model.target_model
                hard_sigmoid = HardSigmoid()
                optimizer.zero_grad()
                
                # COMPUTE TARGET OUTPUT
                with torch.no_grad(): 
                    target_out = spd_model.target_model(x)
            
                # FAITHFULNESS LOSS
                
                spd_output, spd_activations, spd_weights = spd_model(x, return_activs_and_weights = True)
                squared_error = 0
                for l in range(spd_model.num_layers):
                    in_diff = target_model.W_in[l] - spd_weights[l]["in"]
                    out_diff = target_model.W_out[l] - spd_weights[l]["out"]
                    
                    # torch.linalg.matrix_norm defaults to the frobenius norm
                    # this takes the frobenius norm of the diff and then squares it
                    squared_error_layer = torch.linalg.matrix_norm(in_diff) ** 2 + torch.linalg.matrix_norm(out_diff) ** 2 
                    squared_error = squared_error + squared_error_layer
                    
                mean_squared_error = squared_error/spd_model.num_layers
                l_faithfulness = mean_squared_error
                
            
                ## IMPORTANCE-MINIMALITY LOSS
                
                pred_importances = []
                l_importance_minimality = 0.0
    
                components_imp_pred_embed_hidden = F.gelu(torch.einsum("nco,cos->ncs", spd_activations[-1]["embed"], spd_model.imp_W_gate_in_in[-1]) + spd_model.imp_b_in_in[-1])
                components_imp_pred_embed = hard_sigmoid(torch.einsum("ncs,cso->nco", components_imp_pred_embed_hidden, spd_model.imp_W_gate_out_in[-1]) + spd_model.imp_b_out_in[-1])
                
                components_imp_pred_unembed_hidden = F.gelu(torch.einsum("nco,cos->ncs", spd_activations[-1]["unembed"], spd_model.imp_W_gate_in_out[-1]) + spd_model.imp_b_in_out[-1])
                components_imp_pred_unembed = hard_sigmoid(torch.einsum("ncs,cso->nco", components_imp_pred_unembed_hidden, spd_model.imp_W_gate_out_out[-1]) + spd_model.imp_b_out_out[-1])
                l_importance_minimality = l_importance_minimality + (components_imp_pred_embed.sum() ** spd_model.hypers["causal_imp_min"]) + (components_imp_pred_unembed.sum() ** spd_model.hypers["causal_imp_min"])
            
                for l in range(spd_model.num_layers):
                    # both activations are n by c containing dot product so we already have hard_sigmoid
                    # spd_activations[l][inout] is shape n,c,1 (nco)
                    # imp_W_in is c by 1 by imp_size (cos)
                    # want to map to ncs then back to nco
                    # imp_b_in is shape (C, s) so should broadcast to ncs nicely
                    
                    # TODO: DEFINE hard_sigmoid AS A TORCH MODULE SO IT CAN CALCULATE THE DERIVATIVE 
                    # in theory should write this as a bunch of models stored in the main model, but that's not how i did it and i've already 
                    # written this, shrug
                    # oh my god this is so bad 
                    # imp_W_gate_in_in is the gate in for the W_in, etc
                    components_imp_pred_hidden_in = F.gelu(torch.einsum("nco,cos->ncs", spd_activations[l]["in"], spd_model.imp_W_gate_in_in[l]) + spd_model.imp_b_in_in[l])
                    components_pred_layer_in = hard_sigmoid(torch.einsum("ncs,cso->nco", components_imp_pred_hidden_in, spd_model.imp_W_gate_out_in[l]) + spd_model.imp_b_out_in[l])
            
                    #same thing for the out matrix in layer l
                    components_imp_pred_hidden_out = F.gelu(torch.einsum("nco,cos->ncs", spd_activations[l]["out"], spd_model.imp_W_gate_in_out[l]) + spd_model.imp_b_in_out[l])
                    components_pred_layer_out = hard_sigmoid(torch.einsum("ncs,cso->nco", components_imp_pred_hidden_out, spd_model.imp_W_gate_out_out[l]) + spd_model.imp_b_out_out[l])
                    
                    pred_importances.append({"in": components_pred_layer_in, "out": components_pred_layer_out})
                    l_importance_minimality = l_importance_minimality + (components_pred_layer_in.sum() ** spd_model.hypers["causal_imp_min"]) + (components_pred_layer_out.sum() ** spd_model.hypers["causal_imp_min"])
                
                pred_importances.append({"embed": components_imp_pred_embed, "unembed": components_imp_pred_unembed})
    
                l_importance_minimality /= N # divide by N, the batch size, to avg across the batch (notated as B in the paper)
            
            
                ## STOCHASTIC RECONSTRUCTION LOSS 
            
                l_stochastic_recon = 0.0
                l_stochastic_recon_layerwise = 0.0
                R = torch.rand(spd_model.hypers["num_mask_samples"], N, spd_model.num_layers+1, 2, C)
            
                for s in range(spd_model.hypers["num_mask_samples"]):
                    layer_masks = []
                    # Running this with a for loop. This is slow; I'd ideally just run 
                    # something like M = G[None, :, :] + (1 - G[None, :, :]) * R
                    # But I think there's a clarity tradeoff here so I'm just going to do this for now
                    layer_mask_embed = pred_importances[-1]["embed"].squeeze() + (torch.ones_like(pred_importances[-1]["embed"].squeeze()) - pred_importances[-1]["embed"].squeeze())*R[s,:,l,0,:]
                    layer_mask_unembed = pred_importances[-1]["unembed"].squeeze() + (torch.ones_like(pred_importances[-1]["unembed"].squeeze()) - pred_importances[-1]["unembed"].squeeze()) * R[s,:,l,1,:]
                    
                    for l in range(spd_model.num_layers):
                        # ugh, this is pretty bad. Ideally I'd go back and refactor so that I don't 
                        # have separate in and out weights, but this is what I have. going to keep it
                        # like this for now, but will refactor once I have something that works.
                        # components pred in is shape N, C, 1, squeeze to N, C. there are L of them
                        # R is N, L, C, sample just l to get N, C
                        # I want masks for each component on each layer on each datapoint, so
                        # masks should be one C vector for each layer, per-datapoint. so NLC
    
                        # whoops i'm now putting squeezes on everything. seems like I should just unsqueeze the thing when I store it, but I don't feel like refactoring it like that :( 
                        layer_mask_in = pred_importances[l]["in"].squeeze() + (torch.ones_like(pred_importances[l]["in"].squeeze()) - pred_importances[l]["in"].squeeze()) * R[s,:,l,0,:]
                        layer_mask_out = pred_importances[l]["out"].squeeze() + (torch.ones_like(pred_importances[l]["out"].squeeze()) - pred_importances[l]["out"].squeeze()) * R[s,:,l,1,:]
                        # these are both shape N, C 
                        layer_masks.append({"in": layer_mask_in, "out": layer_mask_out})
                        
                    layer_masks.append({"embed": layer_mask_embed, "unembed": layer_mask_unembed})
                    masked_out, layerwise_masked_outs = spd_model(x, masks=layer_masks)
                    l_stochastic_recon = l_stochastic_recon + torch.linalg.matrix_norm(target_out - masked_out) # uses matrix norm of difference
                    for i in range(len(layerwise_masked_outs)-1):
                        l_stochastic_recon_layerwise = l_stochastic_recon_layerwise + torch.linalg.matrix_norm(target_out - layerwise_masked_outs[i]["in"]) + torch.linalg.matrix_norm(target_out - layerwise_masked_outs[i]["out"])
                    l_stochastic_recon_layerwise = l_stochastic_recon_layerwise + torch.linalg.matrix_norm(target_out - layerwise_masked_outs[-1]["embed"]) + torch.linalg.matrix_norm(target_out - layerwise_masked_outs[-1]["unembed"])

                l_stochastic_recon /= spd_model.hypers["num_mask_samples"]
                l_stochastic_recon_layerwise /= (spd_model.hypers["num_mask_samples"] * spd_model.num_layers)
                
                loss = l_faithfulness + spd_model.hypers["beta_1"] * l_stochastic_recon + spd_model.hypers["beta_2"] * l_stochastic_recon_layerwise + spd_model.hypers["beta_3"] * l_importance_minimality
                
                loss.backward()
                optimizer.step()
    
                total_loss += loss.item() * x.size(0)
                
                if batch_idx % 5 == 0: 
                    print("Masked out min: ", masked_out.min().item(), ", max: ", masked_out.max().item())
                    print(f"Faithfulness: {l_faithfulness}, Stoch Rec: {l_stochastic_recon}, Stoch Rec Layerwise: {l_stochastic_recon_layerwise}, Importance Min: {l_importance_minimality}")
                    total_norm = 0.0
                    for p in spd_model.parameters():
                        if p.grad is not None:
                            param_norm = p.grad.data.norm(2)  # L2 norm of the gradient
                            total_norm += param_norm.item() ** 2
                    
                    total_norm = total_norm ** 0.5
                    print(f"Total gradient norm: {total_norm}")
                
                t.set_postfix(loss=loss.item())
            
        avg_loss = total_loss / len(dataloader.dataset)
        total_l_stoch_rec, total_l_stoch_rec_l, total_l_imp, total_l_faith = total_l_stoch_rec/len(dataloader.dataset), total_l_stoch_rec_l/len(dataloader.dataset), total_l_imp/len(dataloader.dataset), total_l_faith/len(dataloader.dataset)
        print(f"Epoch {epoch+1}: avg loss = {avg_loss:.6f}")
        print(total_l_faith, total_l_stoch_rec.item(), total_l_stoch_rec_l, total_l_imp.item())
