In [12]:
import os, gzip, math, random, json, gc
from pathlib import Path
from typing import List, Iterable, Tuple, Dict

import torch as t
from torch.utils.data import DataLoader
from tqdm import tqdm
from Bio import SeqIO
from utils import extract_esm_features_batch

In [2]:
# Root of the InterPLM checkout. Override with the INTERPLM_ROOT environment
# variable instead of editing this hard-coded SageMaker path.
PROJECT_ROOT = Path(os.environ.get("INTERPLM_ROOT", "/home/ec2-user/SageMaker/InterPLM"))
FASTA_GZ = PROJECT_ROOT / "data" / "uniprot" / "uniprot_sprot.fasta.gz"
# NOTE(review): "es2" looks like a typo for "esm2"; kept so existing run dirs are still found.
OUT_DIR  = PROJECT_ROOT / "sae_runs" / "es2_swissprot_layer24"

# ---------- Device ----------
device = t.device("cuda" if t.cuda.is_available() else "cpu")
device

device(type='cuda')

In [13]:
ESM_NAME = "facebook/esm2_t33_650M_UR50D"
MAX_LEN = 1024          # truncate longer sequences (applied to the raw residue string)
SEED = 17
BATCH_SIZE_SEQ = 16     # sequences per ESM forward pass
LAYER_TO_TRAIN = 24     # ESM layer whose activations feed the SAE
TOKENS_PER_STEP = 4096  # token vectors per SAE update
TOEKNS_PER_STEP = TOKENS_PER_STEP  # misspelled alias kept for cells that still reference it

# Seed every RNG the pipeline uses: Python's `random`, and torch
# (torch.randperm shuffles the activation batches).
random.seed(SEED)
t.manual_seed(SEED)


In [11]:
from transformers import AutoTokenizer, AutoModel
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tok = AutoTokenizer.from_pretrained(ESM_NAME, do_lower_case=False)
# NOTE(review): the "pooler.* newly initialized" warning refers to the pooler head;
# presumably unused since only per-token hidden states are extracted — confirm.
mdl = AutoModel.from_pretrained(ESM_NAME).eval().to(device)

# The activation-extraction cell refers to the model, tokenizer and device as
# `model`, `tokenizer` and `DEVICE`. Without these aliases that cell raises
# NameError on a fresh kernel.
model, tokenizer, DEVICE = mdl, tok, device

d_in = mdl.config.hidden_size  # SAE input dimension
d_in

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


1280

In [7]:


def iter_swissprot_sequences(fasta_gz: Path, max_len: int = MAX_LEN):
    """Stream ``(record_id, sequence)`` pairs from a gzip-compressed FASTA file.

    Records whose sequence is empty are skipped. A sequence longer than
    ``max_len`` characters is cut down to its first ``max_len`` characters;
    shorter ones are yielded unchanged.
    """
    with gzip.open(fasta_gz, "rt") as handle:
        for record in SeqIO.parse(handle, "fasta"):
            residues = str(record.seq)
            if not residues:
                continue
            if len(residues) > max_len:
                residues = residues[:max_len]
            yield record.id, residues

