# Stochastic Parameter Decomposition

## Imports

In [1]:
# import all the torch stuff, as well as the ViT stuff and such
# import plotly.express as px
# import torch
# from jaxtyping import Int, Float
# from typing import List, Optional, Tuple
# from tqdm import tqdm
# from transformer_lens.hook_points import HookPoint
# from transformer_lens import utils, HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
# import circuitsvis as cv

In [1]:
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import einops
import typing
from tqdm import tqdm

## Toy Model MLP

In [3]:
# create a SPDHookedTransformer Class
# that extends the hookedTransformer
# which inherits all the previous class stuff
# but also contains the SPD training algo?

# how is everything actually structured?
"""
Maybe I should start with implementing it on a simple MLP setup.

Orig network:
- 1 layer MLP 5-2-5 (defined as usual with torch.Sequential presumably)

SPD network:
let's give it 10 subcomponents per layer
then it's just like 10 matmuls?
maybe i should define it like ant did in the toy models of superposition paper
stick them all into a trenchcoat
this might be easier once I have defined the toy modle

"""

"\nMaybe I should start with implementing it on a simple MLP setup. \n\nOrig network:\n- 1 layer MLP 5-2-5 (defined as usual with torch.Sequential presumably)\n\nSPD network:\nlet's give it 10 subcomponents per layer\nthen it's just like 10 matmuls? \nmaybe i should define it like ant did in the toy models of superposition paper\nstick them all into a trenchcoat\nthis might be easier once I have defined the toy modle\n\n"

In [2]:
class ToyResidMLP(nn.Module):
    def __init__(self, config, device="cuda"):
        super().__init__()
        # Initialize Weights for the
        self.num_layers, self.pre_embed_size, self.in_size, self.hidden_size = config["num_layers"], config["pre_embed_size"], config["in_size"], config["hidden_size"]
        self.device = device
        self.W_embed = nn.Parameter(torch.empty((self.pre_embed_size, self.in_size)))
        self.W_unembed = nn.Parameter(torch.empty((self.in_size, self.pre_embed_size)))
        self.W_in = nn.ParameterList([torch.empty((self.in_size, self.hidden_size), device=device) for i in range(self.num_layers)])
        self.W_out = nn.ParameterList([torch.empty((self.hidden_size, self.in_size), device=device) for i in range(self.num_layers)])
        self.b = nn.ParameterList([torch.zeros((self.hidden_size,), device=device) for i in range(self.num_layers)])

        for param in [self.W_embed, self.W_unembed] + list(self.W_in) + list(self.W_out):
            nn.init.xavier_normal_(param)

    def forward(self, x):

        assert x.shape[1] == self.pre_embed_size, f"Input shape {x.shape[0]} does not match model's accepted size {self.pre_embed_size}"
        # embed
        x_resid = torch.einsum("np,pi->ni", x.clone(), self.W_embed)
        N, D = x_resid.shape


        for l in range(self.num_layers):
            hidden = F.relu(torch.einsum("nd,dh -> nh", x_resid, self.W_in[l]) + self.b[l])
            layer_out = torch.einsum("nh,hd -> nd", hidden, self.W_out[l])
            x_resid = x_resid + layer_out
        # am I supposed to have a embed and out?
        x_out = torch.einsum("ni,ip->np", x_resid, self.W_unembed)
        return x_out


# Generated by LLM bc I am lazy and this is not important to me learning stuff atm
class SparseAutoencoderDataset(Dataset):
    """
    Dataset for learning to reconstruct sparse inputs.
    Each item is (input, target) where target is the input itself (or ReLU of input).
    """
    def __init__(self, in_dim=100, n_samples=10000, sparsity=0.9, device="cuda"):
        super().__init__()
        self.in_dim = in_dim
        self.n_samples = n_samples
        self.device = device

        # Pre-generate all samples
        self.inputs = []
        self.targets = []

        for _ in range(n_samples):
            # Sparse input: each entry is -1, 0, or 1, with sparsity
            x = np.random.choice([0, -1, 1], size=(in_dim,), p=[1-sparsity, sparsity/2, sparsity/2])
            x = torch.tensor(x, dtype=torch.float32, device=device)
            target = F.relu(x)

            self.inputs.append(x)
            self.targets.append(target)

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]

def train_toy_resid_mlp(
    model,
    dataloader,
    lr=1e-3,
    num_epochs=10,
    device="cuda",
    print_every=1
):
    model.train()
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    for epoch in range(num_epochs):
        total_loss = 0.0
        for x, y in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * x.size(0)
        avg_loss = total_loss / len(dataloader.dataset)
        if (epoch+1) % print_every == 0:
            print(f"Epoch {epoch+1}: avg MSE loss = {avg_loss:.6f}")

In [4]:
## LLM-Generated Usage Example

torch.autograd.set_detect_anomaly(True)

config = {
    "num_layers": 1,
    "pre_embed_size": 100,
    "in_size": 1000,
    "hidden_size": 50,
    "subcomponents_per_layer": 30,
    "beta_1": 1.0,
    "beta_2": 1.0,
    "beta_3": 0.1,
    "causal_imp_min": 1.0,
    "num_mask_samples": 20,
    "importance_mlp_size": 5,
}


if __name__ == "__main__":
    # Config
    # device = "mps" if torch.backends.mps.is_available() else "cpu"
    device = "cpu"
    # Dataset and DataLoader
    # Dataset and DataLoader
    dataset = SparseAutoencoderDataset(
        in_dim=100,
        n_samples=15000,
        sparsity=0.9,
        device=device,
    )
    print(device)

    dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
    # Model
    toy_model = ToyResidMLP(config, device=device)
    # Train
    train_toy_resid_mlp(toy_model, dataloader, lr=8e-2, num_epochs=20, device=device)

cpu


Epoch 1/20: 100%|██████████| 118/118 [00:02<00:00, 57.75it/s]


Epoch 1: avg MSE loss = 245616.226061


Epoch 2/20: 100%|██████████| 118/118 [00:02<00:00, 54.20it/s]


Epoch 2: avg MSE loss = 69.965749


Epoch 3/20: 100%|██████████| 118/118 [00:01<00:00, 59.31it/s]


Epoch 3: avg MSE loss = 22.630260


Epoch 4/20: 100%|██████████| 118/118 [00:01<00:00, 59.38it/s]


Epoch 4: avg MSE loss = 10.240307


Epoch 5/20: 100%|██████████| 118/118 [00:02<00:00, 58.89it/s]


Epoch 5: avg MSE loss = 5.388696


Epoch 6/20: 100%|██████████| 118/118 [00:01<00:00, 60.20it/s]


Epoch 6: avg MSE loss = 3.096830


Epoch 7/20: 100%|██████████| 118/118 [00:02<00:00, 56.40it/s]


Epoch 7: avg MSE loss = 1.902151


Epoch 8/20: 100%|██████████| 118/118 [00:02<00:00, 58.02it/s]


Epoch 8: avg MSE loss = 1.242300


Epoch 9/20: 100%|██████████| 118/118 [00:01<00:00, 60.57it/s]


Epoch 9: avg MSE loss = 0.858383


Epoch 10/20: 100%|██████████| 118/118 [00:01<00:00, 60.03it/s]


Epoch 10: avg MSE loss = 0.630948


Epoch 11/20: 100%|██████████| 118/118 [00:01<00:00, 59.11it/s]


Epoch 11: avg MSE loss = 0.493606


Epoch 12/20: 100%|██████████| 118/118 [00:01<00:00, 59.60it/s]


Epoch 12: avg MSE loss = 0.410911


Epoch 13/20: 100%|██████████| 118/118 [00:02<00:00, 53.97it/s]


Epoch 13: avg MSE loss = 0.362089


Epoch 14/20: 100%|██████████| 118/118 [00:01<00:00, 59.12it/s]


Epoch 14: avg MSE loss = 0.334819


Epoch 15/20: 100%|██████████| 118/118 [00:01<00:00, 60.74it/s]


Epoch 15: avg MSE loss = 0.321945


Epoch 16/20: 100%|██████████| 118/118 [00:01<00:00, 62.09it/s]


Epoch 16: avg MSE loss = 0.318983


Epoch 17/20: 100%|██████████| 118/118 [00:01<00:00, 62.11it/s]


Epoch 17: avg MSE loss = 0.322831


Epoch 18/20: 100%|██████████| 118/118 [00:02<00:00, 56.44it/s]


Epoch 18: avg MSE loss = 0.331379


Epoch 19/20: 100%|██████████| 118/118 [00:01<00:00, 59.59it/s]


Epoch 19: avg MSE loss = 0.343274


Epoch 20/20: 100%|██████████| 118/118 [00:01<00:00, 63.14it/s]

