In [1]:
import numpy as np
import pandas as pd
import os
from tqdm import tqdm 
from itertools import combinations

from rdkit import Chem
import torch
import time
import shutil
from pathlib import Path
import torch.nn.functional as F

from utils.volume_sampling import sample_discrete_number
from utils.volume_sampling import remove_output_files, run_fpocket, extract_values
from utils.templates import get_one_hot, get_pocket

from src.lightning_anchor_gnn import AnchorGNN_pl
from src.lightning import AR_DDPM
from scipy.spatial import distance
from Bio.PDB import PDBParser
from Bio.PDB.Polypeptide import is_aa, three_to_one

from analysis.reconstruct_mol import reconstruct_from_generated
from analysis.vina_docking import VinaDockingTask

from rdkit.Chem import rdmolfiles
from sampling.sample_mols import generate_mols_for_pocket

from openbabel import openbabel
import tempfile

# Ligand atom vocabulary: element symbol -> one-hot index.
atom_dict = {
    'C': 0, 'N': 1, 'O': 2, 'S': 3, 'B': 4,
    'Br': 5, 'Cl': 6, 'P': 7, 'I': 8, 'F': 9,
}
# Inverse lookup (one-hot index -> element symbol), derived from atom_dict.
idx2atom = {index: symbol for symbol, index in atom_dict.items()}
# Element symbol -> integer that is handed to molecule reconstruction as the
# atomic number of each generated atom.
CROSSDOCK_CHARGES = {
    'C': 6, 'O': 8, 'N': 7, 'F': 9, 'B': 5,
    'S': 16, 'Cl': 17, 'Br': 35, 'I': 53, 'P': 15,
}
# only 4 atoms types for pocket
pocket_atom_dict = {'C': 0, 'N': 1, 'O': 2, 'S': 3}
# One-letter amino-acid code -> one-hot index (20 entries, alphabetical).
amino_acid_dict = {
    'A': 0, 'C': 1, 'D': 2, 'E': 3, 'F': 4,
    'G': 5, 'H': 6, 'I': 7, 'K': 8, 'L': 9,
    'M': 10, 'N': 11, 'P': 12, 'Q': 13, 'R': 14,
    'S': 15, 'T': 16, 'V': 17, 'W': 18, 'Y': 19,
}
# Per-element radii -- presumably van der Waals radii in Angstrom; TODO confirm.
vdws = {
    'C': 1.7, 'N': 1.55, 'O': 1.52, 'S': 1.8, 'B': 1.92,
    'Br': 1.85, 'Cl': 1.75, 'P': 1.8, 'I': 1.98, 'F': 1.47,
}

from utils.volume_sampling import extract_alpha_spheres_coords
from utils.visuals import write_xyz_file, visualize_3d_pocket_molecule, get_pocket_mol

