In [1]:
#@title Download Uni-Mol

#@markdown Please execute this cell by pressing the *Play* button on 
#@markdown the left.
GIT_REPO = 'https://github.com/dptech-corp/Uni-Mol'
UNICORE_URL = 'https://github.com/dptech-corp/Uni-Core/releases/download/0.0.1/unicore-0.0.1+cu113torch1.12.1-cp37-cp37m-linux_x86_64.whl'
DOCKING_DATA_URL = 'https://unimol.dp.tech/data/finetune/protein_ligand_binding_pose_prediction.tar.gz'
DOCKING_WEIGHT_URL = 'https://unimol.dp.tech/ckp/bindind_pose/binding_pose_220908.pt'
!rm *.whl
!wget  {UNICORE_URL} 
!pip3 -q install "unicore-0.0.1+cu113torch1.12.1-cp37-cp37m-linux_x86_64.whl"
!rm -rf ./Uni-Mol
!git clone -b main {GIT_REPO}
!pip3 -q install ./Uni-Mol
!pip install rdkit
!pip install biopandas
!wget  {DOCKING_DATA_URL}
!tar -xzf "protein_ligand_binding_pose_prediction.tar.gz"
!wget {DOCKING_WEIGHT_URL}
!pip install py3Dmol


rm: cannot remove '*.whl': No such file or directory
--2022-09-09 08:16:26--  https://github.com/dptech-corp/Uni-Core/releases/download/0.0.1/unicore-0.0.1+cu113torch1.12.1-cp37-cp37m-linux_x86_64.whl
Resolving github.com (github.com)... 140.82.113.3
Connecting to github.com (github.com)|140.82.113.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/512317326/151c1bd4-8199-4be8-b2b7-9969d7415f4e?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20220909%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20220909T081626Z&X-Amz-Expires=300&X-Amz-Signature=44f903a580a21b0370099db2ca034ae24f3df4e8bd03999279c50037309deaca&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=512317326&response-content-disposition=attachment%3B%20filename%3Dunicore-0.0.1%2Bcu113torch1.12.1-cp37-cp37m-linux_x86_64.whl&response-content-type=application%2Foctet-stream [following]
--2022-09-09 

In [2]:
import os
import sys 
import numpy as np 
import pandas as pd
import biopandas
import lmdb
from biopandas.pdb import PandasPdb
from rdkit import Chem
from rdkit.Chem import AllChem
from sklearn.cluster import KMeans
from rdkit.Chem import rdMolTransforms
from rdkit.Chem.rdMolAlign  import AlignMolConformers
from unimol.utils.docking_utils import docking_data_pre, ensemble_iterations
from tqdm import tqdm 
import pickle
import re
import json
import copy

CASF_PATH = 'protein_ligand_binding_pose_prediction/CASF-2016'
main_atoms = ['N', 'CA', 'C', 'O', 'H']

def load_from_CASF(pdb_id):
  try:
    pdb_path = os.path.join(CASF_PATH, 'casf2016', pdb_id+'_protein.pdb')
    pmol = PandasPdb().read_pdb(pdb_path)
    pocket_residues = json.load(open(os.path.join(CASF_PATH, 'casf2016.pocket.json')))[pdb_id]
    return pmol, pocket_residues
  except:
    print("Currently not support parsing pdb and pocket info from local files.")

def normalize_atoms(atom):
    return re.sub("\d+", "", atom)

def single_conf_gen(tgt_mol, num_confs=1000, seed=42, removeHs=True):
    mol = copy.deepcopy(tgt_mol)
    mol = Chem.AddHs(mol)
    allconformers = AllChem.EmbedMultipleConfs(mol, numConfs=num_confs, randomSeed=seed, clearConfs=True)
    sz = len(allconformers)
    for i in range(sz):
        try:
            AllChem.MMFFOptimizeMolecule(mol, confId=i)
        except:
            continue
    if removeHs:
        mol = Chem.RemoveHs(mol)
    return mol

def clustering_coords(mol, M=1000, N=100, seed=42, removeHs=True):
    rdkit_coords_list = []
    rdkit_mol = single_conf_gen(mol, num_confs=M, seed=seed, removeHs=removeHs)
    noHsIds = [rdkit_mol.GetAtoms()[i].GetIdx() for i in range(len(rdkit_mol.GetAtoms())) if rdkit_mol.GetAtoms()[i].GetAtomicNum()!=1]
    ### exclude hydrogens for aligning
    AlignMolConformers(rdkit_mol, atomIds=noHsIds)
    sz = len(rdkit_mol.GetConformers())
    for i in range(sz):
        _coords = rdkit_mol.GetConformers()[i].GetPositions().astype(np.float32)
        rdkit_coords_list.append(_coords)

    ### exclude hydrogens for clustering
    rdkit_coords_flatten = np.array(rdkit_coords_list)[:, noHsIds].reshape(sz,-1)
    ids = KMeans(n_clusters=N, random_state=seed).fit_predict(rdkit_coords_flatten).tolist()
    coords_list = [rdkit_coords_list[ids.index(i)] for i in range(N)]
    return coords_list

