In [13]:
import json
import pickle

def load_object(filename):
    """Read a pickled Python object from disk.

    Parameters
    ----------
    filename: str or path-like
        Location of the pickle file.

    Returns
    -------
    object
        Whatever object was serialized into `filename`.

    Notes
    -----
    Unpickling can execute arbitrary code, so only use this on files
    that come from a trusted source.
    """
    with open(filename, 'rb') as handle:
        restored = pickle.load(handle)
    return restored
    
# One-off conversion: read the pickled affinity dictionary ...
affinity_dict = load_object('./PDBbind_affinity_dict.pkl')

# ... and write it back out as JSON. `ensure_ascii=False` keeps non-ASCII
# characters as they are, `indent=4` pretty-prints the file.
# NOTE(review): json.dump only succeeds if the pickled object consists of
# JSON-serializable types - confirm for the actual affinity dictionary.
with open('PDBbind_data_dict.json', 'w', encoding='utf-8') as json_file:
    json.dump(affinity_dict, json_file, ensure_ascii=False, indent=4)

In [14]:
import os
import json
import argparse
import numpy as np

from Bio.PDB.PDBParser import PDBParser

# RDKit
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import rdPartialCharges
from rdkit.Chem.MolStandardize import rdMolStandardize  

from time import time
import jax.numpy as jnp
from jax import jit, vmap

from f_parse_pdb_general import parse_pdb

# PyTorch and PyTorch Geometric
import torch
from torch_geometric.utils import to_undirected, add_self_loops
from torch_geometric.data import Data, Batch

### Input to the Algorithm

In [15]:

# def arg_parser():
#     parser = argparse.ArgumentParser(description="Inputs to Graph Generation Script")
#     parser.add_argument('--data_dir', type=str, required=True, help='Path to the data directory containing all proteins(PDB) and ligands (SDF)')
#     parser.add_argument('--affinity_dict', type=str, required=True, help='Path to json file assigning affinity values to complexes in the data')
    
#     parser.add_argument('--embedding_descriptors',
#     nargs='+',
#     help='Provide names of embeddings that should be incorporated (--embedding_descriptors string1 string2 string3)')

#     return parser.parse_args()

# args = arg_parser()
# data_dir = args.data_dir
# affinity_dict_path = args.affinity_dict
# embedding_descriptors = args.embedding_descriptors

### For testing ###
# Hard-coded stand-ins for the command line arguments of the commented-out
# argument parser above.

# JSON file with the affinity values, looked up by complex id
affinity_dict_path = 'PDBbind_data_dict.json'
# Directory that is scanned for '*protein.pdb' and '*ligand.sdf' files
data_dir = 'PDBbind1'
# Names of the embeddings to incorporate
# NOTE(review): this list is not read by any code visible in this part of the
# notebook - confirm where it is consumed.
embedding_descriptors = ['ChemBERTa-10M-MLM', 'ankh_base', 'esm2_t6_8M_UR50D']


### Function Definitions

In [16]:
def parse_sdf_file(file_path):
    """Read all parsable molecules from an SDF file.

    The file is opened with `Chem.SDMolSupplier` using `sanitize=True`,
    `removeHs=True` and `strictParsing=True`. Entries that the supplier
    yields as `None` are dropped.

    Parameters
    ----------
    file_path: str
        Path of the SDF file.

    Returns
    -------
    list
        The molecule objects that are not `None`, in file order.
    """
    supplier = Chem.SDMolSupplier(file_path, sanitize = True, removeHs=True, strictParsing=True)
    return [entry for entry in supplier if entry is not None]


@jit
def compute_connections(protein_atomcoords, pos, cutoff=5.0):
    """Find all protein atoms within `cutoff` of each ligand atom (JIT-compiled).

    Parameters
    ----------
    protein_atomcoords: array of shape (n_protein_atoms, 3)
        Coordinates of the protein atoms.
    pos: array of shape (n_ligand_atoms, 3)
        Coordinates of the ligand atoms.
    cutoff: float, optional
        Largest distance that still counts as "close" (default 5.0, the
        value that used to be hard-coded).

    Returns
    -------
    jax.Array of bool, shape (n_ligand_atoms, n_protein_atoms)
        Entry [i, j] is True if protein atom j is at most `cutoff` away
        from ligand atom i.
    """
    # Inside a jitted function the arguments are JAX tracers. NumPy functions
    # such as np.linalg.norm try to convert them to concrete arrays and fail
    # with a TracerArrayConversionError, so only jax.numpy is used here.
    diff = protein_atomcoords[jnp.newaxis, :, :] - pos[:, jnp.newaxis, :]
    pairwise_distances = jnp.linalg.norm(diff, axis=2)
    return pairwise_distances <= cutoff




def one_of_k_encoding_unk(x, allowable_set):
    """One-hot encode `x`, mapping unknown values onto the last entry.

    If `x` is not contained in `allowable_set`, it is treated as if it
    were the last element of `allowable_set`.

    Parameters
    ----------
    x: object
        Value to encode. Does not have to be present in `allowable_set`.
    allowable_set: list
        List of allowable quantities; the last entry doubles as the
        "unknown" bucket.

    Returns
    -------
    list of bool
        One entry per element of `allowable_set`, True where it equals
        the (possibly substituted) value.

    Examples
    --------
    >>> one_of_k_encoding_unk("s", ["a", "b", "c"])
    [False, False, True]
    """
    target = x if x in allowable_set else allowable_set[-1]
    return [target == candidate for candidate in allowable_set]



def one_of_k_encoding(x, allowable_set):
    """One-hot encode `x` against a fixed list of allowed values.

    Parameters
    ----------
    x: object
        Value to encode. Must be present in `allowable_set`.
    allowable_set: list
        List of allowable quantities.

    Returns
    -------
    list of bool
        One entry per element of `allowable_set`, True where it equals `x`.

    Raises
    ------
    ValueError
        If `x` is not in `allowable_set`.

    Example
    -------
    >>> one_of_k_encoding("a", ["a", "b", "c"])
    [True, False, False]
    """
    if x in allowable_set:
        return [x == candidate for candidate in allowable_set]
    raise ValueError("input {0} not in allowable set{1}:".format(
        x, allowable_set))



def make_undirected_with_self_loops(edge_index, edge_attr, undirected = True, self_loops = True):
    """Optionally make an edge list undirected and add self-loops.

    Parameters
    ----------
    edge_index: torch.Tensor
        Edge index tensor, handed to `to_undirected` / `add_self_loops`.
    edge_attr: torch.Tensor
        Edge feature tensor belonging to `edge_index`.
    undirected: bool
        If True, pass the edges through `to_undirected`.
    self_loops: bool
        If True, pass the edges through `add_self_loops`, using the
        17-entry self-loop feature vector built below as `fill_value`.

    Returns
    -------
    tuple
        The (possibly modified) `edge_index` and `edge_attr`.
    """
    # Same layout as the bond feature vectors:
    # [edge type (3) | length (1) | bond type (5) | conjugated (1) | in ring (1) | stereo (6)]
    loop_features = [0., 1., 0.]        # edge type one-hot: it's a self-loop
    loop_features += [0.] * 14          # zero length, no bond type, not conjugated, not in ring, no stereo
    self_loop_feature_vector = torch.tensor(loop_features)

    if undirected:
        edge_index, edge_attr = to_undirected(edge_index, edge_attr)

    if self_loops:
        edge_index, edge_attr = add_self_loops(edge_index, edge_attr, fill_value = self_loop_feature_vector)

    return edge_index, edge_attr



