In [1]:
import os
import time
import torch
import numpy as np
#Setting path
import sys
sys.path.append('../../')
from models.unet_model import Unet3D
from prediction.unet_prediction import unet_prediction
from utils.utils import Config, load_checkpoint, PDB, preprocess_pdb_for_unet

In [2]:
#Directory holding the gzipped pdb files; replace the placeholder before running the cells below.
pdb_path = 'Here set the path to the directory of your pdb files'
#Lists of pdb ids for the training and validation splits.
#NOTE: allow_pickle=True lets np.load unpickle object arrays, which can execute arbitrary code,
#so only load list files that come from a trusted source.
train_list = np.load(
    '../data_lists/mlp_train_data.npy',
    allow_pickle=True,
)
validation_list = np.load(
    '../data_lists/mlp_validation_data.npy',
    allow_pickle=True,
)

In [5]:
#Parameters for the prediction
config = Config()
#Input preprocessing / batching parameters (config.radius is also passed to preprocess_pdb_for_unet below).
config.radius, config.vs, config.pad = 4, 0.8, 5
config.unet_batch_size, config.include_hetatm = 18, True
#Use the GPU when one is available, otherwise fall back to the CPU so the notebook
#still runs (more slowly) on machines without CUDA instead of failing on model.to / torch.load.
config.device = 'cuda' if torch.cuda.is_available() else 'cpu'
#Post-processing parameters read by unet_prediction.
config.sigmoid = torch.nn.Sigmoid()
config.cap, config.prediction_pad, config.prediction_iterations = 0.12, 0, 3

In [4]:
#Load model
#Channel layout handed to the Unet3D constructor.
in_channels = 3
out_channels = 1
intermediate_channels = [16,32,64,128]
#Initialize model and move it to the configured device.
model = Unet3D(in_channels, out_channels, intermediate_channels)
model.to(config.device)
#Deserialize the checkpoint onto the same device and hand it to load_checkpoint together with the model.
#NOTE(review): this cell does not set train/eval mode explicitly - confirm that
#load_checkpoint or unet_prediction puts the model in eval mode before inference.
checkpoint = '../../checkpoints/unet/Unet3D_36_epoch_374.pth.tar'
load_checkpoint(
    torch.load(checkpoint, map_location=config.device),
    model,
    '374 unet model',
)

=> Loading checkpoint, epoch 374 unet model.


In [5]:
#Compute and save unet prediction of water coordinates for your training dataset.
pdb_list = train_list
pdb_dir = 'mlp_train_36_374'
#np.save does not create missing parent directories, so make sure the output directory exists first.
out_dir = f'./unet_prediction_waters/{pdb_dir}'
os.makedirs(out_dir, exist_ok=True)
st = time.time()
for i, pdb_id in enumerate(pdb_list):
    out_file = os.path.join(out_dir, f'{pdb_id}_waters.npy')
    #Skip ids whose prediction file already exists, so an interrupted run can be resumed.
    if os.path.exists(out_file):
        continue
    #Load pdb (os.path.join works whether or not pdb_path ends with a path separator)
    pdb = PDB(os.path.join(pdb_path, f'{pdb_id}.pdb.gz'))
    #Preprocess pdb class object
    pdb, atom_hetatm_coords, atom_hetatm_atomtype = preprocess_pdb_for_unet(pdb, config.radius, include_hetatm=config.include_hetatm)
    #Predict (only the coordinates are saved; scores are not written to disk)
    pred_coords, scores = unet_prediction(atom_hetatm_coords, atom_hetatm_atomtype, model, config)
    np.save(out_file, pred_coords)
    #Progress line is overwritten in place thanks to end='\r'.
    print(f'{i+1}/{len(pdb_list)}, {round((time.time()-st)/60, 2)} minutes have passed.', end='\r')

In [6]:
#Compute and save unet prediction of water coordinates for your validation dataset.
pdb_list = validation_list
pdb_dir = 'mlp_validation_36_374'
#np.save does not create missing parent directories, so make sure the output directory exists first.
out_dir = f'./unet_prediction_waters/{pdb_dir}'
os.makedirs(out_dir, exist_ok=True)
st = time.time()
for i, pdb_id in enumerate(pdb_list):
    out_file = os.path.join(out_dir, f'{pdb_id}_waters.npy')
    #Skip ids whose prediction file already exists, so an interrupted run can be resumed.
    if os.path.exists(out_file):
        continue
    #Load pdb (os.path.join works whether or not pdb_path ends with a path separator)
    pdb = PDB(os.path.join(pdb_path, f'{pdb_id}.pdb.gz'))
    #Preprocess pdb class object
    pdb, atom_hetatm_coords, atom_hetatm_atomtype = preprocess_pdb_for_unet(pdb, config.radius, include_hetatm=config.include_hetatm)
    #Predict (only the coordinates are saved; scores are not written to disk)
    pred_coords, scores = unet_prediction(atom_hetatm_coords, atom_hetatm_atomtype, model, config)
    np.save(out_file, pred_coords)
    #Progress line is overwritten in place thanks to end='\r'.
    print(f'{i+1}/{len(pdb_list)}, {round((time.time()-st)/60, 2)} minutes have passed.', end='\r')