In [209]:
import json
import pickle

def load_object(filename):
    with open(filename, 'rb') as f:
        return pickle.load(f)
    
affinity_dict = load_object('./PDBbind_affinity_dict.pkl')

# Write the dictionary to a JSON file
with open('PDBbind_data_dict.json', 'w', encoding='utf-8') as json_file:
    json.dump(affinity_dict, json_file, ensure_ascii=False, indent=4)

In [210]:
from Bio.PDB.PDBParser import PDBParser
import argparse
from f_parse_pdb_general import parse_pdb
import numpy as np
from rdkit import Chem
import os
import json
from time import time
import jax.numpy as jnp
from jax import jit, vmap


def parse_sdf_file(file_path):
    suppl = Chem.SDMolSupplier(file_path, sanitize = True, removeHs=True, strictParsing=True)
    molecules = []
    for mol in suppl:
        if mol is not None:
            molecules.append(mol)
    return molecules


#################
# def arg_parser():
#     parser = argparse.ArgumentParser(description='')
#     parser.add_argument('--data_dir', type=str, required=True, help='Path to the data directory containing all proteins(PDB) and ligands (SDF)')
#     parser.add_argument('--affinity_dict', type=str, required=True, help='Path to json file assigning affinity values to complexes in the data')
#     return parser.parse_args()

# args = arg_parser()
# data_dir = args.data_dir
# affinity_dict_path = args.affinity_dict
#################

### For testing ###
affinity_dict_path = 'PDBbind_data_dict.json'
data_dir = 'temporary_data'

# Load the JSON file into a dictionary
with open(affinity_dict_path, 'r', encoding='utf-8') as json_file:
    affinity_dict = json.load(json_file)


# Get sorted lists of proteins and ligands (dirEntry objects) in the data_dir
proteins = sorted([protein for protein in os.scandir(data_dir) if protein.name.endswith('protein.pdb')], key=lambda x: x.name)
ligands = sorted([ligand for ligand in os.scandir(data_dir) if ligand.name.endswith('ligand.sdf')], key=lambda x: x.name)

if not len(proteins) == len(ligands):
    raise ValueError('Number of proteins and ligands does not match')
else:
    print(f'Number of Protein PDBs: {len(proteins)}')
    print(f'Number of Ligand SDFs: {len(ligands)}')


Number of Protein PDBs: 1
Number of Ligand SDFs: 1


In [211]:
# Initialize Log File:
log_folder = os.path.join(data_dir,'.logs/')
if not os.path.exists(log_folder): os.makedirs(log_folder)
log_file_path = os.path.join(log_folder, "preprocessing_logs.txt")
log = open(log_file_path, 'a')
log.write("Data Preprocessing - Log File:\n")
log.write("\n")

1

In [212]:
# Initialize PDB Parser
parser = PDBParser(PERMISSIVE=1, QUIET=True)

amino_acids = ["ALA","ARG","ASN","ASP","CYS","GLN","GLU","GLY","HIS","ILE","LEU","LYS","MET","PHE","PRO","SER","THR","TRP","TYR","VAL"]
known_hetatms = ['ZN','MG','NA','MN','CA','K','NI','FE','CO','HG','CD','CU','CS','AU','LI','GA','IN','BA','RB','SR']
known_residues = amino_acids + known_hetatms


def compute_connections(protein_atomcoords, pos):
    diff = protein_atomcoords[jnp.newaxis, :, :] - pos[:, jnp.newaxis, :]
    pairwise_distances = jnp.linalg.norm(diff, axis=2)
    close = pairwise_distances <= 5
    return close
compute_connections = jit(compute_connections)


