# Setup

Install InterProt, load ESM and SAE.


In [None]:
%%capture
!pip install git+https://github.com/etowahadams/interprot.git

In [None]:
import torch

from transformers import AutoTokenizer, EsmForMaskedLM
from safetensors.torch import load_file
from interprot.sae_model import SparseAutoencoder
from huggingface_hub import hf_hub_download

# Configuration: the ESM checkpoint, its hidden size, the SAE width and the
# ESM layer whose activations the SAE reads.
ESM_MODEL_NAME = "facebook/esm2_t33_650M_UR50D"
ESM_DIM = 1280
SAE_DIM = 4096
LAYER = 24

# Resolve the device exactly once, as a torch.device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the ESM tokenizer and masked-LM model; eval() switches off dropout.
tokenizer = AutoTokenizer.from_pretrained(ESM_MODEL_NAME)
esm_lm = EsmForMaskedLM.from_pretrained(ESM_MODEL_NAME)
esm_lm.to(device).eval()

# Load SAE model. The checkpoint name is derived from the constants above so
# the downloaded weights cannot silently disagree with ESM_DIM / LAYER / SAE_DIM.
checkpoint_path = hf_hub_download(
    repo_id="liambai/InterProt-ESM2-SAEs",
    filename=f"esm2_plm{ESM_DIM}_l{LAYER}_sae{SAE_DIM}.safetensors",
)
sae_model = SparseAutoencoder(ESM_DIM, SAE_DIM)
sae_model.load_state_dict(load_file(checkpoint_path))
sae_model.to(device)
sae_model.eval()

esm2_plm1280_l24_sae4096.safetensors:   0%|          | 0.00/42.0M [00:00<?, ?B/s]

SparseAutoencoder()

# Inference


First, get the ESM layer-24 activations, then encode them with the SAE to get an (L, 4096) tensor, where L is the tokenized sequence length (including the BOS/EOS special tokens).

In [None]:
# Protein sequence to analyse (one-letter amino-acid codes).
seq = "MATLFHDTSQSEENGSDDNLSLENEEKLKALGCREDPVNILVIGPAGAGKSTLINALFGKDVATVGYGARGVTTEIHSYEGEYKGVRIRVYDTVGFEGRSDWSYLRNIRRHEKYDLVLLCTKLGGRVDRDTFLELASVLHEEMWKKTIVVLTFANQFITLGSVAKSNDLEGEINKQIEEYKSYLTGRLSNCVRKEALVGIPFCIAGVEDERELPTTEDWVNTLWDKCIDRCSNETYHFASWFSIIKIVAFGFGVAIGTAIGAIVGSIVPVTGTIIGAIAGGYIGAAITKRVVEKFKNY"

# Tokenize with special tokens, then move every tensor onto the model's device.
tokenized = tokenizer(seq, return_tensors="pt", add_special_tokens=True)
inputs = {name: tensor.to(device) for name, tensor in tokenized.items()}

# Gradient-free forward pass that also returns every layer's hidden states.
with torch.no_grad():
    out = esm_lm(output_hidden_states=True, return_dict=True, **inputs)

# hidden_states is a tuple (len = layers + 1) of (batch, seq_len, hidden_dim)
# tensors; keep the entry for the layer the SAE reads.
esm_layer_acts = out.hidden_states[LAYER]

print("ESM layer-24 acts shape:", esm_layer_acts.shape)

torch.Size([1, 300, 1280])

In [None]:
# Drop the batch axis (single sequence) before handing the activations to the SAE.
per_token_acts = esm_layer_acts[0]

# encode() returns the latents plus the mu/std values that decode() takes back.
sae_latents, mu, std = sae_model.encode(per_token_acts)

print("SAE latents shape:", sae_latents.shape)

SAE latents shape: torch.Size([300, 4096])


Decoding the SAE latents yields an (L, 1280) tensor `decoded_esm_layer_acts`,
i.e. the SAE's reconstruction of the ESM layer-24 activations. The reconstruction is not exact: store the residual as `recons_error`.

In [None]:
# Map the sparse latents back into ESM activation space.
decoded_esm_layer_acts = sae_model.decode(sae_latents, mu, std)

# Residual the SAE does not capture. The decoded tensor is (L, 1280), so add
# the batch axis explicitly instead of relying on implicit broadcasting.
recons_error = esm_layer_acts - decoded_esm_layer_acts.unsqueeze(0)

recons_error

