In [1]:
# Import libraries
import pandas as pd
import numpy as np
import requests
import pickle
import os

from tqdm import tqdm
from collections import defaultdict
from rdkit.Chem.rdmolfiles import *
from rdkit.Chem.rdchem import *
from rdkit.Chem.rdmolops import *
from rdkit.Chem.Draw import *
from rdkit.Chem.Lipinski import *

import re
import time
import json
import zlib
from xml.etree import ElementTree
from urllib.parse import urlparse, parse_qs, urlencode
from requests.adapters import HTTPAdapter, Retry

In [25]:
# Constants
DATA_FOLDER = "./iuphar"  # folder passed to IUPHAR(); holds the input CSVs and generated files
PDB_FOLDER = "./iuphar/targets"  # where getProteinRCSB() writes downloaded PDB files
UNIPROT_API_URL = "https://rest.uniprot.org"  # base URL for the ID-mapping endpoints
POLLING_INTERVAL = 3  # seconds slept between ID-mapping status polls

TRAINING_PROTEINS = ["P07550"] # beta_2 adrenoceptor
# TESTING_PROTEINS = ["P08913"] # alpha_2A adrenoceptor
TESTING_PROTEINS = ["P06213"] # Negative testing

In [18]:
# Shared HTTP session for the UniProt helpers below. The adapter mounted on
# "https://" is configured with total=5 retries, a 0.25 backoff factor and a
# forced retry on 500/502/503/504 responses.
session = requests.Session()
retries = Retry(total=5, backoff_factor=0.25, status_forcelist=[500, 502, 503, 504])
session.mount("https://", HTTPAdapter(max_retries=retries))

def check_response(response):
    """Raise ``requests.HTTPError`` if ``response`` carries an error status.

    Before re-raising, the response body is printed to help diagnose the
    failure: as parsed JSON when possible, otherwise as raw text.

    Args:
        response: object exposing ``raise_for_status()``, ``json()`` and ``text``.

    Raises:
        requests.HTTPError: re-raised unchanged from ``raise_for_status()``.
    """
    try:
        response.raise_for_status()
    except requests.HTTPError:
        # Error bodies are not guaranteed to be JSON. A decode failure here
        # must not replace the HTTPError that the caller needs to see.
        try:
            print(response.json())
        except ValueError:
            print(response.text)
        raise
        
def get_id_mapping_results_link(job_id):
    """Return the ``redirectURL`` field of the ID-mapping details for ``job_id``.

    Issues a GET on ``/idmapping/details/<job_id>`` through the shared session
    and validates the response with ``check_response`` before reading the JSON.
    """
    details = session.get(f"{UNIPROT_API_URL}/idmapping/details/{job_id}")
    check_response(details)
    payload = details.json()
    return payload["redirectURL"]


