In [None]:
# Imports: torch, plotly, jaxtyping, tqdm, and TransformerLens / circuitsvis interpretability tooling
import plotly.express as px
import torch
from jaxtyping import Int, Float
from typing import List, Optional, Tuple
from tqdm import tqdm
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
import circuitsvis as cv

In [2]:
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import einops
import typing

In [None]:
# create your original toy model

In [1]:
# Idea: create an SPDHookedTransformer class
# that extends HookedTransformer,
# inherits all of its existing functionality,
# and also contains the SPD training algorithm (?)

# How should everything actually be structured?
"""
Maybe I should start by implementing it on a simple MLP setup.

Original network:
- 1-layer MLP 5-2-5 (defined as usual with torch.nn.Sequential, presumably)

SPD network:
Let's give it 10 subcomponents per layer;
then it's just like 10 matmuls?
Maybe I should define it like Anthropic did in the Toy Models of Superposition paper
and stick them all into a trenchcoat.
This might be easier once I have defined the toy model.

"""

In [4]:
config = dict(
    # number of residual MLP layers
    num_layers=2,
    # residual-stream / input width
    in_size=10,
    # hidden width of each MLP layer
    hidden_size=5,
    # C: number of rank-one subcomponents per weight matrix in the SPD model
    subcomponents_per_layer=5,
    # loss coefficients -- presumably one per SPD loss term (TODO confirm which)
    beta_1=1,
    beta_2=1,
    beta_3=1,
    # NOTE(review): not read by any visible cell yet; meaning to be confirmed
    causal_imp_min=1,
    num_samples=10,
    # hidden width of each per-subcomponent importance MLP
    importance_mlp_size=10,
)

In [7]:
class ToyResidMLP(nn.Module):
    """Toy residual MLP: ``num_layers`` ReLU MLP blocks added onto a residual stream.

    Layer ``i`` computes ``x <- x + relu(x @ W_in[i] + b[i]) @ W_out[i]`` with
    ``W_in[i]`` of shape (in_size, hidden_size), ``b[i]`` of shape (hidden_size,)
    and ``W_out[i]`` of shape (hidden_size, in_size).

    Args:
        config: dict with integer keys "num_layers", "in_size" and "hidden_size".
        device: torch device the parameters are allocated on (default "mps").
    """

    def __init__(self, config, device="mps"):
        super().__init__()
        self.num_layers = config["num_layers"]
        self.in_size = config["in_size"]
        self.hidden_size = config["hidden_size"]
        self.device = device

        # Explicit nn.Parameter wrapping so the tensors are registered as
        # trainable parameters regardless of PyTorch version.
        self.W_in = nn.ParameterList(
            [nn.Parameter(torch.empty((self.in_size, self.hidden_size), device=device))
             for _ in range(self.num_layers)]
        )
        self.W_out = nn.ParameterList(
            [nn.Parameter(torch.empty((self.hidden_size, self.in_size), device=device))
             for _ in range(self.num_layers)]
        )
        self.b = nn.ParameterList(
            [nn.Parameter(torch.zeros((self.hidden_size,), device=device))
             for _ in range(self.num_layers)]
        )

        # ParameterList does not support `+`, so unpack both lists into one.
        for param in [*self.W_in, *self.W_out]:
            nn.init.xavier_normal_(param)

    def forward(self, x):
        """Apply every residual MLP layer to ``x``.

        Args:
            x: tensor whose last dimension is ``in_size``; any leading batch
                dimensions are allowed.

        Returns:
            Tensor with the same shape as ``x``. ``x`` itself is not modified.
        """
        assert x.shape[-1] == self.in_size, f"Input shape does not match model's accepted size {self.in_size}"

        x_resid = x
        for W_in, b, W_out in zip(self.W_in, self.b, self.W_out):
            hidden = F.relu(torch.einsum("...d,dh->...h", x_resid, W_in) + b)
            layer_out = torch.einsum("...h,hd->...d", hidden, W_out)
            # Out-of-place add: `+=` would overwrite the caller's `x` on the first layer.
            x_resid = x_resid + layer_out
        # TODO: decide whether the model needs separate embed / unembed matrices.
        return x_resid