In [2]:
def get_pocket(pdbfile, pocket_atom_dict, remove_H=True, ca_only=False):
    """Parse a PDB file and featurize all of its standard amino-acid residues.

    Every residue of the first model for which ``is_aa(..., standard=True)``
    holds is kept. No distance-based filtering happens here, so ``pdbfile``
    is expected to already contain only the pocket.

    NOTE: this definition shadows the ``get_pocket`` imported from
    ``utils.templates`` at the top of the notebook.

    Args:
        pdbfile: PDB file handed to ``PDBParser.get_structure``.
        pocket_atom_dict: mapping element symbol -> one-hot index, used in
            full-atom mode only.
        remove_H: full-atom mode only; drop atoms whose element is ``'H'``.
        ca_only: if True, keep one entry per ``'CA'`` atom and encode only
            the residue type.

    Returns:
        Tuple ``(pocket_one_hot, pocket_coords)`` with one row per kept atom.
        In full-atom mode each feature row is the concatenation
        ``[atom one-hot, amino-acid one-hot, backbone flag]``; in ``ca_only``
        mode it is the amino-acid one-hot alone. ``pocket_coords`` holds the
        Biopython coordinates of the same atoms, in the same order.

    Raises:
        ValueError: if the file contains no standard amino-acid residue.
        KeyError: if an element or residue code is missing from the
            lookup dictionaries.
    """
    pdb_struct = PDBParser(QUIET=True).get_structure('', pdbfile)

    # Keep every standard amino-acid residue of the first model.
    pocket_residues = [
        residue for residue in pdb_struct[0].get_residues()
        if is_aa(residue.get_resname(), standard=True)
    ]
    if not pocket_residues:
        raise ValueError(f'no standard amino-acid residues found in {pdbfile}')

    if ca_only:
        try:
            pocket_one_hot = []
            pocket_coords = []
            for res in pocket_residues:
                for atom in res.get_atoms():
                    if atom.name == 'CA':
                        pocket_one_hot.append(np.eye(1, len(amino_acid_dict),
                        amino_acid_dict[three_to_one(res.get_resname())]).squeeze())
                        # Collected together with the one-hot row so that
                        # features and coordinates stay row-aligned.
                        pocket_coords.append(atom.coord)
            pocket_one_hot = np.stack(pocket_one_hot)
            pocket_coords = np.stack(pocket_coords)
        except KeyError as e:
            raise KeyError(f'{e} not in amino acid dict ({pdbfile})')
    else:
        # Flatten once and derive all per-atom arrays from the same list so
        # they are guaranteed to share one ordering.
        atoms = [atom for res in pocket_residues for atom in res.get_atoms()]
        full_atoms = np.array([atom.element for atom in atoms])
        full_coords = np.array([atom.coord for atom in atoms])
        full_atoms_names = np.array([atom.get_id() for atom in atoms])
        pocket_AA = np.array([three_to_one(atom.get_parent().get_resname()) for atom in atoms])

        # removing Hs if present
        if remove_H:
            keep = full_atoms != 'H'
            full_atoms = full_atoms[keep]
            full_coords = full_coords[keep]
            full_atoms_names = full_atoms_names[keep]
            pocket_AA = pocket_AA[keep]
        # Assigned on both paths so remove_H=False also returns coordinates.
        pocket_coords = full_coords

        try:
            pocket_one_hot = []
            for i in range(len(full_atoms)):
                a = full_atoms[i]
                aa = pocket_AA[i]
                atom_onehot = np.eye(1, len(pocket_atom_dict), pocket_atom_dict[a.capitalize()]).squeeze()
                amino_onehot = np.eye(1, len(amino_acid_dict), amino_acid_dict[aa.capitalize()]).squeeze()
                # NOTE(review): 'CA'.capitalize() is 'Ca', which is not in the
                # list below, so C-alpha atoms get a backbone flag of 0. Left
                # unchanged because pretrained checkpoints presumably expect
                # this featurization -- confirm before changing.
                is_backbone = 1 if full_atoms_names[i].capitalize() in ['N','CA','C','O'] else 0
                pocket_one_hot.append(np.concatenate([atom_onehot, amino_onehot, (is_backbone,)]))

            pocket_one_hot = np.stack(pocket_one_hot)
        except KeyError as e:
            raise KeyError(
            f'{e} not in atom dict ({pdbfile})')

    return pocket_one_hot, pocket_coords

In [4]:
# Input protein PDB file. Later cells derive the fpocket output directory from
# this name by stripping its last four characters (pdb[:-4] + '_out').
pdb = '2z3h.pdb'

In [5]:
# Pocket number identified by fpocket; it selects which pocket file is read.
k = 1
# fpocket vertices file for pocket `k`, inside the '<pdb stem>_out' directory.
pqr_file = f'{pdb[:-4]}_out/pockets/pocket{k}_vert.pqr'
# Number of samples to generate.
n_samples = 10

In [6]:
# The input pdb should contain the pocket only (no ligand and only maximum 4A
# around the pocket): get_pocket keeps every standard residue it finds.
pocket_onehot, pocket_coords = get_pocket(pdb, pocket_atom_dict, remove_H=True, ca_only=False)

# --------------- identify the protein pocket with fpocket ----------------
# NOTE: fpocket can sometimes give you the wrong pocket, make sure to check the
# output, visualize the pocket and verify that pocket number `k` is correct.
fpocket_out_dir = pdb[:-4] + '_out'
pqr_file = fpocket_out_dir + '/pockets/pocket' + str(k) + '_vert.pqr'
try:
    # An existing output directory is reused; fpocket only runs when it is missing.
    if not os.path.exists(fpocket_out_dir):
        print('running fpocket...')
        run_fpocket(pdb)
    alpha_spheres = np.array(extract_alpha_spheres_coords(pqr_file))
except Exception as err:
    # Chain the original error so the real cause stays visible in the traceback.
    raise ValueError('fpocket failed!') from err

if alpha_spheres.size == 0:
    raise ValueError(f'no alpha spheres found in {pqr_file}')