def submit_id_mapping(from_db, to_db, ids):
    """Submit an ID-mapping job and return its ``jobId``.

    Args:
        from_db: value sent as the ``from`` form field.
        to_db: value sent as the ``to`` form field.
        ids: iterable of identifier strings; sent comma-joined as ``ids``.

    Returns:
        The ``jobId`` field of the JSON response.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    # Go through the shared session like every other UniProt call in this
    # notebook, so the request uses the same adapter configuration.
    request = session.post(
        f"{UNIPROT_API_URL}/idmapping/run",
        data={"from": from_db, "to": to_db, "ids": ",".join(ids)},
    )
    check_response(request)
    return request.json()["jobId"]

def get_batch(batch_response, file_format, compressed):
    """Yield the decoded follow-up pages of a paginated response.

    Starting from ``batch_response``, the URL returned by ``get_next_link`` for
    the current headers is fetched repeatedly; each fetched page is decoded with
    ``decode_results`` and yielded. Iteration ends when there is no next link.
    The page passed in is not itself yielded.
    """
    next_url = get_next_link(batch_response.headers)
    while next_url:
        page = session.get(next_url)
        page.raise_for_status()
        yield decode_results(page, file_format, compressed)
        next_url = get_next_link(page.headers)


def combine_batches(all_results, batch_results, file_format):
    """Merge one decoded page into the results accumulated so far.

    Args:
        all_results: accumulated results (dict for ``"json"``, list otherwise).
        batch_results: the newly decoded page, same shape as ``all_results``.
        file_format: ``"json"``, ``"tsv"`` or any other format name.

    Returns:
        For ``"json"``: ``all_results`` itself, extended in place under the
        ``"results"`` and ``"failedIds"`` keys. For ``"tsv"``: a new list with
        the page appended minus its first (header) line. Otherwise: a new list
        with the whole page appended.
    """
    if file_format == "json":
        for key in ("results", "failedIds"):
            batch_items = batch_results.get(key)
            if batch_items:
                # The first page may lack a key that a later page provides;
                # create the list on demand instead of raising KeyError.
                all_results.setdefault(key, []).extend(batch_items)
        return all_results
    if file_format == "tsv":
        return all_results + batch_results[1:]
    return all_results + batch_results

def decode_results(response, file_format, compressed):
    """Decode a response body according to ``file_format``.

    When ``compressed`` is true the body is first inflated with zlib using
    ``16 + zlib.MAX_WBITS`` and text formats are decoded as UTF-8; otherwise the
    response's own ``json()`` / ``text`` / ``content`` accessors are used.

    Returns:
        ``"json"``: the parsed JSON object.
        ``"tsv"``: list of non-empty lines.
        ``"xlsx"``: one-element list holding the raw bytes.
        ``"xml"``: one-element list holding the document text.
        anything else: the body text.
    """
    if compressed:
        raw = zlib.decompress(response.content, 16 + zlib.MAX_WBITS)
        if file_format == "xlsx":
            return [raw]
        text = raw.decode("utf-8")
        if file_format == "json":
            return json.loads(text)
    else:
        if file_format == "json":
            return response.json()
        if file_format == "xlsx":
            return [response.content]
        text = response.text

    # From here on only text-based formats remain.
    if file_format == "tsv":
        return [row for row in text.split("\n") if row]
    if file_format == "xml":
        return [text]
    return text


def get_next_link(headers):
    """Return the URL of the ``rel="next"`` entry in the ``Link`` header.

    Returns ``None`` when there is no ``Link`` header or when its value does
    not start with a ``<url>; rel="next"`` entry.
    """
    if "Link" not in headers:
        return None
    found = re.match(r'<(.+)>; rel="next"', headers["Link"])
    return found.group(1) if found else None
        
def get_xml_namespace(element):
    """Return the text between the leading ``{`` and last ``}`` of ``element.tag``.

    Returns an empty string when the tag does not start with a ``{...}`` prefix.
    """
    found = re.match(r"\{(.*)\}", element.tag)
    if found is None:
        return ""
    return found.group(1)


def merge_xml_results(xml_results):
    """Merge several XML documents into one serialized document.

    The first document is the base. Every direct child tagged
    ``{http://uniprot.org/uniprot}entry`` of each later document is inserted
    into the base root at position -1, i.e. just before its last child. The
    namespace of the base root's first child is registered as the default
    namespace before serializing.

    Returns:
        UTF-8 encoded bytes including an XML declaration.
    """
    base_root = ElementTree.fromstring(xml_results[0])
    for document in xml_results[1:]:
        extra_root = ElementTree.fromstring(document)
        entries = extra_root.findall("{http://uniprot.org/uniprot}entry")
        for entry in entries:
            base_root.insert(-1, entry)
    namespace = get_xml_namespace(base_root[0])
    ElementTree.register_namespace("", namespace)
    return ElementTree.tostring(base_root, encoding="utf-8", xml_declaration=True)


def print_progress_batches(batch_index, size, total):
    """Print ``Fetched: <n> / <total>`` for the batch with 0-based ``batch_index``.

    ``n`` is ``(batch_index + 1) * size`` capped at ``total``.
    """
    fetched_so_far = (batch_index + 1) * size
    if fetched_so_far > total:
        fetched_so_far = total
    print(f"Fetched: {fetched_so_far} / {total}")
    
def get_id_mapping_results_search(url):
    """Fetch every page of an ID-mapping results URL and return the merged result.

    The ``format`` (default ``"json"``), ``size`` (default 500, added to the
    query when absent) and ``compressed`` (default false) query parameters of
    ``url`` control decoding. The first page is fetched and validated, the
    total is read from the ``x-total-results`` header, and follow-up pages from
    ``get_batch`` are folded in with ``combine_batches``. Progress is printed
    after each page.

    Returns:
        The combined results; for ``"xml"`` the output of ``merge_xml_results``.
    """
    parts = urlparse(url)
    params = parse_qs(parts.query)

    file_format = params.get("format", ["json"])[0]
    if "size" not in params:
        size = 500
        params["size"] = size
    else:
        size = int(params["size"][0])
    compressed = params.get("compressed", ["false"])[0].lower() == "true"

    # Rebuild the URL so that the (possibly defaulted) size is sent explicitly.
    paged_url = parts._replace(query=urlencode(params, doseq=True)).geturl()
    first_page = session.get(paged_url)
    check_response(first_page)

    results = decode_results(first_page, file_format, compressed)
    total = int(first_page.headers["x-total-results"])
    print_progress_batches(0, size, total)

    later_pages = get_batch(first_page, file_format, compressed)
    for batch_index, batch in enumerate(later_pages, 1):
        results = combine_batches(results, batch, file_format)
        print_progress_batches(batch_index, size, total)

    return merge_xml_results(results) if file_format == "xml" else results

def check_id_mapping_results_ready(job_id):
    """Poll the ID-mapping status endpoint until the job has finished.

    While the response reports ``jobStatus`` ``"NEW"`` or ``"RUNNING"`` the
    function sleeps ``POLLING_INTERVAL`` seconds and polls again. A response
    without ``jobStatus`` is treated as the finished result payload.

    Returns:
        True when the payload has a non-empty ``results`` or ``failedIds``
        entry, False otherwise.

    Raises:
        requests.HTTPError: if a status request fails.
        Exception: carrying the status string for any other ``jobStatus``.
    """
    while True:
        status = session.get(f"{UNIPROT_API_URL}/idmapping/status/{job_id}")
        check_response(status)
        payload = status.json()
        if "jobStatus" not in payload:
            # Either key may be absent from the payload; treat that as empty.
            return bool(payload.get("results") or payload.get("failedIds"))
        # A freshly queued job reports NEW before RUNNING; both mean "wait".
        if payload["jobStatus"] in ("NEW", "RUNNING"):
            print(f"Retrying in {POLLING_INTERVAL}s")
            time.sleep(POLLING_INTERVAL)
        else:
            raise Exception(payload["jobStatus"])

In [30]:
# Function for data processing
def getProteinFASTA(uniprot_id):
    """Download the FASTA text for ``uniprot_id`` from uniprot.org.

    Args:
        uniprot_id: accession string appended to the base URL.

    Returns:
        The response body as text.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    baseUrl = "http://www.uniprot.org/uniprot/"
    currentUrl = baseUrl + uniprot_id + ".fasta"
    # This is a plain download, so use GET rather than POST.
    response = requests.get(currentUrl)
    # Fail loudly instead of returning an error page as if it were FASTA.
    response.raise_for_status()
    return response.text

def getProteinRCSB(pdb_id, uniprot_id):
    """Download a PDB entry from RCSB and save it under ``PDB_FOLDER``.

    Args:
        pdb_id: PDB identifier used to build the download URL.
        uniprot_id: used as the file name (``<uniprot_id>.pdb``).

    Returns:
        Path of the written file.

    Raises:
        requests.HTTPError: if the download fails.
    """
    url = f"http://files.rcsb.org/view/{pdb_id}.pdb"
    cache_path = os.path.join(PDB_FOLDER, f'{uniprot_id}.pdb')
    # NOTE(review): an existing file is always overwritten. A cache check was
    # previously disabled here, presumably because the file is keyed by UniProt
    # ID and could hold a different PDB entry. Confirm before re-enabling.

    pdb_req = requests.get(url)
    pdb_req.raise_for_status()

    # Create the target folder on first use and close the file deterministically.
    os.makedirs(PDB_FOLDER, exist_ok=True)
    with open(cache_path, 'w') as pdb_file:
        pdb_file.write(pdb_req.text)
    return cache_path

def converUniProt2PDB(uniprot_id, pdb_id_index=0):
    """Map a UniProt accession to one PDB ID via the UniProt ID-mapping service.

    Equivalent to posting the form fields ``from="UniProtKB_AC-ID"``,
    ``to="PDB"`` and ``ids=<uniprot_id>`` to
    ``https://rest.uniprot.org/idmapping/run`` and reading the job results.

    Args:
        uniprot_id: accession to map.
        pdb_id_index: position in the returned mapping list to pick.

    Returns:
        The ``to`` field of the mapping at ``pdb_id_index``, or ``None`` when
        the job produced nothing or has fewer mappings than requested.
    """
    job_id = submit_id_mapping(
        from_db="UniProtKB_AC-ID", to_db="PDB", ids=[uniprot_id]
    )
    # Without this guard `results` would be unbound below when the job
    # finishes with neither results nor failed IDs.
    if not check_id_mapping_results_ready(job_id):
        return None

    link = get_id_mapping_results_link(job_id)
    results = get_id_mapping_results_search(link)

    mappings = results.get('results', [])
    if len(mappings) > pdb_id_index:
        return mappings[pdb_id_index]['to']
    return None
    
def getProteinMol(uniprot_id, pdb_id_index=0):
    """Return the molecule loaded from the PDB file for ``uniprot_id``.

    The accession is mapped to a PDB ID with ``converUniProt2PDB`` (asserted to
    be non-None), the structure is downloaded with ``getProteinRCSB``, and the
    saved file is parsed with ``MolFromPDBFile``.
    """
    pdb_id = converUniProt2PDB(uniprot_id, pdb_id_index)
    assert pdb_id is not None
    return MolFromPDBFile(getProteinRCSB(pdb_id, uniprot_id))

def getLigandMol(ligand_smile):
    """Return the result of ``MolFromSmiles`` for the SMILES string ``ligand_smile``."""
    return MolFromSmiles(ligand_smile)

# Functions for reading CSVs
def readCSV(csv_file, skiprows=1):
    """Read ``csv_file`` into a DataFrame with every column typed as ``str``.

    ``skiprows`` (default 1) is forwarded to ``pandas.read_csv``.
    """
    return pd.read_csv(csv_file, skiprows=skiprows, dtype=str)

def findMatchingData(df, key, key_column, result_columns):
    """Select ``result_columns`` from the rows where ``df[key_column] == key``.

    A list of column names yields a DataFrame; a single column name (string)
    yields a Series.
    """
    row_mask = df[key_column] == key
    matching_rows = df.loc[row_mask]
    return matching_rows[result_columns]

In [31]:
# Data class
class IUPHAR:
    """Protein/ligand interaction data assembled from IUPHAR CSV files.

    Attributes set by the constructor:
        interactions_df, ligands_df, GtP_to_UniProt_mapping_df: the raw tables.
        protein_ids: the UniProt IDs passed in.
        proteins: ``{uniprot_id: protein molecule}`` (None values dropped).
        interactions: ``{uniprot_id: {ligand_id: affinity as float}}``.
        ligands: ``{ligand_id: ligand molecule or None}``.
        ligands_smiles: ``{ligand_id: SMILES string}``.
        protein_ligand_pairs: ``[[<uniprot_id>.pdb, SMILES], ...]``.
        keys: ``"<uniprot_id>_<ligand_id>_<affinity>"`` strings.
    """

    def __init__(self, folder, protein_ids, training=True, pdb_id_index=0):
        """Load the CSVs in ``folder``, fetch proteins and write the docking CSV.

        Args:
            folder: directory holding ``interactions.csv``, ``ligands.csv`` and
                ``GtP_to_UniProt_mapping.csv``; the docking input
                ``input_protein_ligand_<training>.csv`` is written here too.
            protein_ids: list of UniProt IDs to load.
            training: only used to name the output CSV.
            pdb_id_index: forwarded to ``getProteinMol``.
        """
        interactions_file = os.path.join(folder, "interactions.csv")
        ligands_file = os.path.join(folder, "ligands.csv")
        GtP_to_UniProt_mapping_file = os.path.join(folder, "GtP_to_UniProt_mapping.csv")

        self.interactions_df = readCSV(interactions_file)
        self.ligands_df = readCSV(ligands_file)
        self.GtP_to_UniProt_mapping_df = readCSV(GtP_to_UniProt_mapping_file)

        # List Protein by Uniprot ID
        self.protein_ids = protein_ids

        # {"UniprotID": Protein_Mol_Object}; proteins that failed to load are dropped
        loaded_proteins = {pid: getProteinMol(pid, pdb_id_index) for pid in self.protein_ids}
        self.proteins = {pid: mol for pid, mol in loaded_proteins.items() if mol is not None}

        # {"UniprotID": {"LigandID": BindingAff}}
        self.interactions = self.getInteractionsByProteins(self.protein_ids)

        # {"LigandID": Ligand_Mol_Object}; getLigands also fills protein_ligand_pairs
        self.protein_ligand_pairs = []
        self.ligands, self.ligands_smiles = self.getLigands(self.interactions)

        # Save protein-ligand for docking
        with open(os.path.join(folder, f'input_protein_ligand_{training}.csv'), "w") as f:
            f.write("protein_path,ligand\n")
            for protein_path, smiles in self.protein_ligand_pairs:
                f.write(f'{protein_path},{smiles}\n')

        # Create keys for following model
        self.keys = []
        for pid, affinity_by_ligand in self.interactions.items():
            for lid, affinity in affinity_by_ligand.items():
                if lid not in self.ligands_smiles:
                    continue
                self.keys.append(f'{pid}_{lid}_{affinity}')

    def getInteractionsByProteins(self, protein_ids):
        """Return ``{uniprot_id: {ligand_id: affinity}}`` for ``protein_ids``.

        Rows are matched on the "Target UniProt ID" column. Rows with a missing
        "Ligand ID" or "Affinity Median" are skipped, and affinities are
        converted with ``float``. A later row for the same ligand overwrites an
        earlier one.
        """
        interactions = defaultdict(dict)
        for pid in protein_ids:
            matches = findMatchingData(self.interactions_df, pid, "Target UniProt ID",
                                       ["Ligand ID", "Affinity Median"])

            for lid, aff in zip(matches["Ligand ID"], matches["Affinity Median"]):
                if pd.isna(lid) or pd.isna(aff):
                    continue
                interactions[pid][lid] = float(aff)

        return interactions

    def getLigands(self, interactions):
        """Resolve ligand molecules and SMILES for every ligand in ``interactions``.

        Side effect: appends one ``[<uniprot_id>.pdb, SMILES]`` entry to
        ``self.protein_ligand_pairs`` for each (protein, ligand) interaction
        whose ligand has a SMILES string.

        Returns:
            ``(ligands, ligands_smiles)``. ``ligands`` maps a ligand ID to its
            molecule, or None when the SMILES cell is missing. Ligand IDs that
            are absent from ``ligands_df`` appear in neither dict.
        """
        # Unique ligand IDs in first-seen order, so the output is deterministic.
        ligand_ids = list(dict.fromkeys(
            lid for lids in interactions.values() for lid in lids
        ))

        ligands = {}
        ligands_smiles = {}
        for lid in ligand_ids:
            smiles_matches = findMatchingData(self.ligands_df, lid, "Ligand ID", "SMILES").tolist()
            if not smiles_matches:
                continue
            smiles = smiles_matches[0]
            if pd.isna(smiles):
                ligands[lid] = None
                continue
            ligands[lid] = getLigandMol(smiles)
            ligands_smiles[lid] = smiles

        # Pair each ligand with the protein(s) it actually interacts with. The
        # previous code reused a loop variable left over from an earlier loop,
        # which tagged every ligand with the last protein ID.
        for pid, lids in interactions.items():
            for lid in lids:
                if lid in ligands_smiles:
                    self.protein_ligand_pairs.append([pid + ".pdb", ligands_smiles[lid]])

        return ligands, ligands_smiles

In [32]:
# Build the training and testing datasets. Each call queries UniProt, downloads
# a PDB file and writes input_protein_ligand_<training>.csv into DATA_FOLDER.
# pdb_id_index=1 picks the second PDB mapping returned for each protein.
training_data = IUPHAR(DATA_FOLDER, TRAINING_PROTEINS, training=True, pdb_id_index=1)
testing_data = IUPHAR(DATA_FOLDER, TESTING_PROTEINS, training=False, pdb_id_index=1)

Fetched: 39 / 39
Fetched: 50 / 50


In [34]:
def findFolder(listStr, protein, smile):
    """Return the first name in ``listStr`` that contains ``protein`` and ends with ``smile``.

    Returns ``None`` when no name matches.
    """
    candidates = (name for name in listStr if protein in name and name.endswith(smile))
    return next(candidates, None)

class ChemicalFeaturesFactory:
    """Lazily built, class-level cache of the RDKit base feature factory."""
    _instance = None

    @classmethod
    def get_instance(cls):
        """Return the cached feature factory, building it on first use.

        The factory is built from ``BaseFeatures.fdef`` in ``RDConfig.RDDataDir``.

        Raises:
            ImportError: if RDKit cannot be imported.
        """
        try:
            from rdkit import RDConfig
            from rdkit.Chem import ChemicalFeatures
        except ModuleNotFoundError:
            raise ImportError("This class requires RDKit to be installed.")

        if cls._instance:
            return cls._instance

        feature_definitions = os.path.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
        cls._instance = ChemicalFeatures.BuildFeatureFactory(feature_definitions)
        return cls._instance

# Module-level feature factory used by get_feature_dict().
factory = ChemicalFeaturesFactory.get_instance()

def get_feature_dict(mol):
    """Map each atom index of ``mol`` to the list of feature families it takes part in.

    Args:
        mol: molecule passed to ``factory.GetFeaturesForMol``, or ``None``.

    Returns:
        ``{atom_idx: [family, ...]}``; an empty dict when ``mol`` is ``None``.
    """
    # Identity check: `== None` would go through the object's own __eq__.
    if mol is None:
        return {}

    # NOTE(review): features are keyed by their atom-ID tuple, so when several
    # features cover exactly the same atoms only the last family is kept.
    # Behaviour is preserved as is; confirm whether that is intended.
    family_by_atoms = {}
    for feature in factory.GetFeaturesForMol(mol):
        family_by_atoms[feature.GetAtomIds()] = feature.GetFamily()

    feature_dict = {}
    for atom_ids, family in family_by_atoms.items():
        for atom_idx in atom_ids:
            feature_dict.setdefault(atom_idx, []).append(family)

    return feature_dict

def load_receptor_ligand_data(keys, training=True):
    """Load docked ligand, receptor and their feature dicts for every key.

    Args:
        keys: strings of the form ``<receptor>_<ligand>_<value>``.
        training: selects ``training_data`` and the ``training`` subfolder when
            true, otherwise ``testing_data`` and the ``testing`` subfolder.

    Returns:
        ``{key: (ligand, receptor, ligand_feature, receptor_feature)}``. The
        ligand is the first molecule of ``rank1.sdf`` in the docking folder
        that matches the receptor name and the ligand SMILES.
    """
    if training:
        split_name, data_agent = "training", training_data
    else:
        split_name, data_agent = "testing", testing_data

    docking_folders = os.listdir(os.path.join(DATA_FOLDER, split_name))
    receptor_feature_cache = {}
    loaded = {}

    for key in tqdm(keys):
        receptor_name, ligand_name, _ = key.split("_")

        # Ligand: '/' in the SMILES is replaced by '-' before matching folder names.
        smiles_suffix = data_agent.ligands_smiles[ligand_name].replace('/', '-')
        docking_folder = findFolder(docking_folders, receptor_name, smiles_suffix)
        assert docking_folder is not None
        supplier = SDMolSupplier("%s/%s/%s/rank1.sdf" % (DATA_FOLDER, split_name, docking_folder))
        ligand = supplier[0]
        ligand_feature = get_feature_dict(ligand)

        # Receptor: features are computed once per receptor and reused.
        receptor = data_agent.proteins[receptor_name]
        if receptor_name not in receptor_feature_cache:
            receptor_feature_cache[receptor_name] = get_feature_dict(receptor)
        receptor_feature = receptor_feature_cache[receptor_name]

        loaded[key] = (ligand, receptor, ligand_feature, receptor_feature)

    return loaded

# Load and save
def load_and_save_data_by_keys(keys, training=True):
    """Load the data for ``keys`` and pickle each entry to ``data_GMGM/<key>``.

    The ``data_GMGM`` directory is created if it does not exist yet.
    """
    loaded = load_receptor_ligand_data(keys, training=training)

    output_dir_missing = not os.path.exists("data_GMGM")
    if output_dir_missing:
        os.mkdir("data_GMGM")

    for key in loaded:
        with open('data_GMGM/'+key, 'wb') as out_file:
            pickle.dump(loaded[key], out_file)

In [35]:
# Load data and save keys
# Pickle one file per key under data_GMGM/ for both splits. This reads the
# docking outputs under DATA_FOLDER/training and DATA_FOLDER/testing.
load_and_save_data_by_keys(training_data.keys, training=True)
load_and_save_data_by_keys(testing_data.keys, training=False)


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [01:28<00:00, 12.67s/it]


In [36]:
# Persist the key lists of both splits as pickles under keys/.
keys_dir_missing = not os.path.exists("keys")
if keys_dir_missing:
    os.mkdir("keys")

for keys_path, split_keys in (
    ("keys/train_%s.pkl"%"IUPHAR", training_data.keys),
    ("keys/test_%s.pkl"%"IUPHAR", testing_data.keys),
):
    with open(keys_path, 'wb') as keys_file:
        pickle.dump(split_keys, keys_file)

# DiffDock

In [None]:
# Run the remaining commands from the DiffDock checkout next to this folder.
cd ../DiffDock
# NOTE(review): HOME is overridden for the rest of the shell session,
# presumably so model weights are cached under esm/model_weights. Confirm.
export HOME=esm/model_weights
# NOTE(review): machine-specific absolute path; adjust to the local checkout.
export PYTHONPATH=$PYTHONPATH:/dfs/user/sttruong/DucWorkspace/DiffDock/esm

In [None]:
# Write data/prepared_for_esm_train.fasta from the training protein/ligand CSV.
# The notebook writes that CSV into its iuphar/ data folder, which is also the
# path the inference step reads, so the path must include iuphar/.
python datasets/esm_embedding_preparation.py \
    --protein_ligand_csv ../graph_regression/iuphar/input_protein_ligand_True.csv \
    --out_file data/prepared_for_esm_train.fasta

In [None]:
# Run the ESM extract script on the training FASTA; output goes to data/esm2_output.
python esm/scripts/extract.py \
    esm2_t33_650M_UR50D \
    data/prepared_for_esm_train.fasta \
    data/esm2_output \
    --repr_layers 33 \
    --include per_tok \
    --truncation_seq_length 30000

In [None]:
# Run DiffDock inference on the training pairs; results go to results/training.
python -m inference \
    --protein_ligand_csv ../graph_regression/iuphar/input_protein_ligand_True.csv \
    --out_dir results/training \
    --inference_steps 20 \
    --samples_per_complex 40 \
    --batch_size 16

In [None]:
# Write data/prepared_for_esm_test.fasta from the testing protein/ligand CSV.
# The notebook writes that CSV into its iuphar/ data folder, which is also the
# path the inference step reads, so the path must include iuphar/.
python datasets/esm_embedding_preparation.py \
    --protein_ligand_csv ../graph_regression/iuphar/input_protein_ligand_False.csv \
    --out_file data/prepared_for_esm_test.fasta

In [None]:
# Run the ESM extract script on the testing FASTA; output goes to data/esm2_output.
python esm/scripts/extract.py \
    esm2_t33_650M_UR50D \
    data/prepared_for_esm_test.fasta \
    data/esm2_output \
    --repr_layers 33 \
    --include per_tok \
    --truncation_seq_length 30000

In [None]:
# Run DiffDock inference on the testing pairs; results go to results/testing.
python -m inference \
    --protein_ligand_csv ../graph_regression/iuphar/input_protein_ligand_False.csv \
    --out_dir results/testing \
    --inference_steps 20 \
    --samples_per_complex 40 \
    --batch_size 16

In [None]:
# Write data/prepared_for_esm_test_neg.fasta from the testing protein/ligand CSV
# (regenerated with the negative-testing protein selected in the notebook).
# The notebook writes that CSV into its iuphar/ data folder, which is also the
# path the inference step reads, so the path must include iuphar/.
python datasets/esm_embedding_preparation.py \
    --protein_ligand_csv ../graph_regression/iuphar/input_protein_ligand_False.csv \
    --out_file data/prepared_for_esm_test_neg.fasta

In [None]:
# Run the ESM extract script on the negative-testing FASTA; output goes to data/esm2_output.
python esm/scripts/extract.py \
    esm2_t33_650M_UR50D \
    data/prepared_for_esm_test_neg.fasta \
    data/esm2_output \
    --repr_layers 33 \
    --include per_tok \
    --truncation_seq_length 30000

In [None]:
# Run DiffDock inference on the negative-testing pairs; results go to results/testing_neg.
python -m inference \
    --protein_ligand_csv ../graph_regression/iuphar/input_protein_ligand_False.csv \
    --out_dir results/testing_neg \
    --inference_steps 20 \
    --samples_per_complex 40 \
    --batch_size 16