Epoch 20: avg MSE loss = 0.356685





## SPD Model and Train Function

In [6]:
# generated by LLM

class HardSigmoid(nn.Module):
    """
    Implements the hard sigmoid activation function as described in the paper:
        σ_H(x) = 0 if x <= 0
               = x if 0 < x < 1
               = 1 if x >= 1
    This is equivalent to: torch.clamp(x, min=0.0, max=1.0)
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Clamp values between 0 and 1
        return torch.clamp(x, min=0.0, max=1.0)


class LowerLeakyHardSigmoid(nn.Module):
    """
    Lower-leaky hard sigmoid: σH,lower(x)
    - 0.01*x if x <= 0 (leaky below 0)
    - x if 0 <= x <= 1 (linear in middle)
    - 1 if x >= 1 (saturated above 1)

    Used for forward pass masks in stochastic reconstruction losses.
    """
    def __init__(self, leak_slope=0.01):
        super().__init__()
        self.leak_slope = leak_slope

    def forward(self, x):
        return torch.where(
            x <= 0,
            self.leak_slope * x,
            torch.where(
                x >= 1,
                torch.ones_like(x),
                x
            )
        )

class UpperLeakyHardSigmoid(nn.Module):
    """
    Upper-leaky hard sigmoid: σH,upper(x)
    - 0 if x <= 0 (hard cutoff below 0)
    - x if 0 <= x <= 1 (linear in middle)
    - 1 + 0.01*(x-1) if x >= 1 (leaky above 1)

    Used for importance loss computation.
    """
    def __init__(self, leak_slope=0.01):
        super().__init__()
        self.leak_slope = leak_slope

    def forward(self, x):
        return torch.where(
            x <= 0,
            torch.zeros_like(x),
            torch.where(
                x >= 1,
                1 + self.leak_slope * (x - 1),
                x
            )
        )

# Test the functions
if __name__ == "__main__":
    lower_leaky = LowerLeakyHardSigmoid()
    upper_leaky = UpperLeakyHardSigmoid()

    test_vals = torch.tensor([-2.0, -0.5, 0.0, 0.3, 0.7, 1.0, 1.5, 2.0])

    print("Input:", test_vals)
    print("Lower-leaky:", lower_leaky(test_vals))
    print("Upper-leaky:", upper_leaky(test_vals))

Input: tensor([-2.0000, -0.5000,  0.0000,  0.3000,  0.7000,  1.0000,  1.5000,  2.0000])
Lower-leaky: tensor([-0.0200, -0.0050,  0.0000,  0.3000,  0.7000,  1.0000,  1.0000,  1.0000])
Upper-leaky: tensor([0.0000, 0.0000, 0.0000, 0.3000, 0.7000, 1.0000, 1.0050, 1.0100])


In [6]:
# generated by LLM

class HardSigmoid(nn.Module):
    """
    Implements the hard sigmoid activation function as described in the paper:
        σ_H(x) = 0 if x <= 0
               = x if 0 < x < 1
               = 1 if x >= 1
    This is equivalent to: torch.clamp(x, min=0.0, max=1.0)
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Clamp values between 0 and 1
        return torch.clamp(x, min=0.0, max=1.0)


class LowerLeakyHardSigmoid(nn.Module):
    """
    Lower-leaky hard sigmoid: σH,lower(x)
    - 0.01*x if x <= 0 (leaky below 0)
    - x if 0 <= x <= 1 (linear in middle)
    - 1 if x >= 1 (saturated above 1)

    Used for forward pass masks in stochastic reconstruction losses.
    """
    def __init__(self, leak_slope=0.01):
        super().__init__()
        self.leak_slope = leak_slope

    def forward(self, x):
        return torch.where(
            x <= 0,
            self.leak_slope * x,
            torch.where(
                x >= 1,
                torch.ones_like(x),
                x
            )
        )

class UpperLeakyHardSigmoid(nn.Module):
    """
    Upper-leaky hard sigmoid: σH,upper(x)
    - 0 if x <= 0 (hard cutoff below 0)
    - x if 0 <= x <= 1 (linear in middle)
    - 1 + 0.01*(x-1) if x >= 1 (leaky above 1)

    Used for importance loss computation.
    """
    def __init__(self, leak_slope=0.01):
        super().__init__()
        self.leak_slope = leak_slope

    def forward(self, x):
        return torch.where(
            x <= 0,
            torch.zeros_like(x),
            torch.where(
                x >= 1,
                1 + self.leak_slope * (x - 1),
                x
            )
        )


