In [1]:
import pandas as pd
import numpy as np
from glob import glob
from path import Path
import json
import seaborn as sns
import prody
from multiprocessing import Pool
from tqdm import tqdm
import itertools
import json
from rdkit import Chem
from rdkit.Chem import AllChem
from io import StringIO
from collections import OrderedDict, Counter
import traceback
import urllib
import urllib.error
import urllib.request
import pybel
from copy import deepcopy
from multiprocessing import Pool
from pymol import cmd

pd.set_option('display.max_columns', None)

from pocketdock.ligand.ligand_expo import LigandExpo
from pocketdock import utils
from pocketdock import pdb_tools



In [2]:
import logging
import logging.config

# Record layout shared by every handler of the pipeline logger:
# level, timestamp, calling function and worker pid.
_LOG_FORMAT = '[%(levelname)s] %(asctime)s %(funcName)s [pid %(process)d] - %(message)s'

LOGGING = dict(
    version=1,
    # dictConfig disables all pre-existing non-root loggers when this is True
    disable_existing_loggers=True,
    formatters=dict(
        default=dict(format=_LOG_FORMAT),
    ),
    handlers=dict(
        console={'class': 'logging.StreamHandler', 'formatter': 'default'},
    ),
    loggers=dict(
        console=dict(handlers=['console'], level='DEBUG', propagate=False),
    ),
)

logging.config.dictConfig(LOGGING)
# Pipeline-wide logger; the batch cell temporarily swaps its handlers for a file handler.
logger = logging.getLogger('console')


def get_file_handler(filename, mode='w', level='DEBUG'):
    """Create a ``logging.FileHandler`` that formats records like the console handler.

    Args:
        filename: path of the log file (opened immediately).
        mode: file open mode; ``'w'`` truncates an existing log.
        level: minimum level handled, name or numeric value.

    Returns:
        The configured ``logging.FileHandler``; the caller owns it and must ``close()`` it.
    """
    handler = logging.FileHandler(filename, mode=mode)
    record_format = '[%(levelname)s] %(asctime)s %(funcName)s [pid %(process)d] - %(message)s'
    handler.setFormatter(logging.Formatter(record_format))
    handler.setLevel(level)
    return handler

In [3]:
# Pipeline thresholds. Distances are in the coordinate units of the structures
# (presumably Å, as in PDB/SDF files).

# Ligand-to-receptor radius used by get_chain to collect pocket atoms.
CHAIN_INTERACTION_DIST = 6
# get_chain requires strictly more than this many pocket *atoms* (not residues).
MIN_POCKET_SIZE = 10
# Entries whose header resolution exceeds this value (or is missing) are skipped.
MAX_RESOLUTION = 4.0
# Ligands with any atom pair closer than this are put into the same ligand group.
LIGAND_GROUP_DIST = 5
# A ligand with any atom this close to a crystal symmetry mate drops its whole group.
MIN_LIGAND_TO_SYMMATES = 6
# Radius around the ligand that defines binding-site residues in _bsite_similarity.
BSITE_SIM_RADIUS = 6
# Radius used by process_ligand_group to decide which domains a ligand group contacts.
LIGAND_DOMAIN_CONTACT_RADIUS = 6
# Entries with more chains than this (e.g. ribosomes) are skipped.
MAX_NUM_CHAINS = 20
# get_chain rejects ligands with a pocket atom at or below this distance (likely covalent).
COVALENT_LIGAND_DISTANCE = 1.8
# Minimum `frac_resolved` (heavy-atom ratio computed in get_cases_for_pdb) to keep a ligand.
FRAC_HEAVY_ATOMS_RESOLVED = 0.9

In [4]:
def _build_d2mat(crd1, crd2):
    return np.sum((crd1[:, None, :] - crd2[None, :, :])**2, axis=2)
            
def _calc_overlap(lig_ag, ref_ag, dist):
    """Fraction of the heavy atoms of `lig_ag` lying within `dist` of any heavy atom of `ref_ag`.

    Args:
        lig_ag: ProDy atoms whose coverage is measured.
        ref_ag: ProDy atoms used as the reference.
        dist: distance cutoff (inclusive).

    Returns:
        float in [0, 1].
    """
    lig_heavy = lig_ag.heavy
    dmat = prody.buildDistMatrix(lig_heavy, ref_ag.heavy)
    # rows = ligand heavy atoms; a row counts if any reference atom is close enough
    n_close = np.count_nonzero((dmat <= dist).any(axis=1))
    return n_close / lig_heavy.numAtoms()

def _aln_to_mapping(aln1, aln2):
    assert len(aln1) == len(aln2)
    i = 0
    j = 0
    mapping = []
    for a, b in zip(aln1, aln2):
        if a != '-' and b != '-':
            mapping.append((i, j))
        if a != '-':
            i += 1
        if b != '-':
            j += 1
    return mapping

def _get_affinity(pdb_id, chemid):
    try:
        url = 'https://data.rcsb.org/rest/v1/core/entry/' + pdb_id.lower()
        with urllib.request.urlopen(url) as f:
            affs = json.loads(f.read().decode())
    except urllib.error.HTTPError as e:
        print('HTTPError for', url)
        affs = {}
        
    for x in affs.get('rcsb_binding_affinity', []):
        if x['comp_id'] == chemid:
            return x
    return None

def calc_identity(aln1, aln2):
    """Sequence identity of two alignment rows, normalised by the ungapped length of `aln1`.

    Counts the columns where both rows have the same character and divides by the number
    of non-gap characters in `aln1`. Identical gap columns also count as matches.
    Raises ZeroDivisionError when `aln1` has no residues.
    """
    n_matches = sum(a == b for a, b in zip(aln1, aln2))
    n_residues = len(aln1) - aln1.count('-')
    return n_matches / n_residues

In [5]:
#get_cases_for_pdb('1QUR')

In [6]:
# Three-letter -> one-letter residue codes used to build sequences from structures.
# Includes the ambiguous codes ASX (B) and GLX (Z). Any other residue name (e.g. a modified
# residue picked up by the 'nonstdaa' selection in process_pdb_chain) raises KeyError on lookup.
three2one = {'SER': 'S', 'LYS': 'K', 'PRO': 'P', 'ALA': 'A', 'ASP': 'D', 'ARG': 'R', 'VAL': 'V', 'CYS': 'C', 'HIS': 'H',
         'ASX': 'B', 'PHE': 'F', 'MET': 'M', 'LEU': 'L', 'ASN': 'N', 'TYR': 'Y', 'ILE': 'I', 'GLN': 'Q', 'THR': 'T',
         'GLX': 'Z', 'GLY': 'G', 'TRP': 'W', 'GLU': 'E'}

