# Setup

Install InterProt, load ESM and SAE

In [1]:
%%capture
!pip install git+https://github.com/etowahadams/interprot.git

In [2]:
import csv
import os
import torch
import warnings

import numpy as np
import pandas as pd
import polars as pl

from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, PredefinedSplit
from sklearn.metrics import accuracy_score
from tqdm import tqdm
from transformers import AutoTokenizer, EsmModel

from interprot.sae_model import SparseAutoencoder
from interprot.utils import get_layer_activations

# Model configuration. These values are passed to SparseAutoencoder /
# get_layer_activations and are also used to build the SAE checkpoint filename.
ESM_DIM = 1280  # presumably the ESM-2 650M hidden size -- TODO confirm
SAE_DIM = 4096  # second argument to SparseAutoencoder
LAYER = 28      # layer index passed to get_layer_activations

ESM_MODEL_NAME = "facebook/esm2_t33_650M_UR50D"
SAE_REPO_ID = "liambai/InterProt-ESM2-SAEs"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load ESM model (tokenizer and weights come from the same checkpoint name)
tokenizer = AutoTokenizer.from_pretrained(ESM_MODEL_NAME)
esm_model = EsmModel.from_pretrained(ESM_MODEL_NAME)
esm_model.to(device)
esm_model.eval()

# Load SAE model
# The filename is derived from the constants above so the checkpoint cannot
# silently disagree with ESM_DIM / LAYER / SAE_DIM. With the current values it
# resolves to "esm2_plm1280_l28_sae4096.safetensors".
# NOTE(review): assumes the repo uses this naming scheme for other layers and
# sizes too -- verify before changing the constants.
checkpoint_path = hf_hub_download(
    repo_id=SAE_REPO_ID,
    filename=f"esm2_plm{ESM_DIM}_l{LAYER}_sae{SAE_DIM}.safetensors",
)
sae_model = SparseAutoencoder(ESM_DIM, SAE_DIM)
sae_model.load_state_dict(load_file(checkpoint_path))
sae_model.to(device)
sae_model.eval()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/95.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/724 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.61G [00:00<?, ?B/s]

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


esm2_plm1280_l28_sae4096.safetensors:   0%|          | 0.00/42.0M [00:00<?, ?B/s]

SparseAutoencoder()

Load the subcellular localization dataset

In [3]:
# Download subcellular localization dataset from https://zenodo.org/records/10631963
!gdown https://drive.google.com/uc?id=1BG91Eu80t546q-9wIDxCaSYpzGh4lPHX

data_path = "balanced.csv"
df = pl.read_csv(data_path)
df = df.filter(pl.col("sequence").str.len_chars() <= 1000)

train_df = df.filter(pl.col("set") == "train")
test_df = df.filter(pl.col("set") != "train")

df.head()

Downloading...
From: https://drive.google.com/uc?id=1BG91Eu80t546q-9wIDxCaSYpzGh4lPHX
To: /content/balanced.csv
100% 6.25M/6.25M [00:00<00:00, 20.3MB/s]


