# Run only one of the following two sections, depending on whether your dataset has predefined chains

# 1 For reproducing the results in the PocketAnchor paper, or for processing customized datasets with predefined PDB IDs and chains

## 1.1 Prepare or load the list file

### 1.1.1 To run our demo data or reproduce the results in the PocketAnchor paper, use the lists in ./lists/

In [2]:
# Demo data: one "<pdbid>_<chains>" entry per line (e.g. '12as_AB').
# Blank lines are skipped so they do not turn into empty PDB IDs downstream.
with open('lists/demo_list.txt') as f:
    list_task = [line.strip() for line in f if line.strip()]
print(list_task)

['121p_A', '12as_AB', '13pk_ABCD', '16pk_A', '17gs_AB', '182l_A', '183l_A', '185l_A', '186l_A', '187l_A', '18gs_AB', '19gs_AB', '1a05_AB', '1a0f_AB', '1a0g_AB']


For reproducing results on COACH420, please use lists/pdbid_chain_list_COACH420.txt

For reproducing results on HOLO4k, please use lists/pdbid_chain_list_HOLO4k.txt

For reproducing results on PDBbind, please use lists/xxx.txt


### 1.1.2 For customized datasets:

In [3]:
# Example: define your own "<pdbid>_<chains>" entries and save them,
# one per line, as lists/<name>.txt.
customized_list = [
    '121p_A', '12as_AB', '13pk_ABCD', '16pk_A', '17gs_AB',
    '182l_A', '183l_A', '185l_A', '186l_A', '187l_A',
    '18gs_AB', '19gs_AB', '1a05_AB', '1a0f_AB', '1a0g_AB',
]
name = 'customized_list'
list_path = 'lists/{}.txt'.format(name)
with open(list_path, 'w') as out:
    out.writelines(entry + '\n' for entry in customized_list)

In [4]:
# Read the list file back (one "<pdbid>_<chains>" entry per line).
# Blank lines are skipped so they do not turn into empty PDB IDs downstream.
with open('lists/{}.txt'.format(name)) as f:
    list_task = [line.strip() for line in f if line.strip()]
print(list_task)

['121p_A', '12as_AB', '13pk_ABCD', '16pk_A', '17gs_AB', '182l_A', '183l_A', '185l_A', '186l_A', '187l_A', '18gs_AB', '19gs_AB', '1a05_AB', '1a0f_AB', '1a0g_AB']


## 1.2 Download pdb files

In [5]:
from multiprocessing import Pool
import os
import subprocess


def download_one(pdbid):
    """Download ``<pdbid>.pdb`` from RCSB into MasifOutput/00-raw_pdbs/ unless it already exists."""
    if not os.path.exists('MasifOutput/00-raw_pdbs/{}.pdb'.format(pdbid)):
        # argv list, no shell: an unusual entry in the list file cannot be
        # interpreted by the shell. wget's exit status is ignored, as before.
        subprocess.run(['wget', '-P', 'MasifOutput/00-raw_pdbs/',
                        'https://files.rcsb.org/download/{}.pdb'.format(pdbid)])

###############################################
### please define this number according to 
### your computational resources
num_processes = 4 
###############################################

# One download per PDB ID (order preserved). Entries such as '1abc_A' and
# '1abc_B' share the same file; fetching it twice in parallel makes wget save
# the second copy as '1abc.pdb.1'.
pdbid_list = list(dict.fromkeys(item[:4] for item in list_task))

with Pool(num_processes) as p:
    res = p.map(download_one, pdbid_list)

# 2 For processing datasets without predefined chains (ligand files are required; e.g., PDBbind)

In [None]:
# Example: bare 4-character PDB IDs; the chains are selected automatically below.
pdbid_list = ['4bny']

## 2.1 Download pdb files

In [None]:
from multiprocessing import Pool
import os
import subprocess


def download_one(pdbid):
    """Download ``<pdbid>.pdb`` from RCSB into MasifOutput/00-raw_pdbs/ unless it already exists."""
    if not os.path.exists('MasifOutput/00-raw_pdbs/{}.pdb'.format(pdbid)):
        # argv list, no shell: an unusual entry cannot be interpreted by the
        # shell. wget's exit status is ignored, as before.
        subprocess.run(['wget', '-P', 'MasifOutput/00-raw_pdbs/',
                        'https://files.rcsb.org/download/{}.pdb'.format(pdbid)])

###############################################
### please define this number according to 
### your computational resources
num_processes = 4 
###############################################

