# SPD Refector

My first implementation was slow and poorly-organized. I'm reimplementing it with the following changes:

1. Make subcomponents their own modules
2. Make importance predictors their own modules
3. Vectorize masking

In [1]:
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import einops
import typing
from tqdm import tqdm

In [102]:
DEVICE = "mps"
CONFIG = {
    "num_layers": 1,
    "pre_embed_size": 100,
    "in_size": 1000,
    "hidden_size": 50,
    "subcomponents_per_layer": 30, 
    "beta_1": 1.0, 
    "beta_2": 1.0, 
    "beta_3": 0.1, 
    "causal_imp_min": 1.0, 
    "num_mask_samples": 20,
    "importance_mlp_size": 5,
}
TRAIN_CONFIG = {
    "lr": 8e-4,
    "lr_step_size": 4,
    "lr_gamma": 0.5,
}
BATCH_SIZE=128
NUM_EPOCHS=25

In [100]:
class ToyResidMLP(nn.Module):
    def __init__(self, config, device="cpu"):
        super().__init__()
        # Initialize Weights for the
        self.num_layers, self.pre_embed_size, self.in_size, self.hidden_size = config["num_layers"], config["pre_embed_size"], config["in_size"], config["hidden_size"]
        self.device = device
        self.W_embed = nn.Parameter(torch.empty((self.pre_embed_size, self.in_size)))
        self.W_unembed = nn.Parameter(torch.empty((self.in_size, self.pre_embed_size)))
        self.W_in = nn.ParameterList([torch.empty((self.in_size, self.hidden_size), device=device) for i in range(self.num_layers)])
        self.W_out = nn.ParameterList([torch.empty((self.hidden_size, self.in_size), device=device) for i in range(self.num_layers)])
        self.b = nn.ParameterList([torch.zeros((self.hidden_size,), device=device) for i in range(self.num_layers)])

        for param in [self.W_embed, self.W_unembed] + list(self.W_in) + list(self.W_out): 
            nn.init.xavier_normal_(param)
        
    def forward(self, x): 
        
        assert x.shape[1] == self.pre_embed_size, f"Input shape {x.shape[0]} does not match model's accepted size {self.pre_embed_size}"
        # embed 
        x_resid = torch.einsum("np,pi->ni", x.clone(), self.W_embed)
        N, D = x_resid.shape

        for l in range(self.num_layers):
            hidden = F.relu(torch.einsum("nd,dh -> nh", x_resid, self.W_in[l]) + self.b[l])
            layer_out = torch.einsum("nh,hd -> nd", hidden, self.W_out[l])
            x_resid = x_resid + layer_out
        # am I supposed to have a embed and out?
        x_out = torch.einsum("ni,ip->np", x_resid, self.W_unembed) 
        return x_out


# I vibecoded this originally, and regretted it. 
class SparseAutoencoderDataset(Dataset):
    def __init__(self, in_dim=100, n_samples=10000, sparsity=0.9, device="cpu"):
        super().__init__()
        self.in_dim = in_dim
        self.n_samples = n_samples
        self.device = device

        # Pre-generate all samples
        self.inputs = []
        self.targets = []
        
        for _ in range(n_samples):
            x = np.random.uniform(-1, 1, size=(in_dim))
            mask = np.random.rand(in_dim) > sparsity  # 1-sparsity fraction will be nonzero
            x = x * mask
            x = torch.tensor(x, dtype=torch.float32, device=device)
            
            target = F.relu(x)
            
            self.inputs.append(x)
            self.targets.append(target)
    
    def __len__(self):
        return self.n_samples
    
    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]

def train_toy_resid_mlp(
    model,
    dataloader,
    lr=1e-3,
    num_epochs=10,
    device="cpu",
    print_every=1
):
    model.train()
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    for epoch in range(num_epochs):
        total_loss = 0.0
        for x, y in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * x.size(0)
        avg_loss = total_loss / len(dataloader.dataset)
        if (epoch+1) % print_every == 0:
            print(f"Epoch {epoch+1}: avg MSE loss = {avg_loss:.6f}")


