In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
import torch as t
from torch import nn, Tensor
from torch.distributions.categorical import Categorical
from torch.nn import functional as F
from dataclasses import dataclass
import numpy as np
import einops
from jaxtyping import Float, Int
from typing import Optional, Callable, Union, List, Tuple
from functools import partial
from tqdm.notebook import tqdm
from dataclasses import dataclass
from rich import print as rprint
from rich.table import Table
from IPython.display import display, HTML
from pathlib import Path

import jaxtyping

In [3]:
import sys
# Make sure exercises are in the path
exercises_dir = Path("../exercises").resolve()
section_dir = (exercises_dir / "part4_superposition_and_saes").resolve()
if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

from plotly_utils import imshow, line, hist
from part4_superposition_and_saes.utils import (
    plot_features_in_2d,
    plot_features_in_Nd,
    plot_features_in_Nd_discrete,
    plot_correlated_features,
    plot_feature_geometry,
    frac_active_line_plot,
    plot_features_in_2d_hierarchy
)
import part4_superposition_and_saes.tests as tests
import part4_superposition_and_saes.solutions as solutions

if t.backends.mps.is_available():
    print("current PyTorch install was "
              "built with MPS enabled.")
    if t.backends.mps.is_built():
        print("MPS is available")
        device = t.device("mps")
else:
    device = t.device("cuda" if t.cuda.is_available() else "cpu")

MAIN = __name__ == "__main__"

current PyTorch install was built with MPS enabled.
MPS is available




current PyTorch install was built with MPS enabled.
MPS is available


# TMH: Superposition in a Nonprivileged Basis

# Define Model

In [4]:
def linear_lr(step, steps):
    return (1 - (step / steps))

def constant_lr(*_):
    return 1.0

def cosine_decay_lr(step, steps):
    return np.cos(0.5 * np.pi * step / (steps - 1))

@dataclass
class Config:
    # We optimize n_instances models in a single training loop to let us sweep over
    # sparsity or importance curves  efficiently. You should treat `n_instances` as
    # kinda like a batch dimension, but one which is built into our training setup.
    n_instances: int
    n_features: int = 6
    tree_depth: int = 2
    branching_factor: int = 2
    n_hidden: int = 2
    n_correlated_pairs: int = 0
    n_anticorrelated_pairs: int = 0
    early_stopping: bool = False