def atom_features(mol, padding_len):
    """Build the node feature matrix for the atoms of an RDKit molecule.

    Each atom is described by one row, composed (in this order) of:

    - `padding_len` zeros (placeholder with the length of the amino acid embeddings)
    - element one-hot over `all_atoms` (metals and halogens are collapsed
      into one class each, hydrogen gets an all-zero encoding)
    - ring membership flag
    - hybridization one-hot (8 classes)
    - formal charge
    - aromaticity flag
    - atomic mass divided by 100
    - total number of hydrogens, one-hot over 0-4
    - degree, one-hot over 0-8 plus an 'OTHER' bucket
    - chiral tag, one-hot over 3 tags plus an 'OTHER' bucket

    Parameters
    ----------
    mol: RDKit molecule
        Molecule whose atoms are featurized.
    padding_len: int
        Number of zeros put in front of every feature row.

    Returns
    -------
    np.ndarray
        One row per atom.

    Raises
    ------
    ValueError
        Raised by `one_of_k_encoding` if the element symbol, the
        hybridization or the hydrogen count is not in its allowed set.
    """
    # Gasteiger charges are computed on the molecule, but the feature vector
    # below uses the formal charge.
    rdPartialCharges.ComputeGasteigerCharges(mol)

    feature_rows = []
    for atom in mol.GetAtoms():

        row = [0] * padding_len

        # Collapse all metals / all halogens into one class each
        element = atom.GetSymbol()
        if element in metals:
            element = 'metal'
        elif element in halogens:
            element = 'halogen'

        # Hydrogen is not part of `all_atoms` and is encoded as all zeros
        if element == 'H':
            element_onehot = [0] * len(all_atoms)
        else:
            element_onehot = one_of_k_encoding(element, all_atoms)

        in_ring = atom.IsInRing()
        hybridization = atom.GetHybridization()
        formal_charge = float(atom.GetFormalCharge())
        is_aromatic = atom.GetIsAromatic()
        scaled_mass = atom.GetMass()/100
        total_hs = atom.GetTotalNumHs()
        n_neighbors = atom.GetDegree()
        chiral_tag = str(atom.GetChiralTag())

        hybridization_classes = [
            Chem.rdchem.HybridizationType.S,
            Chem.rdchem.HybridizationType.SP,
            Chem.rdchem.HybridizationType.SP2,
            Chem.rdchem.HybridizationType.SP2D,
            Chem.rdchem.HybridizationType.SP3,
            Chem.rdchem.HybridizationType.SP3D,
            Chem.rdchem.HybridizationType.SP3D2,
            Chem.rdchem.HybridizationType.UNSPECIFIED,
        ]

        # Assemble the row piece by piece, in the documented order
        row += element_onehot
        row.append(in_ring)
        row += one_of_k_encoding(hybridization, hybridization_classes)
        row.append(formal_charge)
        row.append(is_aromatic)
        row.append(scaled_mass)
        row += one_of_k_encoding(total_hs, [0, 1, 2, 3, 4])
        row += one_of_k_encoding_unk(n_neighbors, [0, 1, 2, 3, 4, 5, 6, 7, 8, 'OTHER'])
        row += one_of_k_encoding_unk(chiral_tag, ['CHI_UNSPECIFIED', 'CHI_TETRAHEDRAL_CW', 'CHI_TETRAHEDRAL_CCW', 'OTHER'])

        feature_rows.append(row)

    return np.array(feature_rows)



def edge_index_and_attr(mol, pos, undirected = True, self_loops = True):
    """Build edge index and edge features from the bonds of an RDKit molecule.

    Every bond contributes one edge (begin atom -> end atom) with a
    17-entry feature vector:

    - edge type one-hot over ['covalent', 'self-loop', 'non-covalent'] (always 'covalent' here)
    - distance between the two atoms in `pos`, divided by 10
    - bond type one-hot over [0., 1.0, 1.5, 2.0, 3.0]
    - conjugation flag
    - ring membership flag
    - bond stereo one-hot (6 classes)

    Parameters
    ----------
    mol: RDKit molecule
        Molecule whose bonds are turned into edges.
    pos: array
        Atom coordinates, indexed by atom index.
    undirected: bool
        Forwarded to `make_undirected_with_self_loops`.
    self_loops: bool
        Forwarded to `make_undirected_with_self_loops`.

    Returns
    -------
    tuple of torch.Tensor
        `edge_index` (int64) and `edge_attr` (float64) as returned by
        `make_undirected_with_self_loops`.

    Raises
    ------
    ValueError
        Raised by `one_of_k_encoding` if a bond type or bond stereo value
        is not in its allowed set.
    """
    stereo_classes = [
        Chem.rdchem.BondStereo.STEREONONE,
        Chem.rdchem.BondStereo.STEREOANY,
        Chem.rdchem.BondStereo.STEREOE,
        Chem.rdchem.BondStereo.STEREOZ,
        Chem.rdchem.BondStereo.STEREOCIS,
        Chem.rdchem.BondStereo.STEREOTRANS,
    ]

    sources = []
    targets = []
    bond_features = []

    for bond in mol.GetBonds():
        begin_idx = bond.GetBeginAtomIdx()
        end_idx = bond.GetEndAtomIdx()
        sources.append(begin_idx)
        targets.append(end_idx)

        # The list display is evaluated left to right, which fixes the feature order
        features = [
            # Edge type (covalent bond, self-loop, non-covalent bond)
            *one_of_k_encoding('covalent', ['covalent','self-loop','non-covalent']),
            # Length of the edge, scaled down by 10
            np.linalg.norm(pos[begin_idx]-pos[end_idx])/10,
            # Bond type as double (1.5 = aromatic)
            *one_of_k_encoding(bond.GetBondTypeAsDouble(), [0.,1.0,1.5,2.0,3.0]),
            # Conjugated?
            bond.GetIsConjugated(),
            # Part of a ring?
            bond.IsInRing(),
            # Stereo configuration
            *one_of_k_encoding(bond.GetStereo(), stereo_classes),
        ]
        bond_features.append(features)

    # Convert to tensors, then make undirected and add self loops if requested
    edge_index = torch.tensor([sources, targets], dtype=torch.int64)
    edge_attr = torch.tensor(bond_features, dtype=torch.float64)
    return make_undirected_with_self_loops(edge_index, edge_attr, undirected=undirected, self_loops=self_loops)



# Element classes used for the one-hot element encoding in atom_features.
# Individual metals and halogens are mapped onto 'metal' / 'halogen' first.
all_atoms = ['B', 'C', 'N', 'O', 'P', 'S', 'Se', 'metal', 'halogen']
halogens = ['F', 'Cl', 'Br', 'I', 'At'] #Halogen atoms Fluorine (F), Chlorine (Cl), Bromine (Br), Iodine (I), and Astatine (At)
# Element symbols that atom_features collapses into the 'metal' class.
# Note: 'Nh', 'Fl', 'Mc' and 'Lv' are listed twice (transition and
# post-transition metals); this has no effect on the membership test.
metals = [
    # Alkali Metals
    'Li', 'Na', 'K', 'Rb', 'Cs', 'Fr',
    # Alkaline Earth Metals
    'Be', 'Mg', 'Ca', 'Sr', 'Ba', 'Ra',
    # Transition Metals
    'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn',
    'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd',
    'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg',
    'Rf', 'Db', 'Sg', 'Bh', 'Hs', 'Mt', 'Ds', 'Rg', 'Cn',
    'Nh', 'Fl', 'Mc', 'Lv',
    # Lanthanides
    'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 
    'Ho', 'Er', 'Tm', 'Yb', 'Lu',
    # Actinides
    'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm', 'Bk', 'Cf', 
    'Es', 'Fm', 'Md', 'No', 'Lr',
    # Post-Transition Metals
    'Al', 'Ga', 'In', 'Sn', 'Tl', 'Pb', 'Bi', 'Nh', 'Fl', 'Mc', 'Lv',
    # Half-Metals
    'As', 'Si', 'Sb', 'Te'
]