## LLM-Generated Usage Example
if __name__ == "__main__":
    device=DEVICE
    config=CONFIG
    
    dataset = SparseAutoencoderDataset(
        in_dim=100,
        n_samples=15000,
        sparsity=0.9,
        device=device,
    )

    print(device)
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
    # Model
    toy_model = ToyResidMLP(config, device=device)
    # Train
    train_toy_resid_mlp(toy_model, dataloader, lr=8e-2, num_epochs=20, device=device)

mps


Epoch 1/20: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:01<00:00, 107.03it/s]


Epoch 1: avg MSE loss = 10775.810246


Epoch 2/20: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:00<00:00, 243.19it/s]


Epoch 2: avg MSE loss = 2.699132


Epoch 3/20: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:00<00:00, 245.40it/s]


Epoch 3: avg MSE loss = 0.914385


Epoch 4/20: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:00<00:00, 199.57it/s]


Epoch 4: avg MSE loss = 0.447380


Epoch 5/20: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:00<00:00, 211.31it/s]


Epoch 5: avg MSE loss = 0.249607


Epoch 6/20: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:00<00:00, 245.06it/s]


Epoch 6: avg MSE loss = 0.150474


Epoch 7/20: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:00<00:00, 212.89it/s]


Epoch 7: avg MSE loss = 0.096076


Epoch 8/20: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:00<00:00, 241.68it/s]


Epoch 8: avg MSE loss = 0.064693


Epoch 9/20: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:00<00:00, 243.43it/s]


Epoch 9: avg MSE loss = 0.045993


Epoch 10/20: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:00<00:00, 211.47it/s]


Epoch 10: avg MSE loss = 0.034528


Epoch 11/20: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:00<00:00, 226.13it/s]


Epoch 11: avg MSE loss = 0.027398


Epoch 12/20: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:00<00:00, 241.22it/s]


Epoch 12: avg MSE loss = 0.022937


Epoch 13/20: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:00<00:00, 239.26it/s]


Epoch 13: avg MSE loss = 0.020132


Epoch 14/20: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:00<00:00, 238.54it/s]


Epoch 14: avg MSE loss = 0.018385


Epoch 15/20: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:00<00:00, 195.21it/s]


Epoch 15: avg MSE loss = 0.017302


Epoch 16/20: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:00<00:00, 223.02it/s]


Epoch 16: avg MSE loss = 0.016653


Epoch 17/20: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:00<00:00, 248.43it/s]


Epoch 17: avg MSE loss = 0.016272


Epoch 18/20: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:00<00:00, 198.96it/s]


Epoch 18: avg MSE loss = 0.016069


Epoch 19/20: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:00<00:00, 215.65it/s]


Epoch 19: avg MSE loss = 0.015973


Epoch 20/20: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:00<00:00, 242.77it/s]

Epoch 20: avg MSE loss = 0.015942





In [25]:
class Subcomponent(nn.Module): 
    # Subcomponents approximate a in_dim x out_dim matrix with c components
    def __init__(self, shape, num_components, device="cuda"):
        super().__init__()
        self.in_dims = shape[0]
        self.out_dims = shape[1]
        self.shape = shape
        self.C = num_components
        self.device=device
        
        self.V = nn.Parameter(torch.empty((self.in_dims, self.C), device=device))
        self.U = nn.Parameter(torch.empty((self.C, self.out_dims), device=device))
        
        nn.init.xavier_normal_(self.V)
        nn.init.xavier_normal_(self.U)

    def forward(self, x, mask=None):
        # self is shape (in_dims, C), (C, out_dims) -> (in_dims, out_dims)
        # x is shape (N, in_dims)
        # mask is shape (N, C)
        N, in_dims = x.shape
        
        if mask is None: 
            activations = x @ self.V
        else: 
            activations = x @ self.V * mask
        out = activations @ self.U
        
        return out, activations

    def return_weights(self): 
        weights = self.V @ self.U
        return weights