class SPDModelMLP(nn.Module):
    # fix the betas stuff below that's basically a type hint
    def __init__(self, target_model, config, device="cuda"):
        super().__init__()

        print("Devices: ", device, target_model.device)
        self.device = device
        object.__setattr__(self, "target_model", target_model) # sets pointer to target_model without registering its parameters as subsidiary
        assert self.device == target_model.device, "Models not on same device"

        self.num_layers, self.pre_embed_size, self.in_size, self.hidden_size, self.imp_hidden_size = config["num_layers"], config["pre_embed_size"], config["in_size"], config["hidden_size"], config["importance_mlp_size"]
        self.C = config["subcomponents_per_layer"]
        self.hypers = dict(list(config.items())[4:]) # sets the "hypers" to contain all the hyperparameters for the model

        # Subcomponent vectors, each of shape C by in_size; to be used
        # with outer product to create our low-rank subcomponent matrices
        self.V_embed = nn.Parameter(torch.empty((self.C, self.pre_embed_size,), device=device))
        self.U_embed = nn.Parameter(torch.empty((self.C, self.in_size,), device=device))
        self.V_unembed = nn.Parameter(torch.empty((self.C, self.in_size,), device=device))
        self.U_unembed = nn.Parameter(torch.empty((self.C, self.pre_embed_size,), device=device))

        self.V_in = nn.ParameterList([torch.empty((self.C, self.in_size,), device=device) for i in range(self.num_layers)])
        self.U_in = nn.ParameterList([torch.empty((self.C, self.hidden_size,), device=device) for i in range(self.num_layers)])
        self.V_out = nn.ParameterList([torch.empty((self.C, self.hidden_size,), device=device) for i in range(self.num_layers)])
        self.U_out = nn.ParameterList([torch.empty((self.C, self.in_size,), device=device) for i in range(self.num_layers)])

        # idk what you do with the biases lol
        self.b = nn.ParameterList([torch.zeros((self.hidden_size,), device=device) for i in range(self.num_layers)])

        # this is so horrible I'm sorry
        # gate_in_in is the gate_in weights for the in subcomponent of each layer
        # gate_in_out is gate_in weights for the out component
        # they each get an extra one on the end which is for the embed matrix
        # Probably should have just registered these as submodels, but I didn't know that those existed before.
        self.imp_W_gate_in_in = nn.ParameterList([torch.empty((self.C, 1, self.imp_hidden_size), device=device) for i in range(self.num_layers+1)])
        self.imp_W_gate_out_in = nn.ParameterList([torch.empty((self.C, self.imp_hidden_size, 1), device=device) for i in range(self.num_layers+1)])
        self.imp_b_in_in = nn.ParameterList([torch.empty((self.C, self.imp_hidden_size), device=device) for i in range(self.num_layers+1)])
        self.imp_b_out_in = nn.ParameterList([torch.empty((self.C, 1), device=device) for i in range(self.num_layers+1)])

        self.imp_W_gate_in_out = nn.ParameterList([torch.empty((self.C, 1, self.imp_hidden_size), device=device) for i in range(self.num_layers+1)])
        self.imp_W_gate_out_out = nn.ParameterList([torch.empty((self.C, self.imp_hidden_size, 1), device=device) for i in range(self.num_layers+1)])
        self.imp_b_in_out = nn.ParameterList([torch.empty((self.C, self.imp_hidden_size), device=device) for i in range(self.num_layers+1)])
        self.imp_b_out_out = nn.ParameterList([torch.empty((self.C, 1), device=device) for i in range(self.num_layers+1)])

        # self.pred_weights = torch.empty((num_layers, self.C), device=device)

        for param in self.parameters():
            # using xavier anyway -- note that the variance etc is
            # changed because we take the outer product in the
            # forward pass
            if param.dim() >= 2: # xavier does not work for 1d tensors
                nn.init.xavier_normal_(param)
            else:
                nn.init.zeros_(param)
                param = param + 0.1


    def forward(self, x, return_activs_and_weights=False, masks=None): # Regular run. Unclear whether I should have masking when I do a regular forward pass.
        v_activations = []
        weight_matrices = []
        x_in = x.clone()
        layerwise_resids = []

        if masks is not None:

            v_activ_embed = torch.einsum("np,cp->nc", x_in, self.V_embed) * masks[-1]["embed"]
            v_activ_embed_nomask = torch.einsum("np,cp->nc", x_in, self.V_embed)

            x_embedded = torch.einsum("nc,ci->ni", v_activ_embed, self.U_embed)
            x_embedded_nomask = torch.einsum("nc,ci->ni", v_activ_embed_nomask, self.U_embed)

            x_resid = x_embedded.clone()

            for l in range(self.num_layers):
                ### Run the forward pass just for this layer
                v_activ_layer_in = torch.einsum("ni,ci->nc", x_resid, self.V_in[l]) * masks[l]["in"] # V activations
                layer_hidden = F.relu(torch.einsum("nc,ch->nh", v_activ_layer_in, self.U_in[l]) + self.b[l]) # U and ReLU
                v_activ_layer_out = torch.einsum("nh,ch->nc", layer_hidden, self.V_out[l]) * masks[l]["out"] # V activ
                layer_out = torch.einsum("nc,ci->ni", v_activ_layer_out, self.U_out[l]) # U
                x_resid = x_resid + layer_out # add to residual stream

                # Run the layerwise forward pass for the whole model, with masking only this layer
                x_resid_layerwise_in = x_embedded_nomask.clone() # pre-embedded
                x_resid_layerwise_out = x_embedded_nomask.clone() # pre-embedded again

                for l_2 in range(self.num_layers): # oh god this is so awful. literally sobbing rn. i can't believe I'm writing it like this
                    if l_2 == l:
                        # run with mask on in
                        v_activ_layer_in_l = torch.einsum("ni,ci->nc", x_resid_layerwise_in, self.V_in[l]) * masks[l]["in"]
                        layer_hidden_l = F.relu(torch.einsum("nc,ch->nh", v_activ_layer_in_l, self.U_in[l])+self.b[l])
                        v_activ_layer_out_l = torch.einsum("nh,ch->nc", layer_hidden_l, self.V_out[l])
                        layer_out_l_masked_in = torch.einsum("nc,ci->ni", v_activ_layer_out_l, self.U_out[l])
                        x_resid_layerwise_in = x_resid_layerwise_in + layer_out_l_masked_in

                        # run with mask on out
                        v_activ_layer_in_l = torch.einsum("ni,ci->nc", x_resid_layerwise_out, self.V_in[l])
                        layer_hidden_l = F.relu(torch.einsum("nc,ch->nh", v_activ_layer_in_l, self.U_in[l])+self.b[l])
                        v_activ_layer_out_l = torch.einsum("nh,ch->nc", layer_hidden_l, self.V_out[l]) * masks[l]["out"]
                        layer_out_l_masked_out = torch.einsum("nc,ci->ni", v_activ_layer_out_l, self.U_out[l])
                        x_resid_layerwise_out = x_resid_layerwise_out + layer_out_l_masked_out


                    else: # THE SAME EXACT THING EXCEPT WITHOUT THE MASKS >_<
                        v_activ_layer_in_l = torch.einsum("ni,ci->nc", x_resid_layerwise_in, self.V_in[l])
                        layer_hidden_l = F.relu(torch.einsum("nc,ch->nh", v_activ_layer_in_l, self.U_in[l]) + self.b[l])
                        v_activ_layer_out_l = torch.einsum("nh,ch->nc", layer_hidden_l, self.V_out[l])
                        layer_out_l = torch.einsum("nc,ci->ni", v_activ_layer_out_l, self.U_out[l])
                        x_resid_layerwise_in = x_resid_layerwise_in + layer_out_l

                        v_activ_layer_in_l = torch.einsum("ni,ci->nc", x_resid_layerwise_out, self.V_in[l])
                        layer_hidden_l = F.relu(torch.einsum("nc,ch->nh", v_activ_layer_in_l, self.U_in[l]) + self.b[l])
                        v_activ_layer_out_l = torch.einsum("nh,ch->nc", layer_hidden_l, self.V_out[l])
                        layer_out_l = torch.einsum("nc,ci->ni", v_activ_layer_out_l, self.U_out[l])
                        x_resid_layerwise_out = x_resid_layerwise_out + layer_out_l

                    x_out_layerwise_in = torch.einsum("nc,cp->np", torch.einsum("ni,ci->nc", x_resid_layerwise_in, self.V_unembed), self.U_unembed)
                    x_out_layerwise_out = torch.einsum("nc,cp->np", torch.einsum("ni,ci->nc", x_resid_layerwise_out, self.V_unembed), self.U_unembed)
                    layerwise_resids.append({"in":x_out_layerwise_in,"out":x_out_layerwise_out})

                # run layerwise for embed and unembed as well, because very intelligently I gave them their own separate weight matrices
                # and they don't get included when iterating through layers
                x_embed_resid = x_embedded.clone()
                for l in range(self.num_layers):
                    v_activ_layer_in_l = torch.einsum("ni,ci->nc", x_embed_resid, self.V_in[l])
                    layer_hidden_l = F.relu(torch.einsum("nc,ch->nh", v_activ_layer_in_l, self.U_in[l]) + self.b[l])
                    v_activ_layer_out_l = torch.einsum("nh,ch->nc", layer_hidden_l, self.V_out[l])
                    layer_out_embed_mask = torch.einsum("nc,ci->ni", v_activ_layer_out_l, self.U_out[l])
                    x_embed_resid = x_embed_resid + layer_out_embed_mask
                x_out_embed_mask = torch.einsum("nc,cp->np", torch.einsum("ni,ci->nc", x_embed_resid, self.V_unembed), self.U_unembed)

                x_unembed_resid = x_embedded_nomask.clone()
                for l in range(self.num_layers):
                    v_activ_layer_in_l = torch.einsum("ni,ci->nc", x_unembed_resid, self.V_in[l])
                    layer_hidden_l = F.relu(torch.einsum("nc,ch->nh", v_activ_layer_in_l, self.U_in[l]) + self.b[l])
                    v_activ_layer_out_l = torch.einsum("nh,ch->nc", layer_hidden_l, self.V_out[l])
                    layer_out_unembed_mask = torch.einsum("nc,ci->ni", v_activ_layer_out_l, self.U_out[l])
                    x_unembed_resid = x_unembed_resid + layer_out_unembed_mask
                x_out_unembed_mask = torch.einsum("nc,cp->np", torch.einsum("ni,ci->nc", x_unembed_resid, self.V_unembed)* masks[-1]["unembed"], self.U_unembed)

                layerwise_resids.append({"embed": x_out_embed_mask, "unembed": x_out_unembed_mask})

            v_activ_unembed = torch.einsum("ni,ci->nc", x_resid, self.V_unembed) * masks[-1]["unembed"]
            x_out = torch.einsum("nc,cp->np", v_activ_unembed, self.U_unembed)


        else:
            W_embed = torch.einsum("cp, ci -> cpi", self.V_embed, self.U_embed).sum(dim=0)
            x_resid = torch.einsum("np, pi -> ni", x_in, W_embed)
            v_activ_embed = torch.einsum("np,cp->nc", x_in, self.V_embed).unsqueeze(-1)

            for l in range(self.num_layers):
                # Use outer product to create weights for the layer, then sum all the subcomponents
                W_in = torch.einsum("ci,ch-> cih", self.V_in[l], self.U_in[l]).sum(dim=0) # shape i h
                W_out = torch.einsum("ch,ci-> chi", self.V_out[l], self.U_out[l]).sum(dim=0) # shape h i
                if return_activs_and_weights == True:
                    weight_matrices.append({"in": W_in, "out": W_out})
                # COMPUTE
                layer_in = F.relu(torch.einsum("ni,ih -> nh", x_resid, W_in) + self.b[l])
                layer_out = torch.einsum("nh,hi->ni", F.relu(layer_in), W_out)

                if return_activs_and_weights == True: # LMAO I HAD IT INDENTED WRONG
                    # calculate activations
                    v_activ_layer_in = torch.einsum("ni,ci->nc", x_resid, self.V_in[l]).unsqueeze(-1)
                    v_activ_layer_out = torch.einsum("nh,ch->nc", layer_in, self.V_out[l]).unsqueeze(-1)
                    # both are now shape n,c,1
                    v_activations.append({"in": v_activ_layer_in, "out": v_activ_layer_out})

                    # code to make sure I'm not doing some goofy shit with the activations at the beginning
                    check_layer_in_uv = F.relu(torch.einsum("nc,ch->nh", v_activ_layer_in.squeeze(), self.U_in[l]) + self.b[l])
                    assert torch.allclose(layer_in, check_layer_in_uv, atol=10), f"Subcomponent activations not calculated correctly, max difference is {torch.max(layer_in-check_layer_in_uv)}; {layer_in}{check_layer_in_uv}"
                    # would in theory have one for the second too but I'm being lazy

            x_resid = x_resid + layer_out
            W_unembed = torch.einsum("ci, cp-> cip", self.V_unembed, self.U_unembed).sum(dim=0)
            x_out = torch.einsum("ni, ip -> np", x_resid, W_unembed)

            if return_activs_and_weights == True:
                weight_matrices.append({"embed": W_embed, "unembed": W_unembed})
                v_activ_unembed = torch.einsum("ni,ci->nc", x_resid, self.V_unembed).unsqueeze(-1)
                v_activations.append({"embed": v_activ_embed, "unembed": v_activ_unembed})
            # putting the embed matrices at the end of the masks. since we iterate through l in range(num_layers) this will not get added! :D

        if return_activs_and_weights:
            return x_out, v_activations, weight_matrices
        elif masks is not None:
            return x_out, layerwise_resids
        else:
            return x_out