# Three-letter codes of the 20 standard amino acids
amino_acids = ["ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE", "LEU",
               "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL"]

# Ion name -> SMILES string of the charged ion, keyed by upper-case name
# NOTE(review): neither of the two dictionaries is read by the code visible
# in this part of the notebook - confirm where they are used.
hetatm_smiles_dict1 = {'ZN': '[Zn+2]', 'MG': '[Mg+2]', 'NA': '[Na+1]', 'MN': '[Mn+2]', 'CA': '[Ca+2]', 'K': '[K+1]',
                      'NI': '[Ni+2]', 'FE': '[Fe+2]', 'CO': '[Co+2]', 'HG': '[Hg+2]', 'CD': '[Cd+2]', 'CU': '[Cu+2]', 
                      'CS': '[Cs+1]', 'AU': '[Au+1]', 'LI': '[Li+1]', 'GA': '[Ga+3]', 'IN': '[In+3]', 'BA': '[Ba+2]',
                      'RB': '[Rb+1]', 'SR': '[Sr+2]'}

# Same mapping as hetatm_smiles_dict1, keyed by the capitalized element symbol
hetatm_smiles_dict2 = {'Zn': '[Zn+2]', 'Mg': '[Mg+2]', 'Na': '[Na+1]', 'Mn': '[Mn+2]', 'Ca': '[Ca+2]', 'K': '[K+1]',
                      'Ni': '[Ni+2]', 'Fe': '[Fe+2]', 'Co': '[Co+2]', 'Hg': '[Hg+2]', 'Cd': '[Cd+2]', 'Cu': '[Cu+2]', 
                      'Cs': '[Cs+1]', 'Au': '[Au+1]', 'Li': '[Li+1]', 'Ga': '[Ga+3]', 'In': '[In+3]', 'Ba': '[Ba+2]',
                      'Rb': '[Rb+1]', 'Sr': '[Sr+2]'}

### Preprocessing

In [17]:
# Read the affinity values (looked up by complex id) from the JSON file
with open(affinity_dict_path, 'r', encoding='utf-8') as json_file:
    affinity_dict = json.load(json_file)


# Collect the protein and ligand files (os.DirEntry objects) in data_dir,
# each list ordered by file name so that matching files line up
proteins = [entry for entry in os.scandir(data_dir) if entry.name.endswith('protein.pdb')]
proteins.sort(key=lambda entry: entry.name)
ligands = [entry for entry in os.scandir(data_dir) if entry.name.endswith('ligand.sdf')]
ligands.sort(key=lambda entry: entry.name)

# Every protein file needs exactly one ligand file as partner
if len(proteins) != len(ligands):
    raise ValueError('Number of proteins and ligands does not match')

print(f'Number of Protein PDBs: {len(proteins)}')
print(f'Number of Ligand SDFs: {len(ligands)}')


# Initialize Log File (opened in append mode; it stays open so that the
# preprocessing loop can write to it and close it afterwards)
log_folder = os.path.join(data_dir,'.logs/')
if not os.path.exists(log_folder):
    os.makedirs(log_folder)
log_file_path = os.path.join(log_folder, "preprocessing_logs.txt")
log = open(log_file_path, 'a')
log.write("Data Preprocessing - Log File:\n")
log.write("\n")


Number of Protein PDBs: 16
Number of Ligand SDFs: 16


1

In [18]:
# Preprocess every protein-ligand pair: check that the complex is usable,
# build the protein coordinate matrix, determine which residues are close to
# which ligand atom and export the results as pickle files into data_dir.
# One line per complex is appended to the log file opened in the cell above.

# Initialize PDB Parser
parser = PDBParser(PERMISSIVE=1, QUIET=True)

# Residue names that are accepted next to the ligand: the 20 standard amino
# acids plus a fixed set of ions
amino_acids = ["ALA","ARG","ASN","ASP","CYS","GLN","GLU","GLY","HIS","ILE","LEU","LYS","MET","PHE","PRO","SER","THR","TRP","TYR","VAL"]
known_hetatms = ['ZN','MG','NA','MN','CA','K','NI','FE','CO','HG','CD','CU','CS','AU','LI','GA','IN','BA','RB','SR']
known_residues = amino_acids + known_hetatms

# A protein atom counts as connected to a ligand atom if their distance is
# at most max_len + 1 (in the units of the coordinate files)
max_len = 4