# ---------------  make a grid box around the pocket ----------------
box_padding = 2.5   # margin added on every side of the pocket bounding box
grid_spacing = 1.5  # spheres of radius 1.5 (vdw radius of C)
min_coords = pocket_coords.min(axis=0) - box_padding
max_coords = pocket_coords.max(axis=0) + box_padding

x_range, y_range, z_range = (
    slice(min_coords[axis], max_coords[axis] + 1, grid_spacing) for axis in range(3)
)
grid = np.mgrid[x_range, y_range, z_range]
grid_points = grid.reshape(3, -1).T  # This transposes the grid to a list of coordinates

# keep only grid points that lie within 3 of at least one alpha sphere
distances_spheres = distance.cdist(grid_points, alpha_spheres)
mask_spheres = (distances_spheres < 3).any(axis=1)
filtered_alpha_points = grid_points[mask_spheres]

# remove grid points that are close to the pocket (closer than 2 to any pocket atom)
pocket_distances = distance.cdist(filtered_alpha_points, pocket_coords)
mask_pocket = (pocket_distances < 2).any(axis=1)
grids = torch.tensor(filtered_alpha_points[~mask_pocket])

# every sample uses the same grid tensor
all_grids = [grids for _ in range(n_samples)]

# the number of remaining grid points is the pocket "volume" used to sample
# one maximum molecule size per sample
pocket_vol = len(grids)
max_mol_sizes = np.array([sample_discrete_number(pocket_vol) for _ in range(n_samples)])

pocket_size = len(pocket_coords)

print('maximum molecule sizes', max_mol_sizes)

maximum molecule sizes [32 25 28 25 27 28 27 27 27 30]




Choosing pocket anchors from the pocket atoms, using the fpocket alpha spheres.

In [15]:
# NOTE: pocket anchors are chosen among the pocket atoms that are close to alpha spheres.
alpha_spheres_pocket_distances = distance.cdist(pocket_coords, alpha_spheres)
# For every pocket atom, count the alpha spheres closer than 4.5.
n_close_spheres = (alpha_spheres_pocket_distances < 4.5).sum(axis=1)
# Candidate anchors: the 7 pocket atoms with the highest counts, best first.
possible_pocket_anchors = np.flip(np.argsort(n_close_spheres))[:7]
# Draw one anchor per sample, with replacement, from the candidates.
pocket_anchors = np.random.choice(possible_pocket_anchors, size=n_samples, replace=True)

In [9]:
# Build a molecule object for the pocket from its atom coordinates and one-hot
# features; it is passed to visualize_3d_pocket_molecule in the cells below.
# NOTE(review): presumably expects the NumPy arrays returned by get_pocket, so
# run this before the cell that converts them to torch tensors -- confirm.
pocket_mol = get_pocket_mol(pocket_coords, pocket_onehot)

  Failed to kekulize aromatic bonds in OBMol::PerceiveBondOrders (title is /tmp/tmpw1389hey)



Visualizing the protein pocket: the alpha spheres from fpocket are shown in yellow and the candidate pocket anchors (from which the anchors are randomly selected) in green.

In [10]:
# Show the pocket without a ligand, together with the candidate anchor atoms
# (sphere_positions1) and the fpocket alpha spheres (sphere_positions2).
visualize_3d_pocket_molecule(
    pocket_mol,
    mol=None,
    spin=False,
    optimize_coords=False,
    sphere_positions1=pocket_coords[possible_pocket_anchors].tolist(),
    sphere_positions2=alpha_spheres.tolist(),
)

<py3Dmol.view at 0x7f27cb2807c0>

In [11]:
# Convert the pocket features and coordinates to float32 torch tensors,
# rebinding the same two names.
pocket_onehot, pocket_coords = (
    torch.tensor(array).float() for array in (pocket_onehot, pocket_coords)
)

In [12]:
# cuda device used for both models and for sampling
# NOTE(review): hard-coded GPU; this cell fails on a machine without CUDA.
dev = 'cuda:0'
# Load the AR_DDPM checkpoint and move the model to `dev`.
model = AR_DDPM.load_from_checkpoint('pocket-gvp.ckpt', device=dev).to(dev)



In [13]:
# Load the AnchorGNN_pl checkpoint and move it to the same device as `model`.
anchor_model = AnchorGNN_pl.load_from_checkpoint('anchor-model.ckpt', device=dev).to(dev)

In [16]:
# running the autofragdiff for 8 fragments
max_num_frags = 8

