# Stochastic Parameter Decomposition

## Imports

In [None]:
# import all the torch stuff, as well as the ViT stuff and such
# import plotly.express as px
# import torch
# from jaxtyping import Int, Float
# from typing import List, Optional, Tuple
# from tqdm import tqdm
# from transformer_lens.hook_points import HookPoint
# from transformer_lens import utils, HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
# import circuitsvis as cv

In [11]:
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import einops
import typing
from tqdm import tqdm

## Toy Model MLP

In [1]:
# create a SPDHookedTransformer Class
# that extends the hookedTransformer
# which inherits all the previous class stuff
# but also contains the SPD training algo?

# how is everything actually structured? 
"""
Maybe I should start with implementing it on a simple MLP setup. 

Orig network:
- 1 layer MLP 5-2-5 (defined as usual with torch.Sequential presumably)

SPD network:
let's give it 10 subcomponents per layer
then it's just like 10 matmuls? 
maybe i should define it like ant did in the toy models of superposition paper
stick them all into a trenchcoat
this might be easier once I have defined the toy modle

"""

In [55]:
config = {
    "num_layers": 1,
    "pre_embed_size": 100,
    "in_size": 1000,
    "hidden_size": 50,
    "subcomponents_per_layer": 10, 
    "beta_1": 1, 
    "beta_2": 1, 
    "beta_3": 1, 
    "causal_imp_min": 1, 
    "num_mask_samples": 20,
    "importance_mlp_size": 10,
}

In [79]:
class ToyResidMLP(nn.Module):
    def __init__(self, config, device="mps"):
        super().__init__()
        # Initialize Weights for the
        self.num_layers, self.pre_embed_size, self.in_size, self.hidden_size = config["num_layers"], config["pre_embed_size"], config["in_size"], config["hidden_size"]
        self.device = device
        self.W_embed = nn.Parameter(torch.empty((self.pre_embed_size, self.in_size)))
        self.W_unembed = nn.Parameter(torch.empty((self.in_size, self.pre_embed_size)))
        self.W_in = nn.ParameterList([torch.empty((self.in_size, self.hidden_size), device=device) for i in range(self.num_layers)])
        self.W_out = nn.ParameterList([torch.empty((self.hidden_size, self.in_size), device=device) for i in range(self.num_layers)])
        self.b = nn.ParameterList([torch.zeros((self.hidden_size,), device=device) for i in range(self.num_layers)])

        for param in [self.W_embed, self.W_unembed] + list(self.W_in) + list(self.W_out): 
            nn.init.xavier_normal_(param)
        
    def forward(self, x): 
        
        assert x.shape[1] == self.pre_embed_size, f"Input shape {x.shape[0]} does not match model's accepted size {self.pre_embed_size}"
        # embed 
        x_resid = torch.einsum("np,pi->ni", x.clone(), self.W_embed)
        N, D = x_resid.shape

        
        for l in range(self.num_layers):
            hidden = F.relu(torch.einsum("nd,dh -> nh", x_resid, self.W_in[l]) + self.b[l])
            layer_out = torch.einsum("nh,hd -> nd", hidden, self.W_out[l])
            x_resid = x_resid + layer_out
        # am I supposed to have a embed and out?
        x_out = torch.einsum("ni,ip->np", x_resid, self.W_unembed) 
        return x_out


# Generated by LLM bc I am lazy and this is not important to me learning stuff atm
class SparseAutoencoderDataset(Dataset):
    """
    Dataset for learning to reconstruct sparse inputs.
    Each item is (input, target) where target is the input itself (or ReLU of input).
    """
    def __init__(self, in_dim=100, n_samples=10000, sparsity=0.9, device="mps"):
        super().__init__()
        self.in_dim = in_dim
        self.n_samples = n_samples
        self.device = device

        # Pre-generate all samples
        self.inputs = []
        self.targets = []
        
        for _ in range(n_samples):
            # Sparse input: each entry is -1, 0, or 1, with sparsity
            x = np.random.choice([0, -1, 1], size=(in_dim,), p=[1-sparsity, sparsity/2, sparsity/2])
            x = torch.tensor(x, dtype=torch.float32, device=device)
            target = F.relu(x)
            
            self.inputs.append(x)
            self.targets.append(target)
    
    def __len__(self):
        return self.n_samples
    
    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]