# Here starts the loop over the complexes
#----------------------------------------------------------
# try/finally makes sure the log file is closed even if a complex raises
try:
    for protein_entry, ligand_entry in zip(proteins, ligands):

        # File names are expected to start with '<complex id>_'; the sorted
        # protein and ligand lists must line up id by id
        protein_id = protein_entry.name.split('_')[0]
        if protein_id != ligand_entry.name.split('_')[0]:
            raise ValueError(f'Protein {protein_entry} and Ligand {ligand_entry} do not match')
        protein_path = protein_entry.path
        ligand_path = ligand_entry.path

        log_string = f'{protein_id}: '

        # -----------------------------------------------------
        # PRESELECTION OF COMPLEXES
        # -----------------------------------------------------

        # TEST 1: AFFINITY DATA - Continue only if there is a valid affinity value available for this complex
        if protein_id not in affinity_dict:
            log_string += 'Protein is not processed: No valid affinity value'
            log.write(log_string + "\n")
            continue

        # TEST 2: LIGAND PARSING - Continue only if exactly one molecule could be parsed from the SDF, else skip this complex
        ligand_mols = parse_sdf_file(ligand_path)
        if len(ligand_mols) != 1:
            log_string += 'Ligand could not be parsed successfully'
            log.write(log_string + "\n")
            continue
        mol = ligand_mols[0]

        # TEST 3: LIGAND SIZE - Get coordinate matrix of the ligand molecule - Continue only if the ligand has at least 5 atoms
        # (the SDF is parsed with removeHs=True)
        pos = np.array(mol.GetConformer().GetPositions())
        if pos.shape[0] < 5:
            log_string += 'Ligand is smaller than 5 Atoms and is skipped'
            log.write(log_string + "\n")
            continue


        # PARSING OF PROTEIN PDB TO GENERATE COORDINATE MATRIX
        # -----------------------------------------------------

        with open(protein_path) as pdbfile:
            protein = parse_pdb(parser, protein_id, pdbfile)

        # Coordinate blocks are collected and stacked once after the chain
        # loop. The empty float64 block keeps the result a float64 (n, 3)
        # numpy array, also when no residue contributes coordinates.
        coord_blocks = [np.empty((0, 3), dtype=np.float64)]
        res_list = []               # (running residue index, residue name)
        residue_memberships = []    # running residue index of every protein atom

        clean_aa_chain = False
        chain_too_long = False
        residue_idx = 1             # running residue index, starts at 1

        # Iterate over the chains in the protein
        for chain in protein:

            chain_data = protein[chain]
            chain_comp = chain_data['composition']

            # CHAIN CONTAINS AMINO ACIDS (AND POSSIBLY HETATMS in chain)
            if chain_comp == [True, False] or chain_comp == [True, True]:
                clean_aa_chain = True

                # If the chain has more than 1022 residues, skip the complex
                # (presumably a 1024 token limit of the embedding model minus
                # two special tokens - TODO confirm)
                if len(chain_data['aa_seq']) > 1022:
                    chain_too_long = True
                    break

                aa_residues = chain_data['aa_residues']
                for residue in aa_residues:
                    res_dict = aa_residues[residue]

                    # Remember the coords of the residue for protein_atomcoords
                    coord_blocks.append(res_dict['coords'])

                    res_list.append((residue_idx, res_dict['resname']))

                    # One membership entry per atom of the residue
                    residue_memberships.extend(residue_idx for _ in res_dict['atom_indeces'])

                    residue_idx += 1

            # CHAIN CONTAINS HETATMS BUT NO AMINO ACIDS
            elif chain_comp == [False, True]:

                hetatm_residues = chain_data['hetatm_residues']
                for hetatm_res in hetatm_residues:
                    hetatmres_dict = hetatm_residues[hetatm_res]

                    # Remember the coords of the residue for protein_atomcoords.
                    # (This branch used jnp.vstack before, which silently turned
                    # the coordinate matrix into a float32 JAX array.)
                    coord_blocks.append(hetatmres_dict['hetatmcoords'])

                    res_list.append((residue_idx, hetatmres_dict['resname']))

                    # One membership entry per atom of the residue
                    residue_memberships.extend(residue_idx for _ in hetatmres_dict['atoms'])

                    residue_idx += 1


        # To do: Can we delete this? Does this ever happen?
        # If a chain is too long to generate an ESM Embedding, skip the complex
        if chain_too_long:
            log_string += 'Protein AA sequence too long for ESM'
            log.write(log_string + "\n")
            continue

        # To do: Can we delete this? Does this ever happen?
        if not clean_aa_chain:
            log_string += 'No clean AA_chain has been found, complex is skipped'
            log.write(log_string + "\n")
            continue

        # Stack all residue coordinates into one (n_atoms, 3) matrix
        protein_atomcoords = np.vstack(coord_blocks)


        # COMPUTE CONNECTIVITY BETWEEN LIGAND AND PROTEIN ATOMS
        # -----------------------------------------------------

        # With numpy --------------------------------------------------------------------------------------------------------
        tic = time()
        # (n_ligand_atoms, n_protein_atoms) matrix of pairwise distances
        diff = protein_atomcoords[np.newaxis, :, :] - pos[:, np.newaxis, :]
        pairwise_distances = np.linalg.norm(diff, axis=2)
        close = pairwise_distances <= max_len + 1

        # For every ligand atom: sorted unique residue indices of all close protein atoms
        memberships = np.array(residue_memberships)
        connections = [np.unique(memberships[np.where(row)]) for row in close]
        np_time = time()-tic

        # All residues that are connected to at least one ligand atom
        connected_res_num = sorted({res for conn in connections for res in conn})
        # residue indices start at 1, res_list is 0-based
        connected_res_name = [res_list[aa-1][1] for aa in connected_res_num]

        print(connections)
        print(connected_res_num)
        print(connected_res_name)

        print(f'Numpy Time: {np_time}')
        # ---------------------------------------------------------------------------------------------------------------------



        # With JAX --------------------------------------------------------------------------------------------------------

        #Warm up the JIT compilation
        # _ = compute_connections(np.array(protein_atomcoords), jnp.array(pos))

        # tic = time()
        # close = compute_connections(protein_atomcoords, pos)
        # connections = [np.unique(np.array(residue_memberships)[np.where(row)]) for row in np.array(close)]
        # jnp_time = time()-tic
        # connected_res_num = sorted(list(set([atm for l in connections for atm in l])))
        # connected_res_name = [res_list[aa-1][1] for aa in connected_res_num]

        # print(f'JaxNumpy Time: {jnp_time}')
        # ---------------------------------------------------------------------------------------------------------------------



        # Skip the complex if the ligand touches a residue that is neither a
        # standard amino acid nor a known ion (residue names are also checked
        # with leading/trailing digits removed)
        unknown_res = [(res not in known_residues and res.strip('0123456789') not in known_residues) for res in connected_res_name]
        if any(unknown_res):
            log_string += 'Ligand has been connected to a unknown protein residue, the complex is therefore skipped'
            log.write(log_string + "\n")
            continue



        # EXPORT DATA
        # -------------------------------------------------------

        # Export protein dictionary as pkl
        filepath = os.path.join(data_dir, f'{protein_id}_protein_dict.pkl')
        with open(filepath, 'wb') as fp:
            pickle.dump(protein, fp)

        # Export CONNECTIONS as dict
        connections_dict = {'connections':connections, 'res_num':connected_res_num, 'res_name':connected_res_name}
        filepath = os.path.join(data_dir, f'{protein_id}_connections.pkl')
        with open(filepath, 'wb') as fp:
            pickle.dump(connections_dict, fp)


        log_string += 'Successful'
        log.write(log_string + "\n")

finally:
    log.close()



