<a href="https://colab.research.google.com/github/casperg92/MaSIF_colab/blob/main/dMaSIF_Colab_V1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# dMaSIF site
Protein binding is determined by the chemical and geometric features of the interacting protein surfaces. Differentiable Molecular Surface Interaction Fingerprinting (dMaSIF) site is a geometric deep learning framework trained on these surface 'fingerprints' to identify potential protein binding sites. For more details, see the original papers:

1) [Gainza, P., Sverrisson, F., Monti, F., Rodola, E., Boscaini, D., Bronstein, M. M., & Correia, B. E. (2020). Deciphering interaction fingerprints from protein molecular surfaces using geometric deep learning. Nature Methods, 17(2), 184-192.](https://doi.org/10.1038/s41592-019-0666-6)

2) [Sverrisson, F., Feydy, J., Correia, B. E., & Bronstein, M. M. (2021). Fast end-to-end learning on protein surfaces. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 15272-15281).](https://doi.org/10.1109/CVPR46437.2021.01502)

In [1]:
#@title Upload a pdb file?
#@markdown pdb files will be uploaded to the '/content/pdbs' folder.
from google.colab import files
import os

# Create folder for the pdbs
pdb_dir = '/content/pdbs'
os.makedirs(pdb_dir, exist_ok=True)

upload_file = False #@param {type:"boolean"}

%cd -q /content/pdbs
if upload_file:
  uploaded = files.upload()
%cd -q /content

In [2]:
#@title Change pdb path and chain name(s), then hit `Runtime` -> `Run all`
#@markdown Note: the pdb file cannot contain an underscore ('_') in its name.
import os
import glob



# target pdb
target_pdb = "/content/MaSIF_colab/example/monomerexample.pdb" #@param {type:"string"}
target_name, target_ext = os.path.splitext(os.path.basename(target_pdb))

if target_ext != '.pdb':
  raise ValueError('Please upload a valid .pdb file!')
if '_' in target_name:
  raise ValueError("The pdb file name cannot contain an underscore ('_').")

chain_name = 'A' #@param {type:"string"}
chains = [chain_name]

# Path to MaSIF weights
#@markdown A resolution of 0.7 Angstrom gives a denser point cloud and better performance. Different radius settings do not seem to impact performance.
model_resolution = '0.7 Angstrom' #@param ["1 Angstrom", "0.7 Angstrom"]
patch_radius = '9 Angstrom' #@param ["9 Angstrom", "12 Angstrom"]


# Each (patch radius, resolution) pair maps to a pretrained checkpoint and the
# settings it was trained with: (checkpoint, radius, resolution, supsampling)
model_settings = {
  ('9 Angstrom', '1 Angstrom'): ('dMaSIF_site_3layer_16dims_9A_100sup_epoch64', 9, 1.0, 100),
  ('9 Angstrom', '0.7 Angstrom'): ('dMaSIF_site_3layer_16dims_9A_0.7res_150sup_epoch85', 9, 0.7, 150),
  ('12 Angstrom', '1 Angstrom'): ('dMaSIF_site_3layer_16dims_12A_100sup_epoch71', 12, 1.0, 100),
  ('12 Angstrom', '0.7 Angstrom'): ('dMaSIF_site_3layer_16dims_12A_0.7res_150sup_epoch59', 12, 0.7, 150),
}
checkpoint, radius, resolution, supsampling = model_settings[(patch_radius, model_resolution)]
model_path = '/content/MaSIF_colab/models/' + checkpoint


# Create new folders, or empty them if they already exist
def reset_dir(path):
  """Create `path` if needed, otherwise delete the files it contains."""
  os.makedirs(path, exist_ok=True)
  for f in glob.glob(os.path.join(path, '*')):
    os.remove(f)

chains_dir = '/content/chains'  # per-chain pdb files
npy_dir = '/content/npys'       # surface features as .npy files
pred_dir = '/content/preds'     # predictions and embeddings
for d in (chains_dir, npy_dir, pred_dir):
  reset_dir(d)
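
The underscore restriction noted above is tied to how the protein ID passed to dMaSIF is built later in this notebook: `"{n}_{c}_{c}"`, i.e. name, chain and chain joined with underscores. The sketch below assumes that the downstream code recovers these parts by splitting on `'_'`, and shows why an underscore in the file name breaks this. `make_pdb_id` and `split_pdb_id` are hypothetical helpers for illustration only.

```python
import os

def make_pdb_id(pdb_path, chain):
    """Build a '<name>_<chain>_<chain>' ID from a .pdb path (hypothetical helper)."""
    name, ext = os.path.splitext(os.path.basename(pdb_path))
    if ext != '.pdb':
        raise ValueError(f'Expected a .pdb file, got {pdb_path!r}')
    if '_' in name:
        raise ValueError(f"PDB file name must not contain '_': {name!r}")
    return f'{name}_{chain}_{chain}'

def split_pdb_id(pdb_id):
    """Recover (name, chain1, chain2) by splitting on '_' (assumed parsing rule)."""
    name, chain1, chain2 = pdb_id.split('_')  # fails if the name itself has '_'
    return name, chain1, chain2

pdb_id = make_pdb_id('/content/MaSIF_colab/example/monomerexample.pdb', 'A')
print(pdb_id)                # monomerexample_A_A
print(split_pdb_id(pdb_id))  # ('monomerexample', 'A', 'A')

# 'my_protein_A_A' splits into four parts, so unpacking into three fails:
# split_pdb_id('my_protein_A_A')  -> ValueError
```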

In [3]:
#@title Install dependencies
from tqdm.notebook import tqdm

# Manually advance the progress bar
def update_pbar(p_bar, c=1):
  p_bar.update(c)
  p_bar.refresh()

p_bar = tqdm(range(10))

# Switch to CUDA 11.1
%cd -q /usr/local/
!rm -rf cuda > /dev/null
!ln -s /usr/local/cuda-11.1 /usr/local/cuda > /dev/null
#!stat cuda

# Git clone MaSIF for Colab (including examples and weights)
print('Downloading MaSIF..')
%cd -q /content
!rm -fr MaSIF_colab > /dev/null
!git clone --quiet https://github.com/casperg92/MaSIF_colab.git > /dev/null
update_pbar(p_bar)


# Downgrade PyTorch to make it compatible with PyTorch Geometric
print('Installing PyTorch..')
!pip install torch==1.8.1+cu111 -f https://download.pytorch.org/whl/cu111/torch_stable.html &> /dev/null
update_pbar(p_bar)
print('Installing PyTorch Geometric..')
!pip install torch-scatter==2.0.7 -f https://data.pyg.org/whl/torch-1.8.1+cu111.html &> /dev/null
!pip install torch-sparse==0.6.11 -f https://data.pyg.org/whl/torch-1.8.1+cu111.html &> /dev/null
!pip install torch-cluster==1.5.9 -f https://data.pyg.org/whl/torch-1.8.1+cu111.html &> /dev/null
!pip install torch-geometric==1.6.1 &> /dev/null
update_pbar(p_bar)
print('Installing PyKeops..')
!pip install git+https://github.com/getkeops/keops.git@python_engine &> /dev/null
update_pbar(p_bar)
print('Installing BioPython..')
!pip install biopython &> /dev/null
update_pbar(p_bar)
print('Installing plyfile..')
!pip install plyfile &> /dev/null
update_pbar(p_bar)
print('Installing pyvtk..')
!pip install pyvtk &> /dev/null
update_pbar(p_bar)
print('Installing nglview..')
!pip install -q nglview &> /dev/null
update_pbar(p_bar)
print('Installing pdbparser..')
!pip install pdbparser &> /dev/null
update_pbar(p_bar)
print('Installing reduce..')
!git clone --quiet https://github.com/rlabduke/reduce > /dev/null
!cmake reduce &> /dev/null
!make &> /dev/null
!sudo make install &> /dev/null
update_pbar(p_bar)

  0%|          | 0/10 [00:00<?, ?it/s]

Downloading MaSIF..
Installing PyTorch..
Installing PyTorch Geometric..
Looking in links: https://data.pyg.org/whl/torch-1.8.1+cu111.html
Collecting torch-sparse==0.6.11
  Downloading https://data.pyg.org/whl/torch-1.8.0%2Bcu111/torch_sparse-0.6.11-cp37-cp37m-linux_x86_64.whl (1.9 MB)
     |████████████████████████████████| 1.9 MB 319 kB/s 
Installing collected packages: torch-sparse
Successfully installed torch-sparse-0.6.11
Looking in links: https://data.pyg.org/whl/torch-1.8.1+cu111.html
Collecting torch-cluster==1.5.9
  Downloading https://data.pyg.org/whl/torch-1.8.0%2Bcu111/torch_cluster-1.5.9-cp37-cp37m-linux_x86_64.whl (1.7 MB)
     |████████████████████████████████| 1.7 MB 687 kB/s 
Installing collected packages: torch-cluster
Successfully installed torch-cluster-1.5.9
Installing PyKeops..
Installing BioPython..
Installing plyfile..
Installing pyvtk..
Installing nglview..
Installing pdbparser..
Installing reduce..
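
The install cell discards most pip output (`&> /dev/null`), so a failed or mismatched install usually only surfaces later as an import error. A small check like the one below lists packages that are missing or installed at an unexpected version. It is a sketch that assumes Python 3.8+ (for `importlib.metadata`) and that each distribution reports exactly the pinned version string; wheel builds sometimes add local suffixes. The pins are copied from the cell above, and `check_versions` is a hypothetical helper.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the install cell above (PyPI distribution names)
EXPECTED = {
    'torch': '1.8.1+cu111',
    'torch-scatter': '2.0.7',
    'torch-sparse': '0.6.11',
    'torch-cluster': '1.5.9',
    'torch-geometric': '1.6.1',
}

def check_versions(expected):
    """Return {package: (expected, installed_or_None)} for every mismatch."""
    mismatches = {}
    for package, wanted in expected.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            installed = None  # not installed at all
        if installed != wanted:
            mismatches[package] = (wanted, installed)
    return mismatches

for package, (wanted, installed) in check_versions(EXPECTED).items():
    print(f'{package}: expected {wanted}, found {installed}')
```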


In [4]:
#@title Load functions
import sys
sys.path.append("/content/MaSIF_colab") 
sys.path.append("/content/MaSIF_colab/data_preprocessing") 

import numpy as np
import pykeops
import torch
from Bio.PDB import *
from MaSIF_colab.data_preprocessing.download_pdb import convert_to_npy
from torch_geometric.data import DataLoader
from torch_geometric.transforms import Compose
import argparse
import shutil
from pathlib import Path

# Custom data loader and model:
from data import ProteinPairsSurfaces, PairData, CenterPairAtoms, load_protein_pair
from data import RandomRotationPairAtoms, NormalizeChemFeatures, iface_valid_filter
from model import dMaSIF
from data_iteration import iterate
from helper import *

# For showing the plot in nglview
from google.colab import output
output.enable_custom_widget_manager()
import nglview as ng
#import ipywidgets as widgets
from pdbparser.pdbparser import pdbparser

# For downloading files
from google.colab import files

def generate_descr(model_path, output_path, pdb_file, npy_directory, radius, resolution, supsampling):
    """Generate descriptors for a dMaSIF site model."""
    parser = argparse.ArgumentParser(description="Network parameters")
    parser.add_argument("--experiment_name", type=str, default=model_path)
    parser.add_argument("--use_mesh", type=bool, default=False)
    parser.add_argument("--embedding_layer",type=str,default="dMaSIF")
    parser.add_argument("--curvature_scales",type=list,default=[1.0, 2.0, 3.0, 5.0, 10.0])
    parser.add_argument("--resolution",type=float,default=resolution)
    parser.add_argument("--distance",type=float,default=1.05)
    parser.add_argument("--variance",type=float,default=0.1)
    parser.add_argument("--sup_sampling", type=int, default=supsampling)
    parser.add_argument("--atom_dims",type=int,default=6)
    parser.add_argument("--emb_dims",type=int,default=16)
    parser.add_argument("--in_channels",type=int,default=16)
    parser.add_argument("--orientation_units",type=int,default=16)
    parser.add_argument("--unet_hidden_channels",type=int,default=8)
    parser.add_argument("--post_units",type=int,default=8)
    parser.add_argument("--n_layers", type=int, default=3)
    parser.add_argument("--radius", type=float, default=radius)
    parser.add_argument("--k",type=int,default=40)
    parser.add_argument("--dropout",type=float,default=0.0)
    parser.add_argument("--site", type=bool, default=True) # set to true for site model
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--search",type=bool,default=False) # Set to true for search model
    parser.add_argument("--single_pdb",type=str,default=pdb_file)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--random_rotation",type=bool,default=False)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--single_protein",type=bool,default=True) # site models use a single protein
    parser.add_argument("--no_chem", type=bool, default=False)
    parser.add_argument("--no_geom", type=bool, default=False)
    
    args = parser.parse_args("")

    model_path = args.experiment_name
    save_predictions_path = Path(output_path)
    
    # Ensure reproducibility:
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    # Data transforms (random rotation is off at inference) and the test dataset:
    transformations = (
        Compose([NormalizeChemFeatures(), CenterPairAtoms(), RandomRotationPairAtoms()])
        if args.random_rotation
        else Compose([NormalizeChemFeatures()])
    )
    
    if args.single_pdb != "":
        single_data_dir = Path(npy_directory)
        test_dataset = [load_protein_pair(args.single_pdb, single_data_dir, single_pdb=True)]
        test_pdb_ids = [args.single_pdb]

    # PyTorch geometric expects an explicit list of "batched variables":
    batch_vars = ["xyz_p1", "xyz_p2", "atom_coords_p1", "atom_coords_p2"]
    test_loader = DataLoader(
        test_dataset, batch_size=args.batch_size, follow_batch=batch_vars
    )

    net = dMaSIF(args)
    # net.load_state_dict(torch.load(model_path, map_location=args.device))
    net.load_state_dict(torch.load(model_path, map_location=args.device)["model_state_dict"])
    net = net.to(args.device)

    # Perform one pass through the data:
    info = iterate(
        net,
        test_loader,
        None,
        args,
        test=True,
        save_path=save_predictions_path,
        pdb_ids=test_pdb_ids,
    )
    return info

  

def show_pointcloud(main_pdb, coord_file, emb_file):
  # Use the predicted binding score (second-to-last feature column) as the
  # b-factor; it is scaled to 0-100 when the records are written below
  b_factor = [emb[-2] for emb in emb_file]

  # Write a pseudo-PDB with one H atom per surface point
  records = []

  for i in range(len(coord_file)):
      points = coord_file[i]
      x_coord = points[0]
      y_coord = points[1]
      z_coord = points[2]

      records.append( { "record_name"       : 'ATOM',
                    "serial_number"     : len(records)+1,
                    "atom_name"         : 'H',
                    "location_indicator": '',
                    "residue_name"      : 'XYZ',
                    "chain_identifier"  : '',
                    "sequence_number"   : len(records)+1,
                    "code_of_insertion" : '',
                    "coordinates_x"     : x_coord,
                    "coordinates_y"     : y_coord,
                    "coordinates_z"     : z_coord,
                    "occupancy"         : 1.0,
                    "temperature_factor": b_factor[i]*100,
                    "segment_identifier": '',
                    "element_symbol"    : 'H',
                    "charge"            : '',
                    } )
    
  pdb = pdbparser()
  pdb.records = records

  pdb.export_pdb("pointcloud.pdb")

  # Load the pseudo-PDB written above as a point-cloud component
  coordPDB = "pointcloud.pdb"
  view = ng.NGLWidget()
  view.add_component(ng.FileStructure(os.path.join("/content", coordPDB)), defaultRepresentation=False)

  # representation with our customized colorscheme.
  view.add_representation('point', 
                          useTexture = 1,
                          pointSize = 2,
                          colorScheme = "bfactor",
                          colorDomain = [100.0, 0.0], 
                          colorScale = 'rwb',
                          selection='_H')

  view.add_component(ng.FileStructure(os.path.join("/content", main_pdb)))
  view.background = 'black'
  return view

def show_structure(main_pdb):
  # Show a structure whose b-factors hold the predicted binding scores
  view = ng.NGLWidget()

  view.add_component(ng.FileStructure(main_pdb), defaultRepresentation=False)
  view.add_representation("cartoon", colorScheme = "bfactor", colorScale = 'rwb', colorDomain = [100.0, 0.0])
  view.add_representation("ball+stick", colorScheme = "bfactor", colorScale = 'rwb', colorDomain = [100.0, 0.0])
  view.background = 'black'
  return view

[KeOps] Compiling main dll ... OK
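
`show_pointcloud` relies on `pdbparser` to write one pseudo-atom per surface point, with the predicted binding score stored in the b-factor column. To show what such a file contains, the following stdlib-only sketch formats the fixed-width ATOM records directly. The column positions follow the PDB format specification; the helper names, and the wrapping of serial and residue numbers to fit their 5- and 4-digit fields, are illustrative assumptions and not how `pdbparser` handles it.

```python
def pdb_atom_line(serial, x, y, z, bfactor, atom_name='H', res_name='XYZ',
                  chain='', res_seq=1, element='H', occupancy=1.0):
    """Format one fixed-width PDB ATOM record (78 columns)."""
    return (
        f"ATOM  {serial % 100000:5d} "          # cols 1-12 (serial wrapped to 5 digits)
        f" {atom_name:<3s} {res_name:>3s} "     # cols 13-21: name, altLoc, resName
        f"{chain:1s}{res_seq % 10000:4d}    "   # cols 22-30: chain, resSeq, iCode
        f"{x:8.3f}{y:8.3f}{z:8.3f}"             # cols 31-54: coordinates
        f"{occupancy:6.2f}{bfactor:6.2f}"       # cols 55-66: occupancy, b-factor
        f"{'':10s}{element:>2s}"                # cols 67-78: element symbol
    )

def pointcloud_pdb_lines(coords, features):
    """One pseudo-H atom per surface point; b-factor = 100 * features[i][-2]."""
    lines = [
        pdb_atom_line(i, x, y, z, bfactor=100.0 * row[-2], res_seq=i)
        for i, ((x, y, z), row) in enumerate(zip(coords, features), start=1)
    ]
    lines.append('END')
    return lines
```

Writing `'\n'.join(pointcloud_pdb_lines(coord, embedding))` to a file should give NGL a point cloud comparable to `pointcloud.pdb`, although details such as residue numbering may differ from what `pdbparser` writes.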




In [5]:
#@title Run MaSIF
# Protonate the pdb file using reduce
tmp_pdb = '/content/pdbs/tmp_1.pdb'
shutil.copyfile(target_pdb, tmp_pdb)

# Remove protons if there are any
!reduce -Trim -Quiet /content/pdbs/tmp_1.pdb > /content/pdbs/tmp_2.pdb
# Add protons
!reduce -HIS -Quiet /content/pdbs/tmp_2.pdb > /content/pdbs/tmp_3.pdb

# Overwrite the target pdb with the protonated version
tmp_pdb = '/content/pdbs/tmp_3.pdb'
shutil.copyfile(tmp_pdb, target_pdb)

!rm /content/pdbs/tmp_1.pdb /content/pdbs/tmp_2.pdb /content/pdbs/tmp_3.pdb

# Generate the surface features
convert_to_npy(target_pdb, chains_dir, npy_dir, chains)

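# Generate the embeddings (IDs follow the '<name>_<chain>_<chain>' pattern,
# which is why the pdb file name cannot contain an underscore)
pdb_name = "{n}_{c}_{c}".format(n=target_name, c=chain_name)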
info = generate_descr(model_path, pred_dir, pdb_name, npy_dir, radius, resolution, supsampling)

# In info I hardcoded memory usage to 0 so MaSIF would run on the CPU. We might want to change this.


  0%|          | 0/1 [00:00<?, ?it/s]

[KeOps] Compiling formula Sum_Reduction(Exp(-Sqrt(Sum((Var(0,3,0)-Var(1,3,1))**2))),1) ... OK
[KeOps] Compiling formula Sum_Reduction((Var(0,1,0)*Exp(-Sqrt(Sum((Var(1,3,0)-Var(2,3,1))**2))))/Var(3,1,1),1) ... OK
[KeOps] Compiling formula Max_SumShiftExpWeight_Reduction(Concat(-Sqrt(Sum((Var(0,3,0)-Var(1,3,1))**2))/Var(2,1,0),1),1) ... OK
[KeOps] Compiling formula Sum_Reduction(-(-((2*SumT((Var(2,1,0)*(Extract(Var(3,2,1),1,1)*Exp(-Sqrt(Sum((Var(0,3,0)-Var(1,3,1))**2))/Var(2,1,0)-Extract(Var(4,2,1),0,1))))*(1/2*Rsqrt(Sum((Var(0,3,0)-Var(1,3,1))**2))),3))*(Var(0,3,0)-Var(1,3,1))))/Var(2,1,0)**2,1) ... OK
[KeOps] Compiling formula Sum_Reduction(-(-((2*SumT(((Var(0,1,0)*(Var(3,1,1)*Var(4,1,1)))*Exp(-Sqrt(Sum((Var(1,3,0)-Var(2,3,1))**2))))*(1/2*Rsqrt(Sum((Var(1,3,0)-Var(2,3,1))**2))),3))*(Var(1,3,0)-Var(2,3,1))))/Var(3,1,1)**2,1) ... OK
[KeOps] Compiling formula Sum_Reduction(-((Var(0,1,0)*Exp(-Sqrt(Sum((Var(1,3,0)-Var(2,3,1))**2))))*Var(4,1,1))/Var(3,1,1)**2,1) ... OK
[KeOps] Compiling form


100%|██████████| 1/1 [00:11<00:00, 11.40s/it]
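
The plotting and hotspot code read the binding score from column `-2` of `*_predfeatures_emb1.npy`. Assuming the saved array has one row per surface point laid out as `[input features | embedding | binding score | label]` (verify this against `data_iteration.py` in the cloned repository), the columns can be split as sketched below. `split_pred_features` is a hypothetical helper, and the default widths of 16 and 16 are an assumption taken from `in_channels` and `emb_dims` in `generate_descr`.

```python
import numpy as np

def split_pred_features(features, n_inputs=16, n_emb=16):
    """Split an (n_points, n_inputs + n_emb + 2) array into named parts."""
    features = np.asarray(features)
    expected = n_inputs + n_emb + 2
    if features.ndim != 2 or features.shape[1] != expected:
        raise ValueError(f'expected shape (n, {expected}), got {features.shape}')
    return {
        'inputs': features[:, :n_inputs],
        'embedding': features[:, n_inputs:n_inputs + n_emb],
        'binding_score': features[:, -2],  # the column the notebook plots
        'label': features[:, -1],
    }

# Possible usage with this notebook's output:
# parts = split_pred_features(np.load('/content/preds/monomerexample_A_predfeatures_emb1.npy'))
```

If the shape check fails on a real file, the layout assumption does not hold for this version of the code.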


In [6]:
#@title Generate PDBs for hotspot atoms and residues
list_hotspot_residues = False #@param {type:"boolean"}

from Bio.PDB.PDBParser import PDBParser
from scipy.spatial.distance import cdist

parser=PDBParser(PERMISSIVE=1)
structure=parser.get_structure("structure", target_pdb)

coord = np.load("/content/preds/{n}_{c}_predcoords.npy".format(n=target_name, c=chain_name))
embedding = np.load("/content/preds/{n}_{c}_predfeatures_emb1.npy".format(n= target_name, c=chain_name))
atom_coords = np.stack([atom.get_coord() for atom in structure.get_atoms()])

b_factor = embedding[:, -2]  # predicted binding score per surface point

# Give each atom the score of its nearest surface point; atoms farther than
# dist_thresh (Angstrom) from every surface point get a score of 0
dists = cdist(atom_coords, coord)
nn_ind = np.argmin(dists, axis=1)
dists = dists[np.arange(len(dists)), nn_ind]
atom_b_factor = b_factor[nn_ind]
dist_thresh = 2.0
atom_b_factor[dists > dist_thresh] = 0.0

for i, atom in enumerate(structure.get_atoms()):
    atom.set_bfactor(atom_b_factor[i] * 100)

# Create folder for the output files
output_dir = '/content/output'
os.makedirs(output_dir, exist_ok=True)

# Save pdb file with per-atom b-factors
io = PDBIO()
io.set_structure(structure)
io.save("/content/output/per_atom_binding.pdb")

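# Per-residue score = highest per-atom score in that residue. Keys include the
# chain, residue name and insertion code so that residues sharing a number
# (e.g. in different chains) are not merged
hotspot_res = {}
for residue in structure.get_residues():
    chain_id = residue.get_parent().id
    _, res_num, icode = residue.id
    res_key = f"{chain_id}:{residue.get_resname()}{res_num}{icode.strip()}"
    res_b_factor = max(atom.get_bfactor() for atom in residue) / 100
    hotspot_res[res_key] = res_b_factor
    for atom in residue.get_atoms():
        atom.set_bfactor(res_b_factor * 100)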

# Save pdb file with per-residue b-factors
io = PDBIO()
io.set_structure(structure)
io.save("/content/output/per_resi_binding.pdb")

if list_hotspot_residues:
  print('Sorted on residue contribution (high to low):')
  for w in sorted(hotspot_res, key=hotspot_res.get, reverse=True):
    print(w, hotspot_res[w])
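
The cell above transfers surface-point scores to atoms with a nearest-neighbour rule (zero beyond 2.0 Å) and then takes the maximum per residue. The same logic, written with plain NumPy broadcasting instead of `scipy.spatial.distance.cdist` and with Biopython factored out, is sketched below. The function names are hypothetical, and the cutoff mirrors `dist_thresh` in the cell.

```python
import numpy as np

def atom_scores_from_points(atom_xyz, point_xyz, point_scores, dist_thresh=2.0):
    """Give each atom the score of its nearest surface point, or 0 if that
    point is farther than dist_thresh."""
    atom_xyz = np.asarray(atom_xyz, dtype=float)
    point_xyz = np.asarray(point_xyz, dtype=float)
    point_scores = np.asarray(point_scores, dtype=float)
    # Pairwise Euclidean distances, shape (n_atoms, n_points), via broadcasting
    dists = np.linalg.norm(atom_xyz[:, None, :] - point_xyz[None, :, :], axis=-1)
    nearest = np.argmin(dists, axis=1)
    nearest_dist = dists[np.arange(len(atom_xyz)), nearest]
    scores = point_scores[nearest]  # fancy indexing returns a new array
    scores[nearest_dist > dist_thresh] = 0.0
    return scores

def residue_max_scores(atom_scores, residue_keys):
    """Per-residue score = max over that residue's atom scores."""
    result = {}
    for key, score in zip(residue_keys, atom_scores):
        result[key] = max(result.get(key, float('-inf')), float(score))
    return result
```

Like `cdist`, this builds the full distance matrix. For large structures, a KD-tree such as `scipy.spatial.cKDTree` avoids materializing it.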



In [9]:
#@title Plot output
#@markdown Blue marks predicted non-binding regions and red marks predicted binding sites. Rerun this cell to change the plotted structure.
plot_structure = 'Pointcloud' #@param ["Pointcloud", "Residues", "Atoms"]

## file addresses
if plot_structure == 'Pointcloud':
  view = show_pointcloud(target_pdb, coord, embedding)
elif plot_structure == "Residues":
  view = show_structure('/content/output/per_resi_binding.pdb')
elif plot_structure == "Atoms":
  view = show_structure('/content/output/per_atom_binding.pdb')

view

2022-03-21 12:17:57 - pdbparser <INFO> All records successfully exported to 'pointcloud.pdb'


NGLWidget(background='black')
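
In both views, the b-factor (binding score × 100) is mapped onto a red-white-blue ramp. Passing `colorDomain = [100.0, 0.0]` presumably reverses NGL's `rwb` scale so that high scores appear red. The sketch below shows a linear blue-white-red ramp of this kind, clamped to the 0-100 range. It is an approximation meant for intuition, not NGL's or PyMOL's exact color interpolation, and `blue_white_red` is a hypothetical helper.

```python
def blue_white_red(value, vmin=0.0, vmax=100.0):
    """Map value to an (r, g, b) tuple in [0, 1]: blue at vmin, white at the
    midpoint, red at vmax (values outside the range are clamped)."""
    if vmax <= vmin:
        raise ValueError('vmax must be greater than vmin')
    t = (value - vmin) / (vmax - vmin)
    t = min(max(t, 0.0), 1.0)
    if t < 0.5:
        # blue (0, 0, 1) -> white (1, 1, 1)
        s = t / 0.5
        return (s, s, 1.0)
    # white (1, 1, 1) -> red (1, 0, 0)
    s = (t - 0.5) / 0.5
    return (1.0, 1.0 - s, 1.0 - s)
```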

In [8]:
#@title Download predictions
!cp /content/preds/* /content/output
!cp /content/pointcloud.pdb /content/output
!zip -r /content/output.zip output
files.download("/content/output.zip")

  adding: output/ (stored 0%)
  adding: output/per_atom_binding.pdb (deflated 77%)
  adding: output/monomerexample_A_predfeatures_emb1.npy (deflated 8%)
  adding: output/per_resi_binding.pdb (deflated 78%)
  adding: output/monomerexample_A_predcoords.npy (deflated 9%)
  adding: output/monomerexample_A_pred_emb1.vtk (deflated 59%)
  adding: output/pointcloud.pdb (deflated 77%)
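
The download cell shells out to `cp` and `zip`. If a pure-Python route is preferable, for example to get a Python error rather than a shell failure that is easy to miss, the same bundling can be sketched with `shutil.copy` and `shutil.make_archive`. `bundle_outputs` is a hypothetical helper, and `google.colab.files.download` would still be needed to start the browser download.

```python
import glob
import os
import shutil

def bundle_outputs(src_patterns, output_dir, archive_base):
    """Copy files matching src_patterns into output_dir, then zip that folder.

    Returns the path of the created archive (archive_base + '.zip')."""
    os.makedirs(output_dir, exist_ok=True)
    for pattern in src_patterns:
        for path in glob.glob(pattern):
            if os.path.isfile(path):
                shutil.copy(path, output_dir)
    parent = os.path.dirname(os.path.abspath(output_dir))
    folder = os.path.basename(os.path.abspath(output_dir))
    # root_dir/base_dir keep the 'output/' prefix inside the archive, like `zip -r output`
    return shutil.make_archive(archive_base, 'zip', root_dir=parent, base_dir=folder)

# Possible usage mirroring the cell above:
# archive = bundle_outputs(['/content/preds/*', '/content/pointcloud.pdb'],
#                          '/content/output', '/content/output')
```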



# Visualization
To view the predictions, download the output and open the pdb files in PyMOL or ChimeraX. The predicted binding sites can be visualized by coloring the structure by b-factor.

In PyMOL, we recommend the command `spectrum b, blue_white_red, minimum=0, maximum=100` for a better visualization.
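
To apply the recommended coloring in one step, the PyMOL commands can be saved as a `.pml` script and run with `pymol script.pml`. The sketch below writes such a script. `write_pymol_script` is a hypothetical helper, the cartoon-plus-sticks display is an illustrative choice mirroring `show_structure`, and paths containing spaces would need quoting in the `load` command.

```python
def write_pymol_script(pdb_path, script_path, minimum=0, maximum=100):
    """Write a .pml script that loads a prediction PDB and colors it by b-factor."""
    commands = [
        f'load {pdb_path}',
        'hide everything',
        'show cartoon',
        'show sticks',
        f'spectrum b, blue_white_red, minimum={minimum}, maximum={maximum}',
    ]
    with open(script_path, 'w') as handle:
        handle.write('\n'.join(commands) + '\n')
    return commands

# e.g. write_pymol_script('per_resi_binding.pdb', 'color_binding.pml')
```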