def parser(pdb_id, smiles):
    pmol, pocket_residues = load_from_CASF(pdb_id)
    pname = pdb_id
    pro_atom = pmol.df['ATOM']
    pro_hetatm = pmol.df['HETATM']

    pro_atom['ID'] = pro_atom['chain_id'].astype(str)+pro_atom['residue_number'].astype(str)
    pro_hetatm['ID'] = pro_hetatm['chain_id'].astype(str)+pro_hetatm['residue_number'].astype(str)

    pocket = pd.concat([pro_atom[pro_atom['ID'].isin(pocket_residues)],
               pro_hetatm[pro_hetatm['ID'].isin(pocket_residues)]], 
               axis=0,
               ignore_index=True)

    pocket['normalize_atom'] = pocket['atom_name'].map(normalize_atoms)
    pocket = pocket[pocket['normalize_atom']!='']
    patoms = pocket['atom_name'].apply(normalize_atoms).values.tolist()
    pcoords = [pocket[['x_coord', 'y_coord', 'z_coord']].values]
    side = [0 if a in main_atoms else 1 for a in patoms]
    residues = (pocket['chain_id'].astype(str) + pocket['residue_number'].astype(str)).values.tolist()


    # generate ligand conformation
    M, N = 100, 10
    mol = Chem.MolFromSmiles(smiles)
    mol = Chem.AddHs(mol)
    AllChem.EmbedMolecule(mol, randomSeed=42)
    latoms = [atom.GetSymbol() for atom in mol.GetAtoms()]
    holo_coordinates = [mol.GetConformer().GetPositions().astype(np.float32)]
    holo_mol = mol
    coordinate_list = clustering_coords(mol, M=M, N=N, seed=42, removeHs=False)
    mol_list = [mol]*N

    print(len(patoms))
    print(len(pcoords))

    return pickle.dumps({
        'atoms':latoms, 
        'coordinates':coordinate_list, 
        'mol_list':mol_list, 
        'pocket_atoms':patoms, 
        'pocket_coordinates':pcoords, 
        'side':side,
        'residue':residues, 
        'holo_coordinates':holo_coordinates, 
        'holo_mol':holo_mol,
        'holo_pocket_coordinates':pcoords, 
        'smi':smiles, 
        'pocket':pname, 
        }, protocol=-1)

def write_lmdb(pdb_id, smiles_list, result_dir='./results'):
  os.makedirs(result_dir, exist_ok=True)
  outputfilename = os.path.join(result_dir, pdb_id+'.lmdb')
  try:
    os.remove(outputfilename)
  except:
    pass
  env_new = lmdb.open(
      outputfilename,
      subdir=False,
      readonly=False,
      lock=False,
      readahead=False,
      meminit=False,
      max_readers=1,
      map_size=int(10e9),
  )
  for i,smiles in enumerate(smiles_list):
    inner_output = parser(pdb_id, smiles)
    txn_write = env_new.begin(write=True)
    txn_write.put(f'{i}'.encode("ascii"), inner_output)
  txn_write.commit()
  env_new.close()

#@title Uni-Mol Docking 

#@markdown Currently this scripts only support CASF-2016 dataset with given pockets residues.

#@markdown You can dock serveral smiles to a given pocket by input smiles split by ','. 

#@markdown If no smiles is given, target ligand will redock to this pocket.

pdb_id = '1o3f'  #@param {type:"string"}
supp = Chem.SDMolSupplier(os.path.join(CASF_PATH,'casf2016',pdb_id+'_ligand.sdf'))
mol = [mol for mol in supp if mol][0]
ori_smiles = Chem.MolToSmiles(mol)
smiles = ''  #@param {type:"string"}
data_path="./protein_ligand_binding_pose_prediction"  
results_path="./results/"  
weight_path="/content/binding_pose_220908.pt"
batch_size=8
dist_threshold=8.0
recycling=3
if smiles.split(',')==0 or smiles == '':
  print('No other smiles inputs')
  smiles_list = [ori_smiles]
else:
  print('Docking with smiles: {}'.smiles)
  smiles_list = smiles.split(',')

write_lmdb(pdb_id, smiles_list, result_dir='./protein_ligand_binding_pose_prediction')
!python ./Uni-Mol/unimol/infer.py --user-dir ./unimol $data_path --valid-subset $pdb_id \
       --results-path $results_path \
       --num-workers 8 --ddp-backend=c10d --batch-size $batch_size \
       --task docking_pose --loss docking_pose --arch docking_pose \
       --path $weight_path \
       --fp16 --fp16-init-scale 4 --fp16-scale-window 256 \
       --dist-threshold $dist_threshold --recycling $recycling \
       --log-interval 50 --log-format simple