sequence,target,set,validation
str,str,str,str
"""MEVLEEPAPGPGGADAAERRGLRRLLLSGF…","""Cell membrane""","""train""",
"""MMKTLSSGNCTLNVPAKNSYRMVVLGASRV…","""Cell membrane""","""train""",
"""MAKRTFSNLETFLIFLLVMMSAITVALLSL…","""Cell membrane""","""train""",
"""MGNCQAGHNLHLCLAHHPPLVCATLILLLL…","""Cell membrane""","""train""",
"""MDPSKQGTLNRVENSVYRTAFKLRSVQTLC…","""Cell membrane""","""train""",


Define some helpers for linear classifier training, evaluation, and inspection.

In [4]:
def train_classifier(seq_acts):
    """Fits a one-vs-rest logistic regression on the rows of ``train_df``.

    The regularization strength ``C`` is chosen by grid search, scored by
    accuracy on a single predefined fold: rows with a truthy ``validation``
    value are the held-out fold, all other rows are always used for fitting.

    Args:
        seq_acts: Mapping from sequence string to a torch tensor of features
            for that sequence.

    Returns:
        A tuple ``(best_estimator, best_C, best_score)`` taken from the
        fitted grid search.
    """
    rows = list(train_df.iter_rows(named=True))

    features = np.array(
        [seq_acts[r['sequence']].cpu().detach().numpy() for r in rows]
    )
    labels = np.array([r['target'] for r in rows])
    # PredefinedSplit convention: 0 = member of the validation fold,
    # -1 = never held out.
    # NOTE(review): relies on truthiness of the "validation" column; confirm
    # it never holds a non-empty string such as "False".
    fold_ids = [0 if r['validation'] else -1 for r in rows]

    search = GridSearchCV(
        LogisticRegression(multi_class='ovr'),
        {'C': [0.01, 0.1, 1, 10, 100]},
        cv=PredefinedSplit(test_fold=fold_ids),
        scoring='accuracy',
    )
    search.fit(features, labels)

    return search.best_estimator_, search.best_params_['C'], search.best_score_


def evaluate_classifier(classifier, seq_acts):
    """Returns the accuracy of ``classifier`` on the rows of ``test_df``.

    Args:
        classifier: Fitted model exposing ``predict``.
        seq_acts: Mapping from sequence string to a torch tensor of features
            for that sequence.

    Returns:
        The accuracy score of the predictions against the ``target`` column.
    """
    rows = list(test_df.iter_rows(named=True))

    features = np.array(
        [seq_acts[r['sequence']].cpu().detach().numpy() for r in rows]
    )
    labels = np.array([r['target'] for r in rows])

    return accuracy_score(labels, classifier.predict(features))


def write_sorted_weights(classifier, output_dir):
    """Writes model weights to CSV files in sorted order.

    Writes one CSV per row of ``classifier.coef_``, named
    ``class_<label>_weights.csv`` where ``<label>`` is the class label with
    "/" replaced by "_" and lower-cased. Each CSV has columns "Index" and
    "Weight", sorted by weight descending. The label and the first rows of
    each written file are printed.

    A binary scikit-learn linear model exposes a single coefficient row that
    scores ``classes_[1]``; in that case one file is written, for that class.

    Args:
        classifier: Fitted linear model exposing ``classes_`` and ``coef_``.
        output_dir: Directory to write CSV files to (created if missing).

    Raises:
        ValueError: If the number of coefficient rows matches neither the
            number of classes nor the binary single-row layout.
    """
    os.makedirs(output_dir, exist_ok=True)

    class_labels = list(classifier.classes_)
    coef_rows = classifier.coef_

    # Binary models store one row (for the positive class, classes_[1]);
    # indexing coef_ by class position would raise IndexError for class 1.
    if len(coef_rows) == 1 and len(class_labels) == 2:
        class_labels = class_labels[1:]
    elif len(coef_rows) != len(class_labels):
        raise ValueError(
            f"coef_ has {len(coef_rows)} rows but classes_ has "
            f"{len(class_labels)} entries"
        )

    for class_label, class_weights in zip(class_labels, coef_rows):
        # str() so that non-string (e.g. integer) class labels also work.
        class_label = str(class_label).replace("/", "_").lower()
        output_file = os.path.join(output_dir, f"class_{class_label}_weights.csv")

        print(f"Class: {class_label}")
        # Pair each weight with its feature index, largest weight first.
        sorted_weights = sorted(enumerate(class_weights), key=lambda x: x[1], reverse=True)
        with open(output_file, 'w', newline='') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(['Index', 'Weight'])
            for index, weight in sorted_weights:
                writer.writerow([index, weight])

        # Preview the top of the file that was just written.
        print(pd.read_csv(output_file).head())

# ESM & SAE Inference

Compute and store mean-pooled embeddings for both ESM and SAE

In [5]:
# Per-sequence mean-pooled features, keyed by the sequence string itself.
seq_esm_acts = {}
seq_sae_acts = {}

# Pure inference: disable autograd explicitly so the tensors kept in the dicts
# cannot hold on to a computation graph.
# NOTE(review): the InterProt helpers may already do this internally; the
# explicit context makes it independent of their implementation.
with torch.no_grad():
    for row in tqdm(df.iter_rows(named=True), total=len(df)):
        seq = row['sequence']
        esm_layer_acts = get_layer_activations(
            tokenizer=tokenizer, plm=esm_model, seqs=[seq], layer=LAYER
        )[0][1:-1]  # Trim off BoS and EoS tokens

        # Mean over the first (per-token) axis gives one vector per sequence.
        seq_esm_acts[seq] = torch.mean(esm_layer_acts, dim=0)

        sae_acts = sae_model.get_acts(esm_layer_acts)
        seq_sae_acts[seq] = torch.mean(sae_acts, dim=0)

100%|██████████| 10461/10461 [17:57<00:00,  9.71it/s]


# Linear classifiers

- Train classifier on ESM embedding (for accuracy comparison)
- Train classifier on SAE embedding
- Inspect and write the weights of the SAE classifier

In [6]:
with warnings.catch_warnings():
    # Suppresses every warning raised while fitting and evaluating.
    # NOTE(review): this also hides solver convergence warnings; the selected
    # C and validation accuracy are printed below so the fit can still be
    # sanity-checked (e.g. best C sitting on the edge of the grid).
    warnings.simplefilter("ignore")

    # Baseline: classifier on the mean-pooled ESM embeddings.
    print("Training ESM classifier")
    esm_classifier, esm_best_c, esm_val_accuracy = train_classifier(seq_esm_acts)
    print(f"ESM best C: {esm_best_c}, validation accuracy: {esm_val_accuracy}")
    esm_accuracy = evaluate_classifier(esm_classifier, seq_esm_acts)
    print(f"ESM accuracy: {esm_accuracy}")

    # Classifier on the mean-pooled SAE activations; its weights are written
    # out per class so the highest-weighted SAE feature indices can be read.
    print("Training SAE classifier")
    sae_classifier, sae_best_c, sae_val_accuracy = train_classifier(seq_sae_acts)
    print(f"SAE best C: {sae_best_c}, validation accuracy: {sae_val_accuracy}")
    sae_accuracy = evaluate_classifier(sae_classifier, seq_sae_acts)
    print(f"SAE accuracy: {sae_accuracy}")
    write_sorted_weights(sae_classifier, 'weights')

Training ESM classifier
ESM accuracy: 0.6194690265486725
Training SAE classifier
SAE accuracy: 0.640117994100295
Class: cell membrane
   Index    Weight
0   3815  2.481609
1    550  1.947456
2   2818  1.482108
3   2000  1.283794
4   3154  1.073169
Class: cytoplasm
   Index    Weight
0   3274  1.052095
1   1849  0.893690
2   2966  0.867804
3   2253  0.771584
4   2554  0.768400
Class: endoplasmic reticulum
   Index    Weight
0   1053  1.239194
1   1298  1.229100
2    848  1.148793
3   1055  0.984984
4    565  0.968559
Class: extracellular
   Index    Weight
0   1541  1.293407
1   1470  1.264098
2   1555  1.151713
3   2111  1.135436
4   2472  1.035609
Class: golgi apparatus
   Index    Weight
0   4086  1.601550
1   1635  1.204602
2    880  1.179289
3   3729  1.169152
4     75  1.096873
Class: lysosome_vacuole
   Index    Weight
0   2354  1.203674
1   3592  1.142584
2   3021  1.105611
3   1683  1.094772
4   2867  0.992007
Class: mitochondrion
   Index    Weight
0   3277  2.362949
1   1948 