loss_history = []
loss_history_stoch_rec = []
loss_history_stoch_rec_layer = []
loss_history_faithfulness = []
loss_history_imp_min = []


def train_SPD(spd_model, dataloader, lr=1e-5, num_epochs=20):

    spd_model.train()
    print(f"Training on device {spd_model.device}")
    optimizer = torch.optim.AdamW(spd_model.parameters(), lr = lr)

    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=4,   # e.g. every 4 epochs
        gamma=0.5      # multiply LR by 0.5 each time
    )

    P = sum(p.numel() for p in spd_model.target_model.parameters())

    upper_leaky_sigmoid = UpperLeakyHardSigmoid()
    lower_leaky_sigmoid = LowerLeakyHardSigmoid()

    for epoch in range(num_epochs):
        total_loss = 0.0
        total_l_stoch_rec, total_l_stoch_rec_l, total_l_imp, total_l_faith = 0.0, 0.0, 0.0, 0.0
        print(f"Starting epoch {epoch+1}, lr = {optimizer.param_groups[0]['lr']:.2e}")

        with tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}") as t:
            for batch_idx, (x,y) in enumerate(t):
                x = x.to(device)
                y = y.to(device)

                N = x.shape[0]     # x is shape N by in_size
                C = spd_model.C
                target_model = spd_model.target_model
                optimizer.zero_grad()

                # COMPUTE TARGET OUTPUT
                with torch.no_grad():
                    target_out = spd_model.target_model(x)

                # FAITHFULNESS LOSS

                spd_output, spd_activations, spd_weights = spd_model(x, return_activs_and_weights = True)
                squared_error = 0

                embed_diff = target_model.W_embed - spd_weights[-1]["embed"]
                unembed_diff = target_model.W_unembed - spd_weights[-1]["unembed"]
                squared_error = squared_error + torch.linalg.matrix_norm(embed_diff) ** 2 + torch.linalg.matrix_norm(unembed_diff) ** 2

                for l in range(spd_model.num_layers):
                    in_diff = target_model.W_in[l] - spd_weights[l]["in"]
                    out_diff = target_model.W_out[l] - spd_weights[l]["out"]

                    # torch.linalg.matrix_norm defaults to the frobenius norm
                    # this takes the frobenius norm of the diff and then squares it
                    squared_error_layer = torch.linalg.matrix_norm(in_diff) ** 2 + torch.linalg.matrix_norm(out_diff) ** 2
                    squared_error = squared_error + squared_error_layer

                mean_squared_error = squared_error/P

                l_faithfulness = mean_squared_error


                ## IMPORTANCE-MINIMALITY LOSS

                pred_importances = []
                l_importance_minimality = 0.0

                components_imp_pred_embed_hidden = F.gelu(torch.einsum("nco,cos->ncs", spd_activations[-1]["embed"], spd_model.imp_W_gate_in_in[-1]) + spd_model.imp_b_in_in[-1])
                components_imp_pred_embed = (torch.einsum("ncs,cso->nco", components_imp_pred_embed_hidden, spd_model.imp_W_gate_out_in[-1]) + spd_model.imp_b_out_in[-1])

                components_imp_pred_unembed_hidden = F.gelu(torch.einsum("nco,cos->ncs", spd_activations[-1]["unembed"], spd_model.imp_W_gate_in_out[-1]) + spd_model.imp_b_in_out[-1])
                components_imp_pred_unembed = (torch.einsum("ncs,cso->nco", components_imp_pred_unembed_hidden, spd_model.imp_W_gate_out_out[-1]) + spd_model.imp_b_out_out[-1])
                l_importance_minimality = l_importance_minimality + (upper_leaky_sigmoid(components_imp_pred_embed).sum() ** spd_model.hypers["causal_imp_min"]) + (upper_leaky_sigmoid(components_imp_pred_unembed).sum() ** spd_model.hypers["causal_imp_min"])

                for l in range(spd_model.num_layers):
                    # both activations are n by c containing dot product so we already have hard_sigmoid
                    # spd_activations[l][inout] is shape n,c,1 (nco)
                    # imp_W_in is c by 1 by imp_size (cos)
                    # want to map to ncs then back to nco
                    # imp_b_in is shape (C, s) so should broadcast to ncs nicely

                    # TODO: DEFINE hard_sigmoid AS A TORCH MODULE SO IT CAN CALCULATE THE DERIVATIVE
                    # in theory should write this as a bunch of models stored in the main model, but that's not how i did it and i've already
                    # written this, shrug
                    # oh my god this is so bad
                    # imp_W_gate_in_in is the gate in for the W_in, etc
                    components_imp_pred_hidden_in = F.gelu(torch.einsum("nco,cos->ncs", spd_activations[l]["in"], spd_model.imp_W_gate_in_in[l]) + spd_model.imp_b_in_in[l])
                    components_pred_layer_in = torch.einsum("ncs,cso->nco", components_imp_pred_hidden_in, spd_model.imp_W_gate_out_in[l]) + spd_model.imp_b_out_in[l]

                    #same thing for the out matrix in layer l
                    components_imp_pred_hidden_out = F.gelu(torch.einsum("nco,cos->ncs", spd_activations[l]["out"], spd_model.imp_W_gate_in_out[l]) + spd_model.imp_b_in_out[l])
                    components_pred_layer_out = (torch.einsum("ncs,cso->nco", components_imp_pred_hidden_out, spd_model.imp_W_gate_out_out[l]) + spd_model.imp_b_out_out[l])

                    pred_importances.append({"in": components_pred_layer_in, "out": components_pred_layer_out})
                    l_importance_minimality = l_importance_minimality + (upper_leaky_sigmoid(components_pred_layer_in).sum() ** spd_model.hypers["causal_imp_min"]) + (upper_leaky_sigmoid(components_pred_layer_out).sum() ** spd_model.hypers["causal_imp_min"])

                pred_importances.append({"embed": components_imp_pred_embed, "unembed": components_imp_pred_unembed})


                l_importance_minimality /= N # divide by N, the batch size, to avg across the batch (notated as B in the paper)


                ## STOCHASTIC RECONSTRUCTION LOSS

                l_stochastic_recon = 0.0
                l_stochastic_recon_layerwise = 0.0
                R = torch.rand((spd_model.hypers["num_mask_samples"], N, spd_model.num_layers+1, 2, C), device=device)

                for s in range(spd_model.hypers["num_mask_samples"]):
                    layer_masks = []
                    # Running this with a for loop. This is slow; I'd ideally just run
                    # something like M = G[None, :, :] + (1 - G[None, :, :]) * R
                    # But I think there's a clarity tradeoff here so I'm just going to do this for now
                    layer_mask_embed = lower_leaky_sigmoid(pred_importances[-1]["embed"].squeeze()) + (torch.ones_like(pred_importances[-1]["embed"].squeeze()) - lower_leaky_sigmoid(pred_importances[-1]["embed"].squeeze()))*R[s,:,spd_model.num_layers,0,:]
                    layer_mask_unembed = lower_leaky_sigmoid(pred_importances[-1]["unembed"].squeeze()) + (torch.ones_like(pred_importances[-1]["unembed"].squeeze()) - lower_leaky_sigmoid(pred_importances[-1]["unembed"].squeeze())) * R[s,:,spd_model.num_layers,1,:]

                    for l in range(spd_model.num_layers):
                        # ugh, this is pretty bad. Ideally I'd go back and refactor so that I don't
                        # have separate in and out weights, but this is what I have. going to keep it
                        # like this for now, but will refactor once I have something that works.
                        # components pred in is shape N, C, 1, squeeze to N, C. there are L of them
                        # R is N, L, C, sample just l to get N, C
                        # I want masks for each component on each layer on each datapoint, so
                        # masks should be one C vector for each layer, per-datapoint. so NLC

                        # whoops i'm now putting squeezes on everything. seems like I should just unsqueeze the thing when I store it, but I don't feel like refactoring it like that :(
                        layer_mask_in = lower_leaky_sigmoid(pred_importances[l]["in"].squeeze()) + (torch.ones_like(pred_importances[l]["in"].squeeze()) - lower_leaky_sigmoid(pred_importances[l]["in"].squeeze())) * R[s,:,l,0,:]
                        layer_mask_out = lower_leaky_sigmoid(pred_importances[l]["out"].squeeze()) + (torch.ones_like(pred_importances[l]["out"].squeeze()) - lower_leaky_sigmoid(pred_importances[l]["out"].squeeze())) * R[s,:,l,1,:]
                        # these are both shape N, C
                        layer_masks.append({"in": layer_mask_in, "out": layer_mask_out})

                    layer_masks.append({"embed": layer_mask_embed, "unembed": layer_mask_unembed})
                    masked_out, layerwise_masked_outs = spd_model(x, masks=layer_masks)
                    l_stochastic_recon = l_stochastic_recon + torch.linalg.matrix_norm(target_out - masked_out) ** 2 # uses matrix norm of difference
                    for i in range(len(layerwise_masked_outs)-1):
                        l_stochastic_recon_layerwise = l_stochastic_recon_layerwise + torch.linalg.matrix_norm(target_out - layerwise_masked_outs[i]["in"]) ** 2 + torch.linalg.matrix_norm(target_out - layerwise_masked_outs[i]["out"]) **2
                        # l_stochastic_recon_layerwise = l_stochastic_recon_layerwise + (target_out - layerwise_masked_outs[i]["in"]) ** 2 + (target_out - layerwise_masked_outs[i]["out"]) ** 2 # switch to MSE

                    l_stochastic_recon_layerwise = l_stochastic_recon_layerwise + torch.linalg.matrix_norm(target_out - layerwise_masked_outs[-1]["embed"]) **2 + torch.linalg.matrix_norm(target_out - layerwise_masked_outs[-1]["unembed"]) **2
                    # l_stochastic_recon_layerwise = l_stochastic_recon_layerwise + (target_out - layerwise_masked_outs[-1]["embed"]) ** 2 + (target_out - layerwise_masked_outs[-1]["unembed"]) ** 2


                l_stochastic_recon /= spd_model.hypers["num_mask_samples"]
                l_stochastic_recon_layerwise /= (spd_model.hypers["num_mask_samples"] * (2 * spd_model.num_layers + 2) ) #one in/out matrix per layer + embed & unembed

                loss = l_faithfulness + spd_model.hypers["beta_1"] * l_stochastic_recon + spd_model.hypers["beta_2"] * l_stochastic_recon_layerwise + spd_model.hypers["beta_3"] * l_importance_minimality

                loss.backward()
                # torch.nn.utils.clip_grad_norm_(spd_model.parameters(), max_norm=1.0)


                if batch_idx % 118 == 0:
                    with torch.no_grad():

                        G = lower_leaky_sigmoid(pred_importances[0]["in"]).squeeze()
                        print("pred_imp min/max:", pred_importances[0]["in"].min(), pred_importances[0]["in"].max())
                        print("mask min/max:", layer_masks[0]["in"].min(), layer_masks[0]["in"].max())

                        loss_history.append(l_faithfulness.item())
                        loss_history_faithfulness.append(l_faithfulness.item())
                        loss_history_stoch_rec.append(l_stochastic_recon.item())
                        loss_history_stoch_rec_layer.append(l_stochastic_recon_layerwise.item())
                        loss_history_imp_min.append(l_importance_minimality.item())

                        print("Masked out min: ", masked_out.min().item(), ", max: ", masked_out.max().item())
                        print(f"Faithfulness: {l_faithfulness}, Stoch Rec: {l_stochastic_recon}, Stoch Rec Layerwise: {l_stochastic_recon_layerwise}, Importance Min: {l_importance_minimality}")
                        total_norm = 0.0
                        for p in spd_model.parameters():
                            if p.grad is not None:
                                param_norm = p.grad.data.norm(2)  # L2 norm of the gradient
                                total_norm += param_norm.item() ** 2

                        total_norm = total_norm ** 0.5
                        print(f"Total gradient norm: {total_norm}")

                optimizer.step()

                total_loss += loss.item() * x.size(0)

                t.set_postfix(loss=loss.item())

        avg_loss = total_loss / len(dataloader.dataset)
        total_l_stoch_rec, total_l_stoch_rec_l, total_l_imp, total_l_faith = total_l_stoch_rec/len(dataloader.dataset), total_l_stoch_rec_l/len(dataloader.dataset), total_l_imp/len(dataloader.dataset), total_l_faith/len(dataloader.dataset)
        print(total_l_faith, total_l_stoch_rec, total_l_stoch_rec_l, total_l_imp)
        scheduler.step()




