In [1]:
from huggingface_hub import login
import os
from dotenv import load_dotenv
# Load environment variables via python-dotenv (typically from a local .env file).
load_dotenv()
# Hugging Face access token; os.getenv returns None when HF_TOKEN is not set.
HF_TOKEN = os.getenv("HF_TOKEN")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Authenticate this session with the Hugging Face Hub using the HF_TOKEN value read from the environment.
login(HF_TOKEN)

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


---

In [3]:
from esm.models.esmc import ESMC
# from esm.models.esm3 import ESM3
from esm.sdk.api import ESMProtein, LogitsConfig  # , GenerationConfig

from esm.utils.structure.protein_chain import ProteinChain
from biotite.database import rcsb
from rcsbapi.search import search_attributes as attrs
from Bio.PDB import PDBList

import torch 
import numpy as np
import pandas as pd 

### 1. Getting all PDBs

In [5]:
# The following gets all PDB entries, but this contains non-proteins which ESM can't handle
"""
if os.path.exists("data/pdb_ids.csv"):
    all_protein_ids = pd.read_csv("data/pdb_ids.csv").values.squeeze().tolist()
    print(len(all_protein_ids), "PDB IDs")
else:
    pdbl = PDBList()
    all_protein_ids = pdbl.get_all_entries()
    # all_protein_ids = ["1CM4"]
    print(len(all_protein_ids), "PDB IDs")
    pd.Series(all_protein_ids).to_csv("data/pdb_ids.csv", index=False)
"""

# Get only the PDB entries that contain at least one protein entity.
# The ID list is cached in data/pdb_ids.csv so the RCSB search only has to run once.
if os.path.exists("data/pdb_ids.csv"):
    # dtype=str keeps every ID as text regardless of what it looks like (e.g. "1E10").
    # Selecting the first column explicitly always yields a list; the previous
    # `.values.squeeze().tolist()` collapsed a single-row file to a bare string.
    all_protein_ids = pd.read_csv("data/pdb_ids.csv", dtype=str).iloc[:, 0].tolist()
    print(len(all_protein_ids), "PDB IDs")
else:
    # Search query: entries whose protein polymer entity count is greater than zero.
    q = (attrs.rcsb_entry_info.polymer_entity_count_protein > 0)
    all_protein_ids = list(q())
    print(len(all_protein_ids), "PDB IDs")
    # Make sure the cache directory exists before writing into it.
    os.makedirs("data", exist_ok=True)
    pd.Series(all_protein_ids).to_csv("data/pdb_ids.csv", index=False)

# Alternatively: 
# Download this file https://ftp.ebi.ac.uk/pub/databases/pdb/derived_data/pdb_entry_type.txt
# and then filter by type "prot"

228524 PDB IDs


### 2A. Getting the protein embeddings for all PDB IDs using **`esm`** library

In [None]:
# Problem: the `esm` SDK does not appear to offer a good way to pass a whole batch of proteins at once,
# so all proteins would have to be run through the model sequentially, which is undesirable.

In [None]:
# esmc = ESMC.from_pretrained("esmc_300m")
# esm3 = ESM3.from_pretrained("esm3-open")

In [None]:
# Disabled code, kept for reference: it is wrapped in a string literal so it never executes,
# and the trailing `;` keeps the notebook from echoing the string as cell output.
# It embeds proteins one at a time with the `esm` SDK.
# NOTE(review): it relies on `esmc`, which in the visible notebook is only created in a
# commented-out line, so this would need the model-loading cell re-enabled before it could run.
"""
%%time 

protein_ids = []
protein_embeddings = [] 

for protein_id in all_protein_ids: 
    protein_chain = ProteinChain.from_pdb(rcsb.fetch(protein_id, "pdb")) # , chain_id="A") 
    # Get protein object with all the ground-truth data (except function for some reason) 
    # In the code, they don't provide a way to automatically fetch function annotations, 
    # instead I have to fetch them myself and then set protein.function_annotations 
    # known_protein = ESMProtein.from_protein_chain(protein_chain) 
    # Get protein with just the sequence data 
    protein = ESMProtein(sequence=protein_chain.sequence) 
    # I don't think we can put all tokens into a batch to run through the model at once? 
    protein_tensor = esmc.encode(protein)
    output = esmc.logits(
        protein_tensor, 
        LogitsConfig(
            return_hidden_states=True,  # !!
            # ESMC-300m has 30 layers, so final layer is at index 29:
            ith_hidden_layer=29
        )
    )
    protein_ids.append(protein_id)
    protein_embeddings.append(output.hidden_states.squeeze())
""";