class Model(nn.Module):
    W: Float[Tensor, "n_instances n_hidden n_features"]
    b_final: Float[Tensor, "n_instances n_features"]
    # Our linear map is x -> ReLU(W.T @ W @ x + b_final)

    def __init__(
        self,
        cfg: Config,
        feature_probability: Optional[Union[float, Tensor]] = None,
        importance: Optional[Union[float, Tensor]] = None,
        early_stopping: bool = False,
        device = device,
    ):
        super().__init__()
        self.cfg = cfg

        if feature_probability is None: feature_probability = t.ones(())
        if isinstance(feature_probability, float): feature_probability = t.tensor(feature_probability)
        self.feature_probability = feature_probability.to(device).broadcast_to((cfg.n_instances, cfg.n_features))
        if importance is None: importance = t.ones(())
        if isinstance(importance, float): importance = t.tensor(importance)
        self.importance = importance.to(device).broadcast_to((cfg.n_instances, cfg.n_features))

        self.W = nn.Parameter(nn.init.xavier_normal_(t.empty((cfg.n_instances, cfg.n_hidden, cfg.n_features))))
        self.b_final = nn.Parameter(t.zeros((cfg.n_instances, cfg.n_features)))
        self.to(device)

        self.early_stopping = early_stopping

        self.device = device


    def forward(
        self,
        features: Float[Tensor, "... instances features"]
    ) -> Float[Tensor, "... instances features"]:
        hidden = einops.einsum(
           features, self.W,
           "... instances features, instances hidden features -> ... instances hidden"
        )
        out = einops.einsum(
            hidden, self.W,
            "... instances hidden, instances hidden features -> ... instances features"
        )
        return F.relu(out + self.b_final)


    def generate_batch(self, batch_size) -> Float[Tensor, "batch_size instances features"]:
        '''
        Generates a batch of data. We'll return to this function later when we apply correlations.
        '''
        pass # See below for solutions


    def calculate_loss(
        self,
        out: Float[Tensor, "batch instances features"],
        batch: Float[Tensor, "batch instances features"],
    ) -> Float[Tensor, ""]:
        '''
        Calculates the loss for a given batch, using this loss described in the Toy Models paper:

            https://transformer-circuits.pub/2022/toy_model/index.html#demonstrating-setup-loss

        Remember, `model.importance` will always have shape (n_instances, n_features).
        '''
        pass # See below for solutions


    def optimize(
        self,
        batch_size: int = 1024,
        steps: int = 10_000,
        log_freq: int = 100,
        lr: float = 1e-3,
        lr_scale: Callable[[int, int], float] = constant_lr,
    ):
        '''
        Optimizes the model using the given hyperparameters.
        '''
        optimizer = t.optim.Adam(list(self.parameters()), lr=lr)

        progress_bar = tqdm(range(steps))

        for step in progress_bar:

            # Update learning rate
            step_lr = lr * lr_scale(step, steps)
            for group in optimizer.param_groups:
                group['lr'] = step_lr

            # Optimize
            optimizer.zero_grad()
            batch = self.generate_batch(batch_size)
            out = self(batch)
            loss = self.calculate_loss(out, batch)
            loss.backward()
            optimizer.step()

            # Display progress bar
            if step % log_freq == 0 or (step + 1 == steps):
                progress_bar.set_postfix(loss=loss.item()/self.cfg.n_instances, lr=step_lr)


tests.test_model(Model)

All tests in `test_model` passed!


In [5]:
def generate_batch(self: Model, batch_size) -> Float[Tensor, "batch_size instances features"]:
    '''
    Generates a batch of data. We'll return to this function later when we apply correlations.
    '''

    n_features = self.cfg.branching_factor ** (self.cfg.tree_depth + 1) - 1

    feat = t.zeros((batch_size, self.cfg.n_instances, n_features))

    rand_shape = (batch_size, self.cfg.n_instances)

    for level in range(self.cfg.tree_depth+1):
        if level == 0:
            feat[:, :, 0] = (t.ones(rand_shape))
            continue

        start_idx = 2**(level) - 1

        for i in range(2**(level-1)):

            if self.early_stopping:
                #50% chance that the child is zero
                one = (t.rand(rand_shape) > 0.5).to(t.int)
                two = (t.rand(rand_shape) > 0.5).to(t.int)
                feat[:, :, start_idx + 2*i] = ((one - two) == 1).to(t.int)
                feat[:, :, start_idx + 2*i + 1] = ((two - one) == 1).to(t.int)
            else:
                feat[:, :, start_idx + 2*i] = (t.rand(rand_shape) > 0.5).to(t.int)
                feat[:, :, start_idx + 2*i + 1] = 1 - feat[:, :, start_idx + 2*i]

            # we need to zero out inactive features
            parent_idx = (start_idx + 2*i - 1) // 2
            feat[:, :, start_idx + 2*i] *= feat[:, :, parent_idx]
            feat[:, :, start_idx + 2*i + 1] *= feat[:, :, parent_idx]

    # shave off root
    feat = feat[:, :, 1:]

    feat = feat.to(self.device)

    return feat

Model.generate_batch = generate_batch


In [6]:
def calculate_loss(
    self: Model,
    out: Float[Tensor, "batch instances features"],
    batch: Float[Tensor, "batch instances features"],
) -> Float[Tensor, ""]:
    '''
    Calculates the loss for a given batch, using this loss described in the Toy Models paper:

        https://transformer-circuits.pub/2022/toy_model/index.html#demonstrating-setup-loss

    Remember, `self.importance` will always have shape (n_instances, n_features).
    '''
    error = self.importance * ((batch - out) ** 2)
    loss = einops.reduce(error, 'batch instances features -> instances', 'mean').sum()
    return loss

