<a href="https://colab.research.google.com/github/casperg92/MaSIF_colab/blob/main/dMaSIF_Colab_V1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#@title Upload a pdb file?
import os

from google.colab import files

upload_file = False #@param {type:"boolean"}

if upload_file:
  # files.upload() returns a dict keyed by the uploaded file names.
  uploaded = files.upload()
  for uploaded_name in uploaded:
    # The next cell still defaults to the bundled example structure, so tell
    # the user which path to paste into `target_pdb`.
    print('Uploaded {}: set target_pdb to "{}" in the next cell.'.format(
        uploaded_name, os.path.abspath(uploaded_name)))

In [2]:
#@title Change pdb path and chain name(s), then hit `Runtime` -> `Run all`
import os

# target pdb
target_pdb = "/content/MaSIF_colab/example/monomerexample.pdb" #@param {type:"string"}
# File name without directory and without everything from the first '.'
# ("monomerexample.pdb" -> "monomerexample"); later cells build the
# prediction file names from it.
target_file_parts = os.path.basename(target_pdb).split('.')
if len(target_file_parts) < 2 or target_file_parts[-1] != 'pdb':
  # Stop here instead of continuing with an unusable target name.
  raise ValueError('Please upload a valid .pdb file! Got: {!r}'.format(target_pdb))
target_name = target_file_parts[0]

chain_name = 'A' #@param {type:"string"}
# Every character is one chain identifier, e.g. 'AB' -> ['A', 'B'].
chains = list(chain_name)
if not chains:
  raise ValueError('Please enter at least one chain name.')

# Path to MaSIF weights
model_resolution = '1 Angstrom' #@param ["1 Angstrom"]
patch_radius = '9 Angstrom' #@param ["9 Angstrom", "12 Angstrom"]

# (model_resolution, patch_radius) -> (checkpoint path, resolution, patch radius)
available_models = {
  ('1 Angstrom', '9 Angstrom'): (
    '/content/MaSIF_colab/models/dMaSIF_site_3layer_16dims_9A_100sup_epoch64', 1.0, 9),
  ('1 Angstrom', '12 Angstrom'): (
    '/content/MaSIF_colab/models/dMaSIF_site_3layer_16dims_12A_100sup_epoch71', 1.0, 12),
}
if (model_resolution, patch_radius) not in available_models:
  # Fail early rather than leaving model_path/resolution/radius undefined.
  raise ValueError('No model available for resolution {!r} with patch radius {!r}.'.format(
      model_resolution, patch_radius))
model_path, resolution, radius = available_models[(model_resolution, patch_radius)]

# Working folders: chain files, preprocessed .npy inputs, and the predicted embeddings.
chains_dir = '/content/chains'
npy_dir = '/content/npys'
pred_dir = '/content/preds'
for folder in (chains_dir, npy_dir, pred_dir):
  os.makedirs(folder, exist_ok=True)

In [17]:
#@title Install dependencies
from tqdm.notebook import tqdm

# Manually advance the progress bar after each install stage.
def update_pbar(p_bar, c=1):
  """Advance the progress bar `p_bar` by `c` steps and redraw it immediately."""
  p_bar.update(c)
  p_bar.refresh()

# One step per install stage below (9 update_pbar calls in this cell).
p_bar = tqdm(range(9))

# Switch to CUDA 11.1:
# re-point the /usr/local/cuda symlink at /usr/local/cuda-11.1, matching the
# +cu111 wheels installed below.
%cd -q /usr/local/
!rm -rf cuda > /dev/null
!ln -s /usr/local/cuda-11.1 /usr/local/cuda > /dev/null
#!stat cuda

# Git clone MaSIF for Colab (including examples and weights).
# Any previous clone is deleted first, so re-running the cell starts clean.
print('Downloading MaSIF..')
%cd -q /content
!rm -fr MaSIF_colab > /dev/null
!git clone --quiet https://github.com/casperg92/MaSIF_colab.git > /dev/null
update_pbar(p_bar)


