#Utils

In [None]:
import os
from typing import Optional

import numpy as np
import polars as pl
import torch
from transformers import PreTrainedModel, PreTrainedTokenizer

In [None]:
def create_file(dir: str, file_name: str) -> None:
    """
    Make sure a file called `file_name` exists inside the directory `dir`.

    An empty file is created when none exists yet; a file that is already
    there is left untouched.

    Args:
        dir: Directory that must already exist.
        file_name: Name of the file to create inside `dir`.

    Raises:
        ValueError: If `dir` is not an existing directory.
    """
    if os.path.isdir(dir):
        target = os.path.join(dir, file_name)
        try:
            # Mode "x" is exclusive creation: it fails instead of truncating.
            with open(target, "x"):
                pass
        except FileExistsError:
            # The file is already present; nothing to do.
            pass
        return None

    raise ValueError(f"The specified directory '{dir}' does not exist.")


def get_layer_activations(
    tokenizer: PreTrainedTokenizer,
    plm: PreTrainedModel,
    seqs: list[str],
    layer: int,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """
    Get the activations of a specific layer in a pLM model. Let:

    ```
    N = len(seqs)
    L = max(len(seq) for seq in seqs) + 2 # +2 for BOS and EOS tokens
    D_MODEL = the layer dimension of the pLM, i.e. "Embedding Dim" column here
        https://github.com/facebookresearch/esm/tree/main?tab=readme-ov-file#available-models
    ```

    The output tensor is of shape (N, L, D_MODEL). Shorter sequences are padded
    by the tokenizer (`padding=True`), so their rows also contain activations
    at padding positions.

    Args:
        tokenizer: The tokenizer to use.
        plm: The pLM model to get the activations from.
        seqs: The sequences to get the activations for.
        layer: Index into `outputs.hidden_states` selecting which layer's
            activations to return.
        device: The device to run on. When None, the device the model already
            lives on (`plm.device`) is used so the inputs always end up next to
            the weights; if the model does not expose a device, CUDA is used
            when available and the CPU otherwise.

    Returns:
        The (N, L, D_MODEL) activations of the specified layer, computed
        without gradient tracking.
    """
    if device is None:
        # Follow the model rather than guessing: picking "cuda" for a model
        # that is still on the CPU (or on another GPU) would make the forward
        # pass fail with a device-mismatch error.
        device = getattr(plm, "device", None)
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    inputs = tokenizer(seqs, padding=True, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = plm(**inputs, output_hidden_states=True)
    # Only the requested layer is returned; the remaining hidden states are
    # released as soon as `outputs` goes out of scope.
    return outputs.hidden_states[layer]


def train_val_test_split(
    df: pl.DataFrame, train_frac: float = 0.9
) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
    """
    Split the rows of `df` into training, validation, and test sets. train_frac
    specifies the fraction of examples to use for training; the rest is split
    evenly between validation and test.

    The assignment is drawn independently per row with `np.random.choice`, so
    the split is stochastic: the resulting set sizes only match the requested
    fractions in expectation, and the result depends on NumPy's global random
    state.

    Args:
        df: The data frame whose rows are split.
        train_frac: The fraction of examples to use for training.

    Returns:
        A tuple containing the training, validation, and test sets.
    """
    is_train = pl.Series(
        np.random.choice([True, False], size=len(df), p=[train_frac, 1 - train_frac])
    )
    seqs_train = df.filter(is_train)
    seqs_val_test = df.filter(~is_train)

    # Each held-out row goes to validation or test with equal probability, as
    # the contract above promises (the previous 0.1 / 0.9 weights sent roughly
    # 90% of the held-out rows to the test set).
    is_val = pl.Series(np.random.choice([True, False], size=len(seqs_val_test), p=[0.5, 0.5]))
    seqs_val = seqs_val_test.filter(is_val)
    seqs_test = seqs_val_test.filter(~is_val)
    return seqs_train, seqs_val, seqs_test


def parse_swissprot_annotation(annotation_str: str, header: str) -> list[dict]:
    """
    Parse a SwissProt annotation string like this:

    ```
    MOTIF 119..132; /note="JAMM motif"; /evidence="ECO:0000255|PROSITE-ProRule:PRU01182"
    ```
    where MOTIF is the header argument. The string may contain several
    occurrences of the header; each one yields at most one entry.

    Parsing is best-effort. An occurrence is skipped when its position is not
    made of plain integers (e.g. `<1..5`) or when it carries no `key=value`
    qualifier. A qualifier without an `=` is ignored instead of aborting the
    whole parse.

    Args:
        annotation_str: The raw annotation text.
        header: The feature keyword that introduces each occurrence.

    Returns:
        [{
            "start": 119,
            "end": 132,
            "note": "JAMM motif",
            "evidence": "ECO:0000255|PROSITE-ProRule:PRU01182",
        }]
        For a single-position feature, "end" equals "start".
    """
    res = []
    for occurrence in annotation_str.split(header + " "):
        parts = [p for p in occurrence.split("; /") if p]
        if not parts:
            continue

        # The first part is the position: either "start..end" or one residue.
        try:
            positions = [int(c) for c in parts[0].split("..")]
        except ValueError:
            continue

        qualifiers = {}
        for part in parts[1:]:
            key, sep, value = part.partition("=")
            if not sep:
                # Malformed qualifier without a value; ignore it.
                continue
            qualifiers[key] = value.replace('"', "").replace(";", "").strip()
        if not qualifiers:
            continue

        res.append(
            {
                "start": positions[0],
                "end": positions[1] if len(positions) > 1 else positions[0],
                **qualifiers,
            }
        )
    return res

#SAE Model

In [None]:
import math
from typing import Optional

import numpy as np
import polars as pl
import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import PreTrainedModel, PreTrainedTokenizer

In [None]:
class SparseAutoencoder(nn.Module):
    def __init__(
        self,
        d_model: int,
        d_hidden: int,
        k: int = 128,
        auxk: int = 256,
        batch_size: int = 256,
        dead_steps_threshold: int = 2000,
    ):
        """
        Initialize the Sparse Autoencoder.

        Args:
            d_model: Dimension of the pLM model.
            d_hidden: Dimension of the SAE hidden layer.
            k: Number of top-k activations to keep.
            auxk: Number of auxiliary activations.
            dead_steps_threshold: How many examples of inactivation before we consider
                a hidden dim dead.

        Adapted from https://github.com/tylercosgrove/sparse-autoencoder-mistral7b/blob/main/sae.py
        based on 'Scaling and evaluating sparse autoencoders' (Gao et al. 2024) https://arxiv.org/pdf/2406.04093
        """
        super().__init__()

        self.w_enc = nn.Parameter(torch.empty(d_model, d_hidden))
        self.w_dec = nn.Parameter(torch.empty(d_hidden, d_model))

        self.b_enc = nn.Parameter(torch.zeros(d_hidden))
        self.b_pre = nn.Parameter(torch.zeros(d_model))

        self.d_model = d_model
        self.d_hidden = d_hidden
        self.k = k
        self.auxk = auxk
        self.batch_size = batch_size

        self.dead_steps_threshold = dead_steps_threshold / batch_size

        # TODO: Revisit to see if this is the best way to initialize
        nn.init.kaiming_uniform_(self.w_enc, a=math.sqrt(5))
        self.w_dec.data = self.w_enc.data.T.clone()
        self.w_dec.data /= self.w_dec.data.norm(dim=0)

        # Initialize dead neuron tracking. For each hidden dimension, save the
        # index of the example at which it was last activated.
        self.register_buffer("stats_last_nonzero", torch.zeros(d_hidden, dtype=torch.long))

    def topK_activation(self, x: torch.Tensor, k: int) -> torch.Tensor:
        """
        Apply top-k activation to the input tensor.

        Args:
            x: (BATCH_SIZE, D_EMBED, D_MODEL) input tensor to apply top-k activation on.
            k: Number of top activations to keep.

        Returns:
            torch.Tensor: Tensor with only the top k activations preserved,and others
            set to zero.

        This function performs the following steps:
        1. Find the top k values and their indices in the input tensor.
        2. Apply ReLU activation to these top k values.
        3. Create a new tensor of zeros with the same shape as the input.
        4. Scatter the activated top k values back into their original positions.
        """
        topk = torch.topk(x, k=k, dim=-1, sorted=False)
        values = F.relu(topk.values)
        result = torch.zeros_like(x)
        result.scatter_(-1, topk.indices, values)
        return result

    def LN(
        self, x: torch.Tensor, eps: float = 1e-5
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Apply Layer Normalization to the input tensor.

        Args:
            x: Input tensor to be normalized.
            eps: A small value added to the denominator for numerical stability.

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing:
                - The normalized tensor.
                - The mean of the input tensor.
                - The standard deviation of the input tensor.

        TODO: Is eps = 1e-5 the best value?
        """
        mu = x.mean(dim=-1, keepdim=True)
        x = x - mu
        std = x.std(dim=-1, keepdim=True)
        x = x / (std + eps)
        return x, mu, std

    def auxk_mask_fn(self) -> torch.Tensor:
        """
        Create a mask for dead neurons.

        Returns:
            torch.Tensor: A boolean tensor of shape (D_HIDDEN,) where True indicates
                a dead neuron.
        """
        dead_mask = self.stats_last_nonzero > self.dead_steps_threshold
        return dead_mask

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass of the Sparse Autoencoder. If there are dead neurons, compute the
        reconstruction using the AUXK auxiliary hidden dims as well.

        Args:
            x: (BATCH_SIZE, D_EMBED, D_MODEL) input tensor to the SAE.

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing:
                - The reconstructed activations via top K hidden dims.
                - If there are dead neurons, the auxiliary activations via top AUXK
                    hidden dims; otherwise, None.
                - The number of dead neurons.
        """
        x, mu, std = self.LN(x)
        x = x - self.b_pre

        pre_acts = x @ self.w_enc + self.b_enc

        # latents: (BATCH_SIZE, D_EMBED, D_HIDDEN)
        latents = self.topK_activation(pre_acts, k=self.k)

        # `(latents == 0)` creates a boolean tensor element-wise from `latents`.
        # `.all(dim=(0, 1))` preserves D_HIDDEN and does the boolean `all`
        # operation across BATCH_SIZE and D_EMBED. Finally, `.long()` turns
        # it into a vector of 0s and 1s of length D_HIDDEN.
        #
        # self.stats_last_nonzero is a vector of length D_HIDDEN. Doing
        # `*=` with `M = (latents == 0).all(dim=(0, 1)).long()` has the effect
        # of: if M[i] = 0, self.stats_last_nonzero[i] is cleared to 0, and then
        # immediately incremented; if M[i] = 1, self.stats_last_nonzero[i] is
        # unchanged. self.stats_last_nonzero[i] means "for how many consecutive
        # iterations has hidden dim i been zero".
        self.stats_last_nonzero *= (latents == 0).all(dim=(0, 1)).long()
        self.stats_last_nonzero += 1

        dead_mask = self.auxk_mask_fn()
        num_dead = dead_mask.sum().item()

        recons = latents @ self.w_dec + self.b_pre
        recons = recons * std + mu

        if num_dead > 0:
            k_aux = min(x.shape[-1] // 2, num_dead)

            auxk_latents = torch.where(dead_mask[None], pre_acts, -torch.inf)
            auxk_acts = self.topK_activation(auxk_latents, k=k_aux)

            auxk = auxk_acts @ self.w_dec + self.b_pre
            auxk = auxk * std + mu
        else:
            auxk = None

        return recons, auxk, num_dead

    @torch.no_grad()
    def forward_val(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the Sparse Autoencoder for validation.

        Args:
            x: (BATCH_SIZE, D_EMBED, D_MODEL) input tensor to the SAE.

        Returns:
            torch.Tensor: The reconstructed activations via top K hidden dims.
        """
        x, mu, std = self.LN(x)
        x = x - self.b_pre
        pre_acts = x @ self.w_enc + self.b_enc
        latents = self.topK_activation(pre_acts, self.k)

        recons = latents @ self.w_dec + self.b_pre
        recons = recons * std + mu
        return recons

    @torch.no_grad()
    def norm_weights(self) -> None:
        """
        Normalize the weights of the Sparse Autoencoder.
        """
        self.w_dec.data /= self.w_dec.data.norm(dim=0)

    @torch.no_grad()
    def norm_grad(self) -> None:
        """
        Normalize the gradient of the weights of the Sparse Autoencoder.
        """
        dot_products = torch.sum(self.w_dec.data * self.w_dec.grad, dim=0)
        self.w_dec.grad.sub_(self.w_dec.data * dot_products.unsqueeze(0))

    @torch.no_grad()
    def get_acts(self, x: torch.Tensor) -> torch.Tensor:
        """
        Get the activations of the Sparse Autoencoder.

        Args:
            x: (BATCH_SIZE, D_EMBED, D_MODEL) input tensor to the SAE.

        Returns:
            torch.Tensor: The activations of the Sparse Autoencoder.
        """
        x, _, _ = self.LN(x)
        x = x - self.b_pre
        pre_acts = x @ self.w_enc + self.b_enc
        latents = self.topK_activation(pre_acts, self.k)
        return latents

    @torch.no_grad()
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        x, mu, std = self.LN(x)
        x = x - self.b_pre
        acts = x @ self.w_enc + self.b_enc
        return acts, mu, std

    @torch.no_grad()
    def decode(self, acts: torch.Tensor, mu: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
        latents = self.topK_activation(acts, self.k)

        recons = latents @ self.w_dec + self.b_pre
        recons = recons * std + mu
        return recons


def loss_fn(
    x: torch.Tensor,
    recons: torch.Tensor,
    auxk: Optional[torch.Tensor] = None,
    mse_scale: float = 1.0,
    auxk_coeff: float = 1.0 / 32.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute the loss function for the Sparse Autoencoder.

    Args:
        x: (BATCH_SIZE, D_EMBED, D_MODEL) input tensor to the SAE.
        recons: (BATCH_SIZE, D_EMBED, D_MODEL) reconstructed activations via top K
            hidden dims.
        auxk: (BATCH_SIZE, D_EMBED, D_MODEL) auxiliary activations via top AUXK
            hidden dims. See A.2. in https://arxiv.org/pdf/2406.04093.
        mse_scale: Multiplier applied to the reconstruction MSE.
        auxk_coeff: Multiplier applied to the auxiliary loss.
            TODO: Is 1/32 the best coefficient?

    Returns:
        tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - The MSE loss between `recons` and `x`.
            - The auxiliary loss: the MSE between `auxk` and the residual
              `x - recons`, with a NaN result replaced by 0. When `auxk` is
              None this is a zero scalar on the same device and dtype as
              `recons`.
    """
    mse_loss = mse_scale * F.mse_loss(recons, x)
    if auxk is not None:
        auxk_loss = auxk_coeff * F.mse_loss(auxk, x - recons).nan_to_num(0)
    else:
        # Keep the placeholder next to the other loss term so the two can be
        # combined or logged without device/dtype surprises.
        auxk_loss = recons.new_zeros(())
    return mse_loss, auxk_loss


def estimate_loss(
    plm: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    layer: int,
    sae_model: SparseAutoencoder,
    examples_set: pl.DataFrame,
    sample_size: int = 100,
) -> float:
    """
    Estimate the loss of the Sparse Autoencoder using a set of examples.

    A random sample of rows is drawn, their "Sequence" values are run through
    the pLM in a single batch, and the SAE reconstruction MSE on the selected
    layer's activations is returned.

    Args:
        plm: The pLM model whose activations the SAE reconstructs.
        tokenizer: The tokenizer matching `plm`.
        layer: The pLM layer to take the activations from.
        sae_model: The Sparse Autoencoder model.
        examples_set: The examples set to estimate the loss on. Must have a
            "Sequence" column.
        sample_size: The number of examples to sample. Capped at the number of
            rows in `examples_set`, so a small set no longer makes the
            sampling fail.

    Returns:
        float: The estimated loss.
    """
    n_samples = min(sample_size, len(examples_set))
    samples = examples_set.sample(n_samples)
    seqs = samples.get_column("Sequence").to_list()
    layer_acts = get_layer_activations(tokenizer=tokenizer, plm=plm, seqs=seqs, layer=layer)

    recons = sae_model.forward_val(layer_acts)
    mse_loss, _ = loss_fn(layer_acts, recons)
    return mse_loss.item()

#Training

In [None]:
import argparse
import os

# Compatibility shim for NumPy 2.0+, where the `np.float_` and `np.complex_`
# aliases no longer exist: re-create them so that code still referring to the
# old names keeps working. Presumably needed by one of the packages imported
# below — TODO confirm which one. Relies on `np` having been imported by an
# earlier cell.
np.float_ = np.float64
np.complex_ = np.complex128

# NOTE(review): the next import rebinds the name `pl`, which earlier cells bind
# to polars. Any code that looks `pl` up at call time (e.g. `pl.Series` in
# `train_val_test_split`) will see pytorch_lightning once this cell has run.
import pytorch_lightning as pl
import wandb
from data_module import SequenceDataModule
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from sae_module import SAELightningModule

TypeError: Descriptors cannot be created directly.
If this call came from a _pb2.py file, your generated code is out of date and must be regenerated with protoc >= 3.19.0.
If you cannot immediately regenerate your protos, some other possible workarounds are:
 1. Downgrade the protobuf package to 3.20.x or lower.
 2. Set PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python (but this will use pure-Python parsing and will be much slower).

More information: https://developers.google.com/protocol-buffers/docs/news/2022-05-06#python-updates