Saprot

In [1]:
import os
print("Current working dir:", os.getcwd())

Current working dir: /home/houxuc/Documents/Topt/SaProt


In [2]:
import os
import pandas as pd
from tqdm import tqdm
from utils.foldseek_util import get_struc_seq

FOLDSEEK_BIN = "bin/foldseek"
INPUT_DIR = "/home/houxuc/Documents/Topt/data/pdbs"
OUTPUT_CSV = "/home/houxuc/Documents/Topt/data/structure_sequences.csv"
def main():
    all_records = []

    structure_files = [
        os.path.join(INPUT_DIR, f)
        for f in os.listdir(INPUT_DIR)
        if f.endswith((".pdb", ".cif"))
    ]

    print(f"Found {len(structure_files)} structure files")

    for path in tqdm(structure_files):
        try:
            result = get_struc_seq(
                FOLDSEEK_BIN,
                path,
                None,          # all chains
                plddt_mask=False
            )

            for chain_id, parsed in result.items():
                seq, foldseek_seq, combined_seq = parsed

                all_records.append({
                    "structure_file": os.path.basename(path),
                    "chain": chain_id,
                    "combined_seq": combined_seq
                })

        except Exception as e:
            print(f"rror: {path} -> {e}")

    df = pd.DataFrame(all_records)
    df.to_csv(OUTPUT_CSV, index=False)

    print(f"\nsaved to {OUTPUT_CSV}")
    print(f"Total chains processed: {len(df)}")


if __name__ == "__main__":
    main()

Found 3129 structure files


100%|██████████| 3129/3129 [01:29<00:00, 35.00it/s]



saved to /home/houxuc/Documents/Topt/data/structure_sequences.csv
Total chains processed: 3129


In [None]:
import torch
import pandas as pd
import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForMaskedLM

MODEL_NAME = "westlake-repl/SaProt_650M_PDB"
CSV_PATH = "/home/houxuc/Documents/Topt/data/structure_sequences.csv"
SAVE_PATH = "/home/houxuc/Documents/Topt/data/saprot_residue_embeddings.npz"

device = "cuda" if torch.cuda.is_available() else "cpu"

print("Loading SaProt model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)
model.to(device)
model.eval()

df = pd.read_csv(CSV_PATH)

ids = []
embeddings_list = []
lengths = []

print("Generating residue-level embeddings...")

for _, row in tqdm(df.iterrows(), total=len(df)):
    
    raw_key = f"{row['structure_file']}_{row['chain']}"
    
    clean_key = raw_key.split(".pdb")[0]

    seq = row["combined_seq"]

    inputs = tokenizer(seq, return_tensors="pt", truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
        hidden = outputs.hidden_states[-1]  # (1, seq_len, 1280)

    hidden = hidden.squeeze(0)      # (seq_len, 1280)
    hidden = hidden[1:-1]           # remove <cls>, <eos>

    L = hidden.shape[0]

    ids.append(clean_key)
    embeddings_list.append(hidden.cpu().numpy())
    lengths.append(L)
    
# ---------- Save ----------
print("Saving variable-length embeddings (no global padding)...")

# Avoid huge dense allocation: keep each sequence as its own array
embeddings_obj = np.array(embeddings_list, dtype=object)

np.savez_compressed(
    SAVE_PATH,
    ids=np.array(ids),
    embeddings=embeddings_obj,
    lengths=np.array(lengths, dtype=np.int32),
)

print(f"Saved cleaned residue-level SaProt embeddings to {SAVE_PATH}")


Loading SaProt model...


Some weights of EsmForMaskedLM were not initialized from the model checkpoint at westlake-repl/SaProt_650M_PDB and are newly initialized: ['esm.embeddings.position_embeddings.weight', 'esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Generating embeddings...


  0%|          | 0/3129 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
100%|██████████| 3129/3129 [05:03<00:00, 10.30it/s]

Padding...





MemoryError: Unable to allocate 28.5 GiB for an array with shape (3129, 1912, 1280) and data type float32

In [None]:
import torch
import pandas as pd
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForMaskedLM

MODEL_NAME = "westlake-repl/SaProt_650M_PDB"
CSV_PATH = "/home/houxuc/Documents/Topt/data/structure_sequences.csv"
SAVE_PATH = "/home/houxuc/Documents/Topt/data/saprot_embeddings.pt"

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)
model.to(device)
model.eval()

df = pd.read_csv(CSV_PATH)

embeddings = {}

for _, row in tqdm(df.iterrows(), total=len(df)):
    key = f"{row['structure_file']}_{row['chain']}"
    seq = row["combined_seq"]

    inputs = tokenizer(seq, return_tensors="pt", truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
        hidden = outputs.hidden_states[-1]

    # mean pooling
    embedding = hidden.mean(dim=1).squeeze().cpu()

    embeddings[key] = embedding

torch.save(embeddings, SAVE_PATH)

print("Saved SaProt embeddings")

  from .autonotebook import tqdm as notebook_tqdm
Some weights of EsmForMaskedLM were not initialized from the model checkpoint at westlake-repl/SaProt_650M_PDB and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight', 'esm.embeddings.position_embeddings.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  0%|          | 0/3129 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
100%|██████████| 3129/3129 [05:08<00:00, 10.14it/s]


Saved SaProt embeddings