# All sampler inputs gathered in one place; every value was prepared in the cells above.
sampling_kwargs = dict(
    # how much to generate
    n_samples=n_samples,
    num_frags=max_num_frags,
    max_mol_sizes=max_mol_sizes,
    # pocket description
    prot_path=pdb,
    pocket_size=pocket_size,
    pocket_coords=pocket_coords,
    pocket_onehot=pocket_onehot,
    pocket_anchors=pocket_anchors,
    all_grids=all_grids,
    lig_coords=None,
    # models and device
    anchor_model=anchor_model,
    diff_model=model,
    device=dev,
    # sampler options
    return_all=False,
    rejection_sampling=False,
)
x, h, mol_masks = generate_mols_for_pocket(**sampling_kwargs)

generating fragment 1
fragment sizes: [11 10 14 11 14 10 10 11 12 13]
generating fragment at step 2
Sampled fragsizes [ 6  6 13  5  1  1  3 12  5  1]
fragment sizes:  [ 6  6 13  5  1  1  3 12  5  1]
mol sizes: [17 16 27 16 15 11 13 23 17 14]
generating fragment at step 3
Sampled fragsizes [1 5 4 4 7 6 5 6 1 9]
fragment sizes:  [1 5 4 4 7 6 5 6 1 9]
mol sizes: [18 21 31 20 22 17 18 29 18 23]
generating fragment at step 4
Sampled fragsizes [ 1  1  7 10  6  9 11  5  2  1]
fragment sizes:  [ 1  1  7 10  6  9 11  5  2  1]
mol sizes: [19 22 38 30 28 26 29 34 20 24]
generating fragment at step 5
Sampled fragsizes [7 6 6 6 9 2 2 9 5 1]
fragment sizes:  [7 6 6 6 9 2 2 9 5 1]
mol sizes: [26 28 37 36 37 28 31 38 25 25]
generating fragment at step 6
Sampled fragsizes [5 5 6 6 2 5 2 5 6 5]
fragment sizes:  [5 5 6 6 2 5 2 5 6 5]
mol sizes: [31 33 37 36 30 33 33 34 31 30]
generating fragment at step 7
Sampled fragsizes [3 5 6 6 3 7 1 6 1 9]
fragment sizes:  [3 5 6 6 3 7 1 6 1 9]
mol sizes: [34 33 37 

In [17]:
def _to_numpy(values):
    """Return `values` as a NumPy array; torch tensors are detached and moved to the CPU first."""
    if torch.is_tensor(values):
        return values.detach().cpu().numpy()
    return np.asarray(values)


# Accepting arrays as well as tensors keeps this cell re-runnable after the
# names have already been rebound to NumPy arrays.
x = _to_numpy(x)
h = _to_numpy(h)
mol_masks = _to_numpy(mol_masks)

# Reconstruct one molecule per generated sample.
# The loop variable is deliberately not called `k`: `k` is the fpocket pocket
# number chosen earlier in the notebook and must not be overwritten here.
all_mols = []
for sample_idx in range(len(x)):
    atom_mask = mol_masks[sample_idx].astype(np.bool_)
    x_mol = x[sample_idx][atom_mask]

    # one-hot rows -> element symbols -> atomic numbers
    atom_inds = h[sample_idx][atom_mask].argmax(axis=1)
    atom_types = [idx2atom[atom_idx] for atom_idx in atom_inds]
    atomic_nums = [CROSSDOCK_CHARGES[symbol] for symbol in atom_types]

    try:
        mol_rec = reconstruct_from_generated(x_mol.tolist(), atomic_nums)
    except Exception as err:
        # Best effort: skip samples that cannot be reconstructed, but report
        # them, because positions in all_mols then no longer match sample indices.
        print(f'sample {sample_idx}: reconstruction failed ({err!r}); skipping')
        continue
    all_mols.append(mol_rec)

  Failed to kekulize aromatic bonds in OBMol::PerceiveBondOrders

  Failed to kekulize aromatic bonds in OBMol::PerceiveBondOrders



In [19]:
# Show the pocket with the second successfully reconstructed molecule (index 1).
# Failed reconstructions are skipped when all_mols is built, so this index is a
# position in all_mols, not necessarily the sample number.
visualize_3d_pocket_molecule(
    pocket_mol,
    mol=all_mols[1],
    spin=False,
    optimize_coords=False,
)

<py3Dmol.view at 0x7f27cb3522e0>