# Downgrade pytorch to make it compatible with pytorch geometric
print('Installing PyTorch..')
!pip install torch==1.8.1+cu111 -f https://download.pytorch.org/whl/cu111/torch_stable.html --quiet
update_pbar(p_bar)
print('Installing PyTorch Geometric..')
# NOTE(review): torch-sparse and torch-cluster are unpinned; a newer release
# may not be built for torch 1.8.1 — consider pinning them like torch-scatter.
!pip install torch-scatter==2.0.6 torch-sparse torch-cluster -f https://data.pyg.org/whl/torch-1.8.1+cu111.html --quiet
!pip install torch-geometric==1.6.1 --quiet
update_pbar(p_bar)
print('Installing PyKeops..')
# NOTE(review): installs from the `python_engine` branch, which is not a fixed revision.
!pip install git+https://github.com/getkeops/keops.git@python_engine --quiet
update_pbar(p_bar)
print('Installing BioPython..')
!pip install biopython --quiet
update_pbar(p_bar)
print('Installing plyfile..')
!pip install plyfile --quiet
update_pbar(p_bar)
print('Installing pyvtk..')
!pip install pyvtk --quiet
update_pbar(p_bar)
print('Installing nglview..')
# `-q` and `--quiet` are the same pip option; passing both is redundant.
!pip install -q nglview --quiet
update_pbar(p_bar)
print('Installing pdbparser..')
!pip install pdbparser --quiet
update_pbar(p_bar)

  0%|          | 0/9 [00:00<?, ?it/s]

Downloading MaSIF..
Installing PyTorch..
Installing PyTorch Geometric..
Installing PyKeops..
Installing BioPython..
Installing plyfile..
Installing pyvtk..
Installing nglview..
Installing pdbparser..


In [4]:
#@title Load functions
import sys
# Make the cloned repository and its data_preprocessing folder importable; the
# `data`, `model`, `data_iteration` and `helper` modules imported below are
# presumably resolved from these paths.
sys.path.append("/content/MaSIF_colab") 
sys.path.append("/content/MaSIF_colab/data_preprocessing") 

import numpy as np
# NOTE(review): pykeops is not referenced directly in this cell.
import pykeops
import torch
import os
# NOTE(review): wildcard import — the names it brings in are not visible here.
from Bio.PDB import *
from MaSIF_colab.data_preprocessing.download_pdb import convert_to_npy
from torch_geometric.data import DataLoader
from torch_geometric.transforms import Compose
import argparse

# Custom data loader and model:
from data import ProteinPairsSurfaces, PairData, CenterPairAtoms, load_protein_pair
from data import RandomRotationPairAtoms, NormalizeChemFeatures, iface_valid_filter
from model import dMaSIF
from data_iteration import iterate
# NOTE(review): wildcard import. `Path` is used later in this notebook but is
# not imported explicitly in this cell; presumably it arrives via one of the
# wildcard imports — confirm.
from helper import *

# For showing the plot in nglview (custom widgets must be enabled in Colab)
from google.colab import output
output.enable_custom_widget_manager()
import nglview as ng
#import ipywidgets as widgets
from pdbparser.pdbparser import pdbparser

def generate_descr(model_path, output_path, pdb_file, npy_directory, radius, resolution):
    """Generate descriptors for a single protein with a dMaSIF site model.

    Builds the network configuration, loads the checkpoint and runs one
    inference pass over the preprocessed protein.

    Args:
        model_path: Checkpoint file; its ``"model_state_dict"`` entry holds the weights.
        output_path: Directory handed to ``iterate`` as ``save_path``.
        pdb_file: Identifier of the preprocessed protein, passed to
            ``load_protein_pair`` and used as its pdb id; must be non-empty.
        npy_directory: Directory holding the preprocessed ``.npy`` inputs.
        radius: Patch radius the checkpoint was trained with.
        resolution: Surface resolution the checkpoint was trained with.

    Returns:
        The value returned by ``iterate`` for the inference pass.

    Raises:
        ValueError: If ``pdb_file`` is empty.
    """
    # Imported here so the function does not rely on a wildcard import for `Path`.
    from pathlib import Path

    # Validate before any heavy work: without an id no dataset can be built.
    if not pdb_file:
        raise ValueError(
            "pdb_file must identify a preprocessed protein, got {!r}".format(pdb_file))

    # Network settings. The architecture values (dims, layers, ...) need to match
    # the checkpoint for load_state_dict to succeed.
    args = argparse.Namespace(
        experiment_name=model_path,
        use_mesh=False,
        embedding_layer="dMaSIF",
        curvature_scales=[1.0, 2.0, 3.0, 5.0, 10.0],
        resolution=resolution,
        distance=1.05,
        variance=0.1,
        sup_sampling=100,
        atom_dims=6,
        emb_dims=16,
        in_channels=16,
        orientation_units=16,
        unet_hidden_channels=8,
        post_units=8,
        n_layers=3,
        radius=radius,
        k=40,
        dropout=0.0,
        site=True,  # site-prediction model
        batch_size=1,
        search=False,  # True only for the search model
        single_pdb=pdb_file,
        seed=42,
        random_rotation=False,
        device="cpu",
        single_protein=True,
        no_chem=False,
        no_geom=False,
    )

    save_predictions_path = Path(output_path)

    # Ensure reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    # A single preprocessed protein forms the whole test set.
    test_dataset = [load_protein_pair(args.single_pdb, Path(npy_directory), single_pdb=True)]
    test_pdb_ids = [args.single_pdb]

    # PyTorch Geometric expects an explicit list of "batched variables":
    batch_vars = ["xyz_p1", "xyz_p2", "atom_coords_p1", "atom_coords_p2"]
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, follow_batch=batch_vars)

    net = dMaSIF(args)
    # torch.load unpickles the file: only load checkpoints from a trusted source.
    checkpoint = torch.load(model_path, map_location=args.device)
    net.load_state_dict(checkpoint["model_state_dict"])
    net = net.to(args.device)

    # One test pass through the data; predictions go to save_predictions_path.
    return iterate(
        net,
        test_loader,
        None,
        args,
        test=True,
        save_path=save_predictions_path,
        pdb_ids=test_pdb_ids,
    )