# Heavy-atom (no hydrogen) bond graph of the 20 standard amino acids:
# residue name -> {atom name: tuple of bonded atom names}.
# In this notebook only the key sets are used: process_pdb_chain compares them (minus OXT)
# with the atom names present in each residue to flag residues with missing atoms.
residue_bonds_noh = {
    'GLY': {'OXT': ('C',), 'C': ('CA', 'O', 'OXT'), 'O': ('C',), 'N': ('CA',), 'CA': ('N', 'C')},
    'ALA': {'OXT': ('C',), 'C': ('CA', 'O', 'OXT'), 'O': ('C',), 'N': ('CA',), 'CA': ('N', 'C', 'CB'), 'CB': ('CA',)},
    'CYS': {'OXT': ('C',), 'C': ('CA', 'O', 'OXT'), 'O': ('C',), 'N': ('CA',), 'CA': ('N', 'C', 'CB'),
            'CB': ('CA', 'SG'), 'SG': ('CB',)},
    'SER': {'OXT': ('C',), 'C': ('CA', 'O', 'OXT'), 'O': ('C',), 'N': ('CA',), 'CA': ('N', 'C', 'CB'),
            'CB': ('CA', 'OG'), 'OG': ('CB',)},
    'MET': {'OXT': ('C',), 'C': ('CA', 'O', 'OXT'), 'O': ('C',), 'N': ('CA',), 'CA': ('N', 'C', 'CB'),
            'CB': ('CA', 'CG'), 'CG': ('CB', 'SD'), 'SD': ('CG', 'CE'), 'CE': ('SD',)},
    'LYS': {'OXT': ('C',), 'C': ('CA', 'O', 'OXT'), 'O': ('C',), 'N': ('CA',), 'CA': ('N', 'C', 'CB'),
            'CB': ('CA', 'CG'), 'CG': ('CB', 'CD'), 'CD': ('CG', 'CE'), 'CE': ('CD', 'NZ'), 'NZ': ('CE',)},
    'ARG': {'OXT': ('C',), 'C': ('CA', 'O', 'OXT'), 'O': ('C',), 'N': ('CA',), 'CA': ('N', 'C', 'CB'),
            'CB': ('CA', 'CG'), 'CG': ('CB', 'CD'), 'CD': ('CG', 'NE'), 'NE': ('CD', 'CZ'), 'CZ': ('NE', 'NH1', 'NH2'),
            'NH1': ('CZ',), 'NH2': ('CZ',)},
    'GLU': {'OXT': ('C',), 'C': ('CA', 'O', 'OXT'), 'O': ('C',), 'N': ('CA',), 'CA': ('N', 'C', 'CB'),
            'CB': ('CA', 'CG'), 'CG': ('CB', 'CD'), 'CD': ('CG', 'OE1', 'OE2'), 'OE1': ('CD',), 'OE2': ('CD',)},
    'GLN': {'OXT': ('C',), 'C': ('CA', 'O', 'OXT'), 'O': ('C',), 'N': ('CA',), 'CA': ('N', 'C', 'CB'),
            'CB': ('CA', 'CG'), 'CG': ('CB', 'CD'), 'CD': ('CG', 'OE1', 'NE2'), 'OE1': ('CD',), 'NE2': ('CD',)},
    'ASP': {'OXT': ('C',), 'C': ('CA', 'O', 'OXT'), 'O': ('C',), 'N': ('CA',), 'CA': ('N', 'C', 'CB'),
            'CB': ('CA', 'CG'), 'CG': ('CB', 'OD1', 'OD2'), 'OD1': ('CG',), 'OD2': ('CG',)},
    'ASN': {'OXT': ('C',), 'C': ('CA', 'O', 'OXT'), 'O': ('C',), 'N': ('CA',), 'CA': ('N', 'C', 'CB'),
            'CB': ('CA', 'CG'), 'CG': ('CB', 'OD1', 'ND2'), 'OD1': ('CG',), 'ND2': ('CG',)},
    'LEU': {'OXT': ('C',), 'C': ('CA', 'O', 'OXT'), 'O': ('C',), 'N': ('CA',), 'CA': ('N', 'C', 'CB'),
            'CB': ('CA', 'CG'), 'CG': ('CB', 'CD1', 'CD2'), 'CD1': ('CG',), 'CD2': ('CG',)},
    'HIS': {'OXT': ('C',), 'C': ('CA', 'O', 'OXT'), 'O': ('C',), 'N': ('CA',), 'CA': ('N', 'C', 'CB'),
            'CB': ('CA', 'CG'), 'CG': ('CB', 'ND1', 'CD2'), 'ND1': ('CG', 'CE1'), 'CD2': ('CG', 'NE2'),
            'CE1': ('ND1', 'NE2'), 'NE2': ('CD2', 'CE1')},
    'PHE': {'OXT': ('C',), 'C': ('CA', 'O', 'OXT'), 'O': ('C',), 'N': ('CA',), 'CA': ('N', 'C', 'CB'),
            'CB': ('CA', 'CG'), 'CG': ('CB', 'CD1', 'CD2'), 'CD1': ('CG', 'CE1'), 'CD2': ('CG', 'CE2'),
            'CE1': ('CD1', 'CZ'), 'CE2': ('CD2', 'CZ'), 'CZ': ('CE1', 'CE2')},
    'TYR': {'OXT': ('C',), 'C': ('CA', 'O', 'OXT'), 'O': ('C',), 'N': ('CA',), 'CA': ('N', 'C', 'CB'),
            'CB': ('CA', 'CG'), 'CG': ('CB', 'CD1', 'CD2'), 'CD1': ('CG', 'CE1'), 'CD2': ('CG', 'CE2'),
            'CE1': ('CD1', 'CZ'), 'CE2': ('CD2', 'CZ'), 'CZ': ('CE1', 'CE2', 'OH'), 'OH': ('CZ',)},
    'TRP': {'OXT': ('C',), 'C': ('CA', 'O', 'OXT'), 'O': ('C',), 'N': ('CA',), 'CA': ('N', 'C', 'CB'),
            'CB': ('CA', 'CG'), 'CG': ('CB', 'CD1', 'CD2'), 'CD1': ('CG', 'NE1'), 'CD2': ('CG', 'CE2', 'CE3'),
            'NE1': ('CD1', 'CE2'), 'CE2': ('CD2', 'NE1', 'CZ2'), 'CE3': ('CD2', 'CZ3'), 'CZ3': ('CE3', 'CH2'),
            'CZ2': ('CE2', 'CH2'), 'CH2': ('CZ2', 'CZ3')},
    'VAL': {'OXT': ('C',), 'C': ('CA', 'O', 'OXT'), 'O': ('C',), 'N': ('CA',), 'CA': ('N', 'C', 'CB'),
            'CB': ('CA', 'CG1', 'CG2'), 'CG1': ('CB',), 'CG2': ('CB',)},
    'THR': {'OXT': ('C',), 'C': ('CA', 'O', 'OXT'), 'O': ('C',), 'N': ('CA',), 'CA': ('N', 'C', 'CB'),
            'CB': ('CA', 'OG1', 'CG2'), 'OG1': ('CB',), 'CG2': ('CB',)},
    'ILE': {'OXT': ('C',), 'C': ('CA', 'O', 'OXT'), 'O': ('C',), 'N': ('CA',), 'CA': ('N', 'C', 'CB'),
            'CB': ('CA', 'CG1', 'CG2'), 'CG1': ('CB', 'CD1'), 'CD1': ('CG1',), 'CG2': ('CB',)},
    'PRO': {'OXT': ('C',), 'C': ('CA', 'O', 'OXT'), 'O': ('C',), 'N': ('CA', 'CD'), 'CA': ('N', 'C', 'CB'),
            'CB': ('CA', 'CG'), 'CG': ('CB', 'CD'), 'CD': ('N', 'CG')},
}


def fetch_AF_model(uniprot_id, version=1):
    """Download the AlphaFold DB model (fragment F1) of `uniprot_id` and parse it with ProDy.

    Args:
        uniprot_id: UniProt accession of the protein.
        version: AlphaFold DB model version used in the file name (``model_v{version}``).
            The default 1 preserves the original URL. NOTE(review): older model versions
            may no longer be served by AlphaFold DB — pass a newer one if downloads fail.

    Returns:
        ProDy AtomGroup parsed from the downloaded PDB text.

    Raises:
        urllib.error.URLError / HTTPError: if the model cannot be downloaded.
    """
    url = f'https://alphafold.ebi.ac.uk/files/AF-{uniprot_id}-F1-model_v{version}.pdb'
    with urllib.request.urlopen(url) as response:
        data = response.read().decode()
    return prody.parsePDBStream(StringIO(data))


def _find_closest_residue(resid_list, target, upstream=True):
    if not upstream:
        resid_list = reversed(resid_list)
    for resid in resid_list:
        if upstream and target <= resid:
            return resid
        if not upstream and target >= resid:
            return resid
    return None


def fetch_domains(instance_data, entity_to_pdb, pdb_residues, entity_len):
    """Build CATH domain definitions for one polymer instance, mapped onto resolved residues.

    Each CATH feature of the RCSB instance record becomes a domain. Each of its segments
    ("chunks") is stored three ways:
      * chunk_ent_orig  - zero-based [start, end] in the entity sequence, as annotated;
      * chunk_ent_map   - the same range shrunk to the nearest residues that have coordinates;
      * chunk_pdb_resid - residue numbers of those boundary residues in the structure.
    Chunks with no resolved residue inside them are dropped, and so are domains left
    without chunks. If no domain survives (or the record has no CATH features), a single
    domain covering the whole chain is created with ``manually_filled=True``.

    Args:
        instance_data: RCSB polymer_entity_instance JSON record.
        entity_to_pdb: entity sequence index -> index into `pdb_residues`.
        pdb_residues: list of ProDy residues of the chain.
        entity_len: length of the entity sequence.

    Returns:
        list of OrderedDicts with keys domain_id (index of the CATH feature),
        manually_filled and domain_chunks.

    Raises:
        ValueError: if no residue of the entity is resolved (nothing to map onto).
    """
    features = instance_data.get('rcsb_polymer_instance_feature') or []
    cath_data = [x for x in features if x.get('type') == 'CATH']
    entity_nums = sorted(entity_to_pdb.keys())

    def _make_chunk(chunk_ent_orig):
        # map the annotated range inwards to the closest residues with coordinates
        chunk_ent_map = [
            _find_closest_residue(entity_nums, chunk_ent_orig[0], True),
            _find_closest_residue(entity_nums, chunk_ent_orig[1], False)
        ]
        # discard the region if nothing inside it is resolved: either side has no
        # resolved residue, or the range lies inside a gap and the mapping is inverted
        if chunk_ent_map[0] is None or chunk_ent_map[1] is None or chunk_ent_map[0] > chunk_ent_map[1]:
            return None
        chunk_pdb = [pdb_residues[entity_to_pdb[i]].getResnum().item() for i in chunk_ent_map]
        return OrderedDict(
            chunk_ent_orig=chunk_ent_orig,
            chunk_ent_map=chunk_ent_map,
            chunk_pdb_resid=chunk_pdb
        )

    domains = []
    for dom_id, dom_data in enumerate(cath_data):
        chunks = []
        for chunk_data in dom_data.get('feature_positions', []):
            # zero based position in the entity sequence
            chunk = _make_chunk([int(chunk_data['beg_seq_id']) - 1, int(chunk_data['end_seq_id']) - 1])
            if chunk is not None:
                chunks.append(chunk)
        if not chunks:
            continue
        domains.append(OrderedDict(
            domain_id=dom_id,
            manually_filled=False,
            domain_chunks=chunks
        ))

    # if there is no domains identified, fill them manually - whole structure
    if not domains:
        chunk = _make_chunk([0, entity_len - 1])
        if chunk is None:
            raise ValueError('No resolved residues are mapped to the entity sequence')
        domains = [OrderedDict(
            domain_id=0,
            manually_filled=True,
            domain_chunks=[chunk]
        )]

    return domains


def find_correct_instance(pdb_id, target_chain, chain_list):
    """Find the RCSB polymer-entity-instance record whose author chain id is `target_chain`.

    The REST endpoint is keyed by the RCSB instance (asym) id, which can differ from the
    author chain id. Each id in `chain_list` is tried as an instance id until one maps
    back to `target_chain`. Ids that are not valid instances (HTTP error) are skipped.

    Args:
        pdb_id: PDB entry id.
        target_chain: author chain id to find.
        chain_list: iterable of candidate instance ids.

    Returns:
        The instance JSON dict, or None if no candidate matches (including an empty list).
    """
    for cid in chain_list:
        url = f'https://data.rcsb.org/rest/v1/core/polymer_entity_instance/{pdb_id}/{cid}'
        try:
            with urllib.request.urlopen(url) as f:
                instance_data = json.loads(f.read().decode())
        except urllib.error.HTTPError:
            # `cid` is not a polymer instance of this entry - try the next candidate
            continue
        if instance_data['rcsb_polymer_entity_instance_container_identifiers']['auth_asym_id'] == target_chain:
            return instance_data
    return None


def process_pdb_chain(outdir, pdb_id, chain):
    """Collect structure, sequence, UniProt, AlphaFold and domain metadata for one PDB chain.

    Side effects: creates `outdir` and writes `{pdb_id}.pdb` (full entry), `rec_orig.pdb`
    (the chain), `AF_orig.pdb` (AlphaFold model) and `AF_aln.pdb` (model superimposed on
    the chain) into it.

    Args:
        outdir: output directory.
        pdb_id: PDB entry id (case-insensitive; upper-cased here).
        chain: author chain id of the receptor chain.

    Returns:
        OrderedDict describing the chain. The caller later adds 'ligand_groups' and writes it
        to case.json.

    Raises:
        RuntimeError: if no RCSB polymer instance maps to the author chain id.
        AssertionError: on failed sanity checks (chain not in entity, non-standard linkage or
            monomers, entity/structure sequence identity <= 0.9).
        urllib.error.URLError, KeyError: on network failures or unexpected API payloads.
    """
    def _get_json(url):
        # every RCSB REST endpoint used below returns a JSON document
        with urllib.request.urlopen(url) as f:
            return json.loads(f.read().decode())

    pdb_id = pdb_id.upper()
    outdir = Path(outdir).mkdir_p()
    ag, header = pdb_tools.get_atom_group(pdb_id, header=True)

    entry_data = _get_json(f'https://data.rcsb.org/rest/v1/core/entry/{pdb_id}')

    # yes, instance id can be different from the chain name. Try the chain name as an
    # instance id first; if it does not exist or belongs to another author chain, loop
    # through all instances (chains) until we find the correct one
    try:
        instance_data = _get_json(f'https://data.rcsb.org/rest/v1/core/polymer_entity_instance/{pdb_id}/{chain}')
    except urllib.error.HTTPError:
        instance_data = None
    if instance_data is None or \
            instance_data['rcsb_polymer_entity_instance_container_identifiers']['auth_asym_id'] != chain:
        instance_data = find_correct_instance(pdb_id, chain, set(ag.getChids()))
    if instance_data is None:
        raise RuntimeError(f'Cannot find correct RESTful link for {pdb_id}/{chain}. '
                           "RCSB and Author's namings are different")

    entity_id = instance_data['rcsb_polymer_entity_instance_container_identifiers']['entity_id']

    entity_data = _get_json(f'https://data.rcsb.org/rest/v1/core/polymer_entity/{pdb_id}/{entity_id}')
    assert chain in entity_data['entity_poly']['pdbx_strand_id'].split(',')

    uniprot_data = _get_json(f'https://data.rcsb.org/rest/v1/core/uniprot/{pdb_id}/{entity_id}')

    # ensure there is no non-standard linkage and aas
    assert entity_data['entity_poly']['nstd_linkage'] == 'no'
    assert entity_data['entity_poly']['nstd_monomer'] == 'no'

    # align residues that have coordinates in the pdb to the
    # sequence announced in the entity description and ID
    # the missing residues as well as partially resolved residues
    entity_seq = entity_data['entity_poly']['pdbx_seq_one_letter_code']
    ag_chain = ag.select(f'chain {chain} and (stdaa or nonstdaa)').copy()
    residues = list(ag_chain.getHierView().iterResidues())
    pdb_seq = ''.join([three2one[x.getResname()] for x in residues])
    entity_aln, pdb_aln = utils.global_align(entity_seq, pdb_seq)[0][:2]

    # sanity check
    assert calc_identity(pdb_aln, entity_aln) > 0.90, 'Entity sequence doesn\'t match the chain'

    # per entity position: residue absent / CA absent / any heavy atom absent (OXT ignored)
    entity_to_pdb = dict(_aln_to_mapping(entity_aln, pdb_aln))
    missing_residue = []
    missing_ca = []
    missing_atoms = []
    for ent_i in range(len(entity_seq)):
        if ent_i not in entity_to_pdb:
            missing_residue.append(True)
            missing_ca.append(True)
            missing_atoms.append(True)
            continue
        missing_residue.append(False)
        res = residues[entity_to_pdb[ent_i]]
        res_ideal_names = set(residue_bonds_noh[res.getResname()].keys())
        res_ideal_names.discard('OXT')
        res_cur_names = set(res.heavy.getNames())
        res_cur_names.discard('OXT')
        missing_ca.append('CA' not in res_cur_names)
        missing_atoms.append(res_ideal_names != res_cur_names)

    # write full pdb
    prody.writePDB(outdir / f'{pdb_id}.pdb', ag)

    # ag_chain
    prody.writePDB(outdir / 'rec_orig.pdb', ag_chain)

    # download alphafold model
    uniprot_id = uniprot_data[0]['rcsb_id']
    ag_af = fetch_AF_model(uniprot_id)
    prody.writePDB(outdir / 'AF_orig.pdb', ag_af)

    # align AF to crystal
    ag_af_aln, rmsd, (af_seq_aln, pdb_seq_aln) = pdb_tools.align(ag_af, ag_chain)
    rmsd = rmsd[0]
    prody.writePDB(outdir / 'AF_aln.pdb', ag_af_aln)

    AF_residues = list(ag_af.getHierView().iterResidues())
    AF_seq = ''.join([three2one[x.getResname()] for x in AF_residues])

    # domain data
    domains = fetch_domains(instance_data, entity_to_pdb, residues, len(entity_seq))

    entity_info = entity_data['entity_poly'].copy()
    entity_info['entity_aln'] = entity_aln
    entity_info['pdb_aln'] = pdb_aln

    rec_dict = OrderedDict(
        case_name=pdb_id + '_' + chain,
        pdb_id=pdb_id,
        pdb_chain=chain,
        instance_id=instance_data['rcsb_polymer_entity_instance_container_identifiers']['asym_id'],
        entity_id=entity_id,
        experiment=header.get('experiment'),
        resolution=header.get('resolution'),
        deposition_date=entry_data['rcsb_accession_info']['deposit_date'],

        seqclus100=pdb_tools.CHAIN_TO_CLUSTER[100][pdb_id + '_' + chain],
        seqclus90=pdb_tools.CHAIN_TO_CLUSTER[90][pdb_id + '_' + chain],
        seqclus40=pdb_tools.CHAIN_TO_CLUSTER[40][pdb_id + '_' + chain],
        seqclus30=pdb_tools.CHAIN_TO_CLUSTER[30][pdb_id + '_' + chain],

        # '0'/'1' strings, one character per entity sequence position
        missing_residue=''.join([str(int(x)) for x in missing_residue]),
        missing_ca=''.join([str(int(x)) for x in missing_ca]),
        missing_atoms=''.join([str(int(x)) for x in missing_atoms]),

        entity_info=entity_info,

        uniprot=OrderedDict(
            uniprot_id=uniprot_id,
            uniprot_seq=uniprot_data[0]['rcsb_uniprot_protein']['sequence'],
            uniprot_name=uniprot_data[0]['rcsb_uniprot_protein']['name']['value'],
        ),

        domains=domains,

        alphafold=OrderedDict(
            seq=AF_seq,
            PDB_CA_RMSD=rmsd,
            AF_seq_aln=af_seq_aln,
            PDB_seq_aln=pdb_seq_aln,
            PDB_AF_identity=calc_identity(pdb_seq_aln, af_seq_aln),
            AF_confidence=ag_af.calpha.getBetas().tolist()
        )
    )
    return rec_dict

In [7]:
def _add_hs(out_mol, in_mol):
    """Write a hydrogenated copy of the ligand in `in_mol` to `out_mol` and a PDB twin.

    Steps:
      1. Open Babel (pybel) reads `in_mol`, adds hydrogens and runs a 500-step MMFF94
         local optimisation; the result is written to `out_mol`.
      2. The protonated molecule is re-read with RDKit, rigidly superimposed (ProDy
         calcTransformation on heavy atoms) back onto the heavy atoms of `in_mol`, and
         saved as `<out_mol without extension>.pdb`.
      3. The superimposed coordinates are copied into the RDKit molecule, Gasteiger
         charges are computed and `out_mol` is rewritten.

    NOTE(review): localopt can move heavy atoms and a rigid superposition does not undo
    internal changes, so heavy-atom coordinates in `out_mol` may deviate slightly from
    `in_mol`. The superposition also assumes heavy atoms keep the same order in both
    files — confirm. The Gasteiger charges are probably not persisted by MolToMolFile — verify.
    """
    # add hydrogens
    mol = next(pybel.readfile('mol', in_mol))
    mol.addh()
    mol.localopt('mmff94', steps=500)
    mol.write('mol', out_mol, overwrite=True)

    # fix back the coordinates: superimpose the optimised heavy atoms onto the input pose
    mol = Chem.MolFromMolFile(out_mol, removeHs=False)
    ag = utils.mol_to_ag(mol)
    ref_ag = utils.mol_to_ag(Chem.MolFromMolFile(in_mol, removeHs=True))
    tr = prody.calcTransformation(ag.heavy, ref_ag.heavy)
    ag = tr.apply(ag)
    prody.writePDB(Path(out_mol).stripext() + '.pdb', ag)

    # push the superimposed coordinates back into the RDKit molecule and overwrite out_mol
    utils.change_mol_coords(mol, ag.getCoords())
    AllChem.ComputeGasteigerCharges(mol, throwOnParamFailure=False)
    Chem.MolToMolFile(mol, out_mol)


def _calc_rmsd(crd1, crd2):
    return np.square(crd1 - crd2).sum(1).mean().item()


def _bsite_similarity(mol_ag, mob_ag, ref_ag, mob_aln, ref_aln):
    """Compare the binding site of `mol_ag` in a reference structure with a mobile structure.

    The pocket is the set of `ref_ag` residues having an atom within BSITE_SIM_RADIUS of
    `mol_ag` (ProDy `exwithin`). Each pocket residue is paired with the `mob_ag` residue
    aligned to it in (`ref_aln`, `mob_aln`), or with None if it is aligned to a gap.

    Args:
        mol_ag: ligand atoms defining the pocket.
        mob_ag: mobile structure (in this notebook: AlphaFold domains already superimposed
            on the crystal, see process_ligand_group).
        ref_ag: reference structure (in this notebook: the crystal domains).
        mob_aln, ref_aln: gapped alignment rows; their ungapped lengths must equal the
            residue counts of `mob_ag` / `ref_ag` (asserted).

    Returns:
        OrderedDict with the pocket size, the CA deviation from `_calc_rmsd` over paired
        residues (None if no residue is paired), the identity over pocket residues (unpaired
        count as mismatches), residue numbers and one-letter sequences for both structures
        (None / '-' for unpaired residues).

    NOTE(review): pocket residues are matched by residue number only (no chain or insertion
    code). An empty pocket makes `pocket` None and raises AttributeError. A residue without
    a CA atom will fail in the CA lookups.
    """
    # reference residue index -> mobile residue index, from the alignment columns
    ref_to_mob = dict(_aln_to_mapping(ref_aln, mob_aln))
    mob_residues = list(mob_ag.copy().getHierView().iterResidues())
    ref_residues = list(ref_ag.copy().getHierView().iterResidues())

    # alignment rows must describe exactly the residues present in the structures
    assert len(mob_residues) == len(mob_aln.replace('-', ''))
    assert len(ref_residues) == len(ref_aln.replace('-', ''))
    
    # reference atoms near the ligand -> residue numbers of pocket residues
    pocket = ref_ag.select(f'exwithin {BSITE_SIM_RADIUS} of sel', sel=mol_ag)
    pocket_resnums = set(pocket.getResnums())
    pocket_ref_residues = []
    pocket_mob_residues = []
    for i, x in enumerate(ref_residues):
        if x.getResnum() in pocket_resnums:
            pocket_ref_residues.append(x)
            mob_id = ref_to_mob.get(i, None)
            # None marks a pocket residue aligned to a gap in the mobile structure
            pocket_mob_residues.append(None if mob_id is None else mob_residues[mob_id])
    
    # CA deviation over paired pocket residues only
    pocket_rmsd = None
    if not all([x is None for x in pocket_mob_residues]):
        pocket_rmsd = _calc_rmsd(
            np.stack([x['CA'].getCoords() for i, x in enumerate(pocket_ref_residues) if pocket_mob_residues[i] is not None]),
            np.stack([x['CA'].getCoords() for x in pocket_mob_residues if x is not None])
        )
    # fraction of pocket residues whose paired residue has the same residue name
    pocket_identity = sum([y is not None and x.getResname() == y.getResname() 
                           for x, y in zip(pocket_ref_residues, pocket_mob_residues)]) / len(pocket_ref_residues)
    return OrderedDict(
        nresidues=len(pocket_ref_residues),
        rmsd_ca=pocket_rmsd,
        identity=pocket_identity,
        resnums_pdb=[x.getResnum().item() for x in pocket_ref_residues],
        resnums_af=[x.getResnum().item() if x is not None else None for x in pocket_mob_residues],
        seq_pdb=''.join([x['CA'].getSequence() for x in pocket_ref_residues]),
        seq_af=''.join([x['CA'].getSequence() if x is not None else '-' for x in pocket_mob_residues])
    )
    

def process_ligand_group(case_dir, case_dict, group_dict):
    """Prepare one ligand group and compare its binding domains between crystal and AlphaFold.

    Steps:
      1. Adds hydrogens to every ligand of the group (`<sdf_id>_ah.mol` / `.pdb`).
      2. Finds the domains of `case_dict['domains']` that have a residue within
         LIGAND_DOMAIN_CONTACT_RADIUS of any ligand atom.
      3. Cuts those domains out of the crystal chain (`domains_crys.pdb`) and the AlphaFold
         model (`domains_AF.pdb`), superimposes them (`domains_AF_aln.pdb`), and compares
         the binding sites with `_bsite_similarity`.

    Args:
        case_dir: chain directory produced by process_pdb_chain (contains rec_orig.pdb and
            AF_orig.pdb).
        case_dict: chain record returned by process_pdb_chain.
        group_dict: dict with 'name' (sub-directory of `case_dir`) and 'ligands' (records with
            'sdf_id'; `<sdf_id>.mol` must already exist in that directory).

    Returns:
        (domain_alignment, bsite_analysis) OrderedDicts.

    Raises:
        ValueError: if no receptor atom / domain is in contact with the ligand group.
    """
    def _resnum_selection(ranges):
        # ProDy selection covering every (first, last) residue-number range
        return ' or '.join([f"(resnum `{a} to {b}`)" for a, b in ranges])

    case_dir = Path(case_dir)
    group_dir = case_dir / group_dict['name']

    # add hs
    for item in group_dict['ligands']:
        _add_hs(group_dir / item['sdf_id'] + '_ah.mol', group_dir / item['sdf_id'] + '.mol')

    # make ligand group ag
    mols = [(x['sdf_id'], utils.mol_to_ag(Chem.MolFromMolFile(group_dir / x['sdf_id'] + '.mol'))) for x in group_dict['ligands']]
    combined_ag = mols[0][1].copy()
    for _, mol in mols[1:]:
        combined_ag += mol.copy()

    # define domains in contact
    pdb_ag = prody.parsePDB(case_dir / 'rec_orig.pdb')
    domains = case_dict['domains']
    bsite_sel = pdb_ag.select(f'exwithin {LIGAND_DOMAIN_CONTACT_RADIUS} of lig', lig=combined_ag)
    if bsite_sel is None:
        raise ValueError(f'No receptor atoms within {LIGAND_DOMAIN_CONTACT_RADIUS} of group {group_dict["name"]}')
    bsite_ag = bsite_sel.copy()

    # Keep the domain records themselves. `domain_id` is the CATH feature index, not the
    # position in `domains` (domains without resolved chunks are dropped upstream), so
    # indexing `domains` by id would pick the wrong domain.
    interacting = [
        dom for dom in domains
        if any(bsite_ag.select(f"resnum `{chunk['chunk_pdb_resid'][0]} to {chunk['chunk_pdb_resid'][1]}`") is not None
               for chunk in dom['domain_chunks'])
    ]
    if not interacting:
        raise ValueError(f'Ligand group {group_dict["name"]} does not contact any domain')
    interacting_domains = [dom['domain_id'] for dom in interacting]

    # select domains in pdb
    pdb_resnums = []
    for dom in interacting:
        for chunk in dom['domain_chunks']:
            a, b = chunk['chunk_pdb_resid']
            pdb_resnums.append((a, b))
    domains_pdb_ag = pdb_ag.select(_resnum_selection(pdb_resnums)).copy()
    prody.writePDB(group_dir / 'domains_crys.pdb', domains_pdb_ag)

    # select domains in AF
    # align AF sequence to the entity sequence and select domain regions
    af_ag = prody.parsePDB(case_dir / 'AF_orig.pdb')
    entity_seq = case_dict['entity_info']['pdbx_seq_one_letter_code']
    af_residues = list(af_ag.getHierView().iterResidues())
    af_seq = ''.join([three2one[x.getResname()] for x in af_residues])
    entity_aln, af_aln = utils.global_align(entity_seq, af_seq)[0][:2]
    ent_to_af_map = dict(_aln_to_mapping(entity_aln, af_aln))
    ent_resids = list(ent_to_af_map.keys())  # ascending entity positions present in AF
    af_resnums = []
    for dom in interacting:
        for chunk in dom['domain_chunks']:
            a, b = chunk['chunk_ent_orig']
            # snap the annotated entity range to residues present in the AF model
            a = af_residues[ent_to_af_map[_find_closest_residue(ent_resids, a, True)]].getResnum().item()
            b = af_residues[ent_to_af_map[_find_closest_residue(ent_resids, b, False)]].getResnum().item()
            af_resnums.append((a, b))
    domains_af_ag = af_ag.select(_resnum_selection(af_resnums)).copy()
    prody.writePDB(group_dir / 'domains_AF.pdb', domains_af_ag)

    # align AF domain to PDB domain
    domains_af_ag_aln, rmsd, (af_aln, pdb_aln) = pdb_tools.align(domains_af_ag, domains_pdb_ag)
    rmsd = rmsd[0]
    prody.writePDB(group_dir / 'domains_AF_aln.pdb', domains_af_ag_aln)

    domain_alignment = OrderedDict(
        interacting_domains=interacting_domains,
        pdb_resi_ranges=pdb_resnums,
        af_resi_ranges=af_resnums,
        af_aln=af_aln,
        pdb_aln=pdb_aln,
        pdb_af_identity=calc_identity(pdb_aln, af_aln),
        rmsd_aln=rmsd
    )

    bsite_analysis = _bsite_similarity(combined_ag, domains_af_ag_aln, domains_pdb_ag, af_aln, pdb_aln)
    return domain_alignment, bsite_analysis
    
    
#process_ligand_group('data/3WKE_A', case_dict, case_dict['ligand_groups'][0])
    

In [8]:
#get_cases_for_pdb('6bg0')

In [9]:
def get_chain(chemid, mol_rd, pdb_ag):
    """Return the id of the single protein chain the ligand `mol_rd` binds to.

    Every rejection is an AssertionError, so the caller drops the ligand group instead of
    aborting the whole entry. A ligand is rejected when:
      * the structure has no atoms other than `chemid`;
      * at most MIN_POCKET_SIZE non-water atoms lie within CHAIN_INTERACTION_DIST of it;
      * any pocket atom is within COVALENT_LIGAND_DISTANCE (likely covalently bound);
      * the pocket has no non-hetero atoms;
      * the best-represented chain holds 80% or less of the non-hetero pocket atoms.

    Args:
        chemid: residue name of the ligand, excluded from the pocket.
        mol_rd: RDKit molecule with the ligand pose.
        pdb_ag: ProDy AtomGroup of the whole entry.

    Returns:
        Chain id with the most non-hetero pocket atoms.
    """
    mol_ag = utils.mol_to_ag(mol_rd)
    not_lig = pdb_ag.select('not resname ' + chemid)
    assert not_lig is not None, f'Structure contains only {chemid}'
    not_lig = not_lig.copy()

    pocket_all = not_lig.select(f'(not water) exwithin {CHAIN_INTERACTION_DIST} of sel', sel=mol_ag)
    assert pocket_all is not None
    assert len(pocket_all) > MIN_POCKET_SIZE

    dmat_min = np.sqrt(_build_d2mat(pocket_all.getCoords(), mol_ag.getCoords())).min()
    assert dmat_min > COVALENT_LIGAND_DISTANCE, f'{dmat_min} <= {COVALENT_LIGAND_DISTANCE}'

    # exclude hetero because ligands in the same group can be in different chains
    # and count toward "counts"
    pocket = pocket_all.select('not hetero')
    assert pocket is not None, 'Pocket has no non-hetero atoms'
    counts = Counter(pocket.getChids())
    # most_common keeps first-seen order on ties, like a stable sort by descending count
    chain, chain_count = counts.most_common(1)[0]
    assert (chain_count / len(pocket)) > 0.8
    return chain


def assert_one_chemid_copy_per_chain(chemid, pdb_ag):
    """Assert that at most one copy of ligand `chemid` lies within 6 of each protein chain.

    A copy is identified by (chain id, residue number). Copies of a ligand assigned to
    different chains frequently share a residue number, so the residue number alone would
    merge them and let entries with two copies at one chain pass.

    Args:
        chemid: ligand residue name.
        pdb_ag: ProDy AtomGroup of the whole entry.

    Raises:
        AssertionError: if some protein chain has more than one copy of `chemid` nearby.
    """
    for chain_ag in pdb_ag.getHierView().iterChains():
        protein_ag = chain_ag.protein
        if protein_ag is None:
            continue
        lig = pdb_ag.select(f'resname {chemid} within 6 of sel', sel=protein_ag)
        if lig is None:
            continue
        copies = set(zip(lig.getChids(), lig.getResnums()))
        assert len(copies) == 1, f'{len(copies)} copies of {chemid} near one chain: {sorted(copies)}'
            

def get_crystall_lattice(pdb, cutoff=6):
    """Build the crystal symmetry mates of entry `pdb` that lie within `cutoff` of it.

    Uses PyMOL `symexp` on the entry file, saves the generated `symm*` objects to a
    temporary PDB file and parses it with ProDy. The temporary file and the PyMOL session
    are cleaned up even if a step fails.

    Args:
        pdb: PDB entry id.
        cutoff: distance passed to `cmd.symexp`.

    Returns:
        ProDy AtomGroup of the symmetry mates. NOTE(review): presumably None when no mates
        are generated (the caller checks for None) — confirm parsePDB behaviour on an
        empty file.
    """
    pdb_file = pdb_tools.pdb_file_path(pdb)
    tmp_file = utils.tmp_file(prefix='alphadock-', suffix='.pdb')
    cmd.reinitialize()
    try:
        cmd.load(pdb_file, 'orig')
        cmd.symexp('symm', 'orig', 'orig', cutoff)
        cmd.save(tmp_file, 'symm*')
        symm_ag = prody.parsePDB(tmp_file)
    finally:
        # do not leak temp files or PyMOL objects when load/save/parse raises
        Path(tmp_file).remove_p()
        cmd.delete('all')
    return symm_ag


#get_crystall_lattice('4ow1')


def _group_proximal_ligands(sdf_ids, coords, max_dist):
    """Partition ligands into connected components of the "closer than max_dist" graph.

    Two ligands are linked when any pair of their atoms is closer than `max_dist`. Ligands
    bridged by a third one therefore end up in the same group regardless of input order.

    Args:
        sdf_ids: ligand ids; their order defines the output order.
        coords: mapping sdf_id -> (N, 3) coordinate array.
        max_dist: linking distance.

    Returns:
        list of lists of sdf_ids; groups ordered by their first member, members in input order.
    """
    n = len(sdf_ids)
    neighbours = [[] for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):
            d2mat = _build_d2mat(coords[sdf_ids[i]], coords[sdf_ids[j]])
            if np.any(d2mat < max_dist ** 2):
                neighbours[i].append(j)
                neighbours[j].append(i)

    visited = [False] * n
    groups = []
    for start in range(n):
        if visited[start]:
            continue
        visited[start] = True
        stack, component = [start], []
        while stack:
            k = stack.pop()
            component.append(k)
            for m in neighbours[k]:
                if not visited[m]:
                    visited[m] = True
                    stack.append(m)
        groups.append([sdf_ids[k] for k in sorted(component)])
    return groups


def get_cases_for_pdb(pdb_id):
    """Build docking cases for every suitable ligand group of one PDB entry.

    Pipeline:
      1. Reject the entry if it cannot be loaded, its resolution is missing or above
         MAX_RESOLUTION, or it has more than MAX_NUM_CHAINS chains.
      2. Load all ligand instances from LigandExpo and group proximal ligands.
      3. Drop a whole group if any ligand fails a check: several copies near a chain,
         missing/unparsable SMILES, too few resolved heavy atoms, not a single binding
         chain, contacts with symmetry mates, or group members on different chains.
      4. Keep one group per (100% sequence cluster, sorted chemids).
      5. Process each receptor chain (process_pdb_chain) and ligand group
         (process_ligand_group) under `data/<PDB>_<chain>/`.

    Args:
        pdb_id: PDB entry id (case-insensitive).

    Returns:
        list of per-chain case dicts, each also written to `data/<PDB>_<chain>/case.json`;
        empty if the entry or all of its groups are rejected.
    """
    pdb_id = pdb_id.upper()
    wdir = Path('data').mkdir_p()

    try:
        db_result = pdb_tools.get_atom_group(pdb_id, header=True)
        assert db_result is not None
        pdb_ag, header = db_result
        # a missing or None resolution fails the cutoff instead of raising TypeError
        resolution = header.get('resolution')
        assert resolution is not None and resolution <= MAX_RESOLUTION, f'resolution {resolution}, max {MAX_RESOLUTION}'
        # avoid crazy large pdb files like ribosomes
        num_chains = len(set(pdb_ag.getChids()))
        assert num_chains <= MAX_NUM_CHAINS, f'{num_chains} > {MAX_NUM_CHAINS}'
    except AssertionError as e:
        logger.warning(f'{pdb_id} does not pass the criteria')
        logger.exception(e)
        return []

    chemids = LigandExpo.get_chemids_list(pdb_id)

    # load every readable instance of every ligand of the entry
    mol_sdf_ids = []
    mol_rds = {}
    mol_rd_coords = {}
    for chemid in chemids:
        for sdf_id in LigandExpo.get_all_sdf_ids(pdb_id, chemid):
            mol_rd = LigandExpo.get_all_sdf_mol(sdf_id)
            if mol_rd is None:
                logger.warning(f'Cannot read {sdf_id}')
                continue
            mol_rds[sdf_id] = mol_rd
            mol_rd_coords[sdf_id] = mol_rd.GetConformer(0).GetPositions()
            mol_sdf_ids.append(sdf_id)

    # find groups of ligands. here we treat groups of proximal ligands as
    # whole instead of discarding them. such cases will represent cooperative binding
    mol_groups = _group_proximal_ligands(mol_sdf_ids, mol_rd_coords, LIGAND_GROUP_DIST)

    # build symmetry mates
    symm_mates_ag = get_crystall_lattice(pdb_id)

    # affinity depends only on (entry, chemid): fetch it once per chemid
    affinities = {}
    detailed_groups = []
    for mol_group in mol_groups:
        detailed_group = []
        # if at least one ligand in the group is bad, we skip the whole group
        try:
            for sdf_id in mol_group:
                mol_rd = mol_rds[sdf_id]

                chemid = sdf_id.split('_')[1]
                assert_one_chemid_copy_per_chain(chemid, pdb_ag)

                # check smiles and that (almost) all heavy atoms are resolved
                smi = LigandExpo.get_smiles(chemid)
                assert smi is not None
                mol_smi = Chem.MolFromSmiles(smi)
                assert mol_smi is not None, f'Cannot parse SMILES of {chemid}: {smi}'
                nheavy_expected = mol_smi.GetNumHeavyAtoms()
                nheavy_resolved = mol_rd.GetNumHeavyAtoms()
                assert nheavy_expected > 0, f'No heavy atoms in SMILES of {chemid}: {smi}'
                # resolved / expected; the inverse ratio would accept partially resolved ligands
                frac_resolved = nheavy_resolved / nheavy_expected
                assert frac_resolved >= FRAC_HEAVY_ATOMS_RESOLVED, f'{nheavy_resolved} < {nheavy_expected}, ({smi})'

                # check that ligand interacts with a single chain
                pdb_chain = get_chain(chemid, mol_rd, pdb_ag)

                # check that there is no symm mates interaction
                mol_ag = utils.mol_to_ag(mol_rd)
                if symm_mates_ag is not None:
                    assert mol_ag.select(f'within {MIN_LIGAND_TO_SYMMATES} of symmmates', symmmates=symm_mates_ag) is None

                if chemid not in affinities:
                    affinities[chemid] = _get_affinity(pdb_id, chemid)

                case = OrderedDict(
                    sdf_id=sdf_id,
                    chemid=chemid,
                    smiles=smi,
                    pdb_id=pdb_id,
                    pdb_chain=pdb_chain,
                    num_heavy_atoms=nheavy_resolved,
                    frac_resolved=frac_resolved,
                    affinity=affinities[chemid]
                )
                detailed_group.append(case)

            # check that all chains in the group are the same
            inter_chains = set([x['pdb_chain'] for x in detailed_group])
            assert len(inter_chains) == 1, f'Multiple interacting chains: {inter_chains}'

        except AssertionError as e:
            logger.info(f'Dropping group {mol_group}')
            logger.exception(e)
            continue

        detailed_groups.append(detailed_group)

    if len(detailed_groups) == 0:
        logger.error('Groups are empty')
        return []

    # select unique (chain, chemids) groups
    unique_groups = []
    _desc_list = []
    for group in detailed_groups:
        seqclus100 = pdb_tools.CHAIN_TO_CLUSTER[100].get(pdb_id + '_' + group[0]['pdb_chain'])
        if seqclus100 is None:
            continue
        _desc = tuple([seqclus100] + sorted([x['chemid'] for x in group]))
        if _desc in _desc_list:
            continue
        _desc_list.append(_desc)
        unique_groups.append(group)

    # prepare dict with chains
    chains_dict = OrderedDict()
    chains_list = sorted(set([x[0]['pdb_chain'] for x in unique_groups]))
    for chain in chains_list:
        chain_name = pdb_id + '_' + chain
        try:
            chains_dict[chain_name] = process_pdb_chain(wdir / chain_name, pdb_id, chain)
            chains_dict[chain_name]['ligand_groups'] = []
        except Exception as e:
            (wdir / chain_name).rmtree_p()
            logger.warning(f'Cannot process {chain_name}, skipping')
            logger.exception(e)
            continue

    # process ligand groups
    for group in unique_groups:
        chain_name = pdb_id + '_' + group[0]['pdb_chain']
        # computed before the try so the error message always names this group
        group_name = '_'.join(sorted([x['chemid'] for x in group]))
        chain_dict = chains_dict.get(chain_name)
        if chain_dict is None:
            continue
        try:
            group_dir = (wdir / chain_name / group_name).mkdir_p()
            group_dict = OrderedDict(
                name=group_name,
                ligands=group
            )

            for lig_dict in group:
                Chem.MolToMolFile(mol_rds[lig_dict['sdf_id']], group_dir / lig_dict['sdf_id'] + '.mol')

            domain_aligment, bsite_analysis = process_ligand_group(wdir / chain_name, chain_dict, group_dict)
            group_dict['domain_aligment'] = domain_aligment
            group_dict['bsite_analysis'] = bsite_analysis
            chain_dict['ligand_groups'].append(group_dict)
        except Exception as e:
            logger.warning(f'Cannot process {group_name}, skipping')
            logger.exception(e)

    cases = []
    for case in chains_dict.values():
        # if all ligand groups were discarded skip this chain
        if len(case['ligand_groups']) == 0:
            (wdir / case['case_name']).rmtree_p()
            continue
        utils.write_json(case, wdir / case['case_name'] / 'case.json')
        cases.append(case)

    return cases

In [10]:
#case_dicts = get_cases_for_pdb('3wke')

In [11]:
#case_dicts[0]

In [12]:
#set(['a',24]) in [set([24,'a'])]

In [13]:
#LigandExpo._PDB_TO_CC #.keys()

In [None]:
# Serial run over every PDB entry known to LigandExpo. The detailed log goes to
# data/log.txt while the console handlers are temporarily swapped out.
log_dir = Path('data').mkdir_p()  # must exist before the log file is opened
old_h = logger.handlers
logger.handlers = [get_file_handler(log_dir / 'log.txt')]
try:
    for pdb in tqdm(list(LigandExpo._PDB_TO_CC.keys())):
        logger.info(f'============ Processing {pdb} ===========')
        try:
            get_cases_for_pdb(pdb)
        except Exception as e:
            # one failing entry (network error, malformed file, ...) must not abort the run
            logger.error(f'Unhandled error while processing {pdb}, skipping')
            logger.exception(e)
        logger.info(f'============  Finished {pdb}  ===========\n\n')
finally:
    # close only the temporary file handler(s), then restore the original handlers
    for h in logger.handlers:
        h.close()
    logger.handlers = old_h

In [None]:
# Single-entry debug run; results are written under data/ (see get_cases_for_pdb).
get_cases_for_pdb('5A1P')

In [None]:
# Single-entry debug run; results are written under data/ (see get_cases_for_pdb).
get_cases_for_pdb('5yqb')

In [None]:
# Interactive check that RDKit parses this SMILES; the cell displays the resulting molecule.
Chem.MolFromSmiles('[H]/N=C(\c1ccc(cc1)C[C@H](C(=O)N2CCCCC2)NC(=O)[C@H](CCC(=O)O)NS(=O)(=O)c3ccc4ccccc4c3)/N')

In [None]:
def run_mp(args):
    """Pool worker: process one PDB entry and capture any exception instead of raising.

    Args:
        args: (chemid, pdb) tuple. `chemid` is only reported back; get_cases_for_pdb
            processes every ligand of the entry.

    Returns:
        dict with 'chemid', 'pdb' and either 'cases' (list of case dicts) or 'exc'
        (formatted traceback).
    """
    chemid, pdb = args
    res = {'chemid': chemid, 'pdb': pdb}
    try:
        res['cases'] = get_cases_for_pdb(pdb)
    except Exception:
        res.update({'exc': traceback.format_exc()})
    return res


def get_args():
    """Yield one (chemid, pdb) work item per unique PDB entry of LigandExpo._CC_TO_PDB.

    get_cases_for_pdb handles all ligands of an entry at once. Scheduling the same entry
    once per chemid would redo the work and let workers race on the same data/ folders.
    """
    seen = set()
    for chemid, pdbs in LigandExpo._CC_TO_PDB.items():
        for pdb in pdbs:
            key = pdb.upper()
            if key in seen:
                continue
            seen.add(key)
            yield chemid, pdb


NUM_WORKERS = 32

work_items = list(get_args())
with Pool(NUM_WORKERS) as p:
    for r in tqdm(p.imap_unordered(run_mp, work_items), total=len(work_items)):
        if 'exc' in r:
            print(r['chemid'], r['pdb'])
            print(r['exc'])