### 2B. Getting the protein embeddings for all PDB IDs using **`huggingface`** and **Synthyra** implementations of ESM 

Source: https://huggingface.co/Synthyra/ESMplusplus_large

#### 2B.1. Sequences

In [19]:
import requests

In [None]:
# For this approach, we need all the proteins' sequences
# We could just do this:
# ProteinChain.from_pdb(rcsb.fetch(protein_id, "pdb")).sequence 
# but we would have to do it for each protein sequentially

In [None]:
# Disabled code, kept for reference: it is wrapped in a string literal so it never executes,
# and the trailing `;` keeps the notebook from echoing the string as cell output.
# NOTE(review): if this is ever revived, the query passes `input_ids=all_protein_ids`
# on every iteration instead of the current `batch`, so each request would ask for the
# full ID list. `DataQuery`, `json` and `time` are also not imported in the visible cells.
"""
# Let's try it by running the API queries to PDB (RCSB) directly 

batch_size = 5000
batches = [all_protein_ids[i:i+batch_size] for i in range(0, len(all_protein_ids), batch_size)]
print(len(batches), "batches")

batches = [batches[0]]

results = []

# Loop through batches
for i, batch in enumerate(batches, start=1):
    print(f"Processing batch {i}/{len(batches)}")
    query = DataQuery(
        input_type="entries",
        input_ids=all_protein_ids,
        return_data_list=[
            "rcsb_id",
            "polymer_entities.entity_poly.rcsb_entity_polymer_type",
            "polymer_entities.entity_poly.pdbx_seq_one_letter_code_can"
        ]
    )
    try:
        # Execute the query for this batch
        batch_results = query.exec()
        json.dump(batch_results, open(f"data/pdb_sequences_batch_{i}.json", "w"))
        results.extend(batch_results)
        # Small delay to avoid rate limits
        time.sleep(1)
    except Exception as e:
        print(f"Failed batch {i}: {e}")
        continue 

json.dump(results, open(f"data/pdb_sequences.json", "w"))
""";

In [None]:
# PDB: drugbank_target, drugbank_info, drugbank_container_identifiers ???
# PDB: ligands ????
# attrs.rcsb_binding_affinity

In [20]:
# Download `pdb_seqres.txt` from here: https://ftp.ebi.ac.uk/pub/databases/pdb/derived_data/
#
# The response is streamed to disk in chunks instead of being materialised as one
# in-memory string, and it is written under a temporary name first so an interrupted
# download never leaves a truncated file that the existence check would later accept.