def train_toy_resid_mlp(
    model,
    dataloader,
    lr=1e-3,
    num_epochs=10,
    device="mps",
    print_every=1
):
    model.train()
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    for epoch in range(num_epochs):
        total_loss = 0.0
        for x, y in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * x.size(0)
        avg_loss = total_loss / len(dataloader.dataset)
        if (epoch+1) % print_every == 0:
            print(f"Epoch {epoch+1}: avg MSE loss = {avg_loss:.6f}")

In [80]:
## LLM-Generated Usage Example

torch.autograd.set_detect_anomaly(True)

if __name__ == "__main__":
    # Config
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Dataset and DataLoader
    # Dataset and DataLoader
    dataset = SparseAutoencoderDataset(
        in_dim=100,
        n_samples=10000,
        sparsity=0.9,
        device=device,
    )
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
    # Model
    model = ToyResidMLP(config, device=device)
    # Train
    train_toy_resid_mlp(model, dataloader, lr=8e-2, num_epochs=20, device=device)

Epoch 1/20: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 101.31it/s]


Epoch 1: avg MSE loss = 365801.498374


Epoch 2/20: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 100.97it/s]


Epoch 2: avg MSE loss = 348.291439


Epoch 3/20: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 102.89it/s]


Epoch 3: avg MSE loss = 42.082283


Epoch 4/20: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 106.78it/s]


Epoch 4: avg MSE loss = 21.854127


Epoch 5/20: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 104.16it/s]


Epoch 5: avg MSE loss = 13.048978


Epoch 6/20: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 106.63it/s]


Epoch 6: avg MSE loss = 8.327664


Epoch 7/20: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 106.65it/s]


Epoch 7: avg MSE loss = 5.563310


Epoch 8/20: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 104.45it/s]


Epoch 8: avg MSE loss = 3.846435


Epoch 9/20: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 103.32it/s]


Epoch 9: avg MSE loss = 2.737103


Epoch 10/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 102.70it/s]


Epoch 10: avg MSE loss = 1.999273


Epoch 11/20: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 93.49it/s]


Epoch 11: avg MSE loss = 1.495929


Epoch 12/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 106.69it/s]


Epoch 12: avg MSE loss = 1.147356


Epoch 13/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 100.04it/s]


Epoch 13: avg MSE loss = 0.903899


Epoch 14/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 106.46it/s]


Epoch 14: avg MSE loss = 0.731857


Epoch 15/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 103.62it/s]


Epoch 15: avg MSE loss = 0.610296


Epoch 16/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 103.59it/s]


Epoch 16: avg MSE loss = 0.524525


Epoch 17/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 104.82it/s]


Epoch 17: avg MSE loss = 0.464593


Epoch 18/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 106.02it/s]


Epoch 18: avg MSE loss = 0.423585


Epoch 19/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 100.32it/s]


Epoch 19: avg MSE loss = 0.396789


Epoch 20/20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 105.60it/s]

Epoch 20: avg MSE loss = 0.380262





## SPD Model and Train Function

In [8]:
# generated by LLM

