In [18]:
import os, gzip, math, random, json, gc
from pathlib import Path
from typing import List, Iterable, Tuple, Dict
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from Bio import SeqIO
from utils import extract_esm_features_batch
import random
import os
import math
from dataclasses import dataclass
from typing import Optional, Iterable, Dict

import torch
import torch.nn.functional as F
from torch import nn
from torch.cuda.amp import autocast, GradScaler


In [8]:
FASTA_GZ = Path("/home/ec2-user/SageMaker/InterPLM/data/uniprot/uniprot_sprot.fasta.gz")
OUT_DIR  = Path("/home/ec2-user/SageMaker/InterPLM/sae_runs/es2_swissprot_layer24")

# ---------- Device ----------
device = t.device("cuda" if t.cuda.is_available() else "cpu")
device

device(type='cuda')

In [None]:
from transformers import AutoTokenizer, AutoModel
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(ESM_NAME, do_lower_case=False)
model = AutoModel.from_pretrained(ESM_NAME).eval().to(device)

In [9]:
ESM_NAME = "facebook/esm2_t33_650M_UR50D"
MAX_LEN = 1024 #truncate longer sequences
SEED = 17
random.seed(SEED)
BATCH_SIZE_SEQ = 16
LAYER_TO_TRAIN=24
TOKENS_PER_STEP = 4096 #otken vectors per SAE update


## Stream Swissprot protein sequences

In [11]:
def iter_swissprot_sequences(fasta_gz: Path, max_len: int = MAX_LEN):
    with gzip.open(fasta_gz, "rt") as fh:
        for rec in SeqIO.parse(fh, "fasta"):
            seq = str(rec.seq)
            if not seq:
                continue
            yield rec.id, (seq[:max_len] if len(seq) > max_len else seq)

def batched_sequences(fasta_gz: Path, batch_size: int = BATCH_SIZE_SEQ):
    buf = []
    for _, seq in iter_swissprot_sequences(fasta_gz):
        buf.append(seq)
        if len(buf) == batch_size:
            yield buf
            buf = []
    if buf:
        yield buf #yields array of [B, L] sequences


## Grab ESM activations in batches

In [46]:

@torch.no_grad()
def activation_batches(
    fasta_gz: Path,
    tokens_per_step: int = TOKENS_PER_STEP,
    layer_sel=LAYER_TO_TRAIN,
    dtype = torch.float16,
    *,
    model,
    tokenizer,
) -> Iterable[torch.Tensor]:

    buf = []
    total = 0
    for seq_batch in batched_sequences(fasta_gz, BATCH_SIZE_SEQ):
        reps, attn = extract_esm_features_batch(
            sequences=seq_batch,
            layer_sel=layer_sel,
            device=DEVICE,
            dtype=dtype,
            model=model,
            tokenizer=tokenizer,
        ) #reps: [B,L,d], attn:[B, L]

        B, L, D = reps.shape #Batch, sequence length, hidden dimension = 1280
        for b in range(B): #for sequence in the batch
            valid = attn[b] # bool[L], which ones were we supposed to attend to?
            if valid.any():
                buf.append(reps[b][valid].detach().to("cpu", dtype=torch.float32)) #float32 for SAE
                total += int(valid.sum())
     
        # When we have enough tokens, build a randomized batch and a randomized remainder.
        #Yield turns the function into a generator, so the local variables are preserved between next() calls. once we drop references to it or stop iteration, it gets cleaned up
        if total >= tokens_per_step:
            all_tokens = torch.cat(buf, dim=0)                # [N, d_in]
            N = all_tokens.size(0)
            perm = torch.randperm(N)                          # random order
            take = perm[:tokens_per_step]
            keep = perm[tokens_per_step:]

            X    = all_tokens[take]                           # [tokens_per_step, d_in]
            rest = all_tokens[keep]                           # [N - tokens_per_step, d_in]

            # Carry forward the (already shuffled) remainder
            buf = ([rest] if rest.numel() else [])
            total = int(rest.size(0)) if rest.numel() else 0

            yield X

    # flush remainder (possibly a short batch)
    if buf:
        X = torch.cat(buf, dim=0)
        if X.numel() > 0:
            # optional: shuffle this short batch, too
            idx = torch.randperm(X.size(0))
            X = X[idx]
            yield X

