The model has a single hidden layer: x -> n_h -> y (x and y have the same size).

Weight shapes: W1 is (n_h, n_x) and W2 is (n_y, n_h).
x: 3 features (should be configurable)
y = abs(x)
n_h: 2 hidden neurons (should be configurable)
loss: MSE

- Vary the activation function used in each layer (e.g. ReLU or identity; ReLU by default).

- Vary the bias values used in each layer (all equal, or individually specified).

- Vary the feature sparsity (all equal, or specific per-feature values, e.g. a function of the feature index).

- Vary the feature importance (all equal, or specific per-feature values, e.g. a function of the feature index).

- Optional: vary the optimizer and learning-rate schedule.

In [2]:
"""Notebook settings and imports."""

%load_ext autoreload
%autoreload 2
# %flow mode reactive

import os

from dataclasses import dataclass, field
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import torch as t

from einops import asnumpy, einsum, rearrange, reduce, repeat, pack, parse_shape, unpack
from einops.layers.torch import Rearrange, Reduce
from jaxtyping import Float, Int
from matplotlib import pyplot as plt
from plotly import express as px
from plotly import graph_objects as go
from plotly import io as pio
from rich import print as rprint
from torch import nn, optim
from torch.nn import functional as F
from tqdm.notebook import tqdm

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
@dataclass
class CisConfig:
    """Hyperparameters for the computation-in-superposition (Cis) toy model.

    Attributes:
        n_instances: Number of independent model instances trained side by side.
        n_feat: Number of input features; the output has the same size.
        n_hidden: Number of hidden-layer neurons.
        act_fn: Activations for the [hidden, output] layers (ReLU for both by default).
        b1: Hidden-layer bias, shape (inst, hid). The string "0" (default) requests a
            zero-initialized bias; otherwise pass a tensor of bias values.
        b2: Output-layer bias, shape (inst, feat); same conventions as ``b1``.
        feat_sparsity: Probability that a feature is zero in a sampled input, either a
            single value for all features or a tensor of per-feature values.
        feat_importance: Per-feature weight in the MSE loss, either a single value for
            all features or a tensor of per-feature values.
        optimizer: Optimizer class to train with (e.g. ``t.optim.Adam``).
    """

    n_instances: int = 2
    # Defaults follow the spec at the top of the notebook: 3 features, 2 hidden neurons
    n_feat: int = 3
    n_hidden: int = 2
    act_fn: List[Callable] = field(default_factory=lambda: [F.relu, F.relu])
    # Strings are immutable, so a plain default is safe (no default_factory needed)
    b1: str | Float[t.Tensor, "inst hid"] = "0"
    b2: str | Float[t.Tensor, "inst feat"] = "0"
    feat_sparsity: float | t.Tensor = 0.99
    feat_importance: float | t.Tensor = 1.0
    optimizer: Callable = t.optim.Adam

    def __post_init__(self):
        """Ensure attribute values are valid, raising ``ValueError`` otherwise."""
        for name in ("n_instances", "n_feat", "n_hidden"):
            value = getattr(self, name)
            if value < 1:
                raise ValueError(f"{name} must be >= 1, got {value}")
        if len(self.act_fn) != 2:
            raise ValueError(
                f"act_fn needs one activation per layer (2), got {len(self.act_fn)}"
            )
        # Sparsity is a probability (see how the model samples active features)
        sparsity = t.as_tensor(self.feat_sparsity)
        if ((sparsity < 0) | (sparsity > 1)).any():
            raise ValueError("feat_sparsity values must lie in [0, 1]")