def calculate_loss_per_instance(
    self: Model,
    out: Float[Tensor, "batch instances features"],
    batch: Float[Tensor, "batch instances features"],
) -> Float[Tensor, ""]:
    '''
    Calculates the loss for a given batch, using this loss described in the Toy Models paper:

        https://transformer-circuits.pub/2022/toy_model/index.html#demonstrating-setup-loss

    Remember, `self.importance` will always have shape (n_instances, n_features).
    '''
    error = self.importance * ((batch - out) ** 2)
    losses = einops.reduce(error, 'batch instances features -> instances', 'mean')
    return losses

Model.calculate_loss = calculate_loss
Model.calculate_loss_per_instance = calculate_loss_per_instance


In [35]:
Model.generate_batch = generate_batch

# Tree of (2, *2)


In [52]:
N_INSTANCES = 5
TREE_DEPTHS = [1, 2, 3, 4, 5, 6]
N_HIDDENS = [2, 3, 4, 5, 6, 7, 8, 9, 10]


losses = []
cfgs = []
models = []

for tree_depth in TREE_DEPTHS:
    n_features = 2**(tree_depth + 1) - 2

    for n_hidden in N_HIDDENS:
        if n_hidden > n_features:
            continue

        cfg = Config(
            n_instances = N_INSTANCES,
            n_features = n_features,
            tree_depth = tree_depth,
            n_hidden = n_hidden,
        )
        cfgs.append(cfg)

for cfg in cfgs:
    print(cfg)
    # continue
    
    model = Model(cfg, device=device, early_stopping=True)
    model.optimize(steps=10000)
    models.append(model)

    batch = model.generate_batch(500)
    out = model(batch)
    loss = model.calculate_loss(out, batch)
    losses.append(loss)