if __name__ == "__main__":
    # Config
    #device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    device = "cpu"
    config = {
        "num_layers": 1,
        "pre_embed_size": 100,
        "in_size": 1000,
        "hidden_size": 50,
        "subcomponents_per_layer": 30,
        "beta_1": 1.0,
        "beta_2": 1.0,
        "beta_3": 0.1,
        "causal_imp_min": 1.0,
        "num_mask_samples": 20,
        "importance_mlp_size": 5,
    }

    dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
    # Model
    spd_model = SPDModelMLP(toy_model, config, "cpu")
    # Train
    train_SPD(spd_model, dataloader, lr=8e-3, num_epochs=25)

Devices:  cpu cpu
Training on device cpu
Starting epoch 1, lr = 8.00e-03


Epoch 1/25:   1%|          | 1/118 [00:02<04:50,  2.48s/it, loss=1.39e+3]

pred_imp min/max: tensor(-0.5432) tensor(0.4992)
mask min/max: tensor(-0.0022) tensor(0.9999)
Masked out min:  -0.49410781264305115 , max:  0.43328431248664856
Faithfulness: 0.0018302702810615301, Stoch Rec: 427.2201232910156, Stoch Rec Layerwise: 960.5885620117188, Importance Min: 15.379110336303711
Total gradient norm: 2491.97116503356


Epoch 1/25: 100%|██████████| 118/118 [04:40<00:00,  2.38s/it, loss=22.8]


0.0 0.0 0.0 0.0
Starting epoch 2, lr = 8.00e-03


Epoch 2/25:   1%|          | 1/118 [00:02<05:14,  2.69s/it, loss=102]

pred_imp min/max: tensor(-1.7503) tensor(1.4362)
mask min/max: tensor(-0.0015) tensor(1.)
Masked out min:  -0.2612195611000061 , max:  0.8695976138114929
Faithfulness: 0.0013476668391376734, Stoch Rec: 51.66410446166992, Stoch Rec Layerwise: 46.16731643676758, Importance Min: 41.769378662109375
Total gradient norm: 175.56545572021062


Epoch 2/25: 100%|██████████| 118/118 [04:40<00:00,  2.38s/it, loss=15]


0.0 0.0 0.0 0.0
Starting epoch 3, lr = 8.00e-03


Epoch 3/25:   1%|          | 1/118 [00:02<04:31,  2.32s/it, loss=65.2]

pred_imp min/max: tensor(-1.6178) tensor(2.1596)
mask min/max: tensor(-0.0055) tensor(1.)
Masked out min:  -0.2901337742805481 , max:  0.840610921382904
Faithfulness: 0.0012874448439106345, Stoch Rec: 30.950634002685547, Stoch Rec Layerwise: 28.87361717224121, Importance Min: 53.40042495727539
Total gradient norm: 230.86277460665954


Epoch 3/25: 100%|██████████| 118/118 [04:39<00:00,  2.37s/it, loss=13.6]


0.0 0.0 0.0 0.0
Starting epoch 4, lr = 8.00e-03


Epoch 4/25:   1%|          | 1/118 [00:02<04:30,  2.31s/it, loss=54.1]

pred_imp min/max: tensor(-1.9547) tensor(3.1598)
mask min/max: tensor(0.0003) tensor(1.)
Masked out min:  -0.30887770652770996 , max:  0.8940812945365906
Faithfulness: 0.0012575257569551468, Stoch Rec: 25.275955200195312, Stoch Rec Layerwise: 23.47529411315918, Importance Min: 53.85778045654297
Total gradient norm: 187.58192810135733


Epoch 4/25: 100%|██████████| 118/118 [04:39<00:00,  2.37s/it, loss=13.5]


