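The model has a single hidden layer, x → n_h → y, where x and y have the same size.

- Weights: W1 has shape (n_h, n_x) and W2 has shape (n_y, n_h). In the code each instance stores the transposes: W1 is (feat, hid) and W2 is (hid, feat).
- x: 3 features (should be configurable)
- Target: y = abs(x)
- n_h: 2 hidden neurons (should be configurable)
- Loss: MSE (the implementation weights each feature's error by its importance)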

- Vary the activation function used in each layer (e.g. ReLU or identity; ReLU by default).

- Vary the bias value used in each layer (all equal, or specific values).

- Vary the feature sparsity (all equal, or specific per-feature values, e.g. a function of the feature index).

- Vary the feature importance (all equal, or specific per-feature values, e.g. a function of the feature index).

- Vary the loss function.

- Optional: vary the optimizer and learning-rate schedule.

In [2]:
"""Notebook settings and imports."""

%load_ext autoreload
%autoreload 2
# %flow mode reactive

import os

from dataclasses import dataclass, field
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import torch as t

from einops import asnumpy, einsum, rearrange, reduce, repeat, pack, parse_shape, unpack
from einops.layers.torch import Rearrange, Reduce
from jaxtyping import Float, Int
from matplotlib import pyplot as plt
from plotly import express as px
from plotly import graph_objects as go
from plotly import io as pio
from rich import print as rprint
from torch import nn, optim
from torch.nn import functional as F
from tqdm.notebook import tqdm

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [61]:
@dataclass
class CisConfig:
    """Hyperparameters for the `Cis` toy model.

    Attributes:
        n_instances: Number of independent model copies trained side by side.
        n_feat: Number of input features (input and output sizes are equal).
        n_hidden: Number of hidden neurons.
        act_fn: Activation functions, one per layer: [hidden layer, output layer].
        b1: Hidden-layer bias. The string "0" means a fixed (untrained) zero bias;
            otherwise an (inst, hid) tensor.
        b2: Output-layer bias. The string "0" means a zero-initialized trainable bias;
            otherwise an (inst, feat) tensor.
        feat_sparsity: Probability that a feature is zero; a scalar or a tensor
            broadcastable to (inst, feat).
        feat_importance: Per-feature weight in the loss; a scalar or a tensor
            broadcastable to (inst, feat).
        optimizer: Optimizer class. `Cis` itself does not read it; callers may use it
            to build the optimizer passed to `Cis.optimize`.
    """

    n_instances: int = 2
    n_feat: int = 2
    n_hidden: int = 4
    act_fn: List[Callable] = field(default_factory=lambda: [F.relu, F.relu])
    b1: str | Float[t.Tensor, "inst hid"] = "0"
    b2: str | Float[t.Tensor, "inst feat"] = "0"
    feat_sparsity: float | t.Tensor = 0.0
    feat_importance: float | t.Tensor = 1.0
    optimizer: Callable = t.optim.Adam

    def __post_init__(self):
        """Ensure attribute values are valid.

        Raises:
            ValueError: On a non-positive size, a wrong number of activations, a bias
                string other than "0", or a numeric sparsity outside [0, 1].
        """
        for name in ("n_instances", "n_feat", "n_hidden"):
            if getattr(self, name) < 1:
                raise ValueError(f"{name} must be a positive integer, got {getattr(self, name)!r}")

        # The model applies act_fn[0] to the hidden layer and act_fn[1] to the output layer.
        if len(self.act_fn) != 2 or not all(callable(fn) for fn in self.act_fn):
            raise ValueError(f"act_fn must hold exactly 2 callables (one per layer), got {self.act_fn!r}")

        # "0" is the only string shorthand the model understands for a bias.
        for name in ("b1", "b2"):
            value = getattr(self, name)
            if isinstance(value, str) and value != "0":
                raise ValueError(f'{name} must be "0" or a tensor, got {value!r}')

        # Sparsity is a probability; only numeric/tensor values can be range-checked here.
        if isinstance(self.feat_sparsity, (int, float, t.Tensor)):
            sparsity = t.as_tensor(self.feat_sparsity, dtype=t.float32)
            if not bool(((sparsity >= 0) & (sparsity <= 1)).all()):
                raise ValueError(f"feat_sparsity must lie in [0, 1], got {self.feat_sparsity!r}")

In [70]:
class Cis(nn.Module):
    """Computation in Superposition toy model.

    Each of `n_instances` independent copies computes, per instance,
        h = act_fn[0](x @ W1 + b1),   y = act_fn[1](h @ W2 + b2),
    and is trained so that y matches `fx(x)` (|x| by default) under an
    importance-weighted MSE.
    """

    # Some attribute type hints
    W1: Float[t.Tensor, "inst feat hid"]
    W2: Float[t.Tensor, "inst hid feat"]
    b1: Float[t.Tensor, "inst hid"]
    b2: Float[t.Tensor, "inst feat"]
    s: Float[t.Tensor, "inst feat"]  # feature sparsity (probability a feature is zeroed)
    i: Float[t.Tensor, "inst feat"]  # feature importance (per-feature loss weight)

    def __init__(self, cfg: CisConfig):
        """Initializes model params and the fixed bias / sparsity / importance tensors.

        Args:
            cfg: Model hyperparameters.

        Raises:
            ValueError: If a bias is a string other than "0".
        """
        super().__init__()
        self.cfg = cfg
        inst_feat = (cfg.n_instances, cfg.n_feat)
        inst_hid = (cfg.n_instances, cfg.n_hidden)

        # Model Weights
        # NOTE(review): xavier_normal_ on a 3-D tensor computes fans conv-style
        # (fan_in = dim1 * dim2, fan_out = dim0 * dim2), so the per-instance init scale
        # depends on n_instances. Kept as-is to preserve existing results.
        self.W1 = nn.Parameter(nn.init.xavier_normal_(t.empty(cfg.n_instances, cfg.n_feat, cfg.n_hidden)))
        self.W2 = nn.Parameter(nn.init.xavier_normal_(t.empty(cfg.n_instances, cfg.n_hidden, cfg.n_feat)))

        # Model Biases: the default hidden bias is fixed at zero, the default output bias is trained.
        self._init_bias("b1", cfg.b1, inst_hid, trainable_zero=False)
        self._init_bias("b2", cfg.b2, inst_feat, trainable_zero=True)

        # Sparsities and importances. Registered as buffers so they follow `.to(device)`.
        self.register_buffer("s", self._feat_values(cfg.feat_sparsity, inst_feat))
        self.register_buffer("i", self._feat_values(cfg.feat_importance, inst_feat))

    def _init_bias(self, name: str, value, shape: Tuple[int, ...], trainable_zero: bool) -> None:
        """Registers bias `name`: "0" -> zeros (a Parameter iff `trainable_zero`), tensor -> fixed buffer."""
        if isinstance(value, str):
            if value != "0":
                raise ValueError(f'{name} must be "0" or a tensor, got {value!r}')
            zeros = t.zeros(shape)
            if trainable_zero:
                setattr(self, name, nn.Parameter(zeros))
            else:
                self.register_buffer(name, zeros)
        elif isinstance(value, nn.Parameter):
            # A caller-supplied Parameter stays a trainable parameter.
            setattr(self, name, value)
        else:
            self.register_buffer(name, t.as_tensor(value, dtype=t.get_default_dtype()))

    @staticmethod
    def _feat_values(value, shape: Tuple[int, int]) -> t.Tensor:
        """Converts a scalar, tensor, or index function into a float tensor broadcastable to `shape`."""
        if isinstance(value, (int, float)):
            # float() keeps the dtype floating even for int inputs such as 0 or 1.
            return t.full(shape, float(value))
        if callable(value):
            # "Function of index": called with the int64 feature indices 0..n_feat-1.
            value = value(t.arange(shape[-1]))
        return t.as_tensor(value, dtype=t.get_default_dtype())

    def gen_batch(self, batch_sz: int) -> Float[t.Tensor, "batch inst feat"]:
        """Generates a batch of sparse feature values.

        Each value is drawn uniformly from [-1, 1) and is independently zeroed with
        probability `self.s`. The batch is created on the same device as the weights.
        """
        shape = (batch_sz, self.cfg.n_instances, self.cfg.n_feat)
        device = self.W1.device
        x = t.rand(shape, device=device) * 2 - 1  # uniform on [-1, 1)
        is_active = t.rand(shape, device=device) < (1 - self.s)
        return x * is_active

    def forward(
        self, x: Float[t.Tensor, "batch inst feat"], fx: Callable = t.abs
    ) -> Float[t.Tensor, ""]:
        """Runs a forward pass through the model and returns the loss.

        Args:
            x: Input batch.
            fx: Target function applied to `x` (absolute value by default).

        Returns:
            Scalar loss: mean over batch, instances and features of `i * (y - fx(x)) ** 2`.
        """
        # Hidden layer
        h = einsum(x, self.W1, "batch inst feat, inst feat hid -> batch inst hid")
        h = self.cfg.act_fn[0](h + self.b1)

        # Output layer
        y = einsum(h, self.W2, "batch inst hid, inst hid feat -> batch inst feat")
        y = self.cfg.act_fn[1](y + self.b2)

        # Importance-weighted MSE
        y_true = fx(x)
        return reduce((y - y_true) ** 2 * self.i, "batch inst feat -> ", "mean")

    def optimize(
        self, optimizer: t.optim.Optimizer, batch_sz: int, steps: int, logging_freq: int
    ) -> List[float]:
        """Trains the model on freshly generated batches.

        Args:
            optimizer: Optimizer already bound to this model's parameters.
            batch_sz: Number of samples per step.
            steps: Number of optimization steps.
            logging_freq: Record the loss every `logging_freq` steps. Values below 1 are
                treated as 1. The final step is always recorded.

        Returns:
            The recorded loss values, in step order.
        """
        losses = []
        # steps // 10 is 0 when steps < 10; guard against a modulo by zero.
        log_every = max(1, logging_freq)
        pbar = tqdm(range(steps), desc="Training")

        for step in pbar:
            loss = self(self.gen_batch(batch_sz))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Log progress
            if step % log_every == 0 or step + 1 == steps:
                loss_val = loss.item()
                losses.append(loss_val)
                pbar.set_postfix({"loss": f"{loss_val:.4f}"})

        return losses

In [99]:
config = CisConfig(
    n_instances=1,
    n_feat=3,
    n_hidden=6,
    feat_sparsity=0.0,  # a float, matching CisConfig's declared type; 0.0 means fully dense inputs
    feat_importance=1.0,
)

In [100]:
# Seed right before building the model so the weight init, and the training batches drawn
# afterwards, are reproducible on Restart & Run All.
SEED = 0
t.manual_seed(SEED)
model = Cis(config)

In [101]:
batch_sz = 128
steps = 5000
logging_freq = steps // 10  # record the loss every 10% of training (plus the final step)

# Build the optimizer from the config so `CisConfig.optimizer` takes effect (Adam by default).
optimizer = model.cfg.optimizer(model.parameters())
losses = model.optimize(optimizer, batch_sz, steps, logging_freq)
losses

Training:   0%|          | 0/5000 [00:00<?, ?it/s]

[0.33576270937919617,
 0.06316550821065903,
 0.028050122782588005,
 0.0024157902225852013,
 0.0004883696092292666,
 6.413614755729213e-05,
 6.53607412459678e-06,
 3.1265510642697336e-06,
 2.5135636860795785e-07,
 9.928488680088776e-07,
 2.62755275315385e-08]

In [102]:
# Learned input -> hidden weights, shape (inst, feat, hid).
model.get_parameter("W1")

Parameter containing:
tensor([[[ 1.0469, -0.0106, -0.8279, -0.0021, -0.0026, -0.0107],
         [-0.0100, -0.0017, -0.0079, -1.1026,  1.3598, -0.0017],
         [ 0.0028,  1.1591,  0.0022,  0.0217,  0.0267, -1.1735]]],
       requires_grad=True)

In [103]:
# Learned hidden -> output weights, shape (inst, hid, feat).
model.get_parameter("W2")

Parameter containing:
tensor([[[ 9.5527e-01,  1.7602e-03,  8.7130e-03],
         [-2.2676e-03, -1.7184e-02,  8.6281e-01],
         [ 1.2080e+00, -2.2295e-03, -1.1111e-02],
         [-8.7021e-03,  9.0732e-01, -1.3165e-03],
         [ 7.0212e-03,  7.3574e-01,  1.0730e-03],
         [ 2.2604e-03,  1.6757e-02,  8.5224e-01]]], requires_grad=True)

In [68]:
# Hidden-layer bias. With the default b1="0" it is a zero tensor that is not an
# nn.Parameter, so it is not updated by training.
model.b1

Parameter containing:
tensor([[-1.1866e-04,  5.7883e-01, -1.6888e-02, -1.4789e-02]],
       requires_grad=True)

In [69]:
# Output-layer bias. With the default b2="0" it is a zero-initialized nn.Parameter and is trained.
model.b2

Parameter containing:
tensor([[0.9745, 0.0626]], requires_grad=True)

In [73]:
# Names of the parameters registered on the model (what the optimizer will update).
print("Model parameters:", list(dict(model.named_parameters())))

Model parameters: ['W1', 'W2', 'b2']