if not os.path.exists("data/pdb_seqres.txt"):
    os.makedirs("data", exist_ok=True)
    partial_path = "data/pdb_seqres.txt.part"
    with requests.get(
        "https://ftp.ebi.ac.uk/pub/databases/pdb/derived_data/pdb_seqres.txt",
        stream=True,
        timeout=60,  # fail instead of hanging forever on a stalled connection
    ) as response:
        if response.status_code != 200:
            raise Exception(f"There was a problem (status code {response.status_code}). Please try again.")
        with open(partial_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    # Atomic rename: the final path only appears once the download is complete.
    os.replace(partial_path, "data/pdb_seqres.txt")

In [8]:
# Build (or reload from cache) the table of protein chain sequences.
# Columns: pdb_id, chain_id, length, name, sequence; one row per unique sequence.
if os.path.exists("data/pdb_sequences.csv"):
    # keep_default_na=False: values such as the chain ID "NA" are real data and must
    # not be converted to NaN, which would break the "<pdb_id>_<chain_id>" identifiers.
    df = pd.read_csv(
        "data/pdb_sequences.csv",
        index_col=False,
        dtype={"pdb_id": str, "chain_id": str, "name": str, "sequence": str},
        keep_default_na=False,
    )
    print(len(df), "PDB sequences")
else:
    if not os.path.exists("data/pdb_seqres.txt"):
        raise Exception("Please download pdb_seqres.txt from https://ftp.ebi.ac.uk/pub/databases/pdb/derived_data/ and put it in the data/ directory.")

    sequences = []

    # Each record is a ">" header line followed by one sequence line. The file is
    # iterated lazily; `next(f, "")` fetches the sequence line that belongs to the
    # header and cannot run past the end of the file.
    with open("data/pdb_seqres.txt") as f:
        for header in f:
            if "mol:protein" not in header:
                continue
            sequence_line = next(f, "")
            # Header layout: "<id>_<chain> mol:<type> length:<n>  <free-text name>".
            # maxsplit=3 keeps the whole (possibly multi-word) name as one field.
            fields = header[1:].strip().split(None, 3)
            entry_id, _, chain_id = fields[0].partition("_")
            sequences.append({
                # Only the entry ID is upper-cased (to match all_protein_ids);
                # chain IDs are case-sensitive and are kept verbatim.
                "pdb_id": entry_id.upper(),
                "chain_id": chain_id,
                # Stored as int so a fresh build matches the dtype of a cached reload.
                "length": int(fields[2].split(":")[1]),
                # The name may be absent from the header.
                "name": fields[3] if len(fields) > 3 else "",
                "sequence": sequence_line.strip(),
            })

    # Explicit columns keep the filters below valid even if no record was parsed.
    df = pd.DataFrame(sequences, columns=["pdb_id", "chain_id", "length", "name", "sequence"])
    # Keep only chains of entries selected earlier, then one row per distinct sequence.
    df = df[df["pdb_id"].isin(all_protein_ids)]
    df = df.drop_duplicates("sequence").reset_index(drop=True)
    print(len(df), "PDB sequences")
    df.to_csv("data/pdb_sequences.csv", index=False)

166982 PDB sequences


#### 2B.2. Embeddings 

Using `Synthyra/ESMplusplus` via `transformers` library, as it is easier to run batched inputs. 

In [9]:
from transformers import AutoModelForMaskedLM
import pickle

In [None]:
# Synthyra ESM models:
#   ESMplusplus_large: corresponds to ESM-C 600m 
#   ESMplusplus_small: corresponds to ESM-C 300m

In [10]:
# Pick the compute device: use CUDA when PyTorch reports it as available, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

Using device: cpu


In [11]:
# Load the Synthyra ESM++ (small) checkpoint and place it on the selected device.
# trust_remote_code=True lets transformers run the model code shipped in the Hub repository.
model = AutoModelForMaskedLM.from_pretrained(
    "Synthyra/ESMplusplus_small",
    trust_remote_code=True,
).to(device)

In [12]:
# Test for a small batch (10 proteins):

x_sequences = df["sequence"][:10].values
tokenized_sequences = model.tokenizer(x_sequences.tolist(), padding=True, return_tensors="pt")
# The model was moved to `device`, so its inputs must be on that device as well
# (without this the forward pass fails whenever the model runs on CUDA).
tokenized_sequences = tokenized_sequences.to(device)

with torch.no_grad():
    output = model(**tokenized_sequences)  # get ALL hidden states by setting output_hidden_states=True

y_embeddings = output.last_hidden_state
y_embeddings.shape
# (batch_size, seq_len, hidden_size)

torch.Size([10, 169, 960])

In [21]:
# Divide the dataframe into batches of size batch_size

batch_size = 100

# Ceiling division, so the final, partially filled batch is counted as well.
# int(len(df) / batch_size) rounded down and silently dropped up to
# batch_size - 1 sequences at the end of the dataframe.
num_batches = (len(df) + batch_size - 1) // batch_size
print(num_batches, "batches")

166982 batches


In [14]:
# Create the output directory for the per-batch embedding files.
# makedirs also creates the parent "data" directory when it is missing (os.mkdir
# would raise FileNotFoundError), and exist_ok=True makes re-running the cell a no-op.
os.makedirs("data/esm_embeddings", exist_ok=True)

In [22]:
%%time 

x_ids = []
y_embeddings = []

for i_batch in range(num_batches):
    print("Batch", i_batch)
    
    batch_df = df.iloc[i_batch*batch_size:i_batch*batch_size+batch_size]
    
    # Combined (pdb_id + chain_id) as identifier: 
    x_ids = (batch_df["pdb_id"] + "_" + batch_df["chain_id"]).values  
    x_sequences = batch_df["sequence"].values
    
    tokenized_sequences = model.tokenizer(
        x_sequences.tolist(), 
        padding=True, 
        return_tensors="pt"
    )
    
    with torch.no_grad():
        output = model(**tokenized_sequences)  # get ALL hidden states by setting output_hidden_states=True
    
    y_embeddings = output.last_hidden_state 
    
    # Save to file:
    # I will save each batch separately for now, and deal with concatenation in a later step 
    with open(f"data/esm_embeddings/esm_embeddings_batch_{i_batch}.pkl", "wb") as f:
        pickle.dump((x_ids, y_embeddings), f) 

Batch 0
CPU times: user 633 ms, sys: 110 ms, total: 743 ms
Wall time: 196 ms


In [24]:
# Example: read back the (ids, embeddings) pair that was pickled for batch 0.
# Only unpickle files produced by this notebook; pickle can execute code on load.
batch_0_path = "data/esm_embeddings/esm_embeddings_batch_0.pkl"
with open(batch_0_path, "rb") as batch_file:
    x_ids, y_embeddings = pickle.load(batch_file)