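Model: a single hidden layer, x -> n_h -> y (x and y have the same size).

W1 has shape (n_h, n_x), W2 has shape (n_y, n_h)
x = 3 features (should be able to vary this)
y = abs(x) (elementwise absolute value)
n_h = 2 neurons (should be able to vary this)
loss = MSE, with each feature's squared error weighted by its feature importance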

- Vary the activation function used in each layer (e.g. ReLU or identity, by default, ReLU)

- Vary the bias value used in each layer (all equal, specific values, or function of index)

- Vary the feature sparsity (all equal, specific values, or function of index)

- Vary the feature importance (all equal, specific values, or function of index)

- Optional: Vary optimizer and lr schedule

In [3]:
"""Notebook settings and imports."""

%load_ext autoreload
%autoreload 2
# %flow mode reactive

import os

from dataclasses import dataclass, field
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import torch as t

from einops import asnumpy, einsum, rearrange, reduce, repeat, pack, parse_shape, unpack
from einops.layers.torch import Rearrange, Reduce
from jaxtyping import Float, Int
from matplotlib import pyplot as plt
from plotly import express as px
from plotly import graph_objects as go
from plotly import io as pio
from rich import print as rprint
from torch import nn, optim
from torch.nn import functional as F
from tqdm.notebook import tqdm

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
@dataclass
class CisConfig:
    n_instances: int = 2
    n_feat: int = 2
    n_hidden: int = 3
    act_fn: List[Callable] = field(default_factory=lambda: [F.relu, F.relu])
    b: str | t.Tensor | Callable = field(default_factory=lambda: "0")
    feat_sparsity: float| t.Tensor | Callable = 0.99
    feat_importance: float | t.Tensor | Callable = 1.0
    optimizer: Callable = t.optim.Adam

    def __post_init__(self):
        """Ensure attribute values are valid."""
        pass

In [None]:
class Cis(nn.Module):
    """Computation in Superposition toy model.

    Holds ``cfg.n_instances`` independent copies of a one-hidden-layer network
    ``x -> act_fn[0](x @ W1 + b1) -> act_fn[1](h @ W2 + b2)`` trained to output ``|x|``.
    """

    # Some attribute type hints
    W1: Float[t.Tensor, "n_instances n_feat n_hidden"]
    W2: Float[t.Tensor, "n_instances n_hidden n_feat"]
    b1: Float[t.Tensor, "n_instances n_hidden"]
    b2: Float[t.Tensor, "n_instances n_feat"]
    S: Float[t.Tensor, "n_instances n_feat"]  # feature sparsity (probability a feature is zero)
    I: Float[t.Tensor, "n_instances n_feat"]  # feature importance (per-feature loss weight)

    def __init__(self, cfg: CisConfig):
        """Initializes model params.

        Bias spec ``cfg.b`` may be:
          * a numeric string or number (e.g. ``"0"``): every bias entry gets that value;
          * a tensor / list / tuple indexed by layer (``b[0]`` hidden, ``b[1]`` output), each
            entry broadcastable to that layer's ``(n_instances, n_units)`` bias shape;
          * a callable ``b(layer, i)`` giving the bias of unit ``i`` in ``layer`` (0 or 1).

        ``cfg.feat_sparsity`` / ``cfg.feat_importance`` may be a number, a tensor (or list)
        broadcastable to ``(n_instances, n_feat)``, or a callable of the feature index.

        Raises:
            ValueError: if a string bias is not numeric.
            TypeError: if a bias/sparsity/importance spec has an unsupported type.
        """
        super().__init__()
        self.cfg = cfg
        n_inst, n_feat, n_hidden = cfg.n_instances, cfg.n_feat, cfg.n_hidden

        # Weights: wrapped in nn.Parameter so they are registered on the module (returned by
        # .parameters() and updated by the optimizer). Xavier init is applied per instance:
        # on the whole 3D tensor nn.init would treat dim 0 (instances) as the fan-out dim.
        self.W1 = nn.Parameter(self._xavier(n_inst, n_feat, n_hidden))
        self.W2 = nn.Parameter(self._xavier(n_inst, n_hidden, n_feat))

        # Biases
        self.b1 = nn.Parameter(self._init_bias(cfg.b, 0, (n_inst, n_hidden)))
        self.b2 = nn.Parameter(self._init_bias(cfg.b, 1, (n_inst, n_feat)))

        # Sparsities and importances: buffers (not parameters) so they are not trained but
        # still move with .to(device) and are saved in the state dict.
        self.register_buffer("S", self._per_feature(cfg.feat_sparsity, "feat_sparsity"))
        self.register_buffer("I", self._per_feature(cfg.feat_importance, "feat_importance"))

    @staticmethod
    def _xavier(n_inst: int, rows: int, cols: int) -> t.Tensor:
        """Return an (n_inst, rows, cols) tensor with each instance Xavier-normal initialized."""
        w = t.empty(n_inst, rows, cols)
        for w_inst in w:  # each slice is a view, so the in-place init fills `w`
            nn.init.xavier_normal_(w_inst)
        return w

    @staticmethod
    def _init_bias(spec, layer: int, shape: tuple[int, int]) -> t.Tensor:
        """Build the initial bias of ``layer`` (0 = hidden, 1 = output) with ``shape``."""
        if isinstance(spec, str):
            try:
                value = float(spec)
            except ValueError as err:
                raise ValueError(f"Bias string must be numeric, got {spec!r}") from err
            return t.full(shape, value)
        if isinstance(spec, (int, float)):
            return t.full(shape, float(spec))
        if isinstance(spec, (t.Tensor, list, tuple)):
            layer_bias = t.as_tensor(spec[layer], dtype=t.float32).detach()
            return layer_bias.broadcast_to(shape).clone()
        if callable(spec):
            vals = t.tensor([float(spec(layer, i)) for i in range(shape[-1])])
            return vals.expand(shape).clone()  # same unit biases for every instance
        raise TypeError(f"Unsupported bias spec of type {type(spec).__name__}")

    def _per_feature(self, spec, name: str) -> t.Tensor:
        """Expand a per-feature spec to a float tensor of shape (n_instances, n_feat)."""
        shape = (self.cfg.n_instances, self.cfg.n_feat)
        if isinstance(spec, (int, float)):
            return t.full(shape, float(spec))
        if isinstance(spec, (t.Tensor, list, tuple)):
            return t.as_tensor(spec, dtype=t.float32).detach().broadcast_to(shape).clone()
        if callable(spec):
            vals = t.tensor([float(spec(i)) for i in range(self.cfg.n_feat)])
            return vals.expand(shape).clone()  # same per-feature values for every instance
        raise TypeError(f"Unsupported {name} of type {type(spec).__name__}")

    def gen_batch(self, batch_sz: int) -> Float[t.Tensor, "batch_sz n_instances n_feat"]:
        """Generates a batch of data (sparse feature vals on [-1, 1]).

        Each feature value is drawn uniformly from [-1, 1) and then zeroed with probability
        equal to its sparsity ``S`` (so ``feat_sparsity=0.99`` makes a feature non-zero in
        about 1% of samples).
        """
        shape = (batch_sz, self.cfg.n_instances, self.cfg.n_feat)
        device = self.S.device  # generate on the same device as the model's buffers
        x = t.rand(shape, device=device) * 2 - 1  # [-1, 1)
        # S is the probability of being zero, so a feature is active with probability 1 - S.
        is_active = t.rand(shape, device=device) >= self.S
        return x * is_active

    def forward(
        self, x: Float[t.Tensor, "batch_sz n_instances n_feat"]
    ) -> Float[t.Tensor, "n_instances"]:
        """Forward pass following the paper's model setup.

        Returns:
            The importance-weighted MSE between the model output and ``|x|``, averaged over
            batch and features but kept separate for each instance.
        """
        # First layer (weights are per instance, so `instance` appears in both operands)
        h = einsum(x, self.W1, "batch instance feat, instance feat hidden -> batch instance hidden")
        h = self.cfg.act_fn[0](h + self.b1)

        # Second layer
        y = einsum(h, self.W2, "batch instance hidden, instance hidden feat -> batch instance feat")
        y = self.cfg.act_fn[1](y + self.b2)

        # Compute weighted MSE loss
        y_true = t.abs(x)
        return reduce((y - y_true) ** 2 * self.I, "batch instance feat -> instance", "mean")

    def optimize(
        self, optimizer: t.optim.Optimizer, batch_sz: int, steps: int, logging_freq: int
    ) -> list[float]:
        """Optimizes the model.

        Args:
            optimizer: Optimizer already constructed over this model's parameters.
            batch_sz: Number of samples drawn per step.
            steps: Number of optimization steps.
            logging_freq: Record the loss every this many steps (and at the last step).

        Returns:
            The recorded losses, each the mean of the per-instance losses at that step.

        Raises:
            ValueError: if ``logging_freq`` is less than 1.
        """
        if logging_freq < 1:
            raise ValueError(f"logging_freq must be >= 1, got {logging_freq}")

        losses = []
        pbar = tqdm(range(steps), desc="Training")

        for step in pbar:
            x = self.gen_batch(batch_sz)
            loss = self(x)  # (n_instances,)
            optimizer.zero_grad()
            # Instances share no parameters, so backpropagating the sum gives each instance
            # exactly the gradient of its own loss.
            loss.sum().backward()
            optimizer.step()
            # Log progress
            if step % logging_freq == 0 or (step + 1 == steps):
                loss_val = loss.mean().item()
                losses.append(loss_val)
                pbar.set_postfix({"loss": f"{loss_val:.4f}"})

        return losses