# Use the pdbid_list defined for THIS section (not section 1's list_task).
# Entries are truncated to the 4-character PDB ID and de-duplicated (order
# preserved) so the same file is never fetched twice in parallel.
pdbid_list = list(dict.fromkeys(item[:4] for item in pdbid_list))

with Pool(num_processes) as p:
    res = p.map(download_one, pdbid_list)

## 2.2 Select proper chains

In [None]:
from pymol import cmd
import numpy as np
import pandas as pd
import os, time, sys, pickle
from collections import defaultdict
from sklearn.metrics import pairwise_distances

# Data locations used by the chain-selection helpers.
# NOTE(review): absolute, machine-specific path — adapt it to your own setup.
prefix = "/data/tiantingzhong/Drug/data11_process/"
pdb_datapath = f"{prefix}PDB/raw/"                   # structures, read as <pdbid>.pdb
pdbbind_datapath = f"{prefix}PDBbind2020/pdbbind/"   # ligands, read as <pdbid>/<pdbid>_ligand.sdf
pdbchains_path = f"{prefix}PDB/cpi_chain/"           # cache: one file per PDB ID holding the chosen chains



In [6]:
def get_biomolecule(pdb_filename):
    """Parse the REMARK 350 biological-assembly records of a PDB file.

    Reading stops at the first ATOM record. A ``REMARK 350 BIOMOLECULE: n``
    line starts assembly ``n``. Following ``APPLY THE FOLLOWING TO CHAINS: ...``
    and ``AND CHAINS: ...`` lines add their comma-separated chain IDs to it.

    Args:
        pdb_filename: path to the PDB file.

    Returns:
        dict mapping biomolecule ID (string, e.g. '1') to the list of chain IDs,
        in file order. Empty when the file has no REMARK 350 assembly records.
    """
    dict_temp = {}
    biomol = None  # current assembly; chain lines before any header are ignored
    with open(pdb_filename, 'r') as f:
        for line in f:
            temp = line.split()
            if not temp:
                continue  # blank line: temp[0] would raise IndexError
            if temp[0] == "ATOM":
                break  # assembly records precede the coordinates
            if temp[0] != "REMARK" or len(temp) < 2 or temp[1] != "350":
                continue
            if "BIOMOLECULE" in line:
                # "REMARK 350 BIOMOLECULE: 1" -> '1'; a malformed header resets it.
                biomol = temp[3] if len(temp) > 3 else None
            elif "APPLY THE FOLLOWING TO CHAIN" in line or "AND CHAIN" in line:
                if biomol is None:
                    continue
                # "... CHAINS: A, B," -> ['A', 'B'] (spaces removed, empties dropped)
                added = line.strip().replace(" ", "").partition(':')[2].split(',')
                dict_temp.setdefault(biomol, []).extend(
                    item for item in added if item != "")
    return dict_temp


def select_biomolecule(pdb_filename, chain_dict, ligand_filename=None):
    """Measure how close each biological assembly of an entry is to its ligand.

    Args:
        pdb_filename: PDB ID of the entry (e.g. '4bny'). Callers pass the bare
            ID; the structure is loaded from ``pdb_datapath + <id> + '.pdb'``.
        chain_dict: mapping biomolecule ID -> list of chain IDs, as returned by
            ``get_biomolecule``.
        ligand_filename: optional path to the ligand file. Defaults to
            ``pdbbind_datapath + '<id>/<id>_ligand.sdf'``.

    Returns:
        dict mapping each biomolecule ID to the minimum ligand-atom /
        assembly-atom distance (``sklearn.metrics.pairwise_distances``), or
        ``np.nan`` when the assembly's chain selection yields no atoms.
        Hetero atoms of the PDB file are not removed before the chain
        selection here.
    """
    # Use the argument, not any global ``pdbid`` left in the notebook session.
    pdbid = pdb_filename
    if ligand_filename is None:
        ligand_filename = pdbbind_datapath + pdbid + '/' + pdbid + '_ligand.sdf'
    # Name the ligand object explicitly so a custom ligand_filename is addressed
    # by the same name PyMOL gives the default PDBbind file.
    ligand_object = '{}_ligand'.format(pdbid)

    distance_dict = {}
    cmd.reinitialize()
    cmd.load(pdb_datapath + pdbid + '.pdb')
    cmd.load(ligand_filename, ligand_object)
    ligand_coords = []
    cmd.iterate_state(-1, ligand_object, 'ligand_coords.append([x,y,z])',
                      space={'ligand_coords': ligand_coords})

    for biomol, chain_list in chain_dict.items():
        if not chain_list:
            # An empty chain list cannot form a meaningful "chain ..." selection.
            distance_dict[biomol] = np.nan
            continue
        selection = 'biomol_{}'.format(biomol)
        cmd.select(selection, 'chain {}'.format('+'.join(chain_list)))
        coords = []
        cmd.iterate_state(-1, selection, 'coords.append([x,y,z])',
                          space={'coords': coords})
        if coords:
            distance_dict[biomol] = pairwise_distances(ligand_coords, coords).min()
        else:
            distance_dict[biomol] = np.nan
    return distance_dict