[array([ 35,  88,  90,  94, 245, 307, 309, 310]), array([ 35,  45,  88,  90,  94, 307, 310]), array([ 35,  45,  48,  88,  90,  94, 310]), array([ 33,  35,  45,  48,  88,  90, 310]), array([33, 35, 45, 48, 88, 90]), array([ 35,  88,  94,  96, 307, 315]), array([ 35,  88,  94,  96, 247, 307, 315]), array([ 35,  88,  94,  96, 245, 247, 257, 307, 315]), array([ 88,  94,  96, 245, 247, 257, 307, 315]), array([ 88,  94,  96, 245, 257, 307]), array([ 35,  88,  94, 245, 257, 307, 309, 310]), array([ 90,  94, 310, 311]), array([ 90,  94, 311]), array([ 94, 310, 311]), array([ 90, 311]), array([ 35,  94, 245, 307, 308, 309, 310]), array([ 94, 243, 245, 307, 309, 310]), array([ 35,  88,  90,  94, 245, 307, 309, 310]), array([ 90,  94, 310, 311]), array([31, 90]), array([ 90, 311]), array([ 94, 309, 310, 311]), array([ 35,  88,  94, 245, 307, 309, 310])]
[31, 33, 35, 45, 48, 88, 90, 94, 96, 243, 245, 247, 257, 307, 308, 309, 310, 311, 315]
['TYR', 'GLY', 'TYR', 'LEU', 'HIS', 'LEU', 'TYR', 'ARG', '



[array([ 10,  48,  50,  91, 129]), array([  8,  10,  40,  48,  50, 129]), array([  8,  10,  40,  48, 129]), array([  8,  10,  40,  48, 129]), array([  8,  10,  40,  48,  50, 129]), array([ 8, 48, 50]), array([ 10,  48,  50,  91, 129]), array([ 10,  40,  48, 129]), array([  8,  10,  40,  44,  48, 129]), array([  8,  10,  40, 124, 126, 129]), array([ 8, 10, 48, 50, 91]), array([  8,  50, 124]), array([48, 91]), array([ 48,  50,  91, 131]), array([ 91, 131, 134]), array([ 50,  91, 118, 131, 134]), array([ 50,  70,  91, 118, 134]), array([ 50,  70,  91, 118, 134]), array([ 91, 131, 134, 906]), array([ 48,  50,  91, 118, 134]), array([ 91, 129, 130, 131]), array([118, 122, 124, 131, 134]), array([ 50,  70, 118, 134])]
[8, 10, 40, 44, 48, 50, 70, 91, 118, 122, 124, 126, 129, 130, 131, 134, 906]
['TYR', 'ARG', 'ARG', 'GLN', 'TYR', 'GLU', 'MET', 'ARG', 'PHE', 'ASN', 'ASP', 'HIS', 'ASP', 'SER', 'ASP', 'PHE', 'ASP']
Numpy Time: 0.049070119857788086
[array([118, 164]), array([118, 164]), array([1



[array([118, 164]), array([118, 164]), array([118]), array([118, 164]), array([118, 164, 165, 166]), array([118, 164, 165]), array([164]), array([118, 164, 165, 166]), array([118, 164, 165, 166]), array([118, 140, 142, 166]), array([118, 138, 140, 148, 164, 165, 166]), array([118, 140, 142, 166]), array([118, 138, 140, 148, 164, 165, 166]), array([118, 138, 140, 142, 148, 164, 166]), array([118, 138, 140, 141, 142, 143, 148, 164, 166]), array([118, 138, 140, 141, 142, 148]), array([118, 138, 140, 141, 142]), array([118, 138, 139, 140, 141, 142, 143, 148]), array([118, 138, 140, 141, 142, 148, 164]), array([164, 165, 166]), array([164, 165]), array([165]), array([164, 165, 166]), array([163, 164, 165]), array([163, 164, 165]), array([163, 164, 165]), array([163, 165]), array([163, 164]), array([], dtype=int64), array([], dtype=int64), array([], dtype=int64), array([], dtype=int64), array([165, 177, 178]), array([165, 177, 178]), array([165, 177, 178, 199, 200]), array([165, 178, 199]), 



[array([257, 259]), array([182, 257, 258, 259, 260, 262]), array([258, 262, 265]), array([257, 258, 259]), array([257, 258, 259, 262]), array([182, 257, 258, 259, 262]), array([258, 259, 262]), array([182, 258, 259, 262]), array([258, 259, 262]), array([258, 259, 262]), array([258, 259, 262]), array([258, 259, 262]), array([256, 257, 258, 259]), array([256, 257, 258, 259]), array([256, 257]), array([255, 256, 257]), array([256, 257, 258]), array([208, 256, 257, 258]), array([208, 256, 257, 258]), array([208, 256, 257, 258, 259]), array([ 86, 231, 234, 256, 257]), array([83, 86]), array([132, 256, 257]), array([ 79, 231, 234, 255, 256, 257]), array([ 79, 132, 234, 255, 256, 257]), array([ 79, 132, 234, 255, 256]), array([ 83, 132, 256]), array([ 79,  83, 132]), array([ 79,  83,  86, 132]), array([83, 86]), array([86]), array([132, 256, 257]), array([ 79,  80, 231, 232, 234, 255]), array([ 79, 231, 234, 255, 256, 257]), array([228, 229, 230, 231, 256, 257, 259, 260, 267, 268]), array([22



[array([ 40,  73, 216, 298, 332, 333]), array([ 40, 216, 271, 272, 298, 333]), array([ 40,  41,  73, 216, 298, 332, 333]), array([ 40,  41,  73, 200, 216, 298, 333]), array([ 40,  41,  73, 200, 216, 298, 333]), array([ 40,  41,  73, 150, 200, 333]), array([ 73,  74, 200, 333]), array([ 73,  74, 145, 147, 150, 200]), array([ 73,  74, 101, 102, 145, 147]), array([ 73,  74, 101, 145, 147]), array([ 73,  74, 101, 102, 145, 147, 150]), array([ 73, 147, 200, 216, 333]), array([ 40,  73, 200, 216, 298, 333]), array([147, 169, 199, 200, 216, 333]), array([ 73,  74, 147, 169, 216]), array([147, 169, 199, 200, 216, 218, 333]), array([147, 199, 200, 216, 218, 333]), array([147, 169, 199, 216, 218]), array([147, 169, 197, 199, 218]), array([ 41,  73,  74,  78, 101, 333]), array([ 41,  73,  74,  78, 101, 102, 150]), array([ 41,  73,  78, 101, 102, 150, 333]), array([ 41,  56,  73,  74,  78, 101, 102]), array([ 41,  72,  73,  74,  78, 101])]
[40, 41, 56, 72, 73, 74, 78, 101, 102, 145, 147, 150, 169,



[array([ 37,  98, 121, 122, 123, 124, 125]), array([ 37,  84,  93, 124, 136, 171, 173]), array([ 37,  38,  41,  84, 171, 173]), array([ 37,  38,  41,  79,  82,  83,  84, 171]), array([40, 44, 82]), array([84, 88, 92]), array([40, 44]), array([36, 37, 40, 44]), array([ 98, 121]), array([ 37,  98, 121, 122, 123, 124]), array([ 37,  38,  41,  79,  84, 171, 173]), array([ 37,  98, 121, 122, 123, 124, 125]), array([ 92,  93,  98, 121, 122, 124, 125]), array([ 37,  93, 124]), array([ 84,  92,  93, 124]), array([ 84,  93, 124]), array([ 37,  84,  93, 124, 173]), array([ 37,  41,  84, 171, 173]), array([37, 40, 41, 84]), array([37, 40, 41, 44, 82, 84]), array([40, 41, 44, 82, 84]), array([44, 82, 84]), array([44, 84, 92]), array([44, 92]), array([44, 92, 98]), array([44, 98]), array([40, 44, 98]), array([40, 44]), array([37, 40]), array([37, 40]), array([ 37,  98, 121, 123, 124]), array([ 98, 121]), array([ 92,  93,  98, 121, 122, 124, 125]), array([ 84,  93, 124, 136, 171, 173]), array([ 37, 



[array([ 40,  73, 216, 298, 332, 333]), array([ 40, 216, 271, 272, 298, 333]), array([ 40, 216, 298, 332, 333]), array([ 40,  41,  73, 200, 216, 298, 333]), array([ 40,  41,  73, 200, 216, 298, 333]), array([ 40,  41,  73,  78, 150, 200, 333]), array([ 41,  73,  74, 200, 333]), array([ 73,  74, 147, 150, 200]), array([ 73,  74, 101, 102, 145, 147]), array([ 73,  74, 101, 145]), array([ 73,  74, 101, 102, 145, 147, 150]), array([ 73, 200, 216, 333]), array([ 40,  73, 200, 216, 298, 333]), array([ 74, 147, 200, 216, 333]), array([ 73,  74, 216]), array([147, 169, 199, 200]), array([145, 147, 169, 199]), array([ 74, 145, 147, 169]), array([ 74, 145, 147, 169]), array([ 74, 145, 169]), array([144, 145, 147, 167, 168, 169]), array([144, 145, 168, 169]), array([144, 145, 147, 167, 168, 169, 170]), array([144, 145, 167, 168, 169, 170]), array([147, 169, 199, 200, 216, 333]), array([147, 169, 199, 200, 216, 218, 333]), array([147, 169, 199, 200, 216, 218]), array([40, 41, 73, 78])]
[40, 41, 73



[array([202, 203, 204, 205, 206, 207]), array([202, 203, 204, 205, 206, 207]), array([202, 203, 205, 206, 207, 208, 247]), array([201, 202, 203, 204, 205, 206, 207]), array([203, 204, 205, 206, 207, 208]), array([203, 205, 206, 207, 208, 223]), array([203, 205, 207, 208, 223, 225]), array([205, 206, 207, 208, 209, 223]), array([203, 204, 205, 206, 207, 208, 223, 306]), array([203, 204, 205, 208, 306]), array([203, 205, 208, 306]), array([203, 205, 218, 306]), array([208, 218, 223]), array([223]), array([205, 208, 218, 223]), array([208, 218, 219, 223]), array([205, 208, 218, 306]), array([205, 208, 218, 306, 349]), array([204, 205, 208, 218, 306, 349]), array([204, 205, 208, 209, 218, 305, 306, 349, 350]), array([205, 208, 218, 306, 309, 348, 349, 350]), array([205, 218, 305, 306, 308, 309, 348, 349, 350]), array([218, 305, 306, 307, 308, 347, 348, 349, 350, 351]), array([218, 306, 308, 309, 348, 349, 350]), array([218, 306, 308, 309, 348, 350]), array([306, 308, 309, 348, 350]), array



[array([ 70,  77, 123, 224, 245, 246, 247]), array([ 70,  74,  77, 123, 224, 245, 246]), array([ 70,  74,  77, 123]), array([ 74,  77, 123]), array([ 77, 247]), array([ 77, 246, 247, 248]), array([ 77, 123, 246, 247]), array([ 74,  77, 123]), array([ 77, 123, 245, 246, 247]), array([123, 221, 245, 246, 247]), array([ 70,  77, 123, 221, 224, 245, 246, 247]), array([ 77, 221, 224, 246, 247]), array([221, 246, 247, 248, 249]), array([198, 246, 247, 248]), array([123, 198, 246, 247, 248]), array([ 74, 198]), array([ 74, 121, 122, 123, 198]), array([ 74, 121, 122, 123, 198]), array([ 74, 121, 122, 123, 198, 246]), array([122, 123, 198, 246, 247]), array([ 70,  77, 221, 224, 245, 246, 247]), array([ 70, 220, 221, 222, 224, 244, 245, 246]), array([ 70, 220, 221, 222, 223, 224, 244, 245]), array([220, 221, 222, 223, 224, 244]), array([ 70, 220, 221, 222, 223, 224, 244, 245, 246, 247]), array([219, 220, 221, 224, 244, 245, 246, 247]), array([219, 220, 221, 223, 244, 246, 247]), array([218, 219,



### Generate Graphs

In [None]:
#-------------------------------------------------------------------------------------------------------------
# GENERATE INTERACTION-GRAPHS OF ALL COMPLEXES
#------------------------------------------------------------------------------------------------------------- 

# Choose the esm embedding that should be used:
embedding_descriptor = args.embedding
num_atomfeatures = 40
num_edgefeatures = 17
output_folder = f'/data/grbv/PDBbind/DTI_5/input_graphs_{embedding_descriptor}_unpad/'


# GET THE PREPROCESSED DATA
# -------------------------------------------------------------------------------
# Generate a lists of all protein-ligand complexes, the corresponding folder path and protein_dictionary paths
input_data_dir = '/data/grbv/PDBbind/DTI5_input_data_processed'
complexes = [subfolder for subfolder in os.listdir(input_data_dir) if len(subfolder) ==4 and subfolder[0].isdigit()]
folder_paths = [os.path.join(input_data_dir, complex) for complex in complexes]
protein_paths = [os.path.join(folder_path, f'{complex}_protein_dict.pkl') for complex, folder_path in zip(complexes, folder_paths)]
ligand_paths = [os.path.join(folder_path, f'{complex}_ligand_san.sdf') for complex, folder_path in zip(complexes, folder_paths)]
affinity_dict = load_object('/data/grbv/PDBbind/DTI5_general_affinity_dict.pkl')
# -------------------------------------------------------------------------------


# CHECK WHICH COMPLEXES ARE PART OF THE TEST DATASETS (CASF2013 and CASF2016) AND WHICH ARE PART OF REFINED SET
# -------------------------------------------------------------------------------
casf_2013_dir = '/data/grbv/PDBbind/raw_data/CASF-2013/coreset'
casf_2016_dir = '/data/grbv/PDBbind/raw_data/CASF-2016/coreset'



casf_2013_complexes = [subfolder for subfolder in os.listdir(casf_2013_dir) if len(subfolder) ==4 and subfolder[0].isdigit()]
casf_2016_complexes = [subfolder for subfolder in os.listdir(casf_2016_dir) if len(subfolder) ==4 and subfolder[0].isdigit()]

data_dir_general = '/data/grbv/PDBbind/raw_data/v2020_general_san'
data_dir_refined = '/data/grbv/PDBbind/raw_data/v2020_refined_san'

general_complexes = [protein for protein in os.listdir(data_dir_general) if len(protein)==4 and protein[0].isdigit()]
refined_complexes = [protein for protein in os.listdir(data_dir_refined) if len(protein)==4 and protein[0].isdigit()]


# Are all 2013 complexes present in the preprocessed data? 
missing_2013 = []
for complex in casf_2013_complexes: 
    if not complex in complexes: missing_2013.append(complex)
print(f'CASF-2013 complexes that are not present in preprocessed data: {missing_2013}')

# Are all 2016 complexes present in the preprocessed data? 
missing_2016 = []
for complex in casf_2016_complexes: 
    if not complex in complexes: missing_2016.append(complex)
print(f'CASF-2016 complexes that are not present in preprocessed data: {missing_2016}')
# -------------------------------------------------------------------------------


# Create Output Folders
# -----------------------------------------------------------
# Training graphs and the two CASF benchmark sets are stored in separate folders below output_folder
train_folder = os.path.join(output_folder, 'training_data')
test_folder = os.path.join(output_folder, 'test_data')
casf2013_folder = os.path.join(test_folder, 'casf2013')
casf2016_folder = os.path.join(test_folder, 'casf2016')

for folder in [train_folder, test_folder, casf2013_folder, casf2016_folder]:
    # exist_ok avoids the race between checking for the folder and creating it
    os.makedirs(folder, exist_ok=True)
# -----------------------------------------------------------


# Initialize Log:
# -------------------------------------------------------------------------------
# NOTE(review): this is plain string concatenation - the log folder only ends up inside
# output_folder if output_folder ends with a path separator; otherwise it becomes a
# sibling folder named '<output_folder>.logs/'. Confirm which one is intended.
log_folder = output_folder + '.logs/'

os.makedirs(log_folder, exist_ok=True)
log_file_path = os.path.join(log_folder, f"graph_generation_{embedding_descriptor}.txt")

# Opened in append mode: repeated runs add to the same file. The handle is closed after the main loop.
log = open(log_file_path, 'a')
log.write("Generation of Featurized Interaction Graphs - Log File:\n")
log.write("Data: PDBbind v2020 refined and general set merged\n")
log.write("\n")

# Ids of all complexes for which no graph could be generated
skipped = []

# Use a quarter of the currently configured torch threads, but never fewer than one:
# torch.set_num_threads requires a positive number, and because the value is read back from
# torch, re-running this cell would otherwise shrink it step by step down to 0.
num_threads = max(1, torch.get_num_threads() // 4)
torch.set_num_threads(num_threads)
# -------------------------------------------------------------------------------



# Loop over all complexes and generate one interaction graph per complex
#----------------------------------------------------------
# ind = complexes.index('1a0q')
# for complex_id, folder_path, protein_path, ligand_path in zip(complexes[ind:ind+1], folder_paths[ind:ind+1], protein_paths[ind:ind+1], ligand_paths[ind:ind+1]):

# Encodings used for the metadata vector stored in graph.data. They are loop-invariant and
# therefore defined once. The insertion order of affmetric_encoding is also the lookup
# priority for the affinity value of a complex: Ki first, then Kd, then IC50.
affmetric_encoding = {'Ki':1., 'Kd':2.,'IC50':3.}
precision_encoding = {'=':0., '>':1., '<':2., '>=':3., '<=':4., '~':5.}


def _skip_complex(complex_id, message):
    """Write `message` for `complex_id` to the log and record the complex in the global `skipped` list."""
    log.write(f'{complex_id}: {message}' + "\n")
    skipped.append(complex_id)


try:
    for complex_id, folder_path, protein_path, ligand_path in zip(complexes, folder_paths, protein_paths, ligand_paths):

        print(complex_id)

        # Load necessary data
        protein_dict = load_object(protein_path)
        ligand = parse_sdf_file(ligand_path)

        try:
            connections_dict = load_object(os.path.join(folder_path, f'{complex_id}_connections.pkl'))
        except FileNotFoundError:
            _skip_complex(complex_id, 'Skipped - No connections_dict found (Fail in preprocessing)')
            continue

        aa_embedding = torch.load(os.path.join(folder_path, f'{complex_id}_{embedding_descriptor}.pt'))
        aa_emb_len = aa_embedding.shape[1]

        # Access the ligand mol object and generate coordinate matrix (pos).
        # Exactly one ligand molecule is required (an empty list is rejected here as well).
        if not len(ligand) == 1:
            _skip_complex(complex_id, 'Skipped - More than one ligand molecule provided')
            continue

        mol = ligand[0]
        conformer = mol.GetConformer()
        coordinates = conformer.GetPositions()
        pos = np.array(coordinates)


        #===================================================================================================================================
        # Create Interaction Graph
        #===================================================================================================================================

        # Edge Index and Node Feature Matrix for Substrate
        #------------------------------------------------------------------------------------------
        # Two ligand feature matrices are built, padded to match the two protein node
        # representations (embedding vector vs. amino acid identity encoding)
        x_lig_emb = atom_features(mol, padding_len=aa_emb_len)
        x_lig_aa = atom_features(mol, padding_len=len(amino_acids))

        if np.isnan(x_lig_emb).any():
            _skip_complex(complex_id, 'Skipped - Nans during ligand feature computation')
            continue

        edge_index_lig, edge_attr_lig = edge_index_and_attr(mol, pos, self_loops=False, undirected=False)
        #------------------------------------------------------------------------------------------


        # Add the data of the amino acids identified as neighbors to X and POS
        #------------------------------------------------------------------------------------------
        connections = connections_dict['connections']
        connections_res_num = connections_dict['res_num']
        connections_res_name = connections_dict['res_name']

        # Joined Residues Dictionary: residues of all chains under one running index starting at 1
        protein = {}
        residue_idx = 1
        for chain in protein_dict:
            chain_comp = protein_dict[chain]['composition']

            if chain_comp == [True, False] or chain_comp == [True, True]:
                for residue in protein_dict[chain]['aa_residues']:
                    protein[residue_idx] = protein_dict[chain]['aa_residues'][residue]
                    residue_idx += 1

            elif chain_comp == [False, True]:
                for hetatm_res in protein_dict[chain]['hetatm_residues']:
                    protein[residue_idx] = protein_dict[chain]['hetatm_residues'][hetatm_res]
                    residue_idx += 1

        # Iterate over the connection enzyme residues
        x_prot_emb = np.array([]).reshape(0, aa_emb_len + num_atomfeatures)
        x_prot_aa = np.array([]).reshape(0, len(amino_acids) + num_atomfeatures)

        new_indeces = []              # node index assigned to each connected residue
        count = pos.shape[0]          # protein nodes are numbered after the ligand atoms
        residue_mismatch = False
        incomplete_residue = False    # becomes a (residue, resname) tuple if a CA atom is missing

        for residue, resname in zip(connections_res_num, connections_res_name):

            if not resname == protein[residue]['resname']:
                print(f'Complex {complex_id}: Residues do not match with')
                residue_mismatch = True

            # IF THE RESIDUE IS AN AMINO ACID
            if resname in amino_acids:

                try:
                    ca_idx = protein[residue]['atoms'].index('CA')
                except ValueError:
                    # Remember the offending residue; the complex is skipped after this loop
                    incomplete_residue = (residue, resname)
                    continue

                # Add coords of the CA atom to pos
                coords = protein[residue]['coords'][ca_idx]
                pos = np.vstack((pos, coords))

                padding = np.zeros((1, num_atomfeatures))

                # Add feature vector to prot_emb matrix
                # (keys of `protein` start at 1, rows of aa_embedding at 0)
                embedding = aa_embedding[residue-1]
                features = np.concatenate((embedding[np.newaxis,:], padding), axis=1)
                x_prot_emb = np.vstack((x_prot_emb, features))

                # Add feature vector to prot_aa matrix
                aa_identity = one_of_k_encoding_unk(resname, amino_acids)
                features = np.concatenate((np.array(aa_identity)[np.newaxis,:], padding), axis=1)
                x_prot_aa = np.vstack((x_prot_aa, features))

            # IF THE RESIDUE IS A HETATM
            else:
                # Add coords of hetatm to pos
                # NOTE(review): `count` advances by one per residue, so this presumably relies on
                # every HETATM contributing exactly one coordinate row and one feature row - confirm
                coords = protein[residue]['hetatmcoords']
                pos = np.vstack((pos, coords))

                # Add feature vector of hetatm to x (residue name without digits is the SMILES lookup key)
                resname_smiles = hetatm_smiles_dict1[resname.strip('0123456789')]
                hetatm_mol = Chem.MolFromSmiles(resname_smiles)

                hetatm_features_emb = atom_features(hetatm_mol, padding_len=aa_emb_len)
                hetatm_features_aa = atom_features(hetatm_mol, padding_len=len(amino_acids))

                x_prot_emb = np.vstack((x_prot_emb, hetatm_features_emb))
                x_prot_aa = np.vstack((x_prot_aa, hetatm_features_aa))

            new_indeces.append(count)
            count += 1


        # MASTER NODE
        # add master node features (all zeros) to x_prot and add a point with the mean of the
        # ligand coordinates to pos
        master_node_features_emb = np.zeros([1, aa_emb_len + num_atomfeatures], dtype=np.float64)
        master_node_features_aa = np.zeros([1, len(amino_acids) + num_atomfeatures], dtype=np.float64)

        x_prot_aa = np.vstack((x_prot_aa, master_node_features_aa))
        x_prot_emb = np.vstack((x_prot_emb, master_node_features_emb))

        pos = np.vstack((pos, np.mean(pos[:x_lig_emb.shape[0],:], axis=0)))
        #------------------------------------------------------------------------------------------


        # Check that no nans have been added to x during the feature computation
        if np.isnan(x_prot_emb).any() or np.isnan(x_prot_aa).any():
            _skip_complex(complex_id, 'Skipped - Nans during enzyme residues feature computation')
            continue

        # Check that there has been no residue mismatch
        if residue_mismatch:
            _skip_complex(complex_id, 'Skipped - Mismatch between "connections" and protein_dict found!')
            continue

        # If in one of the residues the CA atom was not found, PDB is incomplete, skip complex
        if incomplete_residue:
            _skip_complex(complex_id, f'Skipped - Protein residue {incomplete_residue} missing CA-Atom')
            continue


        # EDGE INDEX, EDGE ATTR - Add the connection identified above to the edge_index
        #------------------------------------------------------------------------------------------
        # Residue number -> node index in the graph
        mapping = {key: value for key, value in zip(connections_res_num, new_indeces)}

        edge_index_prot = [[],[]]
        edge_attr_prot = []

        # connections[i] holds the residues connected to ligand atom i
        for index, neighbor_list in enumerate(connections):
            for enzyme_residue in neighbor_list:

                residue_node = mapping[enzyme_residue]
                edge_index_prot[0].append(index)
                edge_index_prot[1].append(residue_node)

                distance = np.linalg.norm(pos[index]-pos[residue_node])

                # Feature vector of the new (non-covalent) edge
                non_cov_feature_vec =   [0.,0.,1.,                # non-covalent interaction
                                        distance/10,              # length divided by 10
                                        0.,0.,0.,0.,0.,           # bondtype = non-covalent
                                        0.,                       # is not conjugated
                                        0.,                       # is not in ring
                                        0.,0.,0.,0.,0.,0.]        # No stereo -> non-covalent

                edge_attr_prot.append(non_cov_feature_vec)

        edge_index_prot = torch.tensor(edge_index_prot, dtype=torch.int64)
        edge_attr_prot = torch.tensor(edge_attr_prot, dtype=torch.float64)
        #------------------------------------------------------------------------------------------


        # Merging the two edge_indeces and edge_attrs into an overall edge_index and edge_attr
        edge_index = torch.concatenate( [edge_index_lig, edge_index_prot], axis=1 )
        edge_attr = torch.concatenate( [edge_attr_lig, edge_attr_prot], axis=0 )

        # Make undirected and add remaining self-loops
        edge_index, edge_attr = make_undirected_with_self_loops(edge_index, edge_attr)
        edge_index_prot, edge_attr_prot = make_undirected_with_self_loops(edge_index_prot, edge_attr_prot)
        edge_index_lig, edge_attr_lig = make_undirected_with_self_loops(edge_index_lig, edge_attr_lig)

        # Master Node edge index. Connect all nodes to a hypothetical master node in a directed way
        # (information flows only from the other nodes into the master node).
        # The master node is the last row of x_prot, so its index equals the number of all other nodes.
        n_lig_nodes = x_lig_emb.shape[0]
        n_normal_nodes = n_lig_nodes + x_prot_emb.shape[0] - 1

        master_lig = [list(range(n_lig_nodes)),
                      [n_normal_nodes] * n_lig_nodes]

        master_prot = [list(range(n_lig_nodes, n_normal_nodes)),
                       [n_normal_nodes] * (x_prot_emb.shape[0] - 1)]

        edge_index_master_lig = torch.tensor(master_lig, dtype=torch.int64)
        edge_index_master_prot = torch.tensor(master_prot, dtype=torch.int64)


        # Check the shapes of the input tensors
        #------------------------------------------------------------------------------------------
        # shape_error holds the log message of the first failed check (None if all checks pass).
        # An IndexError (a tensor with fewer dimensions than expected) is also treated as a failed check.
        shape_error = None

        try:
            if pos.shape[1] != 3:
                shape_error = f'Skipped - POS has shape {pos.shape}'

            elif x_lig_emb.shape[1] != num_atomfeatures+aa_emb_len:
                shape_error = f'Skipped - x_lig_emb has shape {x_lig_emb.shape}'

            elif x_lig_aa.shape[1] != num_atomfeatures+len(amino_acids):
                shape_error = f'Skipped - x_lig_aa has shape {x_lig_aa.shape}'

            elif x_prot_emb.shape[1] != aa_emb_len + num_atomfeatures:
                shape_error = f'Skipped - x_prot_emb has shape {x_prot_emb.shape}'

            elif x_prot_aa.shape[1] != len(amino_acids) + num_atomfeatures:
                shape_error = f'Skipped - x_prot_aa has shape {x_prot_aa.shape}'

            elif x_prot_aa.shape[0] != x_prot_emb.shape[0]:
                shape_error = f'Skipped - Dimension 0 of x_prot_emb {x_prot_emb.shape} and x_prot_aa {x_prot_aa.shape} not identical '

            elif x_prot_aa.shape[0] + x_lig_emb.shape[0] != pos.shape[0]:
                shape_error = f'Skipped - x_lig_emb {x_lig_emb.shape} and x_prot {x_prot_aa.shape} not consistent with POS {pos.shape} '

            else:
                edge_shapes = [(edge_index.shape, edge_attr.shape),
                               (edge_index_lig.shape, edge_attr_lig.shape),
                               (edge_index_prot.shape, edge_attr_prot.shape)]

                for edge_ind, edge_at in edge_shapes:
                    if edge_ind[0] != 2 or edge_at[1] != num_edgefeatures or edge_ind[1] != edge_at[0]:
                        # Report all three pairs once, on a single log line
                        shape_error = ('Skipped - edge indeces error: '
                                       f'{edge_index.shape, edge_attr.shape} '
                                       f'{edge_index_lig.shape, edge_attr_lig.shape} '
                                       f'{edge_index_prot.shape, edge_attr_prot.shape}')
                        break

        except IndexError as e:
            shape_error = 'Skipped -' + str(e)

        if shape_error is not None:
            _skip_complex(complex_id, shape_error)
            continue
        #------------------------------------------------------------------------------------------


        # Retrieve the binding affinity and other metadata of the complex
        # -------------------------------------------------------------------------------------------
        complex_affinity = affinity_dict.get(complex_id)
        if complex_affinity is None:
            _skip_complex(complex_id, 'Skipped - No entry found in affinity_dict')
            continue

        # First available metric in priority order. Without this check, a complex lacking all
        # three metrics would silently reuse the affinity of the previously processed complex.
        affinity_metric = next((metric for metric in affmetric_encoding if metric in complex_affinity), None)
        if affinity_metric is None:
            _skip_complex(complex_id, 'Skipped - No Ki, Kd or IC50 value found in affinity_dict')
            continue
        affinity = complex_affinity[affinity_metric]

        resolution = complex_affinity['resolution']
        log_kd_ki = complex_affinity['log_kd_ki']
        precision = complex_affinity['precision']

        # Resolutions that cannot be converted to a number are encoded as 0
        try:
            resolution = float(resolution)
        except (TypeError, ValueError):
            resolution = 0


        # Find out if the graph is part of the test or training data and save into the corresponding folder:
        # -------------------------------------------------------------------------------------------
        log_string = f'{complex_id}: ' + 'Successful - Saved in '
        save_folders = []

        in_casf_2013 = complex_id in casf_2013_complexes
        in_casf_2016 = complex_id in casf_2016_complexes
        in_refined = False

        if in_casf_2013:
            log_string += 'CASF2013 '
            save_folders.append(casf2013_folder)

        if in_casf_2016:
            log_string += 'CASF2016 '
            save_folders.append(casf2016_folder)

        # Everything that is not part of a CASF benchmark goes into the training data;
        # in_refined is only determined for training complexes
        if (not in_casf_2013) and (not in_casf_2016):
            log_string += 'Training Data'
            save_folders.append(train_folder)
            in_refined = complex_id in refined_complexes

        metadata = [in_refined, affmetric_encoding[affinity_metric], resolution, precision_encoding[precision], float(log_kd_ki)]

        graph = Data(

                x_lig_emb = torch.tensor(x_lig_emb, dtype=torch.float64),
                x_lig_aa = torch.tensor(x_lig_aa, dtype=torch.float64),

                x_prot_emb = torch.tensor(x_prot_emb, dtype=torch.float64),
                x_prot_aa = torch.tensor(x_prot_aa, dtype=torch.float64),

                edge_index = edge_index,
                edge_index_lig = edge_index_lig,
                edge_index_prot = edge_index_prot,

                edge_index_master_lig = edge_index_master_lig,
                edge_index_master_prot = edge_index_master_prot,

                edge_attr = edge_attr,
                edge_attr_lig = edge_attr_lig,
                edge_attr_prot = edge_attr_prot,

                pos = torch.tensor(pos, dtype=torch.float64),
                affinity= torch.tensor(affinity, dtype=torch.float64),

                id = complex_id,
                data = torch.tensor(metadata, dtype=torch.float64)
                )

        # Save the Graph
        for save_folder in save_folders:
            torch.save(graph, os.path.join(save_folder, f'{complex_id}_graph_{embedding_descriptor}.pt'))

        log.write(log_string + "\n")

finally:
    # Close (and thereby flush) the log even if processing a complex raises unexpectedly
    log.close()

print(f'Graph Generation Finished - Skipped Complexes {skipped}')