In [46]:
class ImportancePredictor(nn.Module):
    def __init__(self, hidden_size, num_components, device="cuda"): 
        super().__init__()
        self.hidden_size = hidden_size
        self.C = num_components
        self.device = device

        # These are techncially (C, hidden, 1) and (C, 1, hidden) 
        # but squeezing and unsqueezing is less efficient
        self.W_gate_in = nn.Parameter(torch.empty((self.C, hidden_size), device=device))
        self.W_gate_out = nn.Parameter(torch.empty((self.C, hidden_size), device=device))
        
        # I think that nonzero biases will be easier to learn 
        self.b_in = nn.Parameter(torch.zeros((self.C, hidden_size), device=device) + 0.1)
        self.b_out = nn.Parameter(torch.zeros((self.C,), device=device) + 0.1) # technically shape (C, 1)
        
        nn.init.xavier_normal_(self.W_gate_in)
        nn.init.xavier_normal_(self.W_gate_out)


    def forward(self, subcomponent_activations):
        # Activations are shape (N, C) (see Subcomponent.forward)
        # In_weights are shape (C, hidden_size) 
        hidden = F.gelu(torch.einsum("nc,cs->ncs", subcomponent_activations, self.W_gate_in) + self.b_in)
        prediction_out = torch.einsum("ncs,cs->nc", hidden, self.W_gate_out) + self.b_out

        return prediction_out # now shape (N,C). Unsqueeze to get (N,C,1)

In [40]:
class MLPSubcomponentLayer(nn.Module):
    def __init__(self, embed_size, hidden_size, num_components, device="cuda"):
        super().__init__()
        
        # "in matrix" since `in` is a builtin thing
        self.in_mat = Subcomponent((embed_size, hidden_size), num_components, device=device)
        self.out_mat = Subcomponent((hidden_size, embed_size), num_components, device=device)
        self.bias = nn.Parameter(torch.zeros((1, hidden_size,), device=device)) 

    def forward(self, x, masks=None):
        # Mask should be dict {in: (N, C), out: (N,C)}
        if masks is None: 
            masks={"in": None, "out": None} # So that mask[0] won't break; instead will pass None in which is ok
        mat_output, activs_in = self.in_mat(x, masks["in"])
        hidden= F.relu(mat_output + self.bias)
        out, activs_out = self.out_mat(hidden, masks["out"])
        return out, {"in": activs_in, "out": activs_out}

    def return_weights_layer(self):
        return {"in": self.in_mat.return_weights(), "out": self.out_mat.return_weights()}
        

In [28]:
class SPDModelMLP(nn.Module): 
    def __init__(self, target_model, config, device="cuda"): 
        super().__init__()
        self.device = device
        object.__setattr__(self, "target_model", target_model) # sets pointer to target_model without registering its parameters as subsidiary
        
        # Unpack Config
        self.C, self.num_layers, self.pre_embed_size, self.embed_size, self.hidden_size, self.imp_hidden_size = config["subcomponents_per_layer"], config["num_layers"], config["pre_embed_size"], config["in_size"], config["hidden_size"], config["importance_mlp_size"]
        self.hypers = dict(list(config.items())[4:])
        self.num_matrices = self.num_layers * 2 + 2
        self.P = sum(p.numel() for p in self.target_model.parameters())


        # Define weights/subcomponents 
        self.embed = Subcomponent((self.pre_embed_size, self.embed_size), self.C, device=device)
        self.unembed = Subcomponent((self.embed_size, self.pre_embed_size), self.C, device=device)
        
        self.layers = nn.ModuleList(
            MLPSubcomponentLayer(self.embed_size, self.hidden_size, self.C, device=device) for _ in range(self.num_layers)
        )
        
        # Define Importance Predictors
        # You can index importance predictors via imp_pred_ers[layer]["in"/"out"]
        self.imp_pred_ers = nn.ModuleList(
            [nn.ModuleDict({
                name: ImportancePredictor(self.imp_hidden_size, self.C, device=device) for name in ["in", "out"]                    
            }) for l in range(self.num_layers)] +
            [nn.ModuleDict({
                name: ImportancePredictor(self.imp_hidden_size, self.C, device=device) for name in ["embed", "unembed"]                    
            })]
        )

    def forward(self, x, masks=None, return_activs_weights=False):
        # Create the masks object so that dict keys never break (simplifies code)
        if masks is None:
            layer_masks = [
                { name: None for name in ("in", "out") }
                for _ in range(self.num_layers)
            ]
            end_masks = { name: None for name in ("embed", "unembed") }
            masks = layer_masks + [end_masks]

        
        activations = []

        # regular forward pass
        x, embed_activs = self.embed(x, masks[-1]["embed"])
        for l in range(self.num_layers):
            x, layer_activations = self.layers[l](x, masks=masks[l])
            activations.append(layer_activations)
            
        x, unembed_activs = self.unembed(x, masks[-1]["unembed"])
        activations.append({"embed": embed_activs, "unembed": unembed_activs})             
        
        
        if not return_activs_weights:
            return x
        else:
            weights = []

            for l in range(self.num_layers):
                weights.append(self.layers[l].return_weights_layer())
            weights.append({"embed": self.embed.return_weights(), "unembed": self.unembed.return_weights()})
            
            return x, activations, weights