def batched_sequences(fasta_gz: Path, batch_size: int = BATCH_SIZE_SEQ):
    """Group SwissProt sequences into lists of ``batch_size`` strings.

    Sequences come from ``iter_swissprot_sequences`` (its default ``max_len``
    truncation applies); record IDs are dropped.

    Args:
        fasta_gz: gzip-compressed FASTA file.
        batch_size: number of sequences per yielded list; must be >= 1.

    Yields:
        A list of ``batch_size`` sequence strings; the final list may be shorter.
        The strings have varying lengths; this is not a padded [B, L] array.

    Raises:
        ValueError: on the first ``next()`` if ``batch_size`` < 1. A non-positive
            size would never equal ``len(buf)``, so the whole file would be
            buffered into a single batch.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    buf: List[str] = []
    for _, seq in iter_swissprot_sequences(fasta_gz):
        buf.append(seq)
        if len(buf) == batch_size:
            yield buf
            buf = []  # new list: the yielded one now belongs to the caller
    if buf:
        yield buf  # trailing partial batch


MAFSAEDVLKEYDRRRRMEALLLSLYYPNDRKLLDYKEWSPPRVQVECPKAPVEWNNPPSEKGLIVGHFSGIKYKGEKAQASEVDVNKMCCWVSKFKDAMRRYQGIQTCKIPGKVLSDLDAKIKAYNLTVEGVEGFVRYSRVTKQHVAAFLKELRHSKQYENVNLIHYILTDKRVDIQHLEKDLVKDFKALVESAHRMRQGHMINVKYILYQLLKKHGHGPDGPDILTVKTGSKGVLYDDSFRKIYTDLGWKFTPL
('sp|Q6GZX4|001R_FRG3G', 'MAFSAEDVLKEYDRRRRMEALLLSLYYPNDRKLLDYKEWSPPRVQVECPKAPVEWNNPPSEKGLIVGHFSGIKYKGEKAQASEVDVNKMCCWVSKFKDAMRRYQGIQTCKIPGKVLSDLDAKIKAYNLTVEGVEGFVRYSRVTKQHVAAFLKELRHSKQYENVNLIHYILTDKRVDIQHLEKDLVKDFKALVESAHRMRQGHMINVKYILYQLLKKHGHGPDGPDILTVKTGSKGVLYDDSFRKIYTDLGWKFTPL')
MSIIGATRLQNDKSDTYSAGPCYAGGCSAFTPRGTCGKDWDLGEQTCASGFCTSQPLCARIKKTQVCGLRYSSKGKDPLVSAEWDSRGAPYVRCTYDADLIDTQAQVDQFVSMFGESPSLAERYCMRGVKNTAGELVSRVSSDADPAGGWCRKWYSAHRGPDQDAALGSFCIKNPGAADCKCINRASDPVYQKVKTLHAYPDQCWYVPCAADVGELKMGTQRDTPTNCPTQVCQIVFNMLDDGSVTMDDVKNTINCDFSKYVPPPPPPKPTPPTPPTPPTPPTPPTPPTPPTPRPVHNRKVMFFVAGAVLVAILISTVRW
('sp|Q6GZX3|002L_FRG3G', 'MSIIGATRLQNDKSDTYSAGPCYAGGCSAFTPRGTCGKDWDLGEQTCASGFCTSQPLCARIKKTQVCGLRYSSKGKDPLVSAEWDSRGAPYVRCTYDADLIDTQAQVDQFV

## Grab ESM activations in batches

In [15]:

@torch.no_grad()
def activation_batches(
    fasta_gz: Path,
    tokens_per_step: int = TOEKNS_PER_STEP,
    layer_sel=LAYER_TO_TRAIN,
    dtype = torch.float16,
    esm_model=None,
    esm_tokenizer=None,
    esm_device=None,
) -> Iterable[torch.Tensor]:
    """Stream shuffled batches of per-token ESM activations for SAE training.

    Sequences are read from ``fasta_gz`` in groups of ``BATCH_SIZE_SEQ`` and
    passed to ``extract_esm_features_batch``. The activations at positions
    where the mask is true are moved to CPU as float32 and buffered. Whenever
    the buffer holds at least ``tokens_per_step`` rows, it is shuffled and
    exactly ``tokens_per_step`` rows are yielded. Shuffling mixes tokens from
    different sequences; the shuffled remainder is carried forward.

    Args:
        fasta_gz: gzip-compressed FASTA file.
        tokens_per_step: rows per yielded batch; must be >= 1.
        layer_sel: layer selector forwarded to ``extract_esm_features_batch``.
        dtype: compute dtype forwarded to ``extract_esm_features_batch``.
        esm_model / esm_tokenizer / esm_device: objects forwarded to the
            extractor. When left as None they default to the notebook globals
            ``mdl``, ``tok`` and ``device`` from the model-loading cell.

    Yields:
        float32 CPU tensors of shape [tokens_per_step, D]. At the end, one
        final shuffled batch with fewer than ``tokens_per_step`` rows is
        yielded if any tokens remain.

    Raises:
        ValueError: on the first ``next()`` if ``tokens_per_step`` < 1. With a
            non-positive step the batching loop below would never terminate.
    """
    if tokens_per_step < 1:
        raise ValueError(f"tokens_per_step must be >= 1, got {tokens_per_step}")

    # Resolved at call time so the defaults follow the notebook's current globals.
    model = mdl if esm_model is None else esm_model
    tokenizer = tok if esm_tokenizer is None else esm_tokenizer
    run_device = device if esm_device is None else esm_device

    buf: List[torch.Tensor] = []
    total = 0
    for seq_batch in batched_sequences(fasta_gz, BATCH_SIZE_SEQ):
        reps, attn = extract_esm_features_batch(
            sequences=seq_batch,
            layer_sel=layer_sel,
            device=run_device,
            dtype=dtype,
            model=model,
            tokenizer=tokenizer,
        )  # assumed reps: [B, L, D], attn: [B, L] mask — TODO confirm against the helper
        # NOTE(review): if this mask also covers <cls>/<eos>, special-token
        # activations end up in the SAE training data — confirm.

        for b in range(reps.shape[0]):  # each sequence in the batch
            # .bool(): an integer 0/1 attention mask would otherwise be treated
            # as row *indices* (gathering rows 0 and 1) instead of a mask.
            valid = attn[b].bool()
            n_valid = int(valid.sum())
            if n_valid:
                buf.append(reps[b][valid].detach().to("cpu", dtype=torch.float32))  # float32 for SAE
                total += n_valid

        # `while`, not `if`: one sequence batch can add more than tokens_per_step
        # tokens. Yielding only once per batch would let the buffer (and the
        # cost of the cat/randperm below) grow without bound.
        while total >= tokens_per_step:
            all_tokens = torch.cat(buf, dim=0)            # [N, D]
            perm = torch.randperm(all_tokens.size(0))     # random order
            X = all_tokens[perm[:tokens_per_step]]        # [tokens_per_step, D]
            rest = all_tokens[perm[tokens_per_step:]]     # [N - tokens_per_step, D]

            # Carry the (already shuffled) remainder forward.
            buf = [rest] if rest.size(0) else []
            total = int(rest.size(0))

            yield X

    # Flush the remainder, which now has fewer than tokens_per_step rows.
    if buf:
        X = torch.cat(buf, dim=0)
        if X.numel() > 0:
            yield X[torch.randperm(X.size(0))]

## Sparse Autoencoder (ReLU activation function)

In [None]:
import torch.nn as nn  # this cell runs before the cell that imports nn, so import it here


class SAE(nn.Module):
    """Early draft of the sparse autoencoder; only stores its configuration.

    The full ``SAE`` defined in a later cell replaces this class, so this cell
    only needs to run cleanly on a fresh kernel.
    """

    def __init__(self, d_in: int, d_hidden: int, tied: bool = False, use_layernorm: bool = False):
        super().__init__()
        self.d_in, self.d_hidden, self.tied = d_in, d_hidden, tied
        # NOTE(review): use_layernorm is accepted but unused in this draft.

In [None]:
import torch.nn as nn
import torch.nn.functional as F  # `torch.nn.function` does not exist


class SAE(nn.Module):
    """
    One-layer sparse autoencoder with an untied decoder.

    Data flow:
        x --(subtract bias)--> encoder (Linear) --> LayerNorm (optional) --> ReLU --> f
        f --> decoder (Linear, no bias) --> (add bias) --> x_hat

    Shapes, with
        B = batch size in tokens (each row is one token embedding),
        D = input dim (hidden size of the chosen ESM layer, e.g. 320/480/640/1280),
        K = number of features (dictionary size):

        x:              [B, D]
        f:              [B, K]
        x_hat:          [B, D]
        encoder.weight: [K, D]  (row i produces feature i's pre-activation)
        decoder.weight: [D, K]  (column i is dictionary atom i in input space)

    Notes:
        - The learned ``bias`` is subtracted before encoding and added back
          after decoding.
        - The decoder is untied (a separate matrix from the encoder).
        - ReLU makes ``f`` non-negative; sparsity is encouraged by a penalty
          such as ``l1``.
    """

    def __init__(
        self,
        d_in: int,  # D
        d_hidden: int,  # K
        use_layernorm: bool = False,
        init_unitnorm_decoder: bool = True,
    ):
        """
        Args:
            d_in: input dimension D.
            d_hidden: number of dictionary features K.
            use_layernorm: apply LayerNorm over the K pre-activations before ReLU.
            init_unitnorm_decoder: initialise decoder columns as random unit
                vectors. If False, use Xavier-uniform initialisation.
        """
        super().__init__()
        self.d_in = d_in  # D
        self.d_hidden = d_hidden  # K

        # Learned centering bias b in R^D.
        self.bias = nn.Parameter(torch.zeros(d_in))  # [D]

        # Encoder W_e: R^D -> R^K (with bias).
        self.encoder = nn.Linear(d_in, d_hidden, bias=True)  # weight [K, D], bias [K]

        self.ln = nn.LayerNorm(d_hidden) if use_layernorm else nn.Identity()

        # Decoder W_d: R^K -> R^D, no bias (self.bias is added after decoding).
        self.decoder = nn.Linear(d_hidden, d_in, bias=False)  # weight [D, K]

        # Initialisation (trailing `_` = in-place).
        nn.init.xavier_uniform_(self.encoder.weight)  # [K, D]
        if self.encoder.bias is not None:
            nn.init.zeros_(self.encoder.bias)  # [K]

        if init_unitnorm_decoder:
            # Each atom decoder.weight[:, i] gets unit L2 norm (norm over dim 0 = D).
            with torch.no_grad():
                W = torch.randn_like(self.decoder.weight)  # [D, K]
                W = W / (W.norm(dim=0, keepdim=True) + 1e-12)
                self.decoder.weight.copy_(W)
        else:
            nn.init.xavier_uniform_(self.decoder.weight)  # [D, K]

    # Core API

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: [B, D] -> f: [B, K] (non-negative).

        nn.Linear computes (x - b) @ W_e.T + b_e, i.e. [B, D] @ [D, K] -> [B, K].
        """
        # Centering by the learned bias stops the encoder wasting capacity on the mean.
        f_pre = self.encoder(x - self.bias)  # [B, K]
        f = F.relu(self.ln(f_pre))  # [B, K]
        return f

    def decode(self, f: torch.Tensor) -> torch.Tensor:
        """
        f: [B, K] -> x_hat: [B, D].

        nn.Linear computes f @ W_d.T with W_d stored as [D, K] (columns are
        atoms), i.e. [B, K] @ [K, D] -> [B, D]; the centering bias is added back.
        """
        x_hat_no_bias = self.decoder(f)  # [B, D]
        return x_hat_no_bias + self.bias  # [B, D]

    def forward(
        self, x: torch.Tensor, output_features: bool = False
    ) -> "Tuple[torch.Tensor, torch.Tensor] | torch.Tensor":
        """
        x: [B, D]
        returns:
            - if output_features=False: x_hat [B, D]
            - if output_features=True:  (x_hat [B, D], f [B, K])
        """
        f = self.encode(x)  # [B, K] feature activations
        x_hat = self.decode(f)  # [B, D]
        return (x_hat, f) if output_features else x_hat

    # Helpers

    @staticmethod
    def l1(f: torch.Tensor) -> torch.Tensor:
        """Mean absolute activation over all entries (sparsity penalty). f: [B, K] -> scalar."""
        return f.abs().mean()

    @staticmethod
    def l0(f: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
        """Average number of features with activation > eps per row. f: [B, K] -> scalar."""
        return (f > eps).float().sum(dim=-1).mean()

    @torch.no_grad()
    def rescale_features_(self, scale: torch.Tensor, eps: float = 1e-8) -> None:
        """
        Rescale features in place so that f' = f / s while x_hat is unchanged.

        Intended for use after estimating per-feature maxima or P99 values.
        Because ReLU commutes with positive scaling, this is achieved by:
            encoder.weight rows /= s
            encoder.bias        /= s
            decoder.weight cols *= s

        Args:
            scale: [K] per-feature scale; values are clamped to >= eps.
            eps: lower bound applied to ``scale``.

        Raises:
            NotImplementedError: if the model uses LayerNorm. LayerNorm
                re-normalises across features, so per-feature scaling of the
                pre-activations would change f and x_hat.
            ValueError: if ``scale`` does not have shape [K].
        """
        if isinstance(self.ln, nn.LayerNorm):
            raise NotImplementedError(
                "rescale_features_ cannot preserve reconstructions when use_layernorm=True"
            )
        w = self.encoder.weight
        s = torch.clamp(scale.detach().to(device=w.device, dtype=w.dtype), min=eps)  # [K]
        if s.shape != (self.d_hidden,):
            raise ValueError(f"scale must have shape ({self.d_hidden},), got {tuple(s.shape)}")

        # encoder.weight: [K, D] (row i <-> feature i)
        self.encoder.weight.data.div_(s[:, None])
        if self.encoder.bias is not None:
            self.encoder.bias.data.div_(s)

        # decoder.weight: [D, K] (column i <-> feature i)
        self.decoder.weight.data.mul_(s[None, :])