## Introduction

This notebook provides reference implementations for [JumpReLU SAEs](https://arxiv.org/abs/2407.14435) in JAX and PyTorch, expanding on the pseudo-code provided in the paper. Specifically, we include:

* Implementations of the `jumprelu` and `step` functions with custom backward passes;
* Implementations of the SAE forward pass and L0-based loss function;
* Training loop implementations that optionally keep the rows of the decoder matrix at unit norm.

We don't implement some features of the training setup described in the paper (e.g. warm-up of the learning rate and of the sparsity coefficient $\lambda$), but these are reasonably easy to add if desired.
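As a rough illustration of how such a warm-up could be added, the sketch below ramps a value linearly from zero to its final value and then holds it. The name `linear_warmup` and its parameters are hypothetical, and the linear shape is an assumption; consult the paper for the schedules it actually uses.

```python
def linear_warmup(step: int, warmup_steps: int, final_value: float) -> float:
    """Ramps linearly from 0 to final_value over warmup_steps, then holds."""
    if warmup_steps <= 0:
        # No warm-up requested: use the final value from the first step.
        return final_value
    fraction = min(step / warmup_steps, 1.0)
    return final_value * fraction


# Example: warm the sparsity coefficient up to 0.1 over four steps.
schedule = [linear_warmup(s, warmup_steps=4, final_value=0.1) for s in range(6)]
```

In a training loop, the value returned for the current step could be passed as the sparsity coefficient (or used to scale the learning rate) in place of a constant.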

The notebook also provides comprehensive tests to check that the JAX and PyTorch implementations are consistent. This includes an end-to-end training test where we train SAEs (on synthetic data, using identical initialisations) using both the JAX and PyTorch implementations and check that we get similar parameters after three steps. You may find these tests useful for checking other implementations of JumpReLU SAEs against these reference implementations.

You should be able to run this notebook on a CPU runtime.

## Setup

This section imports the modules needed by the rest of the notebook, sets some constants (mainly hyperparameters), generates synthetic data, and initialises the SAE parameters that we use later for testing.

In [1]:
# @title Imports
import dataclasses
import functools
import itertools

import chex
import numpy as np
import jax
import jax.numpy as jnp
import plotly.express as px
import torch
from torch import nn

jax.config.update("jax_enable_x64", True)

In [2]:
# @title Hyperparameters and constants

NUM_STEPS = 3
BATCH_SIZE = 1024
ACTIVATIONS_SIZE = 16
SAE_WIDTH = 128
THRESHOLD_INIT = 0.001
# We use a higher bandwidth than in the paper to ensure a non-zero gradient
# to the threshold at every step (since we'll only be taking three steps)
BANDWIDTH = 0.1
FIX_DECODER_NORMS = True
LEARNING_RATE = 0.001  # Note this is not the learning rate in the paper
ADAM_B1 = 0.0
DATA_SEED = 9328302
PARAMS_SEED = 24396

In [3]:
# @title Create some synthetic data for testing

rng = np.random.default_rng(DATA_SEED)
dataset = rng.normal(
    size=(NUM_STEPS, BATCH_SIZE, ACTIVATIONS_SIZE)
) / np.sqrt(ACTIVATIONS_SIZE)
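
Dividing standard normal samples by the square root of the dimension gives each synthetic activation vector an expected squared norm of one. The sketch below checks this numerically; it uses its own seed and local variable names (`d`, `samples`), which are illustrative and not part of the notebook.

```python
import numpy as np

d = 16  # Same dimensionality as ACTIVATIONS_SIZE in this notebook
rng = np.random.default_rng(0)  # Illustrative seed, not DATA_SEED
samples = rng.normal(size=(3, 1024, d)) / np.sqrt(d)

# Each coordinate has variance 1/d, so the squared norm sums to about 1.
mean_squared_norm = np.mean(np.sum(samples**2, axis=-1))
```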

In [4]:
# @title Choose random SAE weights for testing

# We choose an initialization that is useful for testing. Specifically
# this means we initialize the biases and threshold to non-zero values
# and that we don't set the encoder weights to the transpose of the decoder
# (since the two won't be equal in general during training).
rng = np.random.default_rng(PARAMS_SEED)
W_dec = (rng.uniform(size=(SAE_WIDTH, ACTIVATIONS_SIZE)) - 0.5)
W_dec /= np.linalg.norm(W_dec, axis=-1, keepdims=True)
W_enc = (rng.uniform(size=(ACTIVATIONS_SIZE, SAE_WIDTH)) - 0.5)
b_enc = (rng.uniform(size=(SAE_WIDTH,)) - 0.5) * 0.1
b_dec = (rng.uniform(size=(ACTIVATIONS_SIZE,)) - 0.5) * 0.1
threshold = 0.15 * (rng.uniform(size=(SAE_WIDTH,))) * 0.1

## PyTorch implementation

In this section we translate the JAX implementation defined in the previous section into PyTorch. We'll then check carefully that the PyTorch implementation is consistent with the JAX one, the key test being that training over multiple steps with either implementation (using synthetic data and identical initialisation) yields the same parameters (up to numerical tolerance).

In [5]:
from typing import Any, Tuple


def rectangle_pt(x: torch.Tensor) -> torch.Tensor:
    return ((x > -0.5) & (x < 0.5)).to(x)


class Step2(torch.autograd.Function):
    BANDWIDTH = 0.001

    @staticmethod
    def forward(x: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
        return (x > threshold).to(x)

    @staticmethod
    def setup_context(
        ctx: Any, inputs: Tuple[torch.Tensor, torch.Tensor], output: torch.Tensor
    ) -> None:
        x, threshold = inputs
        del output
        ctx.save_for_backward(x, threshold)

    @staticmethod
    def backward(
        ctx: Any, grad_output: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x, threshold = ctx.saved_tensors
        x_grad = 0.0 * grad_output  # We don't apply STE to x input
        threshold_grad = torch.sum(
            -(1.0 / Step2.BANDWIDTH)
            * rectangle_pt((x - threshold) / Step2.BANDWIDTH)
            * grad_output,
            dim=0,
        )
        return x_grad, threshold_grad


class JumpReLU2(torch.autograd.Function):
    BANDWIDTH = 0.001

    @staticmethod
    def forward(x: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
        return x * (x > threshold).to(x)

    @staticmethod
    def setup_context(
        ctx: Any, inputs: Tuple[torch.Tensor, torch.Tensor], output: torch.Tensor
    ) -> None:
        x, threshold = inputs
        del output
        ctx.save_for_backward(x, threshold)

    @staticmethod
    def backward(
        ctx: Any, grad_output: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x, threshold = ctx.saved_tensors
        x_grad = (x > threshold) * grad_output  # We don't apply STE to x input
        threshold_grad = torch.sum(
            -(threshold / JumpReLU2.BANDWIDTH)
            * rectangle_pt((x - threshold) / JumpReLU2.BANDWIDTH)
            * grad_output,
            dim=0,
        )
        return x_grad, threshold_grad

# Override the class-level defaults with the notebook-wide bandwidth.
Step2.BANDWIDTH = BANDWIDTH
JumpReLU2.BANDWIDTH = BANDWIDTH
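
For reference, the threshold gradients returned by the two `backward` methods in this cell correspond to the pseudo-derivatives below. The symbols are chosen here for illustration: $z$ is a pre-activation, $\theta$ the threshold, $\varepsilon$ the bandwidth, $H$ the step function and $K$ the rectangle kernel computed by `rectangle_pt`.

$$
\frac{\partial}{\partial \theta} H(z - \theta) := -\frac{1}{\varepsilon} K\!\left(\frac{z - \theta}{\varepsilon}\right),
\qquad
\frac{\partial}{\partial \theta} \mathrm{JumpReLU}_\theta(z) := -\frac{\theta}{\varepsilon} K\!\left(\frac{z - \theta}{\varepsilon}\right),
$$

$$
K(u) = \begin{cases} 1 & \text{if } -\tfrac{1}{2} < u < \tfrac{1}{2}, \\ 0 & \text{otherwise.} \end{cases}
$$

In the code, each pseudo-derivative is multiplied by the incoming `grad_output` and summed over the batch dimension, since the threshold is shared across examples. The gradient with respect to $z$ is left as the ordinary derivative: zero for the step function and $H(z - \theta)$ for JumpReLU.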

In [6]:
# @title STEs, forward pass and loss function

def rectangle_pt(x):
    return ((x > -0.5) & (x < 0.5)).to(x)


class Step(torch.autograd.Function):
    @staticmethod
    def forward(x, threshold):
        return (x > threshold).to(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, threshold = inputs
        del output
        ctx.save_for_backward(x, threshold)

    @staticmethod
    def backward(ctx, grad_output):
        x, threshold = ctx.saved_tensors
        x_grad = 0.0 * grad_output  # We don't apply STE to x input
        threshold_grad = torch.sum(
            -(1.0 / BANDWIDTH)
            * rectangle_pt((x - threshold) / BANDWIDTH)
            * grad_output,
            dim=0,
        )
        return x_grad, threshold_grad


class JumpReLU(torch.autograd.Function):
    @staticmethod
    def forward(x, threshold):
        return x * (x > threshold).to(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, threshold = inputs
        del output
        ctx.save_for_backward(x, threshold)

    @staticmethod
    def backward(ctx, grad_output):
        x, threshold = ctx.saved_tensors
        x_grad = (x > threshold) * grad_output  # We don't apply STE to x input
        threshold_grad = torch.sum(
            -(threshold / BANDWIDTH)
            * rectangle_pt((x - threshold) / BANDWIDTH)
            * grad_output,
            dim=0,
        )
        return x_grad, threshold_grad
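
To make the threshold gradients concrete, here is a small NumPy sketch that mirrors the arithmetic of the two `backward` methods on a hand-picked batch. The function names and the example values are illustrative assumptions, and this is not a substitute for the autograd classes above.

```python
import numpy as np


def rectangle(x):
    """Rectangle kernel: 1 inside (-0.5, 0.5), 0 elsewhere."""
    return ((x > -0.5) & (x < 0.5)).astype(x.dtype)


def step_threshold_grad(x, threshold, grad_output, bandwidth):
    """Pseudo-gradient of the step function w.r.t. the threshold."""
    kernel = rectangle((x - threshold) / bandwidth)
    return np.sum(-(1.0 / bandwidth) * kernel * grad_output, axis=0)


def jumprelu_threshold_grad(x, threshold, grad_output, bandwidth):
    """Pseudo-gradient of JumpReLU w.r.t. the threshold."""
    kernel = rectangle((x - threshold) / bandwidth)
    return np.sum(-(threshold / bandwidth) * kernel * grad_output, axis=0)


# Batch of two examples with two features each. Only pre-activations within
# bandwidth / 2 of the threshold contribute to the threshold gradient.
x = np.array([[0.52, 0.10], [0.90, 0.48]])
threshold = np.array([0.5, 0.5])
grad_output = np.ones_like(x)

step_grad = step_threshold_grad(x, threshold, grad_output, bandwidth=0.1)
jumprelu_grad = jumprelu_threshold_grad(x, threshold, grad_output, bandwidth=0.1)
```

Here exactly one example per feature falls inside the kernel window, so each feature receives a contribution from a single example. This also shows why a very small bandwidth can leave the threshold with a zero gradient on a small batch.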


In [7]:

class Sae(nn.Module):
    def __init__(self, sae_width, activations_size, use_pre_enc_bias, mode):
        super().__init__()
        self.use_pre_enc_bias = use_pre_enc_bias
        self.W_enc = nn.Parameter(torch.tensor(W_enc))
        self.b_enc = nn.Parameter(torch.tensor(b_enc))
        self.W_dec = nn.Parameter(torch.tensor(W_dec))
        self.b_dec = nn.Parameter(torch.tensor(b_dec))
        self.log_threshold = nn.Parameter(
            torch.tensor(np.log(threshold))
        )
        self.mode = mode

    def forward(self, x):
        if self.use_pre_enc_bias:
            x = x - self.b_dec

        pre_activations = x @ self.W_enc + self.b_enc
        threshold = torch.exp(self.log_threshold)
        if self.mode == "1":
            feature_magnitudes = JumpReLU.apply(pre_activations, threshold)
        else:
            feature_magnitudes = JumpReLU2.apply(pre_activations, threshold)
        x_reconstructed = feature_magnitudes @ self.W_dec + self.b_dec
        return x_reconstructed, pre_activations


def loss_fn_pt(sae, x, sparsity_coefficient, use_pre_enc_bias, mode):
    x_reconstructed, pre_activations = sae(x)

    # Compute per-example reconstruction loss
    reconstruction_error = x - x_reconstructed
    reconstruction_loss = torch.sum(reconstruction_error**2, dim=-1)

    # Compute per-example sparsity loss
    threshold = torch.exp(sae.log_threshold)

    if mode == "1":
        l0 = torch.sum(Step.apply(pre_activations, threshold), dim=-1)
    else:
        l0 = torch.sum(Step2.apply(pre_activations, threshold), dim=-1)
    sparsity_loss = sparsity_coefficient * l0

    # Return the batch-wise mean total loss
    return torch.mean(reconstruction_loss + sparsity_loss, dim=0)
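
The loss above is the batch mean of the squared reconstruction error plus the sparsity coefficient times the number of active features. The forward-only NumPy sketch below reproduces that calculation on a toy input so the numbers can be checked by hand. The name `sae_loss_np` is hypothetical, and the sketch assumes the pre-encoder bias is not used.

```python
import numpy as np


def sae_loss_np(x, W_enc, b_enc, W_dec, b_dec, threshold, sparsity_coefficient):
    """Forward-only JumpReLU SAE loss (no gradients, no pre-encoder bias)."""
    pre_activations = x @ W_enc + b_enc
    active = pre_activations > threshold
    feature_magnitudes = pre_activations * active  # JumpReLU forward pass
    x_reconstructed = feature_magnitudes @ W_dec + b_dec
    reconstruction_loss = np.sum((x - x_reconstructed) ** 2, axis=-1)
    l0 = np.sum(active, axis=-1)  # Number of active features per example
    return np.mean(reconstruction_loss + sparsity_coefficient * l0)


# Toy SAE with identity weights, zero biases and a threshold of 0.5.
identity = np.eye(2)
zeros = np.zeros(2)
threshold = np.array([0.5, 0.5])

# One feature fires and reconstructs x exactly: loss = 0 + 0.1 * 1.
loss = sae_loss_np(
    np.array([[1.0, 0.0]]), identity, zeros, identity, zeros, threshold, 0.1
)
```

With the input `[[0.4, 0.0]]` instead, no feature exceeds the threshold, so the sparsity term vanishes and the loss is the squared norm of the input.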

In [8]:
# @title Training loop

def remove_parallel_component_pt(x, v):
    """Returns x with component parallel to v projected away (in PyTorch)."""
    v_normalised = v / (torch.norm(v, dim=-1, keepdim=True) + 1e-6)
    parallel_component = torch.einsum("...d,...d->...", x, v_normalised)
    return x - parallel_component[..., None] * v_normalised

def train_pt(
    dataset_iterator,
    sparsity_coefficient,
    use_pre_enc_bias,
    fix_decoder_norms,
    mode
):
    sae = Sae(SAE_WIDTH, ACTIVATIONS_SIZE, use_pre_enc_bias, mode)
    optimizer = torch.optim.Adam(
        sae.parameters(), lr=LEARNING_RATE, betas=(ADAM_B1, 0.999)
    )
    for batch in dataset_iterator:
        optimizer.zero_grad()
        loss_pt = loss_fn_pt(
            sae, torch.tensor(batch), sparsity_coefficient, use_pre_enc_bias, mode
        )
        loss_pt.backward()
        if fix_decoder_norms:
            sae.W_dec.grad = remove_parallel_component_pt(
                sae.W_dec.grad, sae.W_dec.data
            )
        optimizer.step()
        if fix_decoder_norms:
            sae.W_dec.data = sae.W_dec.data / torch.norm(
                sae.W_dec.data, dim=-1, keepdim=True
            )
    return sae
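
When `fix_decoder_norms` is set, the training loop removes the part of the decoder gradient that is parallel to each decoder row before the optimizer step, then renormalises the rows afterwards. The NumPy sketch below mirrors `remove_parallel_component_pt` and checks that the projected gradient is (numerically) orthogonal to each row. The shapes, the seed and the NumPy function name are illustrative assumptions.

```python
import numpy as np


def remove_parallel_component(x, v):
    """Returns x with the component parallel to v projected away."""
    v_normalised = v / (np.linalg.norm(v, axis=-1, keepdims=True) + 1e-6)
    parallel_component = np.einsum("...d,...d->...", x, v_normalised)
    return x - parallel_component[..., None] * v_normalised


rng = np.random.default_rng(0)  # Illustrative seed
w_dec = rng.normal(size=(4, 3))  # Four features, three activation dims
w_dec /= np.linalg.norm(w_dec, axis=-1, keepdims=True)
grad = rng.normal(size=(4, 3))

projected = remove_parallel_component(grad, w_dec)
# Per-row dot products with the decoder rows; these should be close to zero.
dots = np.einsum("fd,fd->f", projected, w_dec)
```

The small `1e-6` in the denominator guards against division by zero, so the dot products are tiny rather than exactly zero.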

In [9]:
sparsity_coefficients = [0.0, 0.01, 0.1]  # Arbitrarily chosen
use_pre_enc_bias_l = [True, False]
fix_decoder_norms_l = [True, False]


for sparsity_coefficient, use_pre_enc_bias, fix_decoder_norms in itertools.product(
    sparsity_coefficients, use_pre_enc_bias_l, fix_decoder_norms_l
):
    print(
        f"Testing {sparsity_coefficient=}, {use_pre_enc_bias=}, "
        f"{fix_decoder_norms=}... ",
        end="",
        flush=True,
    )

    # Train using the JAX implementation
    # params_jax_trained = train_jax(
    #     iter(dataset),
    #     sparsity_coefficient=sparsity_coefficient,
    #     use_pre_enc_bias=use_pre_enc_bias,
    #     fix_decoder_norms=fix_decoder_norms,
    # )

    # Train using the PyTorch implementation that reads the module-level
    # BANDWIDTH (Step / JumpReLU)
    sae_pt_trained_1 = train_pt(
        iter(dataset),
        sparsity_coefficient,
        use_pre_enc_bias,
        fix_decoder_norms,
        "1"
    )

    # Train using the PyTorch implementation that reads the class-level
    # BANDWIDTH attribute (Step2 / JumpReLU2)
    sae_pt_trained_2 = train_pt(
        iter(dataset),
        sparsity_coefficient,
        use_pre_enc_bias,
        fix_decoder_norms,
        "2"
    )

    # First we want to make sure the params have actually evolved, otherwise
    # this test isn't meaningful!
    # chex.assert_trees_all_close(
    #     jax.tree.map(
    #         lambda x, y: np.mean(np.abs(x - y)) > 0.001,
    #         params_init,
    #         params_jax_trained,
    #     ),
    #     jax.tree.map(lambda _: True, params_init),
    # )

    # Now we check whether the parameters obtained using either implementation
    # are close
    chex.assert_trees_all_close(
        jax.tree.map(lambda x: x.numpy(), dict(sae_pt_trained_1.state_dict())),
        jax.tree.map(lambda x: x.numpy(), dict(sae_pt_trained_2.state_dict())),
    )

    print("OK.")

Testing sparsity_coefficient=0.0, use_pre_enc_bias=True, fix_decoder_norms=True... 

NameError: name 'params_init' is not defined
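
The commented-out sanity check in the test cell refers to `params_init`, which is not defined anywhere in this notebook. One possible way to express the same "have the parameters actually moved?" check on plain dictionaries of NumPy arrays is sketched below. The helper name `params_have_evolved` and the toy parameter dictionaries are hypothetical.

```python
import numpy as np


def params_have_evolved(params_init, params_trained, min_mean_abs_change=0.001):
    """Maps each parameter name to whether its mean absolute change exceeds the cutoff."""
    return {
        name: bool(
            np.mean(np.abs(params_trained[name] - params_init[name]))
            > min_mean_abs_change
        )
        for name in params_init
    }


# Toy example: W_enc moved by 0.01 everywhere, b_enc did not move at all.
params_init = {"W_enc": np.zeros((2, 3)), "b_enc": np.zeros(3)}
params_trained = {"W_enc": np.full((2, 3), 0.01), "b_enc": np.zeros(3)}
evolved = params_have_evolved(params_init, params_trained)
```

For the PyTorch model, an initial snapshot could be taken before training with something like `{k: v.detach().clone().numpy() for k, v in sae.state_dict().items()}` and compared against the same dictionary built after training.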