class HardSigmoid(nn.Module):
    """
    Implements the hard sigmoid activation function as described in the paper:
        σ_H(x) = 0 if x <= 0
               = x if 0 < x < 1
               = 1 if x >= 1
    This is equivalent to: torch.clamp(x, min=0.0, max=1.0)
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Clamp values between 0 and 1
        return torch.clamp(x, min=0.0, max=1.0)

tensor([0.0000, 0.0000, 0.0000, 0.3000, 0.7000, 1.0000, 1.0000])


In [None]:
class SPDModelMLP(nn.Module): 
    # fix the betas stuff below that's basically a type hint
    def __init__(self, target_model, config, device="mps"): 
        self.device = device
        self.target_model = target_model
        self.num_layers, self.pre_embed_size, self.in_size, self.hidden_size, self.imp_hidden_size = config["num_layers"], config["pre_embed_size"], config["in_size"], config["hidden_size"], config["importance_mlp_size"]
        assert self.device == target_model.device, "Models not on same device"
        self.C = config["subcomponents_per_layer"]
        self.hypers = dict(list(config.items())[4:]) # sets the "hypers" to contain all the hyperparameters for the model
        
        # Subcomponent vectors, each of shape C by in_size; to be used
        # with outer product to create our low-rank subcomponent matrices
        self.V_embed = nn.Parameter(torch.empty((self.C, self.pre_embed_size,), device=device))
        self.U_embed = nn.Parameter(torch.empty((self.C, self.in_size,), device=device))
        self.V_unembed = nn.Parameter(torch.empty((self.C, self.in_size,), device=device))
        self.U_unembed = nn.Parameter(torch.empty((self.C, self.pre_embed_size,), device=device))

        self.V_in = nn.ParameterList([torch.empty((self.C, self.in_size,), device=device) for i in range(num_layers)])
        self.U_in = nn.ParameterList([torch.empty((self.C, self.hidden_size,), device=device) for i in range(num_layers)]) 
        self.V_out = nn.ParameterList([torch.empty((self.C, self.hidden_size,), device=device) for i in range(num_layers)])
        self.U_out = nn.ParameterList([torch.empty((self.C, self.in_size,), device=device) for i in range(num_layers)])
        
        
        # idk what you do with the biases lol
        self.b = nn.ParameterList([torch.zeros((self.hidden_size,), device=device) for i in range(num_layers)])
        
        # this is so horrible I'm sorry
        # gate_in_in is the gate_in weights for the in subcomponent of each layer
        # gate_in_out is gate_in weights for the out component
        # they each get an extra one on the end which is for the embed matrix
        self.imp_W_gate_in_in = nn.ParameterList([torch.empty(C, 1, self.imp_hidden_size) for i in range(num_layers)+1])
        self.imp_W_gate_out_in = nn.ParameterList([torch.empty(C, self.imp_hidden_size, 1) for i in range(num_layers)+1]) 
        self.imp_b_in_in = nn.ParameterList([torch.empty(C, self.imp_hidden_size) for i in range(num_layers)+1])
        self.imp_b_out_in = nn.ParameterList([torch.empty(C, 1) for i in range(num_layers)+1])
        
        self.imp_W_gate_in_out = nn.ParameterList([torch.empty(C, 1, self.imp_hidden_size) for i in range(num_layers)+1])
        self.imp_W_gate_out_out = nn.ParameterList([torch.empty(C, self.imp_hidden_size, 1) for i in range(num_layers)+1]) 
        self.imp_b_in_out = nn.ParameterList([torch.empty(C, self.imp_hidden_size) for i in range(num_layers)+1])
        self.imp_b_out_out = nn.ParameterList([torch.empty(C, 1) for i in range(num_layers)+1])
        
        # self.pred_weights = torch.empty((num_layers, self.C), device=device)

        for param in self.V_in + self.U_in + self.V_out + self.U_out + self.b: 
            # using xavier anyway -- note that the variance etc is
            # changed because we take the outer product in the 
            # forward pass
            nn.init.xavier_normal_(param)

        
    def forward(self, x, return_activs_and_weights=False, masks=None): # Regular run. Unclear whether I should have masking when I do a regular forward pass.
        v_activations = []
        weight_matrices = []
        x_in = x.clone()
        layerwise_resids = []

        if masks is not None: 
            
            v_activ_embed = torch.einsum("np,cp->nc", x_in, self.V_embed) * masks[-1]["embed"] # TODO: ADD MASKS FOR EMBED
            x_embedded = torch.einsum("nc,ci->nh", v_activ_embed, self.U_embed)
            x_resid = x_embedded.clone()
                        
            for l in range(self.num_layers):
                # may have dimension issues in these einsums :(
                ### Run the forward pass just for this layer
                v_activ_layer_in = torch.einsum("ni,ci->nc", x_resid, self.V_in[l]) * masks[l]["in"]
                layer_hidden = torch.einsum("nc,ch->nh", v_activ_layer_in, self.U_in[l])
                v_activ_layer_out = torch.einsum("nh,ch->nc", layer_in, self.V_out[l]) * masks[l]["out"]
                layer_out = torch.einsum("nc,ci->ni", v_activ_layer_in, self.U_out[l])
                x_resid = x_resid + layer_out

                # Run the layerwise forward pass for the whole model, with masking only this layer
                x_resid_layerwise = x_embedded.clone() # pre-embedded
                for l_2 in self.num_layers: # oh god this is so awful. literally sobbing rn. i can't believe I'm writing it like this
                    if l_2 == l: 
                        v_activ_layer_in_l = torch.einsum("ni,ci->nc", x_resid_layerwise, self.V_in[l]) * masks[l]["in"]
                        layer_hidden_l = torch.einsum("nc,ch->nh", v_activ_layer_in_l, self.U_in[l])
                        v_activ_layer_l = torch.einsum("nh,ch->nc", layer_in_l, self.V_out[l]) * masks[l]["out"]
                        layer_out_l = torch.einsum("nc,ci->ni", v_activ_layer_in_l, self.U_out[l])
                    else: # THE SAME EXACT THING EXCEPT WITHOUT THE MASKS >_< 
                        v_activ_layer_in_l = torch.einsum("ni,ci->nc", x_resid_layerwise, self.V_in[l])
                        layer_hidden_l = torch.einsum("nc,ch->nh", v_activ_layer_in_l, self.U_in[l])
                        v_activ_layer_l = torch.einsum("nh,ch->nc", layer_in_l, self.V_out[l])
                        layer_out_l = torch.einsum("nc,ci->ni", v_activ_layer_in_l, self.U_out[l])
                    x_resid_layerwise = x_resid_layerwise + layer_out_l
                v_activ_unembed_layerwise = torch.einsum("nc,cp->np", torch.einsum("ni,ci->nc", x_resid, self.V_unembed) * masks[-1]["unembed"], self.U_unembed)
                layerwise_resids.append(x_resid_layerwise)
            
            v_activ_unembed = torch.einsum("ni,ci->nc", x_resid, self.V_unembed) * masks[-1]["unembed"]
            x_out = torch.einsum("nc,cp->np", v_activ_unembed, self.U_unembed)

                                
        else: 
            W_embed = torch.einsum("cp, ci -> cpi", self.V_embed, self.U_embed).sum(dim=0)
            x_resid = torch.einsum("np, pi -> ni", x_in, self.W_embed)
            
            for l in range(self.num_layers):
                # Use outer product to create weights for the layer, then sum all the subcomponents
                W_in = torch.einsum("ci,ch-> cih", self.V_in[l], self.U_in[l]).sum(dim=0) # shape i h
                W_out = torch.einsum("ch,ci-> chi", self.V_out[l], self.U_out[l]).sum(dim=0) # shape h i
                if return_activs_and_weights == True:
                    weight_matrices.append({"in": W_in, "out", W_out})
                # COMPUTE 
                layer_in = torch.einsum("ni,ih -> nh", x_resid, W_in) + self.b[l]
                layer_out = torch.einsum("nh,hi->ni", F.relu(layer_in), W_out)
            
            if return_activs_and_weights == True: 
                # calculate activations
                v_activ_layer_in = torch.einsum("ni,ci->nc", x_resid, self.V_in[l]).unsqueeze(-1)
                v_activ_layer_out = torch.einsum("nh,ch->nc", layer_in, self.V_out[l]).unsqueeze(-1)
                # both are now shape n,c,1
                v_activations.append({"in": v_activ_layer_in, "out": v_activ_layer_out})
                
                # code to make sure I'm not doing some goofy shit with the activations at the beginning
                check_layer_in_uv = torch.einsum("nc,ch->nh", v_activ_layer_in, self.U_in[l])
                assert layer_in == check_layer_in_uv, "Subcomponent activations not calculated correctly"
                # would in theory have one for the second too but I'm being lazy
            
            x_resid = x_resid + layer_out
            W_unembed = torch.einsum("ci, cp-> cip", self.V_unembed, self.U_unembed)
            x_out = torch.einsum("ni, ip -> np", x_resid, self.W_embed)
            weight_matrices.append({"embed": W_embed, "unembed": W_unembed})
            # putting the embed matrices at the end of the masks. since we iterate through l in range(num_layers) this will not get added! :D

        if return_activs_and_weights:
            return x_resid, v_activations, weight_matrices 
        else if masks is not None: 
            return x_resid, layerwise_resids
        else: 
            return x_resid
    


In [None]:
def generate_batch(config): 
    # return shape N, config.in_size 
    pass

"""