tensor([[[-1.6259,  1.0401,  1.6099,  ..., -2.9674, -1.3321,  1.6541],
         [-1.4183,  0.1804, -0.9507,  ..., -0.6793, -2.6099, -1.9202],
         [ 0.0181,  5.0232,  3.1489,  ...,  0.4581, -0.1087,  1.5976],
         ...,
         [ 2.1382,  2.4558,  2.0668,  ...,  2.7426, -0.4519, -2.6419],
         [ 3.0010,  1.8902,  5.3748,  ..., -1.5816,  1.0738, -5.1460],
         [-0.2634, -1.1311, -1.2312,  ..., -4.4766,  2.6242, -0.4650]]])

In [None]:
# Sanity check: reconstruction + residual must give back the original activations.
reconstruction_plus_error = decoded_esm_layer_acts + recons_error
assert torch.allclose(reconstruction_plus_error, esm_layer_acts, atol=1e-05)

In [None]:
# Which SAE feature to steer, and how far beyond the largest observed
# activation to push it.
STEER_FEATURE = 220
STEER_MULTIPLIER = 5

# Re-encode so this cell always starts from unmodified latents. The clamp
# below writes into `sae_latents` in place; without this, re-running the cell
# would compute `max_act` from already-clamped values and compound the steering.
sae_latents, mu, std = sae_model.encode(esm_layer_acts[0])

# Clamp the chosen feature at every sequence position.
max_act = sae_latents.max()
sae_latents[:, STEER_FEATURE] = max_act * STEER_MULTIPLIER

Decode the modified SAE latents back into ESM feature space and predict the steered sequence.

In [None]:
clamped = sae_model.decode(sae_latents, mu, std).unsqueeze(0)   # (1, L, 1280)

# Steered layer-LAYER activations: steered reconstruction plus the residual
# the SAE could not explain.
steered_acts = clamped + recons_error                           # (1, L, 1280)

# lm_head is trained on the output of the *final* encoder layer (after the
# closing layer norm), so the steered activations must pass through the
# remaining encoder layers rather than being fed to lm_head directly.
# hidden_states[LAYER] is the output of encoder layer index LAYER - 1
# (index 0 of hidden_states holds the embeddings), so that output is replaced.
encoder_layers = esm_lm.esm.encoder.layer
assert 1 <= LAYER < len(encoder_layers), "LAYER must index an intermediate encoder layer"


def _inject_steered_acts(module, args, output):
    """Forward hook that swaps the layer's hidden-state output for `steered_acts`.

    Handles both return conventions of the encoder layer: a tuple whose first
    element is the hidden states, or the hidden-state tensor itself.
    """
    if isinstance(output, tuple):
        return (steered_acts.to(output[0].dtype),) + tuple(output[1:])
    return steered_acts.to(output.dtype)


hook_handle = encoder_layers[LAYER - 1].register_forward_hook(_inject_steered_acts)
try:
    # Get full‐vocab logits from the rest of the model
    with torch.no_grad():
        logits = esm_lm(**inputs).logits                        # (1, L, vocab_size)
finally:
    # Always detach the hook so later forward passes are unaffected.
    hook_handle.remove()

# Mask out all special‐token IDs
special_ids = {
    tokenizer.pad_token_id,
    tokenizer.cls_token_id,
    tokenizer.eos_token_id,
    tokenizer.mask_token_id,
    tokenizer.unk_token_id,
}
for sid in special_ids:
    logits[..., sid] = -1e9

# Argmax over the remaining classes, skipping BoS/EoS
tokens = torch.argmax(logits[:, 1:-1, :], dim=-1)               # (1, L-2)

# Convert to letters
steered_seq = "".join(tokenizer.convert_ids_to_tokens(tokens[0].tolist()))
print("Steered sequence:", steered_seq)

Steered sequence: MSSSSSSSSSSEESSSSSSSAEEAQAAAAALAEEAKKLKIVVVGASSAGKSTFINSTSGTTSASSSSSSTSSTTTTVTYVSSSSNKRVVLVDTVGVFDSEEALVLLVLLLIASVDLVLLLLLLLNSSSEVTVETFTTAFDAEAANRVVVVLNNCDEVNNEEEEENNNNEEEEENEEIEEIINIIITIITNNNNNNIIVNIIVVVVNNAAAATLLVTLLLDLLLVLLLLLLLAAAAAAQAAKIAIIAAAATAATTAAGAAAGAAAGALAPAAGAAAGAAAGGAAGAAAAKKVAKKKKKK