In [None]:
class Cis(nn.Module):
    """Computation in Superposition toy model.

    Each of ``cfg.n_instances`` independent instances computes
    ``y = act_fn[1](act_fn[0](x @ W1 + b1) @ W2 + b2)`` on sparse features ``x`` and is
    trained to reproduce ``fx(x)`` (``abs(x)`` by default) under an importance-weighted
    MSE loss.
    """

    # Some attribute type hints
    W1: Float[t.Tensor, "inst feat hid"]
    W2: Float[t.Tensor, "inst hid feat"]
    b1: Float[t.Tensor, "inst hid"]
    b2: Float[t.Tensor, "inst feat"]
    s: Float[t.Tensor, "inst feat"]  # feature sparsity
    i: Float[t.Tensor, "inst feat"]  # feature importance

    def __init__(self, cfg: CisConfig):
        """Initializes model params.

        Args:
            cfg: Model configuration (sizes, activations, biases, sparsity, importance).
        """
        super().__init__()
        self.cfg = cfg
        n_inst, n_feat, n_hid = cfg.n_instances, cfg.n_feat, cfg.n_hidden

        # Model weights: wrapped in nn.Parameter so `self.parameters()` (and therefore
        # the optimizer) sees them.
        self.W1 = nn.Parameter(self._xavier(n_inst, n_feat, n_hid))
        self.W2 = nn.Parameter(self._xavier(n_inst, n_hid, n_feat))

        # Model biases
        self.b1 = self._init_bias(cfg.b1, (n_inst, n_hid))
        self.b2 = self._init_bias(cfg.b2, (n_inst, n_feat))

        # Sparsities and importances, expanded to (inst, feat). Registered as buffers so
        # they follow the module across `.to(device)` and appear in the state dict.
        self.register_buffer("s", self._per_feature(cfg.feat_sparsity, n_inst, n_feat))
        self.register_buffer("i", self._per_feature(cfg.feat_importance, n_inst, n_feat))

    @staticmethod
    def _xavier(n_inst: int, n_in: int, n_out: int) -> t.Tensor:
        """Xavier-normal init of an (n_inst, n_in, n_out) stack, one matrix per instance."""
        w = t.empty(n_inst, n_in, n_out)
        # Init each instance's 2D slice separately; initializing the 3D tensor directly
        # would fold the instance dim into the fan-in/fan-out computation.
        for w_inst in w:
            nn.init.xavier_normal_(w_inst)
        return w

    @staticmethod
    def _init_bias(value: str | t.Tensor, shape: Tuple[int, int]) -> t.Tensor:
        """Build a bias from its config value.

        A string (e.g. "0") is parsed as a float and becomes a trainable constant bias of
        ``shape``; a tensor is used as given (pass an ``nn.Parameter`` to make it trainable).
        """
        if isinstance(value, str):
            return nn.Parameter(t.full(shape, float(value)))
        return value

    @staticmethod
    def _per_feature(
        value: float | t.Tensor | Callable, n_inst: int, n_feat: int
    ) -> Float[t.Tensor, "inst feat"]:
        """Expand a scalar, a tensor, or a function of feature index to (inst, feat).

        A callable is called with ``t.arange(n_feat)`` (e.g. ``lambda idx: 0.9 ** idx``).
        Raises a RuntimeError if a tensor value cannot broadcast to (inst, feat).
        """
        if callable(value):
            value = value(t.arange(n_feat))
        as_t = t.as_tensor(value, dtype=t.get_default_dtype())
        # clone() materializes the broadcast view into a real (inst, feat) tensor
        return as_t.expand(n_inst, n_feat).clone()

    def gen_batch(self, batch_sz: int) -> Float[t.Tensor, "batch inst feat"]:
        """Generates a batch of data (sparse feature vals on [-1, 1])."""
        shape = (batch_sz, self.cfg.n_instances, self.cfg.n_feat)
        device = self.s.device

        # Feature values uniform on [-1, 1); each is independently kept (non-zero)
        # with probability 1 - sparsity.
        x = t.rand(shape, device=device) * 2 - 1
        is_active = t.rand(shape, device=device) < (1 - self.s)

        return x * is_active

    def forward(
        self, x: Float[t.Tensor, "batch inst feat"], fx: Callable = t.abs
    ) -> Float[t.Tensor, ""]:
        """Runs a forward pass through model returning the loss.

        Args:
            x: Input batch.
            fx: Target function applied to ``x`` to get the labels (default ``t.abs``).

        Returns:
            Scalar importance-weighted MSE, averaged over batch, instances and features.
        """
        # Hidden layer (weights are per instance, so `inst` appears in both operands)
        h = einsum(x, self.W1, "batch inst feat, inst feat hid -> batch inst hid")
        h = self.cfg.act_fn[0](h + self.b1)

        # Output layer
        y = einsum(h, self.W2, "batch inst hid, inst hid feat -> batch inst feat")
        y = self.cfg.act_fn[1](y + self.b2)

        # Compute weighted MSE loss
        y_true = fx(x)
        loss = reduce((y - y_true) ** 2 * self.i, "batch inst feat -> ", "mean")

        return loss

    def optimize(
        self, optimizer: t.optim.Optimizer, batch_sz: int, steps: int, logging_freq: int
    ) -> List[float]:
        """Optimizes the model.

        Args:
            optimizer: Optimizer already constructed over ``self.parameters()``.
            batch_sz: Samples per training step.
            steps: Number of training steps.
            logging_freq: Record the loss every ``logging_freq`` steps (and on the last).

        Returns:
            The recorded loss values.

        Raises:
            ValueError: If ``logging_freq`` is less than 1.
        """
        if logging_freq < 1:
            raise ValueError(f"logging_freq must be >= 1, got {logging_freq}")

        losses = []
        pbar = tqdm(range(steps), desc="Training")

        for step in pbar:
            x = self.gen_batch(batch_sz)
            loss = self(x)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Log progress
            if step % logging_freq == 0 or (step + 1 == steps):
                losses.append(loss.item())
                pbar.set_postfix({"loss": f"{loss.item():.4f}"})

        return losses

NameError: name 'CisConfig' is not defined

In [3]:
# Toy (batch, instance, feat) tensor with entries uniform on [-1, 1)
x = t.rand(2, 2, 3).mul(2).sub(1)

In [None]:
# Collapse every axis of x into a single scalar sum (empty output pattern)
reduce(x, "b i f -> ", "sum")

tensor(1.1897)

In [10]:
# Display the sampled tensor as the cell's output
x

tensor([[[ 0.2367,  0.0967, -0.2356],
         [-0.1621, -0.4485,  0.5213]],

        [[-0.3576,  0.8022,  0.2868],
         [ 0.1523, -0.2837,  0.5812]]])