In [5]:
#@title Run MaSIF
# Step 1 - preprocessing: turn the selected chains of the target structure
# into the .npy inputs the network reads.
convert_to_npy(target_pdb, chains_dir, npy_dir, chains)

# Step 2 - inference: the protein id repeats the chain string twice,
# i.e. "<target>_<chains>_<chains>".
pdb_name = "{n}_{c}_{c}".format(n= target_name, c=chain_name)
info = generate_descr(
    model_path=model_path,
    output_path=pred_dir,
    pdb_file=pdb_name,
    npy_directory=npy_dir,
    radius=radius,
    resolution=resolution,
)

# Memory usage reported in `info` is hardcoded to 0 so MaSIF runs on the CPU;
# this may need revisiting.

100%|██████████| 1/1 [00:02<00:00,  2.04s/it]


In [8]:
#@title Plot output

## file addresses: the selected structure and the per-point predictions in pred_dir
main_pdb = target_pdb
coord = np.load(os.path.join(
    pred_dir, "{n}_{c}_predcoords.npy".format(n=target_name, c=chain_name)))
embedding = np.load(os.path.join(
    pred_dir, "{n}_{c}_predfeatures_emb1.npy".format(n=target_name, c=chain_name)))

# Column of the emb1 features used to colour the surface points.
# NOTE(review): the meaning of column 24 is not documented here — confirm it is the intended score.
EMBEDDING_COLUMN = 24

# Min-max scale the chosen column to [0, 1]; it is written below as a 0-100 B-factor.
raw_scores = np.asarray(embedding)[:, EMBEDDING_COLUMN].astype(float)
score_min = raw_scores.min()
score_range = raw_scores.max() - score_min
if score_range > 0:
    b_factor = (raw_scores - score_min) / score_range
else:
    # All points share one value: avoid dividing by zero and colour them uniformly.
    b_factor = np.zeros_like(raw_scores)

if len(b_factor) != len(coord):
    raise ValueError(
        "Expected one embedding row per surface point, got {} embeddings for {} points.".format(
            len(b_factor), len(coord)))

# Write every surface point as a hydrogen "atom" of a pseudo-PDB so NGL can
# draw the point cloud and colour it by B-factor.
records = []
for serial, (point, score) in enumerate(zip(coord, b_factor), start=1):
    records.append({
        "record_name"       : 'ATOM',
        "serial_number"     : serial,
        "atom_name"         : 'H',
        "location_indicator": '',
        "residue_name"      : 'XYZ',
        "chain_identifier"  : '',
        "sequence_number"   : serial,
        "code_of_insertion" : '',
        "coordinates_x"     : point[0],
        "coordinates_y"     : point[1],
        "coordinates_z"     : point[2],
        "occupancy"         : 1.0,
        "temperature_factor": float(score) * 100,
        "segment_identifier": '',
        "element_symbol"    : 'H',
        "charge"            : '',
    })

pdb = pdbparser()
pdb.records = records

# Export and read back through the same absolute path, so the result does not
# depend on the current working directory.
coordPDB = os.path.join("/content", "result_pdb.pdb")
pdb.export_pdb(coordPDB)

view = ng.NGLWidget()
view.add_component(ng.FileStructure(coordPDB), defaultRepresentation=False)

# Points coloured by B-factor on a red-white-blue scale; '_H' selects the
# hydrogen pseudo-atoms written above.
view.add_representation('point',
                        useTexture = 1,
                        pointSize = 2,
                        colorScheme = "bfactor",
                        colorScale = 'rwb',
                        selection='_H')

# Overlay the input structure.
view.add_component(ng.FileStructure(os.path.join("/content", main_pdb)))
view.background = 'black'
view

2022-02-19 10:36:16 - pdbparser <INFO> All records successfully exported to 'result_pdb.pdb'


NGLWidget(background='black')