def generate_docking_input(predict_file, reference_file, tta_times=10, output_dir='./results'):
  (
    mol_list,
    smi_list,
    pocket_list,
    pocket_coords_list,
    distance_predict_list,
    holo_distance_predict_list,
    holo_coords_list,
    holo_center_coords_list,
  ) = docking_data_pre(reference_file, predict_file)
  iter = ensemble_iterations(
    mol_list,
    smi_list,
    pocket_list,
    pocket_coords_list,
    distance_predict_list,
    holo_distance_predict_list,
    holo_coords_list,
    holo_center_coords_list,
    tta_times=tta_times,
  )
  for i, content in enumerate(iter):
    pocket = content[3]
    output_name = os.path.join(output_dir, "{}.{}.pkl".format(pocket,i))
    try:
      os.remove(output_name)
    except:
      pass
    pd.to_pickle(content, output_name)

predict_file=os.path.join(results_path, 'content_'+pdb_id+'.out.pkl')
reference_file=os.path.join(data_path, pdb_id+'.lmdb')   
generate_docking_input(predict_file, reference_file, tta_times=10, output_dir=results_path)
for i,smiles in enumerate(smiles_list):
  print('Docking {}'.format(smiles))
  input_path = os.path.join(results_path, "{}.{}.pkl".format(pdb_id,i))
  ligand_path = os.path.join(results_path, "docking.{}.{}.sdf".format(pdb_id,i))
  cmd = "python ./Uni-Mol/unimol/utils/coordinate_model.py --input {} --output-ligand {}".format(input_path,ligand_path)
  os.system(cmd)



No other smiles inputs
350
1
2022-09-09 08:18:27 | INFO | unimol.inference | loading model(s) from /content/binding_pose_220908.pt
2022-09-09 08:18:28 | INFO | unimol.tasks.docking_pose | ligand dictionary: 30 types
2022-09-09 08:18:28 | INFO | unimol.tasks.docking_pose | pocket dictionary: 9 types
2022-09-09 08:18:33 | INFO | unimol.inference | Namespace(activation_dropout=0.0, activation_fn='gelu', adam_betas='(0.9, 0.999)', adam_eps=1e-08, all_gather_list_size=16384, allreduce_fp32_grad=False, arch='docking_pose', attention_dropout=0.1, batch_size=8, batch_size_valid=8, bf16=False, bf16_sr=False, broadcast_buffers=False, bucket_cap_mb=25, conf_size=10, cpu=False, curriculum=0, data='./protein_ligand_binding_pose_prediction', data_buffer_size=10, ddp_backend='c10d', delta_pair_repr_norm_loss=-1.0, device_id=0, disable_validation=False, dist_threshold=8.0, distributed_backend='nccl', distributed_init_method=None, distributed_no_spawn=False, distributed_num_procs=1, distributed_port=-1

In [3]:
#@title Visualize Docking results

#@markdown Note: Target ligand in Complex is shown in red color

#@markdown Only visualize the first dock ligand

import py3Dmol
import matplotlib.pyplot as plt
pdb_path = os.path.join(CASF_PATH, 'casf2016', pdb_id+'_protein.pdb')
ligand_path = os.path.join(results_path, "docking.{}.{}.sdf".format(pdb_id,0))
gt_ligand_path = os.path.join(CASF_PATH,'casf2016',pdb_id+'_ligand.sdf')
view = py3Dmol.view()
view.removeAllModels()
view.setViewStyle({'style':'outline','color':'black','width':0.1})
pdb_path = os.path.join(CASF_PATH, 'casf2016', pdb_id+'_protein.pdb')
view.addModel(open(pdb_path,'r').read(),format='pdb')
Prot=view.getModel()
# Prot.setStyle({'cartoon':{'arrows':True, 'tubes':True, 'style':'oval', 'color':'white'}})
Prot.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':0.5,'max':0.9}}})

view.addSurface(py3Dmol.VDW,{'opacity':0.6,'color':'white'})


view.addModel(open(ligand_path,'r').read(),format='sdf')
ref_m = view.getModel()
ref_m.setStyle({},{'stick':{'colorscheme':'greenCarbon','radius':0.2}})

view.addModel(open(gt_ligand_path,'r').read(),format='sdf')
ref_m = view.getModel()
ref_m.setStyle({},{'stick':{'colorscheme':'redCarbon','radius':0.2}})

view.zoomTo()
view.show()

In [4]:
#@title Download the prediction
#@markdown **The content of zip file**:
#@markdown 1. PDB formatted structures
#@markdown 2. Docking ligand SDF files
#@markdown 3. Target ligand SDF files.

from google.colab import files
file_lists = []
pdb_path = os.path.join(CASF_PATH, 'casf2016', pdb_id+'_protein.pdb')
file_lists.append(pdb_path)
for i in range(len(smiles_list)):
  ligand_path = os.path.join(results_path, "docking.{}.{}.sdf".format(pdb_id,i))
  file_lists.append(ligand_path)
gt_ligand_path = os.path.join(CASF_PATH,'casf2016',pdb_id+'_ligand.sdf')
file_lists.append(gt_ligand_path)

!zip -j {"unimol.docking."+pdb_id}.zip {" ".join(file_lists)}
files.download(f'{"unimol.docking."+pdb_id}.zip')

  adding: 1o3f_protein.pdb (deflated 76%)
  adding: docking.1o3f.0.sdf (deflated 76%)
  adding: 1o3f_ligand.sdf (deflated 72%)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>