0.0 0.0 0.0 0.0
Starting epoch 5, lr = 4.00e-03


Epoch 5/25:   1%|          | 1/118 [00:02<05:12,  2.67s/it, loss=45.5]

pred_imp min/max: tensor(-1.9458) tensor(2.2914)
mask min/max: tensor(-0.0034) tensor(1.)
Masked out min:  -0.4131135046482086 , max:  0.8017029762268066
Faithfulness: 0.001228934619575739, Stoch Rec: 20.911834716796875, Stoch Rec Layerwise: 19.14470672607422, Importance Min: 54.56949234008789
Total gradient norm: 86.25928470924507


Epoch 5/25: 100%|██████████| 118/118 [04:40<00:00,  2.38s/it, loss=12.3]


0.0 0.0 0.0 0.0
Starting epoch 6, lr = 4.00e-03


Epoch 6/25:   1%|          | 1/118 [00:02<04:31,  2.32s/it, loss=42.8]

pred_imp min/max: tensor(-2.8286) tensor(2.9667)
mask min/max: tensor(-0.0057) tensor(1.)
Masked out min:  -0.4108780026435852 , max:  0.65569007396698
Faithfulness: 0.0012213274603709579, Stoch Rec: 19.574831008911133, Stoch Rec Layerwise: 17.68197250366211, Importance Min: 55.02389144897461
Total gradient norm: 76.2337629171995


Epoch 6/25: 100%|██████████| 118/118 [04:40<00:00,  2.38s/it, loss=12.6]


0.0 0.0 0.0 0.0
Starting epoch 7, lr = 4.00e-03


Epoch 7/25:   1%|          | 1/118 [00:02<04:28,  2.29s/it, loss=41.2]

pred_imp min/max: tensor(-1.7720) tensor(2.4049)
mask min/max: tensor(-0.0021) tensor(1.)
Masked out min:  -0.3597259223461151 , max:  0.761620283126831
Faithfulness: 0.0012131613912060857, Stoch Rec: 18.714948654174805, Stoch Rec Layerwise: 16.90483283996582, Importance Min: 55.32904052734375
Total gradient norm: 78.00765288851747


Epoch 7/25: 100%|██████████| 118/118 [04:40<00:00,  2.37s/it, loss=12.4]


0.0 0.0 0.0 0.0
Starting epoch 8, lr = 4.00e-03


Epoch 8/25:   1%|          | 1/118 [00:02<05:12,  2.67s/it, loss=40.6]

pred_imp min/max: tensor(-2.2780) tensor(2.8542)
mask min/max: tensor(-0.0035) tensor(1.)
Masked out min:  -0.4139865040779114 , max:  0.8102222681045532
Faithfulness: 0.0012052588863298297, Stoch Rec: 18.010154724121094, Stoch Rec Layerwise: 16.773448944091797, Importance Min: 58.292762756347656
Total gradient norm: 53.72612105276772


Epoch 8/25: 100%|██████████| 118/118 [04:40<00:00,  2.38s/it, loss=12.3]


0.0 0.0 0.0 0.0
Starting epoch 9, lr = 2.00e-03


Epoch 9/25:   1%|          | 1/118 [00:02<04:29,  2.31s/it, loss=39.8]

pred_imp min/max: tensor(-2.0257) tensor(3.1190)
mask min/max: tensor(-0.0012) tensor(1.)
Masked out min:  -0.33889704942703247 , max:  0.7203965783119202
Faithfulness: 0.0011970062041655183, Stoch Rec: 17.234954833984375, Stoch Rec Layerwise: 16.594478607177734, Importance Min: 59.33014678955078
Total gradient norm: 79.38111243781908


Epoch 9/25: 100%|██████████| 118/118 [04:39<00:00,  2.37s/it, loss=12.3]


0.0 0.0 0.0 0.0
Starting epoch 10, lr = 2.00e-03


Epoch 10/25:   1%|          | 1/118 [00:02<04:28,  2.30s/it, loss=37.7]

pred_imp min/max: tensor(-2.3798) tensor(2.4826)
mask min/max: tensor(-0.0034) tensor(1.)
Masked out min:  -0.36474189162254333 , max:  0.6204657554626465
Faithfulness: 0.0011943482095375657, Stoch Rec: 16.137128829956055, Stoch Rec Layerwise: 15.59235668182373, Importance Min: 59.794029235839844
Total gradient norm: 39.31164905544749


Epoch 10/25: 100%|██████████| 118/118 [04:39<00:00,  2.37s/it, loss=11.8]


0.0 0.0 0.0 0.0
Starting epoch 11, lr = 2.00e-03


Epoch 11/25:   1%|          | 1/118 [00:02<05:11,  2.66s/it, loss=38.4]

pred_imp min/max: tensor(-2.3439) tensor(3.3763)
mask min/max: tensor(-0.0054) tensor(1.)
Masked out min:  -0.42482319474220276 , max:  0.8554617762565613
Faithfulness: 0.001190854818560183, Stoch Rec: 16.5850830078125, Stoch Rec Layerwise: 15.883773803710938, Importance Min: 59.534976959228516
Total gradient norm: 44.088907997087176


Epoch 11/25: 100%|██████████| 118/118 [04:40<00:00,  2.38s/it, loss=12.2]


0.0 0.0 0.0 0.0
Starting epoch 12, lr = 2.00e-03


Epoch 12/25:   1%|          | 1/118 [00:02<04:29,  2.30s/it, loss=38]

pred_imp min/max: tensor(-1.8942) tensor(2.4391)
mask min/max: tensor(-0.0074) tensor(1.)
Masked out min:  -0.3182293772697449 , max:  0.9881288409233093
Faithfulness: 0.0011874587507918477, Stoch Rec: 16.289627075195312, Stoch Rec Layerwise: 15.744044303894043, Importance Min: 59.250038146972656
Total gradient norm: 52.8275046152963


Epoch 12/25: 100%|██████████| 118/118 [04:41<00:00,  2.38s/it, loss=11.9]


0.0 0.0 0.0 0.0
Starting epoch 13, lr = 1.00e-03


Epoch 13/25:   1%|          | 1/118 [00:02<04:41,  2.41s/it, loss=38.7]

pred_imp min/max: tensor(-1.8829) tensor(3.1376)
mask min/max: tensor(-0.0044) tensor(1.)
Masked out min:  -0.3526413142681122 , max:  0.8232038021087646
Faithfulness: 0.0011842931853607297, Stoch Rec: 16.653886795043945, Stoch Rec Layerwise: 16.096633911132812, Importance Min: 59.05140686035156
Total gradient norm: 44.951287555002125


Epoch 13/25: 100%|██████████| 118/118 [04:39<00:00,  2.37s/it, loss=11.7]


0.0 0.0 0.0 0.0
Starting epoch 14, lr = 1.00e-03


Epoch 14/25:   1%|          | 1/118 [00:02<04:47,  2.46s/it, loss=38.9]

pred_imp min/max: tensor(-2.4398) tensor(3.2440)
mask min/max: tensor(-0.0020) tensor(1.)
Masked out min:  -0.3931247293949127 , max:  0.6370551586151123
Faithfulness: 0.0011827393900603056, Stoch Rec: 16.75955581665039, Stoch Rec Layerwise: 16.235301971435547, Importance Min: 58.92161178588867
Total gradient norm: 41.745517683541806


Epoch 14/25: 100%|██████████| 118/118 [04:39<00:00,  2.37s/it, loss=11.9]


0.0 0.0 0.0 0.0
Starting epoch 15, lr = 1.00e-03


Epoch 15/25:   1%|          | 1/118 [00:02<04:30,  2.31s/it, loss=37.4]

pred_imp min/max: tensor(-2.3556) tensor(3.2931)
mask min/max: tensor(-0.0031) tensor(1.)
Masked out min:  -0.3528532087802887 , max:  0.7803086042404175
Faithfulness: 0.0011812610318884254, Stoch Rec: 15.981236457824707, Stoch Rec Layerwise: 15.46411418914795, Importance Min: 59.079986572265625
Total gradient norm: 46.26904320493297


Epoch 15/25: 100%|██████████| 118/118 [04:38<00:00,  2.36s/it, loss=12.2]


0.0 0.0 0.0 0.0
Starting epoch 16, lr = 1.00e-03


Epoch 16/25:   1%|          | 1/118 [00:02<04:36,  2.36s/it, loss=37.4]

pred_imp min/max: tensor(-1.9932) tensor(3.4312)
mask min/max: tensor(-0.0016) tensor(1.)
Masked out min:  -0.38177233934402466 , max:  0.867191731929779
Faithfulness: 0.0011797931510955095, Stoch Rec: 15.928387641906738, Stoch Rec Layerwise: 15.545616149902344, Importance Min: 59.27867889404297
Total gradient norm: 38.119449864821625