def get_all_chains(pdbid):
    """Return the set of chain IDs found in ``<pdb_datapath><pdbid>.pdb``.

    Hetero atoms (PyMOL ``het`` selection) are removed first, so chains made
    only of hetero atoms are not reported.
    """
    cmd.reinitialize()
    cmd.load('{}{}.pdb'.format(pdb_datapath, pdbid))
    cmd.remove('het')
    return {chain
            for obj_name in cmd.get_names()
            for chain in cmd.get_chains(obj_name)}

def get_remove_chains(pdbid):
    """Return chains whose non-hetero atoms lie within 0.1 of the PDBbind ligand.

    Hetero atoms are removed from the structure first. The ligand is loaded
    from ``<pdbbind_datapath><pdbid>/<pdbid>_ligand.sdf``. The selection is the
    ligand object expanded by 0.1, and the empty chain ID is dropped from the
    result. The ligand's own atoms presumably carry that empty ID (verify).
    """
    cmd.reinitialize()
    cmd.load(pdb_datapath + '{}.pdb'.format(pdbid))
    cmd.remove('het')
    ligand_path = pdbbind_datapath + '{}/{}_ligand.sdf'.format(pdbid, pdbid)
    cmd.load(ligand_path)
    cmd.select('near_ligand', '{}_ligand expand 0.1'.format(pdbid))
    found = set()
    cmd.iterate('near_ligand', 'chains.add(chain)', space={'chains': found})
    found.discard('')
    return found


def get_chains_one_sample(pdbid, isprint=False, isread=True):
    """Choose the protein chains to keep for one PDB entry and cache the result.

    Steps:
      1. If ``isread`` and ``pdbchains_path + pdbid`` exists, return its first
         line (the cached result).
      2. Parse the REMARK 350 assemblies from the raw PDB file.
      3. Pick the assembly closest to the ligand. If the file defines no
         assembly, use every non-hetero chain.
      4. Drop chains whose atoms lie within 0.1 of the ligand
         (``get_remove_chains``).
      5. Write the sorted, concatenated chain IDs to the cache and return them.

    Args:
        pdbid: 4-character PDB ID.
        isprint: print intermediate results for debugging.
        isread: reuse a cached result when available.

    Returns:
        Concatenated chain IDs (e.g. 'AB'), or "" if the distance computation
        raised (the error is printed).
    """
    cache_file = pdbchains_path + pdbid
    if isread and os.path.exists(cache_file):
        with open(cache_file, 'r') as f:
            return f.readline().strip()

    # get_biomolecule expects a file path, not a bare PDB ID.
    biomol_dict = get_biomolecule(pdb_datapath + '{}.pdb'.format(pdbid))
    if isprint:
        print('biomol_dict', biomol_dict)
    try:
        result_dict = select_biomolecule(pdbid, biomol_dict)
    except Exception as E:
        # Best effort over a whole dataset: report and skip this entry.
        print(pdbid, E)
        return ""
    if isprint:
        print('result_dict', result_dict)

    if len(biomol_dict) == 0:
        select_chains = get_all_chains(pdbid)
    else:
        select_key, selected_value = '', 99999
        for key, value in result_dict.items():
            # NaN marks an assembly with no atoms. `value != np.nan` is always
            # True, so test for NaN explicitly.
            if not np.isnan(value) and value < selected_value:
                selected_value = value
                select_key = key
        if isprint:
            print('select_key', select_key)
        assert select_key != '', pdbid
        select_chains = biomol_dict[select_key]
    if isprint:
        print('select_chains', select_chains)

    remove_chains = get_remove_chains(pdbid)
    if isprint:
        print('remove_chains', remove_chains)
    final_chains = list(set(select_chains) - remove_chains)
    if isprint:
        print('final_chains', final_chains)

    chains = "".join(sorted(final_chains))
    os.makedirs(pdbchains_path, exist_ok=True)  # the cache dir may not exist yet
    with open(cache_file, 'w') as f:
        f.write(chains)

    return chains