one last thing to do to make the model work

parameters:
    model,
    dataloader,
    lr=1e-3,
    num_epochs=10,
    device="mps",
    print_every=1
    
--- 

    model.train()
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    for epoch in range(num_epochs):
        total_loss = 0.0
        for x, y in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * x.size(0)
        avg_loss = total_loss / len(dataloader.dataset)
        if (epoch+1) % print_every == 0:
            print(f"Epoch {epoch+1}: avg MSE loss = {avg_loss:.6f}")
"""

def train_SPD(spd_model, dataloader, "more stuff goes here"): # could also implement this by passing in the original model?
    
    for t in train_steps:
        x = generate_batch() # TODO. possibly put in the model? idk if needed.
        N = x.shape[0]     # x is shape N by in_size
        C = model.C
        target_model = spd_model.target_model
        hard_sigmoid = HardSigmoid()
        
        # COMPUTE TARGET OUTPUT
        with torch.no_grad(): 
            target_out = spd_model.target_model(x)
    
        # FAITHFULNESS LOSS
        
        spd_output, spd_activations spd_weights = model(x, return_activations = True, return_weight_matrices = True)
        squared_error = 0
        for l in range(num_layers):
            in_diff = target_model.W_in[l] - spd_weights[l]["in"]
            out_diff = target_model.W_out[l] - spd_weights[l]["out"]
            
            # torch.linalg.matrix_norm defaults to the frobenius norm
            # this takes the frobenius norm of the diff and then squares it
            squared_error_layer = torch.linalg.matrix_norm(in_diff) ** 2 + torch.linalg.matrix_norm(out_diff) ** 2 
            squared_error = squared_error + squared_error_layer
            
        mean_squared_error = squared_error/num_layers
        l_faithfulness = mean_squared_error
        
    
        ## IMPORTANCE-MINIMALITY LOSS
        
        pred_importances = []
        l_importance_minimality = 0
    
        components_imp_pred_embed_hidden = F.gelu(torch.einsum("nco,cos->ncs", spd_activations[-1]["embed"], spd_model.imp_W_gate_in_in[-1]) + spd_model.imp_b_in[-1])
        components_imp_pred_embed = hard_sigmoid(torch.einsum("ncs,cso->nco", components_imp_pred_hidden, spd_model.imp_W_gate_out_in[-1]) + spd_model.imp_b_out_in[-1])
        
        components_imp_pred_unembed_hidden = F.gelu(torch.einsum("nco,cos->ncs", spd_activations[-1]["unembed"], spd_model.imp_W_gate_in_in[-1]) + spd_model.imp_b_in[-1])
        components_imp_pred_unembed = hard_sigmoid(torch.einsum("ncs,cso->nco", components_imp_pred_unembed_hidden, spd_model.imp_W_gate_out_in[-1]) + spd_model.imp_b_out_in[-1])
        l_importance_minimality = l_importance_minimality + (components_imp_pred_embed ** spd_model.hypers["importance_mlp_size"]) + (components_imp_pred_unembed ** spd_model.hypers["importance_mlp_size"])
    
        for l in range(num_layers):
            # both activations are n by c containing dot product so we already have hard_sigmoid
            # spd_activations[l][inout] is shape n,c,1 (nco)
            # imp_W_in is c by 1 by imp_size (cos)
            # want to map to ncs then back to nco
            # imp_b_in is shape (C, s) so should broadcast to ncs nicely
            
            # TODO: DEFINE hard_sigmoid AS A TORCH MODULE SO IT CAN CALCULATE THE DERIVATIVE 
            # in theory should write this as a bunch of models stored in the main model, but that's not how i did it and i've already 
            # written this, shrug
            # oh my god this is so bad 
            # imp_W_gate_in_in is the gate in for the W_in, etc
            components_imp_pred_hidden_in = F.gelu(torch.einsum("nco,cos->ncs", spd_activations[l]["in"], spd_model.imp_W_gate_in_in[l]) + spd_model.imp_b_in_in[l])
            components_pred_layer_in = hard_sigmoid(torch.einsum("ncs,cso->nco", components_imp_pred_hidden, spd_model.imp_W_gate_out_in[l]) + spd_model.imp_b_out_in[l])
    
            #same thing for the out matrix in layer l
            components_imp_pred_hidden_out = F.gelu(torch.einsum("nco,cos->ncs", spd_activations[l]["out"], spd_model.imp_W_gate_in_out[l]) + spd_model.imp_b_in_out[l])
            components_pred_layer_out = hard_sigmoid(torch.einsum("ncs,cso->nco", components_imp_pred_hidden, spd_model.imp_W_gate_out_out[l]) + spd_model.imp_b_out_out[l])
            
            pred_importances.append({"in": components_pred_layer_in, "out": components_pred_layer_out)
            l_importance_minimality = l_importance_minimality + (components_pred_layer_in ** spd_model.hypers["importance_mlp_size"]) + (components_pred_layer_out ** spd_model.hypers["importance_mlp_size"])
        
        l_importance_minimality /= N # divide by N, the batch size, to avg across the batch (notated as B in the paper)
    
    
        ## STOCHASTIC RECONSTRUCTION LOSS 
    
        l_stochastic_recon = 0
        l_stochastic_recon_layerwise = 0
        R = torch.rand(model.hypers["num_mask_samples"], N, model.num_layers+1, 2, C)
        layer_masks = []
    
        for s in range(model.hypers["num_mask_samples"]):
            # Running this with a for loop. This is slow; I'd ideally just run 
            # something like M = G[None, :, :] + (1 - G[None, :, :]) * R
            # But I think there's a clarity tradeoff here so I'm just going to do this for now
            layer_mask_embed = pred_importances[-1]["embed"].squeeze() + (torch.ones_like(pred_importances[-1]["embed"]) - pred_importances[l]["embed"])*R[s,:,l,0,:]
            layer_mask_unembed = pred_importances[-1]["unembed"].squeeze() + (torch.ones_like(pred_importances[l]["unembed"]) - pred_importances[l]["unembed"]) * R[s,:,l,1,:]
            
            for l in range(num_layers):
                # ugh, this is pretty bad. Ideally I'd go back and refactor so that I don't 
                # have separate in and out weights, but this is what I have. going to keep it
                # like this for now, but will refactor once I have something that works.
                # components pred in is shape N, C, 1, squeeze to N, C. there are L of them
                # R is N, L, C, sample just l to get N, C
                # I want masks for each component on each layer on each datapoint, so
                # masks should be one C vector for each layer, per-datapoint. so NLC
                layer_mask_in = pred_importances[l]["in"].squeeze() + (torch.ones_like(pred_importances[l]["in"]) - pred_importances[l]["in"]) * R[s,:,l,0,:]
                layer_mask_out = pred_importances[l]["out"].squeeze() + (torch.ones_like(pred_importances[l]["out"]) - pred_importances[l]["out"]) * R[s,:,l,1,:]
                # these are both shape N, C 
                layer_masks.append({"in": layer_mask_in, "out": layer_mask_out})
                
            masked_out, layerwise_masked_outs = model(x, masks=layer_masks)
            l_stochastic_recon = l_stochastic_recon + torch.linalg.matrix_norm(target_out - masked_out) # uses matrix norm of difference
            for i in range(len(layerwise_masked_outs)-1):
                l_stochastic_recon_layerwise = l_stochastic_recon_layerwise + torch.linalg.matrix_norm(target_out - layerwise_masked_outs[i])
    
        l_stochastic_recon /= model.hypers["num_mask_samples"]
        l_stochastic_recon_layerwise /= (model.hypers["num_mask_samples"] * model.num_layers)
    
        loss = l_faithfulness + model.hypers["beta1"] * l_stochastic_recon + model.hypers["beta2"] * l_stochastic_recon_layerwise + model.hypers["beta3"] * l_importance_minimality

    """
    TODO: 
    - generate batches
        - i want a dataloader object that i can iterate through
        - my first test will be the 1-layer 50-hidden 100-in dim
    - start testing
    - wrap functions with tqdm
    """
    