In [29]:
class HardSigmoid(nn.Module):
    """
    Implements the hard sigmoid activation function as described in the paper:
        σ_H(x) = 0 if x <= 0
               = x if 0 < x < 1
               = 1 if x >= 1
    This is equivalent to: torch.clamp(x, min=0.0, max=1.0)
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Clamp values between 0 and 1
        return torch.clamp(x, min=0.0, max=1.0)


class LowerLeakyHardSigmoid(nn.Module):
    """
    Lower-leaky hard sigmoid: σH,lower(x)
    - 0.01*x if x <= 0 (leaky below 0)
    - x if 0 <= x <= 1 (linear in middle)  
    - 1 if x >= 1 (saturated above 1)
    
    Used for forward pass masks in stochastic reconstruction losses.
    """
    def __init__(self, leak_slope=0.01):
        super().__init__()
        self.leak_slope = leak_slope

    def forward(self, x):
        return torch.where(
            x <= 0, 
            self.leak_slope * x,
            torch.where(
                x >= 1,
                torch.ones_like(x),
                x
            )
        )

class UpperLeakyHardSigmoid(nn.Module):
    """
    Upper-leaky hard sigmoid: σH,upper(x)  
    - 0 if x <= 0 (hard cutoff below 0)
    - x if 0 <= x <= 1 (linear in middle)
    - 1 + 0.01*(x-1) if x >= 1 (leaky above 1)
    
    Used for importance loss computation.
    """
    def __init__(self, leak_slope=0.01):
        super().__init__()
        self.leak_slope = leak_slope

    def forward(self, x):
        return torch.where(
            x <= 0,
            torch.zeros_like(x), 
            torch.where(
                x >= 1,
                1 + self.leak_slope * (x - 1),
                x
            )
        )

In [104]:
def train_spd(spd_model, dataloader, train_config, num_epochs=1):

    # SPD model is a model
    spd_model.train()
    print(f"Training on device {spd_model.device}")
    optimizer = torch.optim.AdamW(spd_model.parameters(), lr = train_config["lr"])

    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, 
        step_size=train_config["lr_step_size"],   # e.g. every 4 epochs
        gamma=train_config["lr_gamma"]      # multiply LR by 0.5 each time
    )
    
    # P = sum(p.numel() for p in spd_model.target_model.parameters()) [moved into model]
    upper_leaky_sigmoid = UpperLeakyHardSigmoid()
    lower_leaky_sigmoid = LowerLeakyHardSigmoid()

    for epoch in range(num_epochs):
        total_loss = 0.0
        total_l_stoch_rec, total_l_stoch_rec_l, total_l_imp, total_l_faith = 0.0, 0.0, 0.0, 0.0
        print(f"Starting epoch {epoch+1}, lr = {optimizer.param_groups[0]['lr']:.2e}")

        with tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}") as t:
            for batch_idx, (x,y) in enumerate(t):
                x = x.to(device)
                y = y.to(device)
                
                N = x.shape[0]     # x is shape N by in_size
                S, C, P, L = spd_model.hypers["num_mask_samples"], spd_model.C, spd_model.P, spd_model.num_layers
                target_model = spd_model.target_model
                optimizer.zero_grad()

                # ====== TARGET MODEL OUTPUT ======
                with torch.no_grad(): 
                    target_out = spd_model.target_model(x)

                # ====== FAITHFULNESS LOSS ======
                spd_output, spd_activations, spd_weights = spd_model(x, return_activs_weights = True)
                squared_error = 0

                for l in range(L):
                    in_diff = target_model.W_in[l] - spd_weights[l]["in"]
                    out_diff = target_model.W_out[l] - spd_weights[l]["out"]
                    squared_error_layer = torch.linalg.matrix_norm(in_diff)**2 + torch.linalg.matrix_norm(out_diff)**2 
                    squared_error = squared_error + squared_error_layer

                embed_diff = target_model.W_embed - spd_weights[-1]["embed"]
                unembed_diff = target_model.W_unembed - spd_weights[-1]["unembed"]
                squared_error_embed = torch.linalg.matrix_norm(embed_diff)**2 + torch.linalg.matrix_norm(unembed_diff)**2 
                squared_error = squared_error + squared_error_embed

                l_faithfulness = squared_error/P # "Mean Squared Error" across parameters
                

                # ===== IMPORTANCE MINIMALITY LOSS ======
                l_importance_minimality = 0.0

                pred_importances = []

                imp_pred_embed = spd_model.imp_pred_ers[-1]["embed"](spd_activations[-1]["embed"])
                imp_pred_unembed = spd_model.imp_pred_ers[-1]["unembed"](spd_activations[-1]["unembed"])
                l_importance_minimality = l_importance_minimality + (upper_leaky_sigmoid(imp_pred_embed)** spd_model.hypers["causal_imp_min"]).sum() + (upper_leaky_sigmoid(imp_pred_unembed) ** spd_model.hypers["causal_imp_min"]).sum()
                for l in range(L):
                    imp_pred_in = spd_model.imp_pred_ers[l]["in"](spd_activations[l]["in"])
                    imp_pred_out = spd_model.imp_pred_ers[l]["out"](spd_activations[l]["out"])
                    pred_importances.append({"in": imp_pred_in, "out": imp_pred_out})
                    l_importance_minimality = l_importance_minimality + (upper_leaky_sigmoid(imp_pred_in) ** spd_model.hypers["causal_imp_min"]).sum()  + (upper_leaky_sigmoid(imp_pred_out)** spd_model.hypers["causal_imp_min"]).sum() 
                pred_importances.append({"embed": imp_pred_embed, "unembed": imp_pred_unembed})

                l_importance_minimality /= N


                # ===== STOCHASTIC RECONSTRUCTION LOSS ======
                l_stochastic_recon = 0.0
                l_stochastic_recon_layerwise = 0.0
                R = torch.rand((S, N, L+1, 2, C), device=device)

                # we have activs in shape (N,C). Stack and unsqueeze(0) to (1,2,N,C) and cat along dim 1 to get (L+1,2,N,C)
                stacked_imps = torch.cat([torch.stack((pred_importances[l]["in"], pred_importances[l]["out"])).reshape(1,2,N,C) for l in range(L)] + [torch.stack((pred_importances[-1]["embed"],pred_importances[-1]["unembed"])).reshape(1,2,N,C)])
                # reshape (L+1, 2, N, C) -> (N, L+1, 2, C)
                stacked_imps = torch.movedim(stacked_imps, 2, 0)
                # Apply sigmoid and then reshape to (1, N, L+1, 2, C)
                G = lower_leaky_sigmoid(stacked_imps).unsqueeze(0)
                masks = G + (1-G)*R # shape (S, N, L+1, 2, C)
                masks = masks.reshape(S*N, L+1, 2, C)

                # Move masks back into layerwise structure
                layer_masks = [
                    { name: masks[:,l,idx,:] for name, idx in [("in",0), ("out",1)] } for l in range(L)
                ]
                end_masks = { name: masks[:,l+1,idx,:] for name, idx in [("embed", 0), ("unembed", 1)] }
                masks_dictified = layer_masks + [end_masks]

                #Tile inputs (1, N, embed_size) S times to have N_new = S * N -> (N * S, embed_size)
                x_repeated = x.clone().unsqueeze(0).expand(S, -1, -1).reshape(S * N, -1)

                # Run regular masked loss and update stochastic recon
                masked_out = spd_model(x_repeated, masks=masks_dictified)
                target_out_tiled = target_out.clone().unsqueeze(0).expand(S, -1, -1).reshape(S*N, -1)
                l_stochastic_recon = l_stochastic_recon + torch.linalg.matrix_norm(target_out_tiled-masked_out)**2
                l_stochastic_recon /= S

                # ===== STOCHASTIC RECONSTRUCTION LAYERWISE LOSS ======
                # Create Layerwise Mask constructor
                layer_masks_none = [
                    { name: None for name in ("in", "out") }
                    for _ in range(L)
                ]
                end_masks_none = { name: None for name in ("embed", "unembed") }
                layer_mask_constructor = layer_masks_none + [end_masks_none]

                # Embed and unembed outs
                embed_mask = layer_mask_constructor.copy()
                embed_mask[-1]["embed"] = masks_dictified[-1]["embed"]
                embed_out=spd_model(x_repeated, embed_mask)

                unembed_mask = layer_mask_constructor.copy()
                unembed_mask[-1]["unembed"] = masks_dictified[-1]["unembed"]
                unembed_out=spd_model(x_repeated,unembed_mask)
                l_stochastic_recon_layerwise = l_stochastic_recon_layerwise + torch.linalg.matrix_norm(target_out_tiled-embed_out)**2 + torch.linalg.matrix_norm(target_out_tiled-unembed_out)**2

                for l in range(l): 
                    for name in ["in", "out"]:
                        layer_mask = layer_mask_constructor.copy()
                        layer_mask[l][name] = masks_dictified[l][name]
                        layer_out = spd_model(x_repeated,layer_mask)
                        l_stochastic_recon_layerwise = l_stochastic_recon_layerwise + torch.linalg.matrix_norm(target_out_tiled-embed_out)**2
                        
                l_stochastic_recon_layerwise /= (S * spd_model.num_matrices)
                                
                beta1, beta2, beta3 = spd_model.hypers["beta_1"], spd_model.hypers["beta_2"], spd_model.hypers["beta_3"]

                # Loss computations
                loss = l_faithfulness + beta1*l_stochastic_recon + beta2*l_stochastic_recon_layerwise + beta3*l_importance_minimality
                loss.backward()
                optimizer.step() 
                t.set_postfix(loss=loss.item())
            print(f"(Last batch) Faithfulness: {l_faithfulness}, Stoch Rec: {l_stochastic_recon}, Stoch Rec Layerwise: {l_stochastic_recon_layerwise}, Importance Min: {l_importance_minimality}")



In [106]:
if __name__ == "__main__":

    device = DEVICE # specified at beginning of file
    config = CONFIG
    train_config = TRAIN_CONFIG
    
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    spd_model = SPDModelMLP(toy_model, config, device)
    train_spd(spd_model, dataloader, train_config, num_epochs=NUM_EPOCHS)

Training on device mps
Starting epoch 1, lr = 8.00e-04


Epoch 1/25: 100%|████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:02<00:00, 47.14it/s, loss=1.27]


(Last batch) Faithfulness: 0.0020853562746196985, Stoch Rec: 0.8647415041923523, Stoch Rec Layerwise: 0.3849048614501953, Importance Min: 0.15678267180919647
Starting epoch 2, lr = 8.00e-04


Epoch 2/25: 100%|████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:02<00:00, 49.19it/s, loss=1.12]


(Last batch) Faithfulness: 0.002012841636314988, Stoch Rec: 0.7959372401237488, Stoch Rec Layerwise: 0.3166409134864807, Importance Min: 0.03770322725176811
Starting epoch 3, lr = 8.00e-04


Epoch 3/25: 100%|████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:02<00:00, 49.94it/s, loss=1.15]


(Last batch) Faithfulness: 0.0019466744270175695, Stoch Rec: 0.8418064117431641, Stoch Rec Layerwise: 0.2982615530490875, Importance Min: 0.05962591990828514
Starting epoch 4, lr = 8.00e-04


Epoch 4/25: 100%|████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:02<00:00, 48.37it/s, loss=1.18]


(Last batch) Faithfulness: 0.001888036960735917, Stoch Rec: 0.8582914471626282, Stoch Rec Layerwise: 0.3030981123447418, Importance Min: 0.11846653372049332
Starting epoch 5, lr = 8.00e-04


Epoch 5/25: 100%|████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:02<00:00, 48.82it/s, loss=1.18]


(Last batch) Faithfulness: 0.0018354130443185568, Stoch Rec: 0.8585014343261719, Stoch Rec Layerwise: 0.30919575691223145, Importance Min: 0.14506344497203827
Starting epoch 6, lr = 8.00e-04


Epoch 6/25: 100%|████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:02<00:00, 48.94it/s, loss=0.96]


(Last batch) Faithfulness: 0.0017866144189611077, Stoch Rec: 0.6929793953895569, Stoch Rec Layerwise: 0.24649257957935333, Importance Min: 0.1913660317659378
Starting epoch 7, lr = 8.00e-04


Epoch 7/25: 100%|████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:02<00:00, 48.87it/s, loss=1.35]


(Last batch) Faithfulness: 0.0017462731339037418, Stoch Rec: 0.9930356740951538, Stoch Rec Layerwise: 0.3165791630744934, Importance Min: 0.38594111800193787
Starting epoch 8, lr = 8.00e-04


Epoch 8/25: 100%|████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:02<00:00, 47.41it/s, loss=1.14]


(Last batch) Faithfulness: 0.0017149151535704732, Stoch Rec: 0.7765235900878906, Stoch Rec Layerwise: 0.29021525382995605, Importance Min: 0.722955048084259
Starting epoch 9, lr = 8.00e-04


Epoch 9/25: 100%|████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:02<00:00, 48.24it/s, loss=1.31]


(Last batch) Faithfulness: 0.001686715753749013, Stoch Rec: 0.8582080006599426, Stoch Rec Layerwise: 0.3389424681663513, Importance Min: 1.14987051486969
Starting epoch 10, lr = 8.00e-04


Epoch 10/25: 100%|███████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:02<00:00, 48.78it/s, loss=1.13]


(Last batch) Faithfulness: 0.0016565141268074512, Stoch Rec: 0.728486955165863, Stoch Rec Layerwise: 0.2528616487979889, Importance Min: 1.5116595029830933
Starting epoch 11, lr = 8.00e-04


Epoch 11/25: 100%|███████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:02<00:00, 49.44it/s, loss=1.12]


(Last batch) Faithfulness: 0.001639935071580112, Stoch Rec: 0.6906450986862183, Stoch Rec Layerwise: 0.2563306391239166, Importance Min: 1.6820393800735474
Starting epoch 12, lr = 8.00e-04


Epoch 12/25: 100%|███████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:02<00:00, 49.17it/s, loss=1.18]


(Last batch) Faithfulness: 0.0016250226181000471, Stoch Rec: 0.6643909215927124, Stoch Rec Layerwise: 0.2479252815246582, Importance Min: 2.709808111190796
Starting epoch 13, lr = 8.00e-04


Epoch 13/25: 100%|███████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:02<00:00, 48.52it/s, loss=1.26]


(Last batch) Faithfulness: 0.001602047705091536, Stoch Rec: 0.6674121618270874, Stoch Rec Layerwise: 0.2529870569705963, Importance Min: 3.4084484577178955
Starting epoch 14, lr = 8.00e-04


Epoch 14/25: 100%|████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:02<00:00, 49.68it/s, loss=1.4]


(Last batch) Faithfulness: 0.0015834869118407369, Stoch Rec: 0.687835693359375, Stoch Rec Layerwise: 0.26173466444015503, Importance Min: 4.462008953094482
Starting epoch 15, lr = 8.00e-04


Epoch 15/25: 100%|███████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:02<00:00, 49.09it/s, loss=1.37]


(Last batch) Faithfulness: 0.0015691678272560239, Stoch Rec: 0.646634042263031, Stoch Rec Layerwise: 0.25131699442863464, Importance Min: 4.746394634246826
Starting epoch 16, lr = 8.00e-04


Epoch 16/25: 100%|███████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:02<00:00, 49.45it/s, loss=1.32]


(Last batch) Faithfulness: 0.0015613391296938062, Stoch Rec: 0.5961915254592896, Stoch Rec Layerwise: 0.2123069316148758, Importance Min: 5.140549182891846
Starting epoch 17, lr = 8.00e-04


Epoch 17/25: 100%|███████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:02<00:00, 48.09it/s, loss=1.37]


(Last batch) Faithfulness: 0.001552774803712964, Stoch Rec: 0.5644382238388062, Stoch Rec Layerwise: 0.23972158133983612, Importance Min: 5.598663330078125
Starting epoch 18, lr = 8.00e-04


Epoch 18/25: 100%|███████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:02<00:00, 49.39it/s, loss=1.49]


(Last batch) Faithfulness: 0.001542234793305397, Stoch Rec: 0.6072854995727539, Stoch Rec Layerwise: 0.23044335842132568, Importance Min: 6.460878849029541
Starting epoch 19, lr = 8.00e-04


Epoch 19/25: 100%|███████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:02<00:00, 48.86it/s, loss=1.36]


(Last batch) Faithfulness: 0.001534105627797544, Stoch Rec: 0.5499958992004395, Stoch Rec Layerwise: 0.2130579650402069, Importance Min: 5.904991149902344
Starting epoch 20, lr = 8.00e-04


Epoch 20/25: 100%|███████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:02<00:00, 45.51it/s, loss=1.32]


(Last batch) Faithfulness: 0.0015282455133274198, Stoch Rec: 0.5234001278877258, Stoch Rec Layerwise: 0.21105948090553284, Importance Min: 5.80555534362793
Starting epoch 21, lr = 8.00e-04


Epoch 21/25: 100%|███████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:02<00:00, 42.42it/s, loss=1.19]


(Last batch) Faithfulness: 0.0015207236865535378, Stoch Rec: 0.46776431798934937, Stoch Rec Layerwise: 0.18135127425193787, Importance Min: 5.426008224487305
Starting epoch 22, lr = 8.00e-04


Epoch 22/25: 100%|███████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:03<00:00, 36.81it/s, loss=1.28]


(Last batch) Faithfulness: 0.001516428543254733, Stoch Rec: 0.4723374843597412, Stoch Rec Layerwise: 0.19280461966991425, Importance Min: 6.134443283081055
Starting epoch 23, lr = 8.00e-04


Epoch 23/25: 100%|████████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:02<00:00, 43.40it/s, loss=1.4]


(Last batch) Faithfulness: 0.0015142298070713878, Stoch Rec: 0.5100651979446411, Stoch Rec Layerwise: 0.20373240113258362, Importance Min: 6.881712436676025
Starting epoch 24, lr = 8.00e-04


Epoch 24/25: 100%|███████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:02<00:00, 47.69it/s, loss=1.46]


(Last batch) Faithfulness: 0.0015123555203899741, Stoch Rec: 0.5451668500900269, Stoch Rec Layerwise: 0.22488179802894592, Importance Min: 6.857992649078369
Starting epoch 25, lr = 8.00e-04


Epoch 25/25: 100%|███████████████████████████████████████████████████████████████████████████████████████| 118/118 [00:02<00:00, 47.82it/s, loss=1.39]

(Last batch) Faithfulness: 0.0015070681693032384, Stoch Rec: 0.5214187502861023, Stoch Rec Layerwise: 0.2080308496952057, Importance Min: 6.5756354331970215



