In [1]:
from IPython.display import clear_output
!pip install transformer-lens jaxtyping datasets

clear_output()

In [19]:
import json
import wandb
import torch
import einops
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm
from pathlib import Path
from torch import Tensor
from rich.table import Table
from jaxtyping import Float, Int
from rich import print as rprint
from typing import Callable, Tuple
from dataclasses import dataclass, asdict, field
from transformer_lens import HookedTransformer
from torch.distributions.categorical import Categorical
from transformer_lens.utils import (
    get_act_name,
    load_dataset,
    tokenize_and_concatenate,
    download_file_from_hf
)

In [3]:
@dataclass
class Config:
    """
    Hyperparameters for training a sparse autoencoder (SAE) on cached model activations.

    Fields declared with `field(init=False)` are derived from the other fields in
    `__post_init__` and cannot be passed to the constructor.
    """
    # --- autoencoder ---
    d_in: int = 768                       # width of the activation vectors the SAE reads
    dict_mult: int = 32                   # d_sae = d_in * dict_mult
    d_sae: int = field(init=False)        # number of SAE latents (derived)
    tied_weights: bool = False            # if True, the decoder is the transposed encoder
    layer: int = 8                        # transformer block the activations are taken from
    device: str = 'cuda:3'                # device the model and the activation buffer are placed on
    l1_coefficient: float = 8e-5          # weight of the L1 sparsity term in the loss
    weight_normalize_eps: float = 1e-8    # added to vector norms before dividing by them

    # --- activation source ---
    seq_len: int = 128                    # tokens per row of the tokenized dataset
    batch_size: int = 4096                # activation vectors per SAE training step
    component_name: str = "resid_post"    # hook component, combined with `layer` into `act_name`
    act_name: str = field(init=False)     # full hook name (derived)

    # --- activation buffer ---
    buffer_mult: int = 384                          # buffer holds batch_size * buffer_mult activations
    buffer_size: int = field(init=False)            # activation vectors stored in the buffer (derived)
    buffer_batches: int = field(init=False)         # token rows needed to fill the buffer (derived)
    model_batch_size: int = field(init=False)       # token rows per model forward pass (derived)

    # --- training ---
    log_freq: int = 50                    # steps between progress-bar updates
    lr: float = 4e-4                      # base learning rate

    def __post_init__(self):
        """Fill in the derived fields from the user-settable ones."""
        self.d_sae = self.d_in * self.dict_mult
        self.buffer_size = self.batch_size * self.buffer_mult
        # Each token row contributes `seq_len` activation vectors.
        self.buffer_batches = self.buffer_size // self.seq_len
        # Integer division happens first, so this is 16 * (batch_size // seq_len).
        self.model_batch_size = (self.batch_size // self.seq_len * 16)
        self.act_name = get_act_name(self.component_name, self.layer)

    
cfg = Config()

In [4]:
# Load the pretrained GPT-2 small checkpoint onto the device chosen in the config.
model = HookedTransformer.from_pretrained("gpt2-small", device=cfg.device)



Loaded pretrained model gpt2-small into HookedTransformer


In [5]:
# Tokenize the dataset into fixed-length rows. The row length comes from the
# config because the buffer sizes derived in `Config.__post_init__` assume that
# every token row yields exactly `cfg.seq_len` activation vectors.
data = load_dataset("NeelNanda/c4-code-20k", split="train")
tokenized_data = tokenize_and_concatenate(data, model.tokenizer, max_length=cfg.seq_len)
tokenized_data = tokenized_data.shuffle(seed=42)
all_tokens = tokenized_data["tokens"]
# Only the token tensor is needed from here on.
del data, tokenized_data
print("Tokens shape: ", all_tokens.shape)

Tokens shape:  torch.Size([325017, 128])


In [6]:
def constant_lr(*_):
    """Learning-rate schedule that ignores its arguments and always yields the multiplier 1.0."""
    multiplier = 1.0
    return multiplier

class Buffer():
    """
    Shuffled store of model activations that serves training batches for the SAE.

    The buffer holds `cfg.buffer_size` activation vectors of width `cfg.d_in` in
    bfloat16 on `cfg.device`. It is filled completely on construction; afterwards
    `next()` hands out consecutive slices and triggers `refresh()` once the read
    pointer is about to pass the first half, which overwrites the front of the
    buffer with new activations and reshuffles everything.

    Relies on the notebook-level globals `model` (run to produce activations) and
    `all_tokens` (the token rows that are fed to it).
    """
    def __init__(self, cfg: Config):
        self.cfg = cfg
        # Allocate directly on the target device instead of building the tensor
        # on the CPU first and copying it over.
        self.buffer = torch.zeros(
            (self.cfg.buffer_size, self.cfg.d_in),
            requires_grad=False,
            dtype=torch.bfloat16,
            device=self.cfg.device,
        )
        self.pointer = 0        # next row of `self.buffer` to read / write
        self.token_pointer = 0  # next row of `all_tokens` to feed to the model
        self.first = True       # True until the first (full) fill has happened
        self.refresh()

    @torch.no_grad()
    def refresh(self):
        """
        Write new activations into the front of the buffer, then shuffle all rows.

        The first call requests `cfg.buffer_batches` token rows, later calls half
        of that. Runs without autograd: the buffer never requires gradients, so
        the model's forward pass does not need to record a graph.
        """
        self.pointer = 0
        with torch.autocast("cuda", torch.bfloat16):
            num_batches = self.cfg.buffer_batches if self.first else self.cfg.buffer_batches // 2
            self.first = False

            # One iteration per model forward pass of `model_batch_size` token rows.
            for _ in range(0, num_batches, self.cfg.model_batch_size):
                tokens = all_tokens[self.token_pointer: self.token_pointer + self.cfg.model_batch_size]

                # Only the requested activation is cached, and the forward pass
                # stops right after the block that produces it.
                _, cache = model.run_with_cache(
                    tokens,
                    stop_at_layer=self.cfg.layer + 1,
                    names_filter=self.cfg.act_name,
                )

                acts = einops.rearrange(
                    cache[self.cfg.act_name],
                    "batch seq d_model -> (batch seq) d_model"
                )
                del cache

                self.buffer[self.pointer: self.pointer + acts.shape[0]] = acts
                self.pointer += acts.shape[0]
                self.token_pointer += self.cfg.model_batch_size

                # Wrap around before a slice could run past the end of the dataset.
                if self.token_pointer + self.cfg.model_batch_size >= all_tokens.shape[0]:
                    self.token_pointer = 0

        self.pointer = 0
        # Indexing with a permutation builds a new, shuffled tensor.
        self.buffer = self.buffer[torch.randperm(self.buffer.shape[0]).to(self.cfg.device)]

    def next(self) -> Float[Tensor, "batch d_in"]:
        """
        Return the next `cfg.batch_size` activation vectors.

        When serving another batch would move the pointer past the first half of
        the buffer, the buffer is refreshed before returning.
        """
        out = self.buffer[self.pointer: self.pointer + self.cfg.batch_size]
        self.pointer += self.cfg.batch_size

        if self.pointer + self.cfg.batch_size > self.buffer.shape[0] // 2:
            # `out` is a view into `self.buffer`, and `refresh()` writes the new
            # activations into that same storage in place before it re-binds
            # `self.buffer`. Copy the batch first so the caller receives the
            # shuffled rows that were selected rather than freshly written ones.
            out = out.clone()
            self.refresh()

        return out

buffer = Buffer(cfg)

In [7]:
class SAE(nn.Module):
    """
    Single-layer sparse autoencoder with a ReLU encoder and a linear decoder.

    Encoding:  acts = ReLU((h - b_dec) @ W_enc + b_enc)
    Decoding:  h_reconstructed = acts @ W_dec
    Loss:      mean over the batch of  MSE(h, h_reconstructed) + l1_coefficient * sum(|acts|)

    With `cfg.tied_weights` the decoder is the transposed encoder; otherwise it is
    a separate parameter whose rows are renormalised after every optimiser step.

    `optimize` and `resample_advanced` draw their data from the notebook-level
    global `buffer`.
    """
    def __init__(self,
                 cfg: Config,
                 model
                 ):
        super().__init__()
        self.cfg = cfg
        # Keep a handle on the language model without registering it as a
        # submodule. A plain `self.model = model` would, for an `nn.Module` such
        # as HookedTransformer, add all of its weights to `self.parameters()` and
        # `self.state_dict()`, i.e. to the optimiser and to every saved
        # checkpoint. Checkpoints written before this change contain `model.*`
        # keys and have to be loaded with `strict=False`.
        object.__setattr__(self, "model", model)

        self.W_enc = nn.Parameter(nn.init.xavier_uniform_(torch.empty(self.cfg.d_in, self.cfg.d_sae)))

        self.b_enc = nn.Parameter(torch.zeros(self.cfg.d_sae))
        self.b_dec = nn.Parameter(torch.zeros(self.cfg.d_in))

        if self.cfg.tied_weights:
            self._W_dec = None
        else:
            self._W_dec = nn.Parameter(nn.init.xavier_uniform_(torch.empty(self.cfg.d_sae, self.cfg.d_in)))

    @property
    def W_dec(self) -> Float[Tensor, "d_sae d_in"]:
        """Decoder weights: the dedicated parameter, or the transposed encoder when weights are tied."""
        return self._W_dec if self._W_dec is not None else self.W_enc.transpose(1, 0)

    @property
    def W_dec_normalized(self) -> Float[Tensor, "d_sae d_in"]:
        """Returns decoder weights, normalized over the autoencoder input dimension."""
        # The epsilon keeps an all-zero row from turning into NaNs.
        row_norms = self.W_dec.norm(dim=-1, keepdim=True)
        return self.W_dec / (row_norms + self.cfg.weight_normalize_eps)

    def forward(self,
                h: Float[Tensor, "batch_seq d_in"]
    ) -> tuple[
            dict[str, Float[Tensor, "batch_seq"]],
            Float[Tensor, ""],
            Float[Tensor, "batch_seq d_sae"],
            Float[Tensor, "batch_seq d_in"]
        ]:
        """
        Encode and reconstruct a batch of activations.

        Returns:
            loss_dict: per-example "L_reconstruction" (squared error averaged over d_in)
                and "L_sparsity" (sum of absolute latent activations).
            loss: scalar, batch mean of L_reconstruction + l1_coefficient * L_sparsity.
            acts: latent activations.
            h_reconstructed: the decoder output.
        """
        assert h.shape[1] == self.cfg.d_in

        acts = F.relu(
            einops.einsum(
                (h - self.b_dec), self.W_enc,
                "batch_seq d_in, d_in d_sae -> batch_seq d_sae"
            ) + self.b_enc
        )

        h_reconstructed = einops.einsum(
            acts, self.W_dec,
            "batch_seq d_sae, d_sae d_in -> batch_seq d_in"
        )
        assert h_reconstructed.shape == h.shape

        L_reconstruction = ((h - h_reconstructed) ** 2).mean(dim=-1)
        L_sparsity = acts.abs().sum(dim=-1)

        loss_dict = {
            "L_reconstruction": L_reconstruction,
            "L_sparsity": L_sparsity
        }

        loss = (L_reconstruction + (self.cfg.l1_coefficient * L_sparsity)).mean()

        return loss_dict, loss, acts, h_reconstructed


    def optimize(
            self,
            steps: int = 30_000,
            log_freq: int = 50,
            lr_scale: Callable[[int, int], float] = constant_lr,
            resample_freq: int = 2500,
            resample_window: int = 500,
            resample_scale: float = 0.5

    ):
        """
        Train the autoencoder on batches drawn from the global `buffer`, logging to wandb.

        Args:
            steps: number of optimiser steps.
            log_freq: the progress bar is updated every `log_freq` steps (and on the last step).
            lr_scale: called as `lr_scale(step, steps)`; its result multiplies `cfg.lr`.
            resample_freq: dead latents are resampled whenever `(step + 1) % resample_freq == 0`.
            resample_window: number of most recent steps a latent must have been
                inactive for to count as dead.
            resample_scale: scale applied to the encoder weights of resampled latents.
        """
        assert resample_window <= resample_freq

        name = f"L{self.cfg.layer}_{self.cfg.d_sae}_L1-{self.cfg.l1_coefficient}_Lr-{self.cfg.lr}"

        wandb.init(project="Autoencoders", name=name)

        optimizer = torch.optim.Adam(list(self.parameters()), lr=self.cfg.lr, betas=(0.0, 0.999))
        progress_bar = tqdm(range(steps))
        # Per-step fraction of the batch on which each latent fired. Only the most
        # recent `resample_window` entries are ever read, so older ones are dropped.
        frac_active_list = []

        for step in progress_bar:
            if ((step + 1) % resample_freq == 0) and frac_active_list:
                frac_active_in_window = torch.stack(frac_active_list[-resample_window:], dim=0)
                self.resample_advanced(frac_active_in_window, resample_scale, self.cfg.batch_size)

            # Update learning rate
            step_lr = self.cfg.lr * lr_scale(step, steps)
            for group in optimizer.param_groups:
                group["lr"] = step_lr

            h = buffer.next()

            loss_dict, loss, acts, _ = self.forward(h)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            with torch.no_grad():
                l2_loss = loss_dict["L_reconstruction"]
                l1_loss = loss_dict["L_sparsity"]

                assert(l1_loss.shape[0] == h.shape[0] and l2_loss.shape[0] == h.shape[0])
                l2_loss = l2_loss.mean()
                l1_loss = (self.cfg.l1_coefficient * l1_loss).mean()
                l0 = (acts > 0).float().sum(-1).mean()
                frac_active = (acts.abs() > 1e-8).float().mean(0)

                frac_active_list.append(frac_active)
                # One entry is appended per step, so removing one keeps the list at
                # `resample_window` entries instead of one d_sae-sized tensor per step.
                if resample_window > 0 and len(frac_active_list) > resample_window:
                    del frac_active_list[0]

                # Log plain Python numbers rather than tensors.
                log_dict = {
                    "losses/loss": loss.item(),
                    "losses/l1_loss": l1_loss.item(),
                    "losses/l2_loss": l2_loss.item(),
                    'metrics/frac_active': frac_active.mean().item(),
                    "metrics/l0": l0.item(),
                }

                # Keep the decoder rows at unit norm after every update.
                if not self.cfg.tied_weights:
                    self.W_dec.data = self.W_dec_normalized

            wandb.log(log_dict)

            if step % log_freq == 0 or (step + 1 == steps):
                progress_bar.set_postfix(
                    lr=step_lr,
                    frac_active=frac_active.mean().item(),
                    **{k: v.mean(0).sum().item() for k, v in loss_dict.items()},  # type: ignore
                )

    @torch.no_grad()
    def resample_advanced(
        self,
        frac_active_in_window: Float[Tensor, "window d_sae"],
        resample_scale: float,
        batch_size: int,
    ) -> None:
        """
        Re-initialise latents that never fired in any step of `frac_active_in_window`.

        Resampling method is:
            - Draw a batch `h` from the global `buffer` and compute its per-example
              reconstruction loss
            - Randomly choose rows of `h` with probability proportional to the square
              of their reconstruction loss
            - Set the dead rows of W_dec to these (centered and normalized) vectors, and
              the dead columns of W_enc to the same vectors scaled by
              `resample_scale` times the mean encoder norm of the alive latents
            - Set b_enc to zero at each dead latent

        Modifies the parameters in place and returns nothing. `batch_size` is
        accepted but not used; the batch size is whatever `buffer.next()` returns.
        """
        # Drawn unconditionally so the buffer advances by one batch on every call.
        h = buffer.next()

        # Find the dead latents. If all latents are alive there is nothing to do.
        is_dead = (frac_active_in_window < 1e-8).all(dim=0)
        dead_latents = torch.nonzero(is_dead).squeeze(-1)
        n_dead = dead_latents.numel()
        if n_dead == 0:
            return

        # Per-example reconstruction loss, shape [batch]
        l2_loss_instance = self.forward(h)[0]["L_reconstruction"]
        if l2_loss_instance.max() < 1e-6:
            return  # If we have zero reconstruction loss, we don't need to resample

        # Draw `n_dead` row indices with probability proportional to the squared loss
        squared_loss = l2_loss_instance.pow(2)
        distn = Categorical(probs=squared_loss / squared_loss.sum())
        replacement_indices = distn.sample((n_dead,))  # type: ignore

        # Index into the batch of hidden activations to get our replacement values
        replacement_values = (h - self.b_dec)[replacement_indices]  # [n_dead d_in]
        replacement_values_normalized = replacement_values / (
            replacement_values.norm(dim=-1, keepdim=True) + self.cfg.weight_normalize_eps
        )

        # Get the norm of alive neurons (or 1.0 if there are no alive neurons)
        W_enc_norm_alive_mean = (
            self.W_enc[:, ~is_dead].norm(dim=0).mean().item()
            if (~is_dead).any()
            else 1.0
        )

        # Lastly, set the new weights & biases (W_dec is normalized, W_enc needs specific scaling, b_enc is zero)
        self.W_dec.data[dead_latents, :] = replacement_values_normalized
        self.W_enc.data[:, dead_latents] = (
            replacement_values_normalized.T * W_enc_norm_alive_mean * resample_scale
        )
        self.b_enc.data[dead_latents] = 0.0

In [8]:
# Put the autoencoder on the device named in the config. The model and the
# activation buffer are both placed on `cfg.device`, so a separately hard-coded
# device string here could silently diverge from them.
device = cfg.device
sae = SAE(cfg, model).to(device)

In [9]:
sae.optimize()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33misistickz[0m ([33mself_research_[0m). Use [1m`wandb login --relogin`[0m to force relogin


100%|██████████| 30000/30000 [39:06<00:00, 12.79it/s, L_reconstruction=8.84e+6, L_sparsity=4.79e+5, frac_active=0.645, lr=0.0004]  


In [20]:
SAVE_DIR = Path("trained_saes")


def save(version: int = 1):
    """
    Write the global `sae`'s weights and config to `SAVE_DIR`.

    Produces `<name>.pt` (the state dict) and `<name>_cfg.json` (the config as
    JSON), where `<name>` encodes layer, dictionary size, L1 coefficient,
    learning rate and `version`. Existing files with the same name are overwritten.
    """
    # torch.save / open do not create missing directories.
    SAVE_DIR.mkdir(parents=True, exist_ok=True)
    name = f"L{sae.cfg.layer}_{sae.cfg.d_sae}_L1-{sae.cfg.l1_coefficient}_Lr-{sae.cfg.lr}_V{version}"
    torch.save(sae.state_dict(), SAVE_DIR/(name+".pt"))
    with open(SAVE_DIR/(name+"_cfg.json"), "w") as f:
        # Serialise the config that belongs to the saved autoencoder, the same
        # object the file name is built from.
        json.dump(asdict(sae.cfg), f)
    print("Saved as version", version)

save()

Saved as version 1


In [48]:
@torch.inference_mode()
def highest_activating_tokens(
    tokens: Int[Tensor, "batch seq"],
    model: HookedTransformer,
    autoencoder: SAE,
    feature_idx: int,
    autoencoder_B: bool = False,
    k: int = 20,
) -> Tuple[Int[Tensor, "k 2"], Float[Tensor, "k"]]:
    '''
    Returns the indices & values for the highest-activating tokens in the given batch of data.

    Args:
        tokens: token ids to run through `model`.
        model: the model whose activations the autoencoder was trained on.
        autoencoder: the SAE; its config names the hook point and layer to read.
        feature_idx: index of the SAE latent to inspect.
        autoencoder_B: unused, kept so existing calls keep working.
        k: number of positions to return (capped at the number of positions available).

    Returns:
        (indices, values): `indices[i]` is the (batch, seq) position of the i-th
        strongest activation of the latent and `values[i]` is that activation.
    '''
    seq_len = tokens.shape[1]

    # Read the same activation the autoencoder was trained on, and stop the
    # forward pass right after the block that produces it.
    act_name = autoencoder.cfg.act_name
    cache = model.run_with_cache(
        tokens,
        stop_at_layer=autoencoder.cfg.layer + 1,
        names_filter=act_name,
    )[1]
    post = cache[act_name]
    post_reshaped = einops.rearrange(post, "batch seq d_model -> (batch seq) d_model")

    # Encoder of `SAE.forward`, restricted to the one latent we care about:
    # centre by b_dec, project onto the latent's encoder column, add its bias, ReLU.
    h_cent = post_reshaped - autoencoder.b_dec
    pre_acts = einops.einsum(
        h_cent, autoencoder.W_enc[:, feature_idx],
        "batch_size n_input_ae, n_input_ae -> batch_size"
    ) + autoencoder.b_enc[feature_idx]
    acts = F.relu(pre_acts)

    # Get the top k largest activations (topk raises if k exceeds the number of entries)
    k = min(k, acts.shape[0])
    top_acts_values, top_acts_indices = acts.topk(k)

    # Convert the flat indices into (batch, seq) indices
    top_acts_batch = top_acts_indices // seq_len
    top_acts_seq = top_acts_indices % seq_len

    return torch.stack([top_acts_batch, top_acts_seq], dim=-1), top_acts_values


def display_top_sequences(top_acts_indices, top_acts_values, tokens):
    """
    Print a table with one row per top activation: the surrounding tokens (the
    activating token highlighted) and the activation value.

    Args:
        top_acts_indices: (batch, seq) index pairs into `tokens`.
        top_acts_values: activation value for each index pair.
        tokens: the token batch the indices refer to.

    Uses the notebook-level global `model` to turn token ids into strings.
    """
    # Local import: `rich` is already a dependency of this notebook.
    from rich.markup import escape

    table = Table("Sequence", "Activation", title="Tokens which most activate this feature")
    # Bound the window by the batch that was actually passed in.
    n_positions = tokens.shape[1]
    for (batch_idx, seq_idx), value in zip(top_acts_indices, top_acts_values):
        batch_idx, seq_idx = int(batch_idx), int(seq_idx)
        # Get the sequence as a string: up to 5 tokens before and 4 after the activating token
        seq = ""
        for i in range(max(seq_idx-5, 0), min(seq_idx+5, n_positions)):
            new_str_token = model.to_single_str_token(tokens[batch_idx, i].item()).replace("\n", "\\n")
            # Token text is arbitrary; escape it so square brackets in it are
            # printed literally instead of being parsed as Rich markup tags.
            new_str_token = escape(new_str_token)
            # Highlight the token with the high activation
            if i == seq_idx: new_str_token = f"[b u dark_orange]{new_str_token}[/]"
            seq += new_str_token
        # Add the sequence and the activation value as a row
        table.add_row(seq, f'{float(value):.2f}')
    rprint(table)

# Inspect a single SAE latent on the first 200 token rows: find where it fires
# most strongly, then print those positions with their surrounding tokens.
tokens = all_tokens[:200]
top_acts_indices, top_acts_values = highest_activating_tokens(tokens, model, sae, feature_idx=11111, autoencoder_B=False)
display_top_sequences(top_acts_indices, top_acts_values, tokens)