Epoch 16/25: 100%|██████████| 118/118 [04:39<00:00,  2.36s/it, loss=12]


0.0 0.0 0.0 0.0
Starting epoch 17, lr = 5.00e-04


Epoch 17/25:   1%|          | 1/118 [00:02<04:45,  2.44s/it, loss=37.2]

pred_imp min/max: tensor(-2.2368) tensor(2.9023)
mask min/max: tensor(-0.0077) tensor(1.)
Masked out min:  -0.3665906488895416 , max:  0.811331570148468
Faithfulness: 0.0011781958164647222, Stoch Rec: 15.753721237182617, Stoch Rec Layerwise: 15.472018241882324, Importance Min: 59.724761962890625
Total gradient norm: 49.74023960609668


Epoch 17/25: 100%|██████████| 118/118 [04:38<00:00,  2.36s/it, loss=12]


0.0 0.0 0.0 0.0
Starting epoch 18, lr = 5.00e-04


Epoch 18/25:   1%|          | 1/118 [00:02<04:30,  2.31s/it, loss=37]

pred_imp min/max: tensor(-2.3253) tensor(3.1987)
mask min/max: tensor(-0.0043) tensor(1.)
Masked out min:  -0.3708230257034302 , max:  0.7845755815505981
Faithfulness: 0.0011771713616326451, Stoch Rec: 15.65058708190918, Stoch Rec Layerwise: 15.361968994140625, Importance Min: 59.72987747192383
Total gradient norm: 36.03940226925164


Epoch 18/25: 100%|██████████| 118/118 [04:39<00:00,  2.37s/it, loss=11.8]


0.0 0.0 0.0 0.0
Starting epoch 19, lr = 5.00e-04


Epoch 19/25:   1%|          | 1/118 [00:02<04:27,  2.29s/it, loss=37.5]

pred_imp min/max: tensor(-1.9934) tensor(2.8908)
mask min/max: tensor(-0.0173) tensor(1.)
Masked out min:  -0.39411628246307373 , max:  0.7565032243728638
Faithfulness: 0.0011763253714889288, Stoch Rec: 15.840815544128418, Stoch Rec Layerwise: 15.654398918151855, Importance Min: 59.68240737915039
Total gradient norm: 38.20987445335886


Epoch 19/25: 100%|██████████| 118/118 [04:39<00:00,  2.37s/it, loss=11.5]


0.0 0.0 0.0 0.0
Starting epoch 20, lr = 5.00e-04


Epoch 20/25:   1%|          | 1/118 [00:02<04:58,  2.55s/it, loss=37]

pred_imp min/max: tensor(-2.2634) tensor(3.1947)
mask min/max: tensor(-0.0084) tensor(1.)
Masked out min:  -0.4205290973186493 , max:  1.0234887599945068
Faithfulness: 0.0011757775209844112, Stoch Rec: 15.6609468460083, Stoch Rec Layerwise: 15.402249336242676, Importance Min: 59.760562896728516
Total gradient norm: 39.57675552012207


Epoch 20/25: 100%|██████████| 118/118 [04:38<00:00,  2.36s/it, loss=11.6]


0.0 0.0 0.0 0.0
Starting epoch 21, lr = 2.50e-04


Epoch 21/25:   1%|          | 1/118 [00:02<04:29,  2.30s/it, loss=37.1]

pred_imp min/max: tensor(-2.1476) tensor(3.6193)
mask min/max: tensor(-0.0082) tensor(1.)
Masked out min:  -0.4715726971626282 , max:  0.9158615469932556
Faithfulness: 0.0011752414284273982, Stoch Rec: 15.694682121276855, Stoch Rec Layerwise: 15.458259582519531, Importance Min: 59.56167221069336
Total gradient norm: 41.448426852835325


Epoch 21/25: 100%|██████████| 118/118 [04:39<00:00,  2.37s/it, loss=12]


0.0 0.0 0.0 0.0
Starting epoch 22, lr = 2.50e-04


Epoch 22/25:   1%|          | 1/118 [00:02<04:29,  2.31s/it, loss=36.8]

pred_imp min/max: tensor(-1.9975) tensor(3.0239)
mask min/max: tensor(-0.0087) tensor(1.)
Masked out min:  -0.4510664939880371 , max:  0.8957698345184326
Faithfulness: 0.0011746195377781987, Stoch Rec: 15.582290649414062, Stoch Rec Layerwise: 15.309797286987305, Importance Min: 59.551883697509766
Total gradient norm: 33.65965425383847


Epoch 22/25: 100%|██████████| 118/118 [04:38<00:00,  2.36s/it, loss=11.7]


0.0 0.0 0.0 0.0
Starting epoch 23, lr = 2.50e-04


Epoch 23/25:   1%|          | 1/118 [00:02<05:00,  2.57s/it, loss=37.2]

pred_imp min/max: tensor(-2.0512) tensor(3.0905)
mask min/max: tensor(-0.0055) tensor(1.)
Masked out min:  -0.3820277452468872 , max:  0.6254575848579407
Faithfulness: 0.0011742596980184317, Stoch Rec: 15.768109321594238, Stoch Rec Layerwise: 15.476945877075195, Importance Min: 59.45616149902344
Total gradient norm: 32.2865267979162


Epoch 23/25: 100%|██████████| 118/118 [04:38<00:00,  2.36s/it, loss=11.9]


0.0 0.0 0.0 0.0
Starting epoch 24, lr = 2.50e-04


Epoch 24/25:   1%|          | 1/118 [00:02<04:33,  2.34s/it, loss=37.3]

pred_imp min/max: tensor(-2.0241) tensor(3.6201)
mask min/max: tensor(-0.0090) tensor(1.)
Masked out min:  -0.4084336459636688 , max:  0.9107349514961243
Faithfulness: 0.0011739039327949286, Stoch Rec: 15.844325065612793, Stoch Rec Layerwise: 15.5100736618042, Importance Min: 59.4804801940918
Total gradient norm: 32.01846260660144


Epoch 24/25: 100%|██████████| 118/118 [04:39<00:00,  2.37s/it, loss=12]


0.0 0.0 0.0 0.0
Starting epoch 25, lr = 1.25e-04


Epoch 25/25:   1%|          | 1/118 [00:02<04:29,  2.30s/it, loss=37.8]

pred_imp min/max: tensor(-1.9682) tensor(2.8519)
mask min/max: tensor(-0.0037) tensor(1.)
Masked out min:  -0.37912943959236145 , max:  0.8561169505119324
Faithfulness: 0.0011736595770344138, Stoch Rec: 16.0897159576416, Stoch Rec Layerwise: 15.774398803710938, Importance Min: 59.39533615112305
Total gradient norm: 28.509992709340263


Epoch 25/25: 100%|██████████| 118/118 [04:38<00:00,  2.36s/it, loss=11.9]

0.0 0.0 0.0 0.0





In [None]:
eval_dataset = SparseAutoencoderDataset(in_dim=100, n_samples=100, sparsity=0.9, device="cuda")
dataloader = DataLoader(eval_dataset, batch_size=128, shuffle=True)
batch = next(iter(dataloader))

out_eval = spd_model(batch[0])
out_eval_target = spd_model.target_model(batch[0])
print("Original Model: \n", out_eval)
print("SPD Model: \n", out_eval_target)

def check_masks(x, spd_model):
    pred_importances = []
    spd_output, spd_activations, spd_weights = spd_model(x, return_activs_and_weights = True)
    components_imp_pred_embed_hidden = F.gelu(torch.einsum("nco,cos->ncs", spd_activations[-1]["embed"], spd_model.imp_W_gate_in_in[-1]) + spd_model.imp_b_in_in[-1])
    components_imp_pred_embed = (torch.einsum("ncs,cso->nco", components_imp_pred_embed_hidden, spd_model.imp_W_gate_out_in[-1]) + spd_model.imp_b_out_in[-1])

    components_imp_pred_unembed_hidden = F.gelu(torch.einsum("nco,cos->ncs", spd_activations[-1]["unembed"], spd_model.imp_W_gate_in_out[-1]) + spd_model.imp_b_in_out[-1])
    components_imp_pred_unembed = (torch.einsum("ncs,cso->nco", components_imp_pred_unembed_hidden, spd_model.imp_W_gate_out_out[-1]) + spd_model.imp_b_out_out[-1])

    for l in range(spd_model.num_layers):
        components_imp_pred_hidden_in = F.gelu(torch.einsum("nco,cos->ncs", spd_activations[l]["in"], spd_model.imp_W_gate_in_in[l]) + spd_model.imp_b_in_in[l])
        components_pred_layer_in = torch.einsum("ncs,cso->nco", components_imp_pred_hidden_in, spd_model.imp_W_gate_out_in[l]) + spd_model.imp_b_out_in[l]
        components_imp_pred_hidden_out = F.gelu(torch.einsum("nco,cos->ncs", spd_activations[l]["out"], spd_model.imp_W_gate_in_out[l]) + spd_model.imp_b_in_out[l])
        components_pred_layer_out = (torch.einsum("ncs,cso->nco", components_imp_pred_hidden_out, spd_model.imp_W_gate_out_out[l]) + spd_model.imp_b_out_out[l])

    pred_importances.append({"in": components_pred_layer_in, "out": components_pred_layer_out})
    pred_importances.append({"embed": components_imp_pred_embed, "unembed": components_imp_pred_unembed})

    return pred_importances

