In [None]:
from biotite.database import rcsb
from rcsbapi.search import search_attributes as attrs
from Bio.PDB import PDBList

import torch 
import numpy as np
import pandas as pd 

from huggingface_hub import login
import os
from dotenv import load_dotenv

In [None]:
# Pull secrets from a .env file into the process environment, then authenticate
# against the Hugging Face Hub with the HF_TOKEN value it provides.
# If HF_TOKEN is unset, None is passed and huggingface_hub asks for a token interactively.
load_dotenv()
HF_TOKEN = os.environ.get("HF_TOKEN")
login(token=HF_TOKEN)

### 1. Get all PDB entries from the Hugging Face dataset `pdb_protein_ligand_complexes`

Download the dataset from https://huggingface.co/datasets/jglaser/pdb_protein_ligand_complexes  
and put it in the `data/` directory.  
I removed unnecessary columns and saved the reduced datasets as `pdb_protein_ligand_train.p` and `pdb_protein_ligand_test.p`, which are much smaller.

### 2. Apply **ESM** (via **Synthyra**) to PDB sequences to get protein embeddings

Source: https://huggingface.co/Synthyra/ESMplusplus_large (the code below uses the smaller `Synthyra/ESMplusplus_small` checkpoint)

#### 2.1. Sequences

In [None]:
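# Open question: are the PDB fields drugbank_target, drugbank_info, drugbank_container_identifiers useful here?
# Open question: how should the ligands of each PDB entry be retrieved?
# Open question: can attrs.rcsb_binding_affinity be used (e.g. for filtering)?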

In [None]:
# Load the cached table of PDB chain sequences. The embedding cells below read
# its `pdb_id`, `chain_id` and `sequence` columns.
SEQUENCES_CSV = "data/pdb_sequences.csv"
if not os.path.exists(SEQUENCES_CSV):
    # Fail loudly here rather than with a NameError on `df` several cells later.
    raise FileNotFoundError(
        f"{SEQUENCES_CSV} not found; expected a CSV with columns pdb_id, chain_id, sequence."
    )
df = pd.read_csv(SEQUENCES_CSV, index_col=False)
print(len(df), "PDB sequences")

In [None]:
# Peek at the reduced training split (see section 1) before deciding how to filter `df`.
# NOTE: read_pickle can execute arbitrary code - only load files you produced yourself.
train_complexes = pd.read_pickle("data/pdb_protein_ligand_train.p")
train_complexes.head()

In [None]:
# TODO: filter df to keep only proteins that have a bound ligand













#### 2.2. Embeddings 

We load `Synthyra/ESMplusplus` through the `transformers` library, because it makes running batched inputs easier.

In [None]:
from transformers import AutoModelForMaskedLM
import pickle

In [None]:
# Run on the GPU whenever PyTorch can see one; fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

In [None]:
# Synthyra ESM models:
#   ESMplusplus_large: corresponds to ESM-C 600m 
#   ESMplusplus_small: corresponds to ESM-C 300m

In [None]:
# ESM++ ships its own modelling code on the Hub, hence trust_remote_code=True.
# nn.Module.to returns the module itself, so the device move can be chained.
model = AutoModelForMaskedLM.from_pretrained(
    "Synthyra/ESMplusplus_small",
    trust_remote_code=True,
).to(device)

In [None]:
# Test for a small batch (10 proteins) before running the full loop.

x_sequences = df["sequence"].iloc[:10].tolist()
# The inputs must live on the same device as the model; otherwise a CUDA run
# fails with a device-mismatch error.
tokenized_sequences = model.tokenizer(x_sequences, padding=True, return_tensors="pt").to(device)

with torch.no_grad():
    output = model(**tokenized_sequences)  # get ALL hidden states by setting output_hidden_states=True

y_embeddings = output.last_hidden_state
y_embeddings.shape
# (batch_size, seq_len, hidden_size)

In [None]:
# Divide the dataframe into batches of size batch_size.
# Use ceiling division so the final, partially filled batch is included;
# int(len(df) / batch_size) silently dropped the last len(df) % batch_size rows.
batch_size = 100
num_batches = (len(df) + batch_size - 1) // batch_size
print(num_batches, "batches")

# TODO: do proper torch dataset object









In [None]:
# Output directory for the per-batch embedding pickles.
# makedirs also creates a missing `data/` parent, and exist_ok makes the cell safe to re-run.
os.makedirs("data/_esm", exist_ok=True)

In [None]:
%%time 

x_ids = []
y_embeddings = []

for i_batch in range(num_batches):
    print("Batch", i_batch)
    
    batch_df = df.iloc[i_batch*batch_size:i_batch*batch_size+batch_size]
    
    # Combined (pdb_id + chain_id) as identifier: 
    x_ids = (batch_df["pdb_id"] + "_" + batch_df["chain_id"]).values  
    x_sequences = batch_df["sequence"].values
    
    tokenized_sequences = model.tokenizer(
        x_sequences.tolist(), 
        padding=True, 
        return_tensors="pt"
    )
    
    with torch.no_grad():
        output = model(**tokenized_sequences)  
    
    y_embeddings = output.last_hidden_state 
    
    # Save to file:
    # I will save each batch separately for now, and deal with concatenation in a later step 
    with open(f"data/_esm/esm_embeddings_batch_{i_batch}.pkl", "wb") as f:
        pickle.dump((x_ids, y_embeddings), f) 

In [None]:
# Sanity check: reload the first batch file written by the batch-embedding cell.
# NOTE: pickle.load can execute arbitrary code - only open files this notebook produced.
with open("data/_esm/esm_embeddings_batch_0.pkl", mode="rb") as handle:
    x_ids, y_embeddings = pickle.load(handle)

Optionally, combine all batch files into a single dataset file:

In [None]:
# Would be preferable for the next steps 

In [None]:
# TODO 