# Here start a loop over the complexes
#----------------------------------------------------------
for protein, ligand in zip(proteins, ligands):
    
    if protein.name.split('_')[0] != ligand.name.split('_')[0]:
        raise ValueError(f'Protein {protein} and Ligand {ligand} do not match')
    else:
        protein_id = protein.name.split('_')[0]
        protein_path = protein.path
        ligand_path = ligand.path
    
    log_string = f'{protein_id}: '
    
    # -----------------------------------------------------
    # PRESELECTION OF COMPLEXES
    # -----------------------------------------------------

    # TEST 1: AFFINITY DATA - Continue only if there is a valid affinity value available for this complex
    if protein_id not in affinity_dict.keys():
        log_string += 'Protein is not processed: No valid affinity value'
        log.write(log_string + "\n")
        continue

    # TEST 2: LIGAND PARSING - Continue only if the parsed ligand has been processed successfully, else skip this complex
    ligand = parse_sdf_file(ligand_path)
    if len(ligand) == 1: mol = ligand[0]
    else:
        log_string += 'Ligand could not be parsed successfully'
        log.write(log_string + "\n")
        continue
    
    # TEST 3: LIGAND SIZE - Get coordinate Matrix of the ligand molecule - Continue only if the ligand has at least 5 heavy atoms
    conformer = mol.GetConformer()
    coordinates = conformer.GetPositions()
    pos = jnp.array(coordinates)
    if pos.shape[0]<5:
        log_string += 'Ligand is smaller than 5 Atoms and is skipped'
        log.write(log_string + "\n")
        continue


    # PARSING OF PROTEIN PDB TO GENERATE COORDINATE MATRIX
    # -----------------------------------------------------

    with open(protein_path) as pdbfile:
        protein = parse_pdb(parser, protein_id, pdbfile)


    protein_atomcoords = jnp.array([], dtype=jnp.int64).reshape(0,3)
    res_list = []
    residue_memberships = []
        
    clean_aa_chain = False
    chain_too_long = False
    residue_idx = 1

    # Iterate over the chains in the protein
    for chain in protein:

        chain_comp = protein[chain]['composition']

        # CHAIN CONTAINS ONLY AMINO ACIDS (AND HETATMS in chain)
        if chain_comp == [True, False] or chain_comp == [True, True]:
            clean_aa_chain = True

            # If the chain is longer than 1024, skip the complex
            if len(protein[chain]['aa_seq']) > 1022:
                chain_too_long = True
                break
            
            for residue in protein[chain]['aa_residues']:
                res_dict = protein[chain]['aa_residues'][residue]

                # Append the coords of the residue to the protein_atomcoords
                protein_atomcoords = jnp.vstack((protein_atomcoords, res_dict['coords']))

                res_list.append((residue_idx, res_dict['resname']))
                
                memb = [residue_idx for atom in res_dict['atom_indeces']]
                residue_memberships.extend(memb)

                residue_idx += 1
        
        # CHAIN CONTAINS HETATMS BUT NO AMINO ACIDS
        elif protein[chain]['composition'] == [False, True]: 
            
            for hetatm_res in protein[chain]['hetatm_residues']:
                hetatmres_dict = protein[chain]['hetatm_residues'][hetatm_res]

                # Append the coords of the residue to the protein_atomcoords
                protein_atomcoords = jnp.vstack((protein_atomcoords, hetatmres_dict['hetatmcoords']))

                res_list.append((residue_idx, hetatmres_dict['resname']))

                memb = [residue_idx for atom in hetatmres_dict['atoms']]
                residue_memberships.extend(memb)

                residue_idx +=1


    # To do: Can we delete this? Does this ever happen?
    # If a chain is too long to generate an ESM Embedding, skip the complex
    if chain_too_long:
        log_string += 'Protein AA sequence too long for ESM'
        log.write(log_string + "\n")
        continue

    # To do: Can we delete this? Does this ever happen?
    if not clean_aa_chain:
        log_string += 'No clean AA_chain has been found, complex is skipped'
        log.write(log_string + "\n")
        continue




    # # COMPUTE CONNECTIVITY BETWEEN LIGAND AND PROTEIN ATOMS
    # # -----------------------------------------------------
    
    # With numpy --------------------------------------------------------------------------------------------------------
    # tic = time()
    # max_len = 4
    # diff = protein_atomcoords[np.newaxis, :, :] - pos[:, np.newaxis, :]
    # pairwise_distances = np.linalg.norm(diff, axis=2)
    # close = pairwise_distances <= max_len + 1
    
    # connections = [np.unique(np.array(residue_memberships)[np.where(row)]) for row in close]
    # np_time = time()-tic

    # connected_res_num = sorted(list(set([atm for l in connections for atm in l])))
    # connected_res_name = [res_list[aa-1][1] for aa in connected_res_num]
    
    # print(connections)
    # print(connected_res_num)
    # print(connected_res_name)
    
    # print(f'Numpy Time: {np_time}')
    # ---------------------------------------------------------------------------------------------------------------------



    # With JAX --------------------------------------------------------------------------------------------------------
    
    #Warm up the JIT compilation
    _ = compute_connections(np.array(protein_atomcoords), jnp.array(pos))
    
    tic = time()
    close = compute_connections(protein_atomcoords, pos)    
    connections = [np.unique(np.array(residue_memberships)[np.where(row)]) for row in np.array(close)]
    jnp_time = time()-tic
    connected_res_num = sorted(list(set([atm for l in connections for atm in l])))
    connected_res_name = [res_list[aa-1][1] for aa in connected_res_num]

    print(f'JaxNumpy Time: {jnp_time}')
    # ---------------------------------------------------------------------------------------------------------------------



    unknown_res = [(res not in known_residues and res.strip('0123456789') not in known_residues) for res in connected_res_name]
    if any(unknown_res):
        log_string += 'Ligand has been connected to a unknown protein residue, the complex is therefore skipped'
        log.write(log_string + "\n")
        continue



    # EXPORT DATA 
    # -------------------------------------------------------

    # Export protein dictionary as pkl
    filepath = os.path.join(data_dir, f'{protein_id}_protein_dict.pkl')
    with open(filepath, 'wb') as fp:
        pickle.dump(protein, fp)

    # Export CONNECTIONS as dict
    connections_dict = {'connections':connections, 'res_num':connected_res_num, 'res_name':connected_res_name}
    filepath = os.path.join(data_dir, f'{protein_id}_connections.pkl')
    with open(filepath, 'wb') as fp:
        pickle.dump(connections_dict, fp)


    log_string += 'Successful'
    log.write(log_string + "\n")

log.close()

JaxNumpy Time: 0.003453969955444336


  protein_atomcoords = jnp.array([], dtype=jnp.int64).reshape(0,3)
