In [1]:
import os
import torch
from tqdm.auto import tqdm
import plotly.express as px
import pandas as pd
from copy import deepcopy
from torch import nn
import torch.nn.functional as F

from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
from huggingface_hub import hf_hub_download, notebook_login, login
import numpy as np

from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
from sae_lens import SAE, HookedSAETransformer

# Pick the compute device by priority: Apple MPS first, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

Device: cuda


In [2]:
# Load the pretrained "gpt2-small" checkpoint onto the selected device.
# `model` is read as a global by the training loop defined further down.
model = HookedSAETransformer.from_pretrained(
    "gpt2-small",
    device=device,
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Loaded pretrained model gpt2-small into HookedTransformer


In [3]:
# --- Data / model geometry ---
BATCH_SIZE = 1024  # rows per DataLoader batch
ACTIVATIONS_SIZE = 768  # size of the activation vectors the SAE reconstructs
HOOK_POINT = 'blocks.8.hook_resid_pre'  # cache key of the activations used for training

# --- JumpReLU hyperparameters ---
THRESHOLD_INIT = 0.001  # initial per-latent threshold (stored as its log)
BANDWIDTH = 0.001  # width of the rectangle kernel in the threshold pseudo-gradients
FIX_DECODER_NORMS = True  # project/renormalise decoder rows on every training step

# --- Optimiser ---
LEARNING_RATE = 0.001  # Note this is not the learning rate in the paper
ADAM_B1 = 0.0  # first Adam beta

# --- Seeds ---
DATA_SEED = 9328302
PARAMS_SEED = 24396
# Seed torch's global RNG so that parameter initialisation (and anything else
# drawing from it, e.g. DataLoader shuffling) is reproducible after
# Restart & Run All. Previously PARAMS_SEED was declared but never applied.
torch.manual_seed(PARAMS_SEED)
# NOTE(review): `rng` is not referenced by the cells visible here - confirm it
# is still needed.
rng = np.random.default_rng(DATA_SEED)

In [4]:
def rectangle_pt(x):
    """Rectangle (boxcar) kernel: 1 where -0.5 < x < 0.5, otherwise 0.

    Both bounds are exclusive. The result is cast to the dtype and device
    of `x`.
    """
    inside_window = torch.logical_and(x > -0.5, x < 0.5)
    return inside_window.to(x)


class Step(torch.autograd.Function):
    """Heaviside step ``(x > threshold)`` with a pseudo-gradient for `threshold`.

    Forward returns 1 where ``x > threshold`` and 0 elsewhere, in the dtype
    and device of `x`.

    Backward passes no gradient to `x` (zeros) and gives `threshold` a
    rectangle-kernel pseudo-derivative of width `BANDWIDTH` (module-level
    constant): ``-(1 / BANDWIDTH) * rectangle_pt((x - threshold) / BANDWIDTH)``
    times the upstream gradient, summed over every dimension that
    `threshold` was broadcast across.
    """

    @staticmethod
    def forward(x, threshold):
        return (x > threshold).to(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, threshold = inputs
        del output  # not needed by backward
        ctx.save_for_backward(x, threshold)

    @staticmethod
    def backward(ctx, grad_output):
        x, threshold = ctx.saved_tensors
        x_grad = 0.0 * grad_output  # We don't apply STE to x input
        per_element = (
            -(1.0 / BANDWIDTH)
            * rectangle_pt((x - threshold) / BANDWIDTH)
            * grad_output
        )
        # Reduce over all broadcast dimensions rather than a hard-coded dim=0,
        # so the result has exactly threshold's shape for inputs with any
        # number of leading dimensions (e.g. (batch, position, latent)).
        threshold_grad = per_element.sum_to_size(threshold.shape)
        return x_grad, threshold_grad


class JumpReLU(torch.autograd.Function):
    """JumpReLU activation ``x * (x > threshold)`` with a threshold pseudo-gradient.

    Forward keeps `x` where it exceeds `threshold` and zeroes it elsewhere.

    Backward passes the upstream gradient to `x` wherever ``x > threshold``
    (the ordinary derivative; no straight-through estimator on that path).
    `threshold` receives a rectangle-kernel pseudo-derivative of width
    `BANDWIDTH` (module-level constant):
    ``-(threshold / BANDWIDTH) * rectangle_pt((x - threshold) / BANDWIDTH)``
    times the upstream gradient, summed over every dimension that
    `threshold` was broadcast across.
    """

    @staticmethod
    def forward(x, threshold):
        return x * (x > threshold).to(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, threshold = inputs
        del output  # not needed by backward
        ctx.save_for_backward(x, threshold)

    @staticmethod
    def backward(ctx, grad_output):
        x, threshold = ctx.saved_tensors
        x_grad = (x > threshold) * grad_output  # We don't apply STE to x input
        per_element = (
            -(threshold / BANDWIDTH)
            * rectangle_pt((x - threshold) / BANDWIDTH)
            * grad_output
        )
        # Reduce over all broadcast dimensions rather than a hard-coded dim=0,
        # so the result has exactly threshold's shape for inputs with any
        # number of leading dimensions (e.g. (batch, position, latent)).
        threshold_grad = per_element.sum_to_size(threshold.shape)
        return x_grad, threshold_grad


class Sae(nn.Module):
    """Sparse autoencoder with a JumpReLU activation and learned thresholds.

    Parameters (all created on the module-level `device`, dtype float32):
        W_enc:          (activations_size, sae_width), Kaiming-uniform init.
        b_enc:          (sae_width,), zeros.
        W_dec:          (sae_width, activations_size), initialised to an
                        independent copy of ``W_enc.T``.
        b_dec:          (activations_size,), zeros.
        log_threshold:  (sae_width,), filled with ``log(THRESHOLD_INIT)``.

    Args:
        sae_width: number of latent features.
        activations_size: size of the last dimension of the inputs.
        use_pre_enc_bias: if true, `b_dec` is subtracted from the input
            before encoding.
    """

    def __init__(self, sae_width, activations_size, use_pre_enc_bias):
        super().__init__()
        self.dtype = torch.float
        self.device = device
        self.use_pre_enc_bias = use_pre_enc_bias
        self.W_enc = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(
                    activations_size, sae_width, dtype=self.dtype, device=self.device
                )
            )
        )
        self.b_enc = nn.Parameter(
            torch.zeros(sae_width, dtype=self.dtype, device=self.device)
        )
        # Clone the transpose: `W_enc.data.T` is a view that shares storage
        # with W_enc, which would silently tie encoder and decoder weights
        # (in-place optimiser updates to one would also modify the other).
        self.W_dec = nn.Parameter(
            self.W_enc.detach().T.clone(memory_format=torch.contiguous_format)
        )
        self.b_dec = nn.Parameter(
            torch.zeros(activations_size, dtype=self.dtype, device=self.device)
        )
        # The threshold is stored as its log; forward() exponentiates it.
        self.log_threshold = nn.Parameter(
            torch.full(
                (sae_width,),
                float(np.log(THRESHOLD_INIT)),
                dtype=self.dtype,
                device=self.device,
            )
        )

    def forward(self, x):
        """Encode and reconstruct `x`.

        Defined as `forward` (not `__call__`) so that `sae(x)` goes through
        `nn.Module.__call__` and registered hooks are honoured.

        Args:
            x: tensor whose last dimension is `activations_size`.

        Returns:
            (x_reconstructed, pre_activations): the reconstruction, with the
            same shape as `x`, and the encoder pre-activations, whose last
            dimension is `sae_width`.
        """
        if self.use_pre_enc_bias:
            x = x - self.b_dec

        pre_activations = x @ self.W_enc + self.b_enc
        threshold = torch.exp(self.log_threshold)
        feature_magnitudes = JumpReLU.apply(pre_activations, threshold)

        x_reconstructed = feature_magnitudes @ self.W_dec + self.b_dec
        return x_reconstructed, pre_activations


def loss_fn_pt(sae, x, sparsity_coefficient):
    """Compute the reconstruction and sparsity terms of the SAE loss.

    Args:
        sae: callable returning ``(x_reconstructed, pre_activations)`` and
            exposing a `log_threshold` tensor.
        x: input activations.
        sparsity_coefficient: multiplier applied to the L0 term.

    Returns:
        A pair ``(reconstruction_loss, sparsity_loss)``. Each is the mean,
        over all leading dimensions, of a quantity computed along the last
        dimension: the squared reconstruction error summed over the last
        dimension, and `sparsity_coefficient` times the number of
        pre-activations above their threshold. The two terms are returned
        separately; the caller adds them.
    """
    x_hat, pre_acts = sae(x)

    # Squared L2 reconstruction error per example.
    squared_error = (x - x_hat) ** 2
    recon_per_example = squared_error.sum(dim=-1)

    # L0 per example: count of latents whose pre-activation exceeds its threshold.
    active = Step.apply(pre_acts, sae.log_threshold.exp())
    l0_per_example = active.sum(dim=-1)
    sparsity_per_example = sparsity_coefficient * l0_per_example

    return recon_per_example.mean(), sparsity_per_example.mean()


def remove_parallel_component_pt(x, v):
    """Return `x` minus its projection onto `v` along the last dimension (PyTorch).

    `v` is scaled by ``1 / (||v|| + 1e-6)`` first, so a zero `v` leaves `x`
    unchanged instead of dividing by zero.
    """
    eps = 1e-6
    unit_v = v / (torch.norm(v, dim=-1, keepdim=True) + eps)
    # Dot product over the last dimension, keeping it for broadcasting.
    coefficient = torch.einsum("...d,...d->...", x, unit_v).unsqueeze(-1)
    return x - coefficient * unit_v

def train_pt(
    sae,
    optimizer,
    dataloader,
    sparsity_coefficient,
    num_steps = 1
):
    """Train `sae` on activations taken from the global `model` at `HOOK_POINT`.

    For each batch the cached activations are L2-normalised along the last
    dimension, the reconstruction and sparsity losses are summed, and one
    optimiser step is taken. When `FIX_DECODER_NORMS` is set, the component
    of the decoder gradient parallel to each decoder row is removed before
    the step and the rows are renormalised to unit norm afterwards.

    Args:
        sae: the autoencoder to train; updated in place.
        optimizer: optimiser constructed over `sae.parameters()`.
        dataloader: iterable of batches indexable with ``'tokens'``.
        sparsity_coefficient: weight of the L0 sparsity term.
        num_steps: maximum number of optimiser steps. Exactly this many are
            taken (fewer if the dataloader is exhausted first). ``None``
            means no limit; values <= 0 perform no steps.

    Returns:
        The same `sae` object.
    """
    if num_steps is not None and num_steps <= 0:
        return sae

    for i, tokens in enumerate(pbar := tqdm(dataloader)):
        with torch.no_grad():
            # NOTE(review): stop_at_layer=9 is hard-coded; presumably chosen so
            # that block 8 (where HOOK_POINT lives) still runs - keep it in
            # sync with HOOK_POINT.
            _, cache = model.run_with_cache(tokens['tokens'], names_filter = [HOOK_POINT], stop_at_layer = 9,)
            norm_res = F.normalize(cache[HOOK_POINT], dim=-1)

        optimizer.zero_grad()
        recon_loss, sparsity_loss = loss_fn_pt(
            sae, norm_res, sparsity_coefficient
        )
        loss_pt = recon_loss + sparsity_loss
        loss_pt.backward()

        if FIX_DECODER_NORMS:
            # Drop the gradient component along each decoder row.
            sae.W_dec.grad = remove_parallel_component_pt(
                sae.W_dec.grad, sae.W_dec.data
            )
        optimizer.step()
        if FIX_DECODER_NORMS:
            # Renormalise every decoder row to unit L2 norm.
            sae.W_dec.data = sae.W_dec.data / torch.norm(
                sae.W_dec.data, dim=-1, keepdim=True
            )

        pbar.set_description_str(f'recon loss: {recon_loss.item()} , sparsity loss: {sparsity_loss.item()}')

        # `i` is zero-based, so i + 1 steps are done at this point. The old
        # `i == num_steps` check ran one step too many.
        if num_steps is not None and i + 1 >= num_steps:
            break
    return sae

In [5]:
# @title Load the OpenWebText training subset
from datasets import load_dataset
from transformer_lens.utils import tokenize_and_concatenate
from torch.utils.data import DataLoader

# Load the first 200,000 rows of the OpenWebText train split
# (non-streaming load).
dataset = load_dataset(
    "Skylion007/openwebtext",
    streaming=False,
    split="train[0:200000]",
)

In [6]:
# Tokenise the text with the model's tokenizer, using max_length=128 and a
# BOS token.
# NOTE(review): streaming=True is passed here although the dataset itself was
# loaded with streaming=False - confirm this is intended.
token_dataset = tokenize_and_concatenate(
    dataset,
    model.tokenizer,
    streaming=True,
    max_length=128,
    add_bos_token=True,
)

# Shuffled batches of BATCH_SIZE rows for the training loop.
dataloader = DataLoader(
    token_dataset,
    shuffle=True,
    batch_size=BATCH_SIZE,
)

In [7]:
# sparsity_coefficient = 0.01

# sae = Sae(768, ACTIVATIONS_SIZE, False)

# for i, tokens in enumerate(pbar := tqdm(dataloader)):
#     with torch.no_grad():
#         _, cache = model.run_with_cache(tokens['tokens'], names_filter = [HOOK_POINT], stop_at_layer = 9,)
#         norm_res = F.normalize(cache[HOOK_POINT], dim=-1)

# x_reconstructed, pre_activations = sae(norm_res)

In [7]:
# Train a 768-latent JumpReLU SAE (no pre-encoder bias) with sparsity
# coefficient 0.01.
sparsity_coefficient = 0.01

jump_sae_768 = Sae(
    sae_width=768,
    activations_size=ACTIVATIONS_SIZE,
    use_pre_enc_bias=False,
)
optimizer = torch.optim.Adam(
    jump_sae_768.parameters(),
    lr=LEARNING_RATE,
    betas=(ADAM_B1, 0.999),
)
jump_sae_768 = train_pt(
    sae=jump_sae_768,
    optimizer=optimizer,
    dataloader=dataloader,
    sparsity_coefficient=sparsity_coefficient,
    num_steps=1000,
)

  0%|          | 0/1743 [00:00<?, ?it/s]

In [8]:
# Persist the trained weights. Note that `.cpu()` moves the module itself to
# the CPU (in place) before its state dict is taken.
torch.save(
    obj=jump_sae_768.cpu().state_dict(),
    f='./jump_sae_768-final.pt',
)

In [17]:
# Train a 1536-latent JumpReLU SAE (no pre-encoder bias) with sparsity
# coefficient 0.005.
sparsity_coefficient = 0.005

jump_sae_1536 = Sae(
    sae_width=1536,
    activations_size=ACTIVATIONS_SIZE,
    use_pre_enc_bias=False,
)
optimizer = torch.optim.Adam(
    jump_sae_1536.parameters(),
    lr=LEARNING_RATE,
    betas=(ADAM_B1, 0.999),
)

jump_sae_1536 = train_pt(
    sae=jump_sae_1536,
    optimizer=optimizer,
    dataloader=dataloader,
    sparsity_coefficient=sparsity_coefficient,
    num_steps=1000,
)

  0%|          | 0/1743 [00:00<?, ?it/s]


KeyboardInterrupt



In [20]:
# Continue training the existing 1536-latent SAE with a smaller sparsity
# coefficient (0.001). The module is first placed on `device`, and a fresh
# Adam optimiser is built, so optimiser state does not carry over from the
# previous run.
sparsity_coefficient = 0.001

jump_sae_1536.to(device)
optimizer = torch.optim.Adam(
    jump_sae_1536.parameters(),
    lr=LEARNING_RATE,
    betas=(ADAM_B1, 0.999),
)
jump_sae_1536 = train_pt(
    sae=jump_sae_1536,
    optimizer=optimizer,
    dataloader=dataloader,
    sparsity_coefficient=sparsity_coefficient,
    num_steps=1000,
)

  0%|          | 0/1743 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [21]:
# Persist the trained weights. Note that `.cpu()` moves the module itself to
# the CPU (in place) before its state dict is taken.
torch.save(
    obj=jump_sae_1536.cpu().state_dict(),
    f='./jump_sae_1536-final.pt',
)