In [1]:
# Imports
import numpy as np
import math
from scipy.stats import pearsonr
import matplotlib.pyplot as plt

from openbabel import openbabel
openbabel.OBMessageHandler().SetOutputLevel(0)
openbabel.obErrorLog.SetOutputLevel(0)

import molgrid
import torch
from models.default2018_model import default2018_Net
from models.gnina_dense_model import Dense

In [2]:
# Seed the molgrid, torch and numpy random generators with one shared value
seed = 42
for _seed_fn in (molgrid.set_random_seed, torch.manual_seed, np.random.seed):
    _seed_fn(seed)

# Reproducibility on GPU: turn off cuDNN benchmark-mode algorithm selection
# and request deterministic cuDNN kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [3]:
# Helper function to get predictions and labels
def get_predictions_gnina(model, test_file, label_idx=1, pred_idx=-1, batch_size=32, data_root='./',
                          device='cuda'):
    """Run ``model`` over every example of a molgrid ``.types`` file.

    Examples are read in file order (``shuffle=False``, ``balanced=False``),
    gridded without random rotation or translation, and fed to the model in
    batches of ``batch_size``. The model is switched to eval mode and run
    under ``torch.no_grad()``.

    Args:
        model: torch module; ``model(input_tensor)[pred_idx]`` is taken as the
            prediction for the batch and flattened to 1-D.
        test_file: path passed to ``ExampleProvider.populate``.
        label_idx: label column extracted from each batch via
            ``batch.extract_label``.
        pred_idx: index into the model's output selecting the prediction.
            NOTE(review): assumes the model returns a sequence of outputs
            (e.g. a tuple); if it returned a single tensor, ``[-1]`` would
            select the last batch row instead — confirm for new models.
        batch_size: number of examples gridded per forward pass.
        data_root: passed as ``data_root`` to ``molgrid.ExampleProvider``
            (presumably the prefix for structure paths inside ``test_file``).
        device: torch device on which the input grid is allocated; defaults to
            ``'cuda'``, the value that used to be hard-coded.

    Returns:
        ``(predictions, labels)`` as 1-D numpy arrays, both truncated to the
        number of examples reported by the provider. The final batch may hold
        extra examples beyond that count; those are discarded.
    """
    model.eval()
    pred_chunks, label_chunks = [], []
    with torch.no_grad():
        provider = molgrid.ExampleProvider(data_root=data_root, balanced=False, shuffle=False)
        provider.populate(test_file)
        gmaker = molgrid.GridMaker()
        grid_dims = tuple(gmaker.grid_dimensions(provider.num_types()))
        # Both buffers are allocated once and overwritten in place every batch.
        input_tensor = torch.zeros((batch_size,) + grid_dims, dtype=torch.float32, device=device)
        float_labels = torch.zeros(batch_size, dtype=torch.float32)

        num_samples = provider.size()
        num_batches = -(-num_samples // batch_size)  # ceiling division
        for _ in range(num_batches):
            batch = provider.next_batch(batch_size)
            batch.extract_label(label_idx, float_labels)
            gmaker.forward(batch, input_tensor, random_rotation=False, random_translation=0.0)
            output = model(input_tensor)[pred_idx]
            # Copy out of the reused buffers before the next batch overwrites them.
            pred_chunks.append(output.detach().cpu().numpy().reshape(-1).copy())
            label_chunks.append(float_labels.numpy().copy())

    if not pred_chunks:
        # Empty .types file: return empty arrays rather than failing in concatenate.
        return np.array([]), np.array([])
    ypred_test = np.concatenate(pred_chunks)[:num_samples]
    y_test = np.concatenate(label_chunks)[:num_samples]
    return ypred_test, y_test

# Default2018 - CrossDocked

## Predictive performance - CASF-2016

In [4]:
# --- configuration -----------------------------------------------------------
data_name = 'CASF-2016'
data_root = f'./data/{data_name}/'
dims = (28, 48, 48, 48)
model_name = './models/crossdock_default2018.pt'

# --- build default2018 on the GPU and load the CrossDocked checkpoint --------
model = default2018_Net(dims).to('cuda')
model.load_state_dict(torch.load(model_name))

# --- predict every example listed in the CASF-2016 types file ----------------
preds, labels = get_predictions_gnina(
    model,
    f'{data_root}casf_2016_prepared.types',
    data_root=data_root,
)

# --- summary metrics: root-mean-square error and Pearson correlation ---------
rmse = np.sqrt(np.square(labels - preds).mean())
corr = pearsonr(preds, labels)[0]

print(f'Performance default2018 on {data_name} - RMSE: {rmse:.3f}, Pearson: {corr:.3f}')

Performance default2018 on CASF-2016 - RMSE: 1.550, Pearson: 0.732


# Dense - CrossDocked

## Predictive performance - CASF-2016

In [None]:
# --- configuration -----------------------------------------------------------
data_name = 'CASF-2016'
data_root = f'./data/{data_name}/'
dims = (28, 48, 48, 48)
model_name = './models/crossdock_dense.pt'

# --- build the Dense model on the GPU and load the CrossDocked checkpoint ----
model = Dense(dims).to('cuda')
model.load_state_dict(torch.load(model_name))

# --- predict every example listed in the CASF-2016 types file ----------------
preds, labels = get_predictions_gnina(
    model,
    f'{data_root}casf_2016_prepared.types',
    data_root=data_root,
)

# --- summary metrics: root-mean-square error and Pearson correlation ---------
rmse = np.sqrt(np.square(labels - preds).mean())
corr = pearsonr(preds, labels)[0]

print(f'Performance of dense on {data_name} - RMSE: {rmse:.3f}, Pearson: {corr:.3f}')