#Setup

Install InterProt, load ESM and SAE.


In [1]:
%%capture
!pip install git+https://github.com/etowahadams/interprot.git

In [1]:
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"

from transformers import AutoTokenizer, EsmForMaskedLM
from safetensors.torch import load_file
from interprot.sae_model import SparseAutoencoder
from huggingface_hub import hf_hub_download

ESM_DIM = 1280
SAE_DIM = 4096
LAYER = 24

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
esm_lm = EsmForMaskedLM.from_pretrained(
    "facebook/esm2_t33_650M_UR50D")
esm_lm.to(device).eval()

# Load SAE model
checkpoint_path = hf_hub_download(
    repo_id="liambai/InterProt-ESM2-SAEs",
    filename="esm2_plm1280_l24_sae4096_100k.safetensors"
)
sae_model = SparseAutoencoder(ESM_DIM, SAE_DIM)
sae_model.load_state_dict(load_file(checkpoint_path))
sae_model.to(device)
sae_model.eval()

tokenizer_config.json:   0%|          | 0.00/95.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/724 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.61G [00:00<?, ?B/s]

(…)sm2_plm1280_l24_sae4096_100k.safetensors:   0%|          | 0.00/42.0M [00:00<?, ?B/s]

SparseAutoencoder()

#Inference


First, get ESM layer 24 activations, encode it with SAE to get a (L, 4096) tensor

In [2]:
seq = "MATLFHDTSQSEENGSDDNLSLENEEKLKALGCREDPVNILVIGPAGAGKSTLINALFGKDVATVGYGARGVTTEIHSYEGEYKGVRIRVYDTVGFEGRSDWSYLRNIRRHEKYDLVLLCTKLGGRVDRDTFLELASVLHEEMWKKTIVVLTFANQFITLGSVAKSNDLEGEINKQIEEYKSYLTGRLSNCVRKEALVGIPFCIAGVEDERELPTTEDWVNTLWDKCIDRCSNETYHFASWFSIIKIVAFGFGVAIGTAIGAIVGSIVPVTGTIIGAIAGGYIGAAITKRVVEKFKNY"

inputs = tokenizer(seq, return_tensors="pt", add_special_tokens=True)
inputs = {k: v.to(device) for k, v in inputs.items()}

# Forward pass → grab hidden_states[LAYER]
with torch.no_grad():
    out = esm_lm(**inputs,
                 output_hidden_states=True,
                 return_dict=True)

# hidden_states is tuple(len=layers+1) of (batch, seq_len, hidden_dim)
esm_layer_acts = out.hidden_states[LAYER]

print("ESM layer-24 acts shape:", esm_layer_acts.shape)

ESM layer-24 acts shape: torch.Size([1, 300, 1280])


In [3]:
sae_latents, mu, std = sae_model.encode(esm_layer_acts[0])

print("SAE latents shape:", sae_latents.shape)

SAE latents shape: torch.Size([300, 4096])


Decoding the SAE latents yields a (L, 1280) tensor `decoded_esm_layer_acts`
(i.e., the SAE's prediction of ESM layer 24 acts). This prediction can be off: compute the error as `recons_error`.

In [4]:
decoded_esm_layer_acts = sae_model.decode(sae_latents, mu, std)
recons_error = esm_layer_acts - decoded_esm_layer_acts

recons_error

tensor([[[-0.4678,  0.7146,  3.8358,  ..., -1.6600, -0.6162, -1.5107],
         [-1.4838,  0.7454,  1.7162,  ..., -2.4572, -4.2533, -4.9109],
         [-1.5524,  2.9380,  1.2259,  ...,  1.1724, -0.8309, -2.2150],
         ...,
         [ 1.8651, -4.5614,  3.0894,  ...,  0.3076,  1.5975,  0.7138],
         [ 4.4343, -3.4748,  2.7736,  ..., -3.1804, -0.1675,  0.2339],
         [ 2.1849, -1.0958, -1.0139,  ..., -4.1312, -0.2550, -2.3275]]])

In [5]:
assert torch.allclose(decoded_esm_layer_acts + recons_error, esm_layer_acts, atol=1e-05)

In [6]:
max_act = sae_latents.max()
sae_latents[:, 220] = max_act * 5

Decode modified SAE latents back into ESM feature‐space.

In [7]:
clamped = sae_model.decode(sae_latents, mu, std).unsqueeze(0)   # (1, L, 1280)

# Get full‐vocab logits
logits = esm_lm.lm_head(clamped + recons_error)                 # (1, L, vocab_size)

# Mask out all special‐token IDs
special_ids = {
    tokenizer.pad_token_id,
    tokenizer.cls_token_id,
    tokenizer.eos_token_id,
    tokenizer.mask_token_id,
    tokenizer.unk_token_id,
}
for sid in special_ids:
    logits[..., sid] = -1e9

# Argmax over the remaining classes, skipping BoS/EoS
tokens = torch.argmax(logits[:, 1:-1, :], dim=-1)               # (1, L-2)

# Convert to letters
steered_seq = "".join(tokenizer.convert_ids_to_tokens(tokens[0].tolist()))
print("Steered sequence:", steered_seq)

Steered sequence: TTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTT
