### Please note before running:
Commit: b962451 (b962451a019e15363bd34b3af9d3a3cd02330947)

Workspace path: Uni-Mol

Notebook path: Uni-Mol/unimol_posebuster_demo.ipynb

### Import modules

In [None]:
import os
import pickle
import numpy as np
import pandas as pd
from rdkit import Chem, RDLogger
from rdkit.Chem import AllChem
from tqdm import tqdm
RDLogger.DisableLog('rdApp.*')  
import warnings
warnings.filterwarnings(action='ignore')
from multiprocessing import Pool
import copy
import lmdb
from biopandas.pdb import PandasPdb
from sklearn.cluster import KMeans
from rdkit.Chem.rdMolAlign  import AlignMolConformers

### Preprocessing functions for generating the LMDB file

In [None]:
# Allowed atom types.
# Atom names that extract_pose_posebuster labels with side=0; every other
# pocket atom name is labelled side=1.
main_atoms = ['N', 'CA', 'C', 'O', 'H']
# Leading letters of atom names regarded as ordinary pocket atoms
# (presumably the elements C, H, N, O and S -- confirm against the pocket dictionary).
allow_pocket_atoms = ['C', 'H', 'N', 'O', 'S']

def cal_configs(coords):
    """Calculate the axis-aligned bounding-box configuration of a pocket.

    Args:
        coords: array-like with one row per atom and three columns (x, y, z).

    Returns:
        A tuple ``(config, centerx, centery, centerz, sizex, sizey, sizez)``.
        The centre is the midpoint between the per-axis minimum and maximum,
        the size is the per-axis extent (maximum minus minimum), and
        ``config`` holds the same six values under the keys
        ``cx, cy, cz, sx, sy, sz``.
    """
    coords = np.asarray(coords)
    upper = np.max(coords, axis=0)
    lower = np.min(coords, axis=0)

    centerx, centery, centerz = list((upper + lower) / 2)
    # The extent must be measured from the minimum (not the mean) so that it
    # describes the same box whose centre is computed above.
    sizex, sizey, sizez = list(upper - lower)
    config = {'cx': centerx, 'cy': centery, 'cz': centerz,
              'sx': sizex, 'sy': sizey, 'sz': sizez}

    return config, centerx, centery, centerz, sizex, sizey, sizez


def filter_pocketatoms(atom):
    """Normalise a pocket atom name, or reject it.

    Leading digits are stripped (``'1HB'`` becomes ``'HB'``). The name is
    rejected when what remains starts with one of the case-sensitive
    two-character prefixes or single initials listed below.

    Args:
        atom: atom name string.

    Returns:
        The name without leading digits, or ``None`` when the atom is
        rejected or nothing is left after stripping the digits.
    """
    # Case-sensitive two-character prefixes that cause rejection
    # (presumably symbols of elements outside C/H/N/O/S -- confirm).
    rejected_prefixes = (
        'Cd', 'Cs', 'Cn', 'Ce', 'Cm', 'Cf', 'Cl', 'Ca',
        'Cr', 'Co', 'Cu', 'Nh', 'Nd', 'Np', 'No', 'Ne', 'Na',
        'Ni', 'Nb', 'Os', 'Og', 'Hf', 'Hg', 'Hs', 'Ho', 'He',
        'Sr', 'Sn', 'Sb', 'Sg', 'Sm', 'Si', 'Sc', 'Se',
    )
    rejected_initials = ('Z', 'M', 'P', 'D', 'F', 'K', 'I', 'B')

    while atom:
        if atom[:2] in rejected_prefixes:
            return None
        if '0' <= atom[0] <= '9':
            # Drop one leading digit and re-run all checks on the remainder.
            atom = atom[1:]
            continue
        if atom[0] in rejected_initials:
            return None
        # Every remaining name is kept, whether or not its first letter is
        # one of C/H/N/O/S; the previous implementation returned the name in
        # both cases.
        return atom
    # Empty or digit-only name: nothing to classify (this used to raise IndexError).
    return None


def single_conf_gen(tgt_mol, num_confs=1000, seed=42, removeHs=True):
    """Embed and MMFF-optimise conformers on a copy of ``tgt_mol``.

    Args:
        tgt_mol: RDKit molecule; it is deep-copied and left unmodified.
        num_confs: number of conformers requested from the embedding.
        seed: random seed passed to the embedding.
        removeHs: when true, hydrogens are removed from the returned molecule.

    Returns:
        The copied molecule carrying the generated conformers.
    """
    mol = copy.deepcopy(tgt_mol)
    mol = Chem.AddHs(mol)
    conf_ids = AllChem.EmbedMultipleConfs(mol, numConfs=num_confs, randomSeed=seed, clearConfs=True)
    for conf_id in conf_ids:
        try:
            AllChem.MMFFOptimizeMolecule(mol, confId=conf_id)
        except Exception:
            # Best effort: a conformer that cannot be optimised keeps its
            # embedded geometry. KeyboardInterrupt/SystemExit still propagate.
            continue
    if removeHs:
        mol = Chem.RemoveHs(mol)
    return mol


def clustering_coords(mol, M=1000, N=100, seed=42, removeHs=True, method='bonds'):
    """Generate ``M`` conformers and return one representative per K-means cluster.

    Conformers are aligned and clustered on their non-hydrogen atoms only;
    the returned coordinates cover all atoms of the generated molecule.

    Args:
        mol: RDKit molecule to generate conformers for.
        M: number of conformers to generate.
        N: number of clusters, i.e. number of coordinate arrays returned.
        seed: random seed for conformer generation and K-means.
        removeHs: forwarded to ``single_conf_gen``.
        method: conformer generation method. Only ``'rdkit_MMFF'`` is
            implemented; the default ``'bonds'`` is rejected, so callers
            must pass the method explicitly.

    Returns:
        List of ``N`` float32 coordinate arrays; entry ``i`` is the first
        conformer assigned to cluster ``i``.

    Raises:
        ValueError: if ``method`` is not ``'rdkit_MMFF'``.
    """
    if method != 'rdkit_MMFF':
        raise ValueError('no conformer generation methods:{}'.format(method))
    rdkit_mol = single_conf_gen(mol, num_confs=M, seed=seed, removeHs=removeHs)

    noHsIds = [atom.GetIdx() for atom in rdkit_mol.GetAtoms() if atom.GetAtomicNum() != 1]
    ### exclude hydrogens for aligning
    AlignMolConformers(rdkit_mol, atomIds=noHsIds)
    rdkit_coords_list = [
        conf.GetPositions().astype(np.float32) for conf in rdkit_mol.GetConformers()
    ]
    sz = len(rdkit_coords_list)

    ### exclude hydrogens for clustering
    rdkit_coords_flatten = np.array(rdkit_coords_list)[:, noHsIds].reshape(sz, -1)
    ids = KMeans(n_clusters=N, random_state=seed).fit_predict(rdkit_coords_flatten).tolist()
    # ids.index(i) picks the first conformer labelled with cluster i.
    coords_list = [rdkit_coords_list[ids.index(i)] for i in range(N)]
    return coords_list


def extract_pose_posebuster(content):
    """Build the pocket and ligand record for one protein-ligand complex.

    The pocket consists of every residue that has at least one non-hydrogen
    ``ATOM`` record within ``dist_thres`` of any ligand heavy atom.

    Args:
        content: tuple ``(pdbid, ligid, protein_path, ligand_path, index)``.
            The protein is read from ``<protein_path>/<pdbid>.pdb`` and the
            ligand from ``<ligand_path>/<pdbid>_<ligid>.sdf``. ``index`` is
            accepted but not used.

    Returns:
        ``None`` when ``pdbid`` is ``'index'`` or ``'readme'``, when no
        molecule can be read from the ligand file, or when the pocket is
        empty. Otherwise the tuple
        ``(pname, patoms, pcoords, side, residues, config, ligand)`` where
        ``pcoords`` is a one-element list holding the float32 pocket
        coordinates, ``side`` flags atoms not in ``main_atoms`` with 1, and
        ``ligand`` is ``[latoms, coordinate_list, holo_coordinates, smiles,
        mol_list, holo_mol]``.
    """
    pdbid, ligid, protein_path, ligand_path, _unused_index = content

    if pdbid == 'index' or pdbid == 'readme':
        return None

    def read_pdb(path, pdbid):
        #### protein preparation
        pfile = os.path.join(path, pdbid + '.pdb')
        return PandasPdb().read_pdb(pfile)

    ### totally posebuster data
    def read_mol(path, pdbid, ligid):
        lsdf = os.path.join(path, f'{pdbid}_{ligid}.sdf')
        supp = Chem.SDMolSupplier(lsdf)
        mols = [mol for mol in supp if mol]
        if len(mols) == 0:
            # Report the unreadable file and let the caller skip this system
            # instead of failing with an IndexError.
            print(lsdf)
            return None
        return mols[0]

    # influence pocket size
    dist_thres = 6

    pmol = read_pdb(protein_path, pdbid)
    pname = pdbid
    mol = read_mol(ligand_path, pdbid, ligid)
    if mol is None:
        return None
    mol = Chem.RemoveHs(mol)
    lcoords = mol.GetConformer().GetPositions().astype(np.float32)

    pdf = pmol.df['ATOM']
    heavy_atom = pdf.element_symbol != 'H'
    filter_std = set()
    for lcoord in lcoords:
        # `records` must be a tuple; a bare string only works through a
        # deprecated code path in biopandas.
        dist = pmol.distance(xyz=list(lcoord), records=('ATOM',))
        df = pdf[(dist <= dist_thres) & heavy_atom][['chain_id', 'residue_number']]
        filter_std.update(zip(df.chain_id.tolist(), df.residue_number.tolist()))

    patoms, residues, coord_chunks = [], [], []
    # Sorted so that the residue order does not depend on set iteration order.
    for chain_id, res_num in sorted(filter_std):
        df = pdf[(pdf.chain_id == chain_id) & (pdf.residue_number == res_num)]
        patoms += df['atom_name'].tolist()
        coord_chunks.append(df[['x_coord', 'y_coord', 'z_coord']].to_numpy())
        residues += [str(chain_id) + str(res_num)] * len(df)

    if not coord_chunks:
        print('empty pocket:', pdbid)
        return None
    pcoords = np.concatenate(coord_chunks, axis=0)
    # The box is computed from all pocket atoms, before the name filter below.
    config = cal_configs(pcoords)[0]

    # filter unnormal atoms, include metal
    atoms, keep_mask, kept_residues = [], [], []
    for atom_name, residue in zip(patoms, residues):
        output = filter_pocketatoms(atom_name)
        keep_mask.append(output is not None)
        if output is not None:
            atoms.append(output)
            kept_residues.append(residue)
    coordinates = pcoords[keep_mask].astype(np.float32)
    residues = kept_residues

    # Internal invariants: names, residues and coordinates stay aligned.
    assert len(atoms) == len(residues)
    assert len(atoms) == coordinates.shape[0]

    patoms = atoms
    pcoords = [coordinates]
    side = [0 if a in main_atoms else 1 for a in patoms]

    smiles = Chem.MolToSmiles(mol)
    mol = AllChem.AddHs(mol, addCoords=True)
    latoms = [atom.GetSymbol() for atom in mol.GetAtoms()]
    holo_coordinates = [mol.GetConformer().GetPositions().astype(np.float32)]
    holo_mol = mol

    # Generate M conformers and keep N cluster representatives.
    M, N = 100, 10
    coordinate_list = clustering_coords(mol, M=M, N=N, seed=42, removeHs=False, method='rdkit_MMFF')
    mol_list = [mol] * N
    ligand = [latoms, coordinate_list, holo_coordinates, smiles, mol_list, holo_mol]

    return pname, patoms, pcoords, side, residues, config, ligand


def parser(content):
    """Extract one complex and serialise it for storage in the LMDB.

    Args:
        content: tuple forwarded unchanged to ``extract_pose_posebuster``.

    Returns:
        The pickled record (highest pickle protocol) as ``bytes``, or
        ``None`` when ``extract_pose_posebuster`` returned ``None`` so the
        caller can count the system as failed.
    """
    result = extract_pose_posebuster(content)
    if result is None:
        # Propagate the failure instead of crashing on tuple unpacking.
        return None
    pname, patoms, pcoords, side, residues, config, ligand = result
    latoms, coordinate_list, holo_coordinates, smiles, mol_list, holo_mol = ligand
    return pickle.dumps(
        {
            "atoms": latoms,
            "coordinates": coordinate_list,
            "mol_list": mol_list,
            "pocket_atoms": patoms,
            "pocket_coordinates": pcoords,
            "side": side,
            "residue": residues,
            "config": config,
            "holo_coordinates": holo_coordinates,
            "holo_mol": holo_mol,
            # The same pocket coordinates are stored under both keys.
            "holo_pocket_coordinates": pcoords,
            "smi": smiles,
            'pocket': pname,
            'scaffold': pname,
        },
        protocol=-1,
    )