importances = check_masks(batch[0], spd_model)
print(importances[0]["in"], importances[0]["out"])

## Plotting Data

In [None]:
import matplotlib.pyplot as plt
import numpy as np


"""
for history in [loss_history, loss_history_stoch_rec, loss_history_stoch_rec_layer,loss_history_faithfulness,loss_history_imp_min]:
    for idx, item in enumerate(history):
        history[idx] = item.item()
"""

def plot_loss_histories(loss_history, loss_history_stoch_rec, loss_history_stoch_rec_layer,
                       loss_history_faithfulness, loss_history_imp_min, save_interval=10):
    """
    Plot loss histories from saved lists

    Args:
        loss_history: List of total losses
        loss_history_stoch_rec: List of stochastic reconstruction losses
        loss_history_stoch_rec_layer: List of stochastic reconstruction layer losses
        loss_history_faithfulness: List of faithfulness losses
        loss_history_imp_min: List of importance min losses
        save_interval: How often losses were saved (default 10 batches)
    """

    # Create batch numbers (assuming you save every save_interval batches)
    batch_numbers = [(i + 1) * save_interval for i in range(len(loss_history))]

    # Create figure with subplots
    fig, ((ax1, ax2), (ax3, ax4), (ax5, ax6)) = plt.subplots(3, 2, figsize=(15, 12))

    # Plot total loss
    ax1.plot(batch_numbers, loss_history, 'b-', linewidth=2, marker='o', markersize=4)
    ax1.set_title('Total Loss', fontsize=14, fontweight='bold')
    ax1.set_xlabel('Batch Number')
    ax1.set_ylabel('Loss')
    ax1.grid(True, alpha=0.3)
    if max(loss_history) > 100:  # Use log scale for large values
        ax1.set_yscale('log')

    # Plot faithfulness loss
    ax2.plot(batch_numbers, loss_history_faithfulness, 'r-', linewidth=2, marker='o', markersize=4)
    ax2.set_title('Faithfulness Loss', fontsize=14, fontweight='bold')
    ax2.set_xlabel('Batch Number')
    ax2.set_ylabel('Loss')
    ax2.grid(True, alpha=0.3)
    if max(loss_history_faithfulness) > 100:
        ax2.set_yscale('log')

    # Plot stochastic reconstruction loss
    ax3.plot(batch_numbers, loss_history_stoch_rec, 'g-', linewidth=2, marker='o', markersize=4)
    ax3.set_title('Stochastic Reconstruction Loss', fontsize=14, fontweight='bold')
    ax3.set_xlabel('Batch Number')
    ax3.set_ylabel('Loss')
    ax3.grid(True, alpha=0.3)

    # Plot stochastic reconstruction layer loss
    ax4.plot(batch_numbers, loss_history_stoch_rec_layer, 'm-', linewidth=2, marker='o', markersize=4)
    ax4.set_title('Stochastic Reconstruction Layer Loss', fontsize=14, fontweight='bold')
    ax4.set_xlabel('Batch Number')
    ax4.set_ylabel('Loss')
    ax4.grid(True, alpha=0.3)
    if max(loss_history_stoch_rec_layer) > 100:
        ax4.set_yscale('log')

    # Plot importance min loss
    ax5.plot(batch_numbers, loss_history_imp_min, 'c-', linewidth=2, marker='o', markersize=4)
    ax5.set_title('Importance Min Loss', fontsize=14, fontweight='bold')
    ax5.set_xlabel('Batch Number')
    ax5.set_ylabel('Loss')
    ax5.grid(True, alpha=0.3)

    # Plot all losses together (normalized)
    max_vals = [max(loss_history), max(loss_history_faithfulness), max(loss_history_stoch_rec),
                max(loss_history_stoch_rec_layer), max(loss_history_imp_min)]

    if all(m > 0 for m in max_vals):  # Only plot if all have valid max values
        ax6.plot(batch_numbers, np.array(loss_history)/max(loss_history), 'b-',
                linewidth=2, label='Total Loss', marker='o', markersize=3)
        ax6.plot(batch_numbers, np.array(loss_history_faithfulness)/max(loss_history_faithfulness), 'r-',
                linewidth=2, label='Faithfulness', marker='s', markersize=3)
        ax6.plot(batch_numbers, np.array(loss_history_stoch_rec)/max(loss_history_stoch_rec), 'g-',
                linewidth=2, label='Stoch Rec', marker='^', markersize=3)
        ax6.plot(batch_numbers, np.array(loss_history_stoch_rec_layer)/max(loss_history_stoch_rec_layer), 'm-',
                linewidth=2, label='Stoch Rec Layer', marker='d', markersize=3)
        ax6.plot(batch_numbers, np.array(loss_history_imp_min)/max(loss_history_imp_min), 'c-',
                linewidth=2, label='Importance Min', marker='v', markersize=3)

    ax6.set_title('All Losses (Normalized)', fontsize=14, fontweight='bold')
    ax6.set_xlabel('Batch Number')
    ax6.set_ylabel('Normalized Loss')
    ax6.legend()
    ax6.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    return fig


fig = plot_loss_histories(loss_history, loss_history_stoch_rec, loss_history_stoch_rec_layer,
                         loss_history_faithfulness, loss_history_imp_min)
# Example usage:
"""
# In your training loop, save losses every 10 batches like this:
if batch_idx % 10 == 0:
    loss_history.append(total_loss.item())
    loss_history_stoch_rec.append(stoch_rec_loss.item())
    loss_history_stoch_rec_layer.append(stoch_rec_layer_loss.item())
    loss_history_faithfulness.append(faithfulness_loss.item())
    loss_history_imp_min.append(imp_min_loss.item())

# Then plot:
fig = plot_loss_histories(loss_history, loss_history_stoch_rec, loss_history_stoch_rec_layer,
                         loss_history_faithfulness, loss_history_imp_min)
"""

# Quick test with dummy data:
def test_plotter():
    # Create some example data
    n_points = 50
    batch_nums = range(10, 10 + n_points * 10, 10)

    # Simulate decreasing losses with some noise
    loss_history = [1000 * np.exp(-0.1 * i) + np.random.normal(0, 50) for i in range(n_points)]
    loss_history_faithfulness = [200 * np.exp(-0.15 * i) + np.random.normal(0, 10) for i in range(n_points)]
    loss_history_stoch_rec = [100 + 50 * np.sin(0.3 * i) + np.random.normal(0, 5) for i in range(n_points)]
    loss_history_stoch_rec_layer = [500 * np.exp(-0.08 * i) + np.random.normal(0, 20) for i in range(n_points)]
    loss_history_imp_min = [10 * np.exp(-0.05 * i) + np.random.normal(0, 1) for i in range(n_points)]

    return plot_loss_histories(loss_history, loss_history_stoch_rec, loss_history_stoch_rec_layer,
                              loss_history_faithfulness, loss_history_imp_min)

# Uncomment to test:
# test_fig = test_plotter()

In [None]:
ended up with
pred_imp min/max: tensor(0.7151) tensor(1.6347)
mask min/max: tensor(0.7313) tensor(1.)
Masked out min:  -7.006863117218018 , max:  5.271817207336426
Faithfulness: 2.7321490847498353e-08, Stoch Rec: 2.3285586833953857, Stoch Rec Layerwise: 1.1707794666290283, Importance Min: 118.15776062011719
Total gradient norm: 889.7729806290456

on hyperparameters

    config = {
        "num_layers": 1,
        "pre_embed_size": 100,
        "in_size": 1000,
        "hidden_size": 50,
        "subcomponents_per_layer": 30,
        "beta_1": 1.0,
        "beta_2": 1.0,
        "beta_3": 5e-2,
        "causal_imp_min": 1.0,
        "num_mask_samples": 20,
        "importance_mlp_size": 5,
    }


ideally should start using wandb lol