## Sparse Autoencoder (ReLU activation function)

In [42]:

class SAE(nn.Module):
    """
    One-layer sparse autoencoder (untied decoder).

    Design:
      x -- (center by bias) --> encoder (Linear) --> LayerNorm? --> ReLU --> f (sparse features)
      f -- decoder(linear) --> + bias --> xhat

    Shapes (let:
    
        B = batch size in *tokens* (each row is one token embedding),
        D = input dim (ESM hidden size for a given layer, e.g., 320/480/640/1280)
        K = number of features or Dictionary size
    ):
        -x: [B, D]
        -f: [B, K]
        -x_hat: [B, D]
        -encoder.weight: [K, D] (each row maps input -> a feature pre-activation)
        -decoder.weight: [D, K] (each column is a dictionary atom in input space)
    
    Notes:
        - The learned bias centers inputs before encoding, then is added back after decoding.
        - Decoder is *untied* (separate matrix from encoder) - standard in mech interp saes.
        - ReLU enforces Nonnegativity/sparsity in f.

    """

    def __init__(
        self,
        d_in: int, #D
        d_hidden: int, #K
        use_layernorm: bool=False,
        init_unitnorm_decoder: bool = True,
    ):
        super().__init__()
        self.d_in = d_in #D
        self.d_hidden = d_hidden #K

        #Learned centering bias b \ in R^D
        self.bias = nn.Parameter(torch.zeros(d_in)) # [D]

        #Encoder W_e: R^D => Expanded into R^K (plus bias)
        self.encoder = nn.Linear(d_in, d_hidden, bias=True) #weight: [K, D], bias: [K]

        self.ln = nn.LayerNorm(d_hidden) if use_layernorm else nn.Identity()

        #Decoder W_d: R^k => R^D (no bias; we add +b afterwards)
        self.decoder = nn.Linear(d_hidden, d_in, bias=False) # weight: [D, K]

        #Initialization

        nn.init.xavier_uniform_(self.encoder.weight) #[K, D], _ represents in place operation
        if self.encoder.bias is not None:
            nn.init.zeros_(self.encoder.bias) #[K]
        
        if init_unitnorm_decoder:
            #make decoder columns unit norm: each atom d_i = decoder.weight[:, i] / ||.||
            with torch.no_grad():
                W = torch.randn_like(self.decoder.weight) #[D, K]
                W = W / (W.norm(dim=0, keepdim=True) + 1e-12)
                self.decoder.weight.copy_(W)
        else:
            nn.init.xavier_uniform_(self.decoder.weight) #[D, K]
    
    #Core API

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: [B, D]
        returns f: [B, K] (sparse, nonnegatvie)
        #Multiplication happens as x @ W_e.T we transpose for dimensions to work out
        #[B, D] @ [D, K] => [B, K]

        """


        #cneter: (x - b) keeps encoder from wasting capacity ont he mean
        f_pre = self.encoder(x - self.bias) #[B, K], encoder has weights [K, D]
        f = F.relu(self.ln(f_pre)) #[B, K]
        return f

    def decode(self, f: torch.Tensor) -> torch.Tensor:
        """
        f: [B, K]
        returns x_hat: [B,D]

        #Same thing here, where is f @ W_d.t + bias
        so [B, K] @ [K, D]

        but the decoder is actually [D, K] where the columns represent each feature
        """

        x_hat_no_bias = self.decoder(f) #[B, D]
        return x_hat_no_bias + self.bias #[B, D]

    def forward(self, x:torch.Tensor, output_features: bool = False) -> Tuple[torch.Tensor, torch.Tensor] | torch.Tensor:
        """
        x: [B, D]
        returns:
            - if output_features=False: x_hat[B, D]  
            - if output_features=True: (x_hat [B, D], f [B, K])    
        """
        f = self.encode(x) #[B, K] feature activations for each dictionary vector
        x_hat = self.decode(f) #[B, D]

        return (x_hat, f) if output_features else x_hat
    
    #Helpers
    @staticmethod
    def l1(f: torch.Tensor) -> torch.Tensor:
        """
        Mean absolute activation (sparsity penalty).
        f: [B,K] -> scalar
        
        """
        return f.abs().mean()
    
    @staticmethod
    def l0(f: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
        """
        Average number of active (nonzero) features per example.
        f: [B, K] -> scalar
        """
        return (f > eps).float().sum(dim=-1).mean()

    @torch.no_grad()
    def rescale_features_(self, scale: torch.Tensor, eps: float=1e-8) -> None:
        """
        Per-feature rescaling that keeps reconstructions invariant but changes feature scale.
        Use this after estimating per-feature maxima or P99 values.
        Goal: want f' = f/s, and x_hat' = x_hat.

        Achieve by:
            encoder.weight(rows) /= s
            encoder.bias /= s
            decoder.weight(cols) *= s
        Args: 
            scale: [K] positive scale per feature (clipped ot >= eps)
        """
        s = torch.clamp(scale.detach().to(self.encoder.weight.device), min=eps) #[K]

        #encoder.weight: [K, D] (row i corresponds to feature i)
        self.encoder.weight.data.div_(s[:, None])
        if self.encoder.bias is not None:
            self.encoder.bias.data.div_(s)
        
        #decoder.weight: [D, K] (column i corresponds to feature i)
        self.decoder.weight.data.mul_(s[None, :])


## Training runner


In [59]:
@dataclass
class SAERunnerConfig:
    d_in: int # ESM layer hidden size (D)
    d_hidden: int #dictionary size (K)
    lr: float = 1e-3
    weight_decay: float = 0.0
    lambda_l1: float = 0.1 #sparsity strength
    grad_clip: Optional[float] = 1.0
    amp_dtype: torch.dtype = torch.float16
    use_amp: bool = True
    device: str = "cuda"
    save_dir: Optional[str] = None
    log_every: int = 50 #Steps for printing
    max_steps: Optional[int] = None #stops early if set

class SAERunner: 
    def __init__(self, ae: nn.Module, cfg: SAERunnerConfig):
        self.ae = ae.to(cfg.device)
        self.cfg = cfg

        self.opt = torch.optim.AdamW(self.ae.parameters(), lr = cfg.lr, weight_decay=cfg.weight_decay)
        self.scaler = GradScaler(enabled=cfg.use_amp)

        if cfg.save_dir:
            os.makedirs(cfg.save_dir, exist_ok=True)
            os.makedirs(os.path.join(cfg.save_dir, "checkpoints"), exist_ok=True)
        self.step = 0
    
    def _step(self, x_cpu: torch.Tensor) -> Dict[str, float]:
        """
        x_cpu: [B, D] float32 on CPU from activation_batches.
        Moves to device and runs one optimization step.
        """
        # Move batch to device (e.g. cuda:0)
        x = x_cpu.to(self.cfg.device, non_blocking=True)  # [B, D]

        # Forward + loss
        x_hat, f = self.ae(x, output_features=True)       # x_hat: [B, D], f: [B, K]
        mse = F.mse_loss(x_hat, x)                        # reconstruction loss
        l1  = f.abs().mean()                              # sparsity penalty
        loss = mse + self.cfg.lambda_l1 * l1              # total loss

        # Backward + step
        self.opt.zero_grad(set_to_none=True)
        loss.backward()
        self.opt.step()

        # Compute l0 (non-differentiable, just for logging)
        with torch.no_grad():
            l0 = (f > 1e-12).float().sum(dim=-1).mean().item()

        return {
            "loss": float(loss.item()),
            "mse": float(mse.item()),
            "l1": float(l1.item()),
            "l0": float(l0),
        }


    def save(self, name: str):
        if not self.cfg.save_dir:
            return
        path = os.path.join(self.cfg.save_dir, "checkpoints", f"{name}.pt")
        torch.save({
            "step": self.step,
            "ae_state_dict": self.ae.state_dict(),
            "opt_state_dict": self.opt.state_dict(),
            "cfg": self.cfg.__dict__,
        }, path)

    def train_stream(self, stream: Iterable[torch.Tensor]):
        """
        stream: yields CPU float32 tensors X of shape [N_tokens, D].
        """
        total_steps = self.cfg.max_steps if self.cfg.max_steps is not None else None

        with tqdm(total=total_steps, desc="Training", unit="batch") as pbar:
            for x in stream:
                stats = self._step(x)
                self.step += 1

                # Update bar
                pbar.set_postfix({
                    "loss": f"{stats['loss']:.4f}",
                    "mse":  f"{stats['mse']:.4f}",
                    "l1":   f"{stats['l1']:.4f}",
                    "l0":   f"{stats['l0']:.1f}",
                })
                pbar.update(1)

                if self.cfg.save_dir and (self.step % (self.cfg.log_every * 10) == 0):
                    self.save(f"ae_{self.step}")

                if self.cfg.max_steps is not None and self.step >= self.cfg.max_steps:
                    break

            # Final checkpoint
            self.save("final")




## Full Training script

In [60]:
def train_sae_on_esm_stream(
    fasta_gz: Path,
    d_in: int, #ESM layer hidden dim, e.g. 320/480/640/1280
    d_hidden: int, #dictionary size(features), e.g. 32 * d_in
    tokens_per_step: int,
    max_steps: int,
    save_dir: str,
    *,
    model,
    tokenizer,
    layer_sel = LAYER_TO_TRAIN,
):
# 1) build model

    ae = SAE(d_in = d_in, d_hidden=d_hidden, use_layernorm=False,init_unitnorm_decoder=True)

    #runner config
    cfg = SAERunnerConfig(
        d_in=d_in,
        d_hidden=d_hidden,
        lr = 1e-3,
        weight_decay=0.0,
        lambda_l1 = 0.1, #tune based on desired sparsity
        grad_clip = 1.0,
        use_amp=False,
        amp_dtype=torch.float16,
        device=DEVICE, 
        save_dir = save_dir,
        log_every=50,
        max_steps=max_steps,
    )

    runner = SAERunner(ae, cfg)

    #3) stream activations and train
    stream = activation_batches(
        fasta_gz = fasta_gz,
        tokens_per_step = tokens_per_step,
        layer_sel = layer_sel,
        dtype=torch.float16,
        model = model,
        tokenizer=tokenizer
    )

    runner.train_stream(stream)
    return runner.ae

In [61]:
fasta_path = Path("/home/ec2-user/SageMaker/InterPLM/data/uniprot/uniprot_sprot.fasta.gz")

TOKENS_PER_STEP = 2048
MAX_STEPS = 5000
SAVE_DIR = f"./sae_runs/layer_{LAYER_TO_TRAIN}_32x"

# Choose layer and dims
ESM_D_IN = 1280                # depends on your ESM checkpoint
EXPANSION = 32
K = EXPANSION * ESM_D_IN
DEVICE="cuda"

d_in = mdl.config.hidden_size  # SAE input dimension

ae = train_sae_on_esm_stream(
    fasta_gz=fasta_path,
    d_in=ESM_D_IN,
    d_hidden=K,
    tokens_per_step=TOKENS_PER_STEP,
    max_steps=MAX_STEPS,
    save_dir=SAVE_DIR,
    model=model,
    tokenizer=tokenizer
)

  self.scaler = GradScaler(enabled=cfg.use_amp)
Training:   1%|          | 47/5000 [00:13<23:07,  3.57batch/s, loss=26.1822, mse=26.1697, l1=0.1245, l0=1504.1]    


OutOfMemoryError: CUDA out of memory. Tried to allocate 86.00 MiB. GPU 0 has a total capacity of 79.25 GiB of which 46.94 MiB is free. Including non-PyTorch memory, this process has 79.20 GiB memory in use. Of the allocated memory 74.81 GiB is allocated by PyTorch, and 3.88 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)