Config(n_instances=5, n_features=2, tree_depth=1, branching_factor=2, n_hidden=2, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=6, tree_depth=2, branching_factor=2, n_hidden=2, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=6, tree_depth=2, branching_factor=2, n_hidden=3, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=6, tree_depth=2, branching_factor=2, n_hidden=4, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=6, tree_depth=2, branching_factor=2, n_hidden=5, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=6, tree_depth=2, branching_factor=2, n_hidden=6, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=14, tree_depth=3, branching_factor=2, n_hidden=2, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=14, tree_depth=3, branching_factor=2, n_hidden=3, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=14, tree_depth=3, branching_factor=2, n_hidden=4, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=14, tree_depth=3, branching_factor=2, n_hidden=5, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=14, tree_depth=3, branching_factor=2, n_hidden=6, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=14, tree_depth=3, branching_factor=2, n_hidden=7, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=14, tree_depth=3, branching_factor=2, n_hidden=8, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=14, tree_depth=3, branching_factor=2, n_hidden=9, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=14, tree_depth=3, branching_factor=2, n_hidden=10, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=30, tree_depth=4, branching_factor=2, n_hidden=2, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=30, tree_depth=4, branching_factor=2, n_hidden=3, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=30, tree_depth=4, branching_factor=2, n_hidden=4, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=30, tree_depth=4, branching_factor=2, n_hidden=5, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=30, tree_depth=4, branching_factor=2, n_hidden=6, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=30, tree_depth=4, branching_factor=2, n_hidden=7, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=30, tree_depth=4, branching_factor=2, n_hidden=8, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=30, tree_depth=4, branching_factor=2, n_hidden=9, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=30, tree_depth=4, branching_factor=2, n_hidden=10, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=62, tree_depth=5, branching_factor=2, n_hidden=2, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=62, tree_depth=5, branching_factor=2, n_hidden=3, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=62, tree_depth=5, branching_factor=2, n_hidden=4, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=62, tree_depth=5, branching_factor=2, n_hidden=5, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=62, tree_depth=5, branching_factor=2, n_hidden=6, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=62, tree_depth=5, branching_factor=2, n_hidden=7, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=62, tree_depth=5, branching_factor=2, n_hidden=8, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=62, tree_depth=5, branching_factor=2, n_hidden=9, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=62, tree_depth=5, branching_factor=2, n_hidden=10, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=126, tree_depth=6, branching_factor=2, n_hidden=2, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=126, tree_depth=6, branching_factor=2, n_hidden=3, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=126, tree_depth=6, branching_factor=2, n_hidden=4, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=126, tree_depth=6, branching_factor=2, n_hidden=5, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=126, tree_depth=6, branching_factor=2, n_hidden=6, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=126, tree_depth=6, branching_factor=2, n_hidden=7, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=126, tree_depth=6, branching_factor=2, n_hidden=8, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=126, tree_depth=6, branching_factor=2, n_hidden=9, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

Config(n_instances=5, n_features=126, tree_depth=6, branching_factor=2, n_hidden=10, n_correlated_pairs=0, n_anticorrelated_pairs=0, early_stopping=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

In [53]:
import plotly.graph_objects as go

# First, let's organize our data
data = {}
for cfg, model, loss in zip(cfgs, models, losses):
    depth = cfg.tree_depth
    n_hidden = cfg.n_hidden
    if depth not in data:
        data[depth] = {}
    data[depth][n_hidden] = loss.item()

# Now let's create a heatmap
depths = sorted(data.keys())
n_hiddens = sorted(set(n_hidden for depth_data in data.values() for n_hidden in depth_data.keys()))

z = [[data[depth].get(n_hidden, float('nan')) for n_hidden in n_hiddens] for depth in depths]

fig = go.Figure(data=go.Heatmap(
    z=z,
    x=n_hiddens,
    y=depths,
    colorscale='Viridis',
    colorbar=dict(title='Loss')
))

fig.update_layout(
    title='Loss vs Tree Depth and n_hidden',
    xaxis_title='n_hidden',
    yaxis_title='Tree Depth',
    xaxis_type='category',
    yaxis_type='category'
)

fig.show()

# Now let's create a line plot
fig = go.Figure()

for depth in depths:
    x = sorted(data[depth].keys())
    y = [data[depth][n_hidden] for n_hidden in x]
    fig.add_trace(go.Scatter(
        x=x,
        y=y,
        mode='lines+markers',
        name=f'Depth {depth}'
    ))

fig.update_layout(
    title='Loss vs n_hidden for Different Tree Depths',
    xaxis_title='n_hidden',
    yaxis_title='Loss',
    xaxis_type='category'
)

fig.show()

# Finally, let's find the minimum n_hidden for "good" loss
threshold = 1e-5  # Define your threshold for "good" loss here
min_n_hidden = {}

for depth in depths:
    for n_hidden in sorted(data[depth].keys()):
        if data[depth][n_hidden] < threshold:
            min_n_hidden[depth] = n_hidden
            break
    if depth not in min_n_hidden:
        min_n_hidden[depth] = max(data[depth].keys())  # If no n_hidden achieves good loss, use the maximum tested

# Create a bar plot of minimum n_hidden for each depth
fig = go.Figure(data=go.Bar(
    x=list(min_n_hidden.keys()),
    y=list(min_n_hidden.values())
))

fig.update_layout(
    title=f'Minimum n_hidden for Loss < {threshold} by Tree Depth',
    xaxis_title='Tree Depth',
    yaxis_title='Minimum n_hidden',
    xaxis_type='category'
)

fig.show()

### Notes

- Currently looking at avg loss across instances. I think min is probably better.
- Compare to `early_stopping=False`
- Compare to non-hierarchical structure