def write_lmdb(protein_path, ligand_path, outpath, meta_info_file, lmdb_name, nthreads=8):
    """Preprocess every system listed in the meta file and store it in one LMDB file.

    Args:
        protein_path: directory containing the protein ``.pdb`` files.
        ligand_path: directory containing the ligand ``.sdf`` files.
        outpath: output directory; created if missing.
        meta_info_file: CSV file with the columns ``pdb_code`` and ``lig_code``.
        lmdb_name: output file name without the ``.lmdb`` suffix.
        nthreads: number of worker processes.

    An existing ``<outpath>/<lmdb_name>.lmdb`` is replaced. Successful
    records are stored under consecutive ASCII integer keys starting at
    ``0``; systems for which ``parser`` returns ``None`` are counted as
    failed and skipped.
    """
    os.makedirs(outpath, exist_ok=True)
    df = pd.read_csv(meta_info_file)
    pdb_ids = list(df['pdb_code'].values)
    lig_ids = list(df['lig_code'].values)
    content_list = [
        (pdb_id, lig_id, protein_path, ligand_path, idx)
        for idx, (pdb_id, lig_id) in enumerate(zip(pdb_ids, lig_ids))
    ]
    outputfilename = os.path.join(outpath, lmdb_name + '.lmdb')
    try:
        os.remove(outputfilename)
    except FileNotFoundError:
        # Nothing to replace; any other error (e.g. permissions) is reported.
        pass
    env_new = lmdb.open(
        outputfilename,
        subdir=False,
        readonly=False,
        lock=False,
        readahead=False,
        meminit=False,
        max_readers=1,
        map_size=int(100e9),
    )
    print("Start preprocessing data...")
    print(f'Number of systems: {len(pdb_ids)}')
    i = 0
    failed_num = 0
    try:
        # The transaction commits when the block exits normally and aborts if
        # an exception escapes; the environment is closed in both cases.
        with env_new.begin(write=True) as txn_write, Pool(nthreads) as pool:
            for inner_output in tqdm(pool.imap(parser, content_list), total=len(content_list)):
                if inner_output is None:
                    failed_num += 1
                    continue
                txn_write.put(f"{i}".encode("ascii"), inner_output)
                i += 1
    finally:
        env_new.close()
    print(f'Total num: {len(pdb_ids)}, Success: {i}, Failed: {failed_num}')
    print("Done!")

### Generate `lmdb` from `pdb` and `sdf`

In [None]:
# Input locations: protein structures, ligand files and the meta CSV.
protein_path = 'eval_sets/posebusters/proteins'
ligand_path = 'eval_sets/posebusters/ligands'
meta_info_file = 'eval_sets/posebusters/posebuster_set_meta.csv'

# Output location; `outpath` and `lmdb_name` are reused by the later cells.
outpath = 'posebuster_test'
lmdb_name = 'posebuster_428'

# Number of worker processes used for preprocessing.
nthreads = 8

write_lmdb(
    protein_path=protein_path,
    ligand_path=ligand_path,
    outpath=outpath,
    meta_info_file=meta_info_file,
    lmdb_name=lmdb_name,
    nthreads=nthreads,
)

### Run inference with the public checkpoint
The command is the same as the one in the [README](https://github.com/dptech-corp/Uni-Mol/tree/main/unimol#protein-ligand-binding-pose-prediction).

In [None]:
# Inference settings. data_path is the directory the LMDB was written to,
# and valid_subset is the LMDB file name without its suffix.
data_path=outpath
results_path="./infer_pose"  # replace with your results path
weight_path="./ckp/binding_pose_220908.pt"
batch_size=8
dist_threshold=8.0
recycling=3
valid_subset=lmdb_name
mol_dict_name='dict_mol.txt'
pocket_dict_name='dict_pkt.txt'

# Copy the example molecule and pocket dictionaries into the data directory
# under the file names chosen above.
!cp ./example_data/molecule/dict.txt $data_path/$mol_dict_name
!cp ./example_data/pocket/dict_coarse.txt $data_path/$pocket_dict_name
# Run the inference script on the generated subset with the settings above.
!python ./unimol/infer.py --user-dir ./unimol $data_path --valid-subset $valid_subset \
       --results-path $results_path \
       --num-workers 8 --ddp-backend=c10d --batch-size $batch_size \
       --task docking_pose --loss docking_pose --arch docking_pose \
       --path $weight_path \
       --fp16 --fp16-init-scale 4 --fp16-scale-window 256 \
       --dist-threshold $dist_threshold --recycling $recycling \
       --log-interval 50 --log-format simple

### Docking and metric calculation
The command is the same as the one in the [README](https://github.com/dptech-corp/Uni-Mol/tree/main/unimol#protein-ligand-binding-pose-prediction).

In [None]:
nthreads=8  # Number of threads passed to the docking script
# Path of the inference output pickle (presumably written by infer.py in the
# previous cell -- confirm the file name if the checkpoint path changes).
predict_file=f"{results_path}/ckp_{lmdb_name}.out.pkl"  # Your inference output file
# Path of the LMDB file generated by write_lmdb above.
reference_file=f"{outpath}/{lmdb_name}.lmdb"  # Your reference file
output_path="./unimol_repro_posebuster428"  # Docking results path

# Run the docking script on the predictions against the reference LMDB.
!python ./unimol/utils/docking.py --nthreads $nthreads --predict-file $predict_file --reference-file $reference_file --output-path $output_path