def toy_train(model, lr, num_steps):
    """Train the toy target model -- not implemented yet; always returns None.

    Planned outline:
      1. create an AdamW optimizer over ``model`` with learning rate ``lr``;
      2. loop for ``num_steps`` steps, wrapped in tqdm;
      3. generate a fresh batch on every step.
    """
    return None

In [None]:
class SPDModelMLP(nn.Module):
    """SPD (parameter-decomposition) counterpart of a residual-MLP target model.

    Each weight matrix is represented as a sum of ``C`` rank-one subcomponents
    built from per-layer factor matrices:

        W_in[i]  = sum_c outer(V_in[i][c],  U_in[i][c])    # (in_size, hidden_size)
        W_out[i] = sum_c outer(V_out[i][c], U_out[i][c])   # (hidden_size, in_size)

    Each layer also owns an importance predictor made of ``C`` independent
    scalar MLPs, 1 -> importance_mlp_size -> 1 (the ``imp_*`` parameters).
    They are allocated and initialised here; no method uses them yet.

    Args:
        target_model: the module being decomposed. It must have a ``device``
            attribute equal to ``device``.
        config: dict with keys "num_layers", "in_size", "hidden_size",
            "importance_mlp_size" and "subcomponents_per_layer".
        device: torch device for every parameter (default "mps").
    """

    def __init__(self, target_model, config, device="mps"):
        # nn.Module.__init__ must run first: assigning a sub-module
        # (self.target_model) before it raises AttributeError.
        super().__init__()
        self.device = device
        # NOTE(review): assigning an nn.Module registers it as a child, so
        # self.parameters() also yields the target's weights -- exclude or
        # freeze them before building an optimizer over this model.
        self.target_model = target_model
        self.num_layers, self.in_size, self.hidden_size, self.imp_hidden_size = (
            config["num_layers"],
            config["in_size"],
            config["hidden_size"],
            config["importance_mlp_size"],
        )
        assert self.device == target_model.device, "Models not on same device"
        self.C = config["subcomponents_per_layer"]

        def per_layer(shape, fill=torch.empty):
            """One freshly allocated parameter of ``shape`` per layer, on ``device``."""
            return nn.ParameterList(
                [nn.Parameter(fill(shape, device=device)) for _ in range(self.num_layers)]
            )

        C, H = self.C, self.imp_hidden_size

        # Subcomponent factor vectors: row c of V and row c of U form the
        # outer product that is subcomponent c's rank-one matrix.
        self.V_in = per_layer((C, self.in_size))
        self.U_in = per_layer((C, self.hidden_size))
        self.V_out = per_layer((C, self.hidden_size))
        self.U_out = per_layer((C, self.in_size))

        # Biases are not decomposed; they start at zero.
        # TODO: decide whether they should instead be copied from the target model.
        self.b = per_layer((self.hidden_size,), fill=torch.zeros)

        # Importance predictor: C independent scalar MLPs per layer
        # (1 -> H -> 1), stacked along the leading C axis.
        self.imp_W_in = per_layer((C, 1, H))
        self.imp_W_out = per_layer((C, H, 1))
        self.imp_b_in = per_layer((C, H), fill=torch.zeros)
        self.imp_b_out = per_layer((C, 1), fill=torch.zeros)

        # Xavier init for the factors anyway. Note that the variance of the
        # summed outer products differs from a directly-initialised matrix.
        # ParameterList has no `+`, so unpack everything into one list.
        for param in [*self.V_in, *self.U_in, *self.V_out, *self.U_out]:
            nn.init.xavier_normal_(param)

        # Per-MLP Xavier-normal std (fan_in + fan_out = 1 + H). Calling
        # xavier_normal_ on the stacked 3-D tensor would take its fans from
        # the wrong axes.
        imp_std = (2.0 / (1 + H)) ** 0.5
        for param in [*self.imp_W_in, *self.imp_W_out]:
            nn.init.normal_(param, mean=0.0, std=imp_std)

    def forward_with_activations(self, x):
        """Unmasked forward pass that also returns intermediates.

        TODO: unclear whether a regular forward pass should apply masking.

        Args:
            x: input of shape (N, in_size). The einsums below require rank 2.

        Returns:
            ``(x_resid, activations, weight_matrices)``:
            - x_resid: output of shape (N, in_size); ``x`` itself is not modified.
            - activations: one (N, hidden_size) tensor per layer, the pre-ReLU
              hidden activations ``x @ W_in + b`` of the summed weights.
            - weight_matrices: one ``{"in": W_in, "out": W_out}`` dict per layer,
              holding the summed subcomponent weights.
        """
        activations = []
        weight_matrices = []
        x_resid = x

        for i in range(self.num_layers):
            # sum_c outer(V[c], U[c]) == V.T @ U: contract over c directly
            # instead of materialising the (C, ., .) stack of outer products.
            W_in = torch.einsum("ci,ch->ih", self.V_in[i], self.U_in[i])
            W_out = torch.einsum("ch,ci->hi", self.V_out[i], self.U_out[i])
            weight_matrices.append({"in": W_in, "out": W_out})

            layer_activations = torch.einsum("di,ih->dh", x_resid, W_in) + self.b[i]
            # TODO: the importance predictors need per-subcomponent activations,
            # not these summed ones -- possibly x_resid @ self.V_in[i].T of shape
            # (N, C); confirm against the SPD formulation.
            activations.append(layer_activations)

            layer_out = torch.einsum("dh,hi->di", F.relu(layer_activations), W_out)
            # Out-of-place add: `+=` would overwrite the caller's `x` on the first layer.
            x_resid = x_resid + layer_out
        return x_resid, activations, weight_matrices


In [None]:
def generate_batch(config):
    """Sample an input batch -- not implemented yet; always returns None.

    Intended to eventually return a tensor of shape (N, config["in_size"]).
    """
    return None

def train_SPD(spd_model, x=None):  # could also implement this by passing in the original model
    """Compute the SPD losses for one batch (work in progress).

    Only the faithfulness loss is implemented so far. Importance prediction,
    importance minimality and the masked (stochastic) reconstruction steps
    are still TODO.

    Args:
        spd_model: SPD model exposing ``target_model`` (whose ``W_in[i]`` and
            ``W_out[i]`` are the weights being decomposed), ``num_layers``, and
            ``forward_with_activations(x) -> (out, activations, weights)``.
        x: input batch, presumably of shape (N, in_size). It is optional only
            because batch generation is not implemented yet.

    Returns:
        The faithfulness loss as a 0-d tensor: the squared difference between
        target and SPD weights, averaged over every weight entry.

    Raises:
        NotImplementedError: if ``x`` is not given.
    """
    if x is None:
        # TODO: x = generate_batch(...) once generate_batch is implemented.
        raise NotImplementedError("batch generation is not implemented yet; pass x explicitly")

    target_model = spd_model.target_model

    with torch.no_grad():
        # Not used yet; needed later for the reconstruction losses.
        target_out = target_model(x)

    spd_output, spd_activations, spd_weights = spd_model.forward_with_activations(x)

    # Faithfulness loss: mean over *all weight entries* of the squared error
    # between the target weights and the summed SPD subcomponents.
    # Summing squares directly avoids matrix_norm(...) ** 2, whose sqrt has an
    # ill-defined gradient at zero.
    squared_error = 0.0
    num_weights = 0
    for i in range(spd_model.num_layers):
        # The target is fixed: detach so this loss never pushes gradients into it.
        in_diff = target_model.W_in[i].detach() - spd_weights[i]["in"]
        out_diff = target_model.W_out[i].detach() - spd_weights[i]["out"]
        squared_error = squared_error + in_diff.pow(2).sum() + out_diff.pow(2).sum()
        num_weights += in_diff.numel() + out_diff.numel()

    if num_weights == 0:
        return torch.tensor(0.0)
    l_faithfulness = squared_error / num_weights

    # TODO: predict causal importances from spd_activations, then compute the
    # importance-minimality and masked reconstruction losses.
    return l_faithfulness

In [None]:
"""
The SPD training step (forward pass plus losses) looks like:
0. generate or sample data
1. run the regular (target) model
2. run the SPD model
3. faithfulness loss (check that the SPD model's weights sum to the original model's weights)
4. get the 'intermediate activations', then get the importance-MLP predictions for each layer
5. compute the importance-minimality loss
6. sample Rs and compute the masked weights
7. run the model with random masking (using the masked weights) -- TODO: finish describing this step
"""