# Train HAC-Net

This notebook allows one to train HAC-Net with any dataset that we provide, or any other dataset of the same format and preprocessing requirements. In this notebook, we train with the training set that corresponds to testing on the PDBbind 2016 core set. 

## Set Up Notebook

In [None]:
# install torch packages necessary for GCNs
# NOTE(review): the wheel index URL below is built from TORCH and CUDA, which are
# not defined in any cell shown in this notebook — TODO confirm they are set
# (e.g. as environment variables matching the installed torch / CUDA versions)
# before this cell is run.
!pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html
!pip install torch-geometric

In [None]:
# import packages
import torch
from torch.utils.data import Dataset
import os
import h5py
from sklearn.metrics import pairwise_distances
from torch_geometric.nn.conv import MessagePassing, GatedGraphConv
from torch_geometric.nn import global_add_pool
from torch_geometric.utils import add_self_loops
from torch_geometric.nn.aggr import AttentionalAggregation
from torch import Tensor
import torch.nn as nn
import numpy as np
import math
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch_geometric.nn import DataParallel as GeometricDataParallel
from torch_geometric.data import DataListLoader, Data
from torch_geometric.utils import dense_to_sparse
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
import scipy as sp
from scipy.stats import *
from sklearn.metrics import *
import random
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
# Bind `stats` explicitly (after the star imports so nothing can shadow it): the
# training functions call stats.pearsonr / stats.spearmanr, and relying on the
# star import above to leak a `stats` name is fragile.
from scipy import stats
from torch._C import NoneType
from torch.optim import Adam, RMSprop, lr_scheduler

In [None]:
# mount google drive
# The google.colab package is provided by the Colab runtime, so this cell is
# expected to run inside Google Colab; Drive is mounted at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

## Data Containers

In [None]:
''' Define a class to contain the data that will be included in the dataloader 
sent to the GCN model '''

class GCN_Dataset(Dataset):
    """Dataset yielding ``(pdbid, pairwise_distance_matrix)`` pairs for the GCN.

    On construction the HDF5 file is scanned once to record every top-level key
    (PDB id) together with its ``affinity`` attribute. Distance matrices are
    computed lazily in ``__getitem__`` from the first three columns of each
    stored array and cached in memory, so each complex is read from disk at
    most once per dataset instance.

    Args:
        data_file: path to the HDF5 file to read.
    """

    def __init__(self, data_file):
        super(GCN_Dataset, self).__init__()
        self.data_file = data_file
        self.data_dict = {}  # cache: index -> (pdbid, distance matrix)
        self.data_list = []  # (pdbid, affinity array of shape (1, -1)) per key

        # retrieve PDB id's and affinities from hdf file
        with h5py.File(data_file, 'r') as f:
            for pdbid in f.keys():
                affinity = np.asarray(f[pdbid].attrs['affinity']).reshape(1, -1)
                self.data_list.append((pdbid, affinity))

    def __len__(self):
        """Number of complexes found in the HDF5 file."""
        return len(self.data_list)

    def __getitem__(self, item):
        """Return ``(pdbid, dists)`` where ``dists`` is the Euclidean distance
        matrix between the rows of columns 0..2 of the stored array."""
        if item in self.data_dict:
            return self.data_dict[item]

        pdbid, _ = self.data_list[item]

        # Open the file only for the duration of the read so that the handle is
        # always released; slicing returns an in-memory array that remains
        # valid after the file is closed.
        with h5py.File(self.data_file, 'r') as f:
            coords = f[pdbid][:, 0:3]
        dists = pairwise_distances(coords, metric='euclidean')

        self.data_dict[item] = (pdbid, dists)
        return self.data_dict[item]

In [None]:
""" Define a class to contain the data that will be included in the dataloader 
sent to the 3D-CNN """

class CNN_Dataset(Dataset):
    """Dataset for the 3D-CNN, backed by an HDF5 file that stays open for reading.

    Every top-level key of the file is treated as a PDB id whose ``affinity``
    attribute is the regression label. Items are returned as
    ``(x, y, pdbid)`` where ``x`` is the stored array with its last axis moved
    to the front and ``y`` is a one-element tensor holding the affinity.
    Call ``close()`` when finished to release the file handle.
    """

    def __init__(self, hdf_path, feat_dim=22):
        super().__init__()
        self.hdf_path = hdf_path
        self.feat_dim = feat_dim
        self.hdf = h5py.File(self.hdf_path, 'r')
        # one [pdbid, affinity] record per key of the file
        self.data_info_list = [
            [key, float(self.hdf[key].attrs['affinity'])] for key in self.hdf.keys()
        ]

    def close(self):
        """Close the underlying HDF5 file."""
        self.hdf.close()

    def __len__(self):
        return len(self.data_info_list)

    def get_data_info_list(self):
        """Return the list of ``[pdbid, affinity]`` records."""
        return self.data_info_list

    def __getitem__(self, idx):
        pdbid, affinity = self.data_info_list[idx]

        # move the last axis of the stored array to position 0
        grid = torch.tensor(self.hdf[pdbid][:]).permute(3, 0, 1, 2)
        label = torch.tensor(np.expand_dims(affinity, axis=0))
        return grid, label, pdbid

In [None]:
""" Define a class to contain the extracted 3D-CNN features that will be included in the dataloader 
sent to the fully-connected network """

class Linear_Dataset(Dataset):
    """Dataset of pre-extracted feature rows for the fully-connected network.

    The ``.npy`` file is expected to hold a 2-D array whose last column is the
    affinity label and whose remaining columns are the features. Both parts
    are converted to float32 on load.

    Args:
        npy_path: path to the ``.npy`` file.
        feat_dim: accepted for interface compatibility; not used.
    """

    def __init__(self, npy_path, feat_dim=22):
        super(Linear_Dataset, self).__init__()
        self.npy_path = npy_path
        # Read the file once and split it, instead of loading it twice.
        # NOTE: allow_pickle=True executes pickled content when the file holds
        # object arrays; only load files from a trusted source.
        full_array = np.load(npy_path, allow_pickle=True)
        self.input_feat_array = full_array[:, :-1].astype(np.float32)
        self.input_affinity_array = full_array[:, -1].astype(np.float32)
        self.data_info_list = []

    def __len__(self):
        """Number of rows (samples) in the feature array."""
        return self.input_feat_array.shape[0]

    def __getitem__(self, idx):
        """Return ``(x, y)``: the feature row and a one-element affinity tensor."""
        data, affinity = self.input_feat_array[idx], self.input_affinity_array[idx]

        x = torch.tensor(data)
        y = torch.tensor(np.expand_dims(affinity, axis=0))
        return x, y

## Model Architecture

In [None]:
""" Define GCN architecture class """

class GCN(torch.nn.Module):
    """Graph network that maps a batch of graphs to one scalar prediction per graph.

    Pipeline (see ``forward``): add self-loops, drop edges whose attribute
    exceeds ``dist_cutoff``, run a GatedGraphConv over the remaining edges,
    combine the propagated and input node features through two small
    networks, sum the result per graph and apply the output MLP.

    Args:
        in_channels: width of the input node features (also used as the
            GatedGraphConv output width).
        gather_width: width of the per-node vectors that are summed per graph.
        prop_iter: number of layers passed to GatedGraphConv.
        dist_cutoff: edges with attribute greater than this value are removed.
    """

    def __init__(self, in_channels, gather_width=128, prop_iter=4, dist_cutoff=3.5):
        super(GCN, self).__init__()

        #define distance cutoff
        # Stored as a plain tensor attribute (not a registered buffer), so it is
        # not part of state_dict and is moved to the GPU manually here.
        self.dist_cutoff=torch.Tensor([dist_cutoff])
        if torch.cuda.is_available():
            self.dist_cutoff = self.dist_cutoff.cuda()

        #Attentional aggregation
        self.gate_net = nn.Sequential(nn.Linear(in_channels, int(in_channels/2)), nn.Softsign(), nn.Linear(int(in_channels/2), int(in_channels/4)), nn.Softsign(), nn.Linear(int(in_channels/4),1))
        self.attn_aggr = AttentionalAggregation(self.gate_net)
        
        #Gated Graph Neural Network
        # NOTE(review): PyG documents the aggregation argument of message-passing
        # layers as `aggr`; confirm that the `aggregation=` keyword used here is
        # actually consumed by GatedGraphConv in the installed version rather
        # than being ignored (or rejected).
        self.gate = GatedGraphConv(in_channels, prop_iter, aggregation=self.attn_aggr)

        #Simple neural networks for use in asymmetric attentional aggregation
        self.attn_net_i=nn.Sequential(nn.Linear(in_channels * 2, in_channels), nn.Softsign(),nn.Linear(in_channels, gather_width), nn.Softsign())
        self.attn_net_j=nn.Sequential(nn.Linear(in_channels, gather_width), nn.Softsign())

        #Final set of linear layers for making affinity prediction
        self.output = nn.Sequential(nn.Linear(gather_width, int(gather_width / 1.5)), nn.ReLU(), nn.Linear(int(gather_width / 1.5), int(gather_width / 2)), nn.ReLU(), nn.Linear(int(gather_width / 2), 1))

    def forward(self, data):
        """Predict one value per graph in ``data``.

        ``data`` must expose ``x``, ``edge_index``, ``edge_attr`` and ``batch``.
        Note that ``data`` is modified in place (tensors moved to the GPU when
        available, self-loops appended to ``edge_index`` / ``edge_attr``).

        Returns:
            Output of the final MLP applied to the per-graph pooled features
            (last dimension of size 1).
        """

        #Move data to GPU
        if torch.cuda.is_available():
            data.x = data.x.cuda()
            data.edge_attr = data.edge_attr.cuda()
            data.edge_index = data.edge_index.cuda()
            data.batch = data.batch.cuda()

        # allow nodes to propagate messages to themselves
        # edge_attr is flattened to 1-D here; the attribute value given to the new
        # self-loop edges is whatever add_self_loops uses by default — TODO confirm
        # it is below dist_cutoff so the self-loops survive the mask below.
        data.edge_index, data.edge_attr = add_self_loops(data.edge_index, data.edge_attr.view(-1))

        # restrict edges to the distance cutoff
        row, col = data.edge_index
        mask = data.edge_attr <= self.dist_cutoff
        mask = mask.squeeze()
        row, col, edge_feat = row[mask], col[mask], data.edge_attr[mask]
        edge_index=torch.stack([row,col],dim=0)

        # propagation
        node_feat_0 = data.x
        # edge_feat is passed as the third positional argument of GatedGraphConv
        # (presumably its edge weight — verify against the PyG signature)
        node_feat_1 = self.gate(node_feat_0, edge_index, edge_feat)
        # softmax over the feature dimension of attn_net_i([propagated, input]),
        # multiplied element-wise with attn_net_j(input)
        node_feat_attn = torch.nn.Softmax(dim=1)(self.attn_net_i(torch.cat([node_feat_1, node_feat_0], dim=1))) * self.attn_net_j(node_feat_0)

        # globally sum features and apply linear layers
        pool_x = global_add_pool(node_feat_attn, data.batch)
        prediction = self.output(pool_x)

        return prediction

In [None]:
""" Define 3D-CNN architecture class """

class Model_3DCNN(nn.Module):
  """3D convolutional network with three squeeze-and-excitation (SE) stages.

  ``forward`` returns both the scalar prediction and the flattened
  convolutional features that feed the first dense layer.

  Args:
    feat_dim: number of input channels of the first convolution.
    output_dim: stored on the instance; the final layer's width is fixed to 1.
    use_cuda: stored on the instance; not otherwise used inside this class.
  """

  def __conv_filter__(self, in_channels, out_channels, kernel_size, stride, padding):
    # Conv3d -> ReLU -> BatchNorm3d building block shared by all conv stages.
    conv_filter = nn.Sequential(nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=True), nn.ReLU(inplace=True), nn.BatchNorm3d(out_channels))
    return conv_filter

  def __init__(self, feat_dim=19, output_dim=1, use_cuda=True):
    super(Model_3DCNN, self).__init__()     
    self.feat_dim = feat_dim
    self.output_dim = output_dim
    self.use_cuda = use_cuda
    
    # SE block
    self.conv_block1 = self.__conv_filter__(self.feat_dim, 64, 9, 2, 3)
    # NOTE(review): glob_pool1 is never used in forward(), which pools with
    # self.glob_pool for all three SE stages (both are AdaptiveAvgPool3d(1)).
    self.glob_pool1 = nn.AdaptiveAvgPool3d(1)
    self.SE_block1 = nn.Linear(in_features=64, out_features=64//16, bias=False)
    self.relu = nn.ReLU()
    self.SE_block1_ = nn.Linear(in_features=64//16, out_features=64, bias=False)
    self.sigmoid = nn.Sigmoid()

    # residual blocks
    self.res_block1 = self.__conv_filter__(64, 64, 7, 1, 3)
    self.res_block2 = self.__conv_filter__(64, 64, 7, 1, 3)

    # SE block
    self.conv_block2 = self.__conv_filter__(64, 128, 7, 3, 3)
    self.glob_pool = nn.AdaptiveAvgPool3d(1)
    self.SE_block2 = nn.Linear(in_features=128, out_features=128//16, bias=False)
    self.SE_block2_ = nn.Linear(in_features=128//16, out_features=128, bias=False)
    self.max_pool = nn.MaxPool3d(2)

    ## SE block
    self.conv_block3 = self.__conv_filter__(128, 256, 5, 2, 2)
    self.SE_block3 = nn.Linear(in_features=256, out_features=256//16, bias=False)
    self.SE_block3_ = nn.Linear(in_features=256//16, out_features=256, bias=False)

    # dense layers
    # linear1 requires exactly 2048 flattened features per sample, i.e. the 256
    # channels of conv_block3 times 8 spatial positions; this constrains the
    # spatial size of the input grid.
    self.linear1 = nn.Linear(2048, 100)
    torch.nn.init.normal_(self.linear1.weight, 0, 1)
    self.linear1_bn = nn.BatchNorm1d(num_features=100, affine=True, momentum=0.1).train()
    self.linear2 = nn.Linear(100, 1)
    torch.nn.init.normal_(self.linear2.weight, 0, 1)

  def forward(self, x):
    """Run the network.

    Returns:
      (linear2_z, flatten): the output of the last linear layer and the
      flattened output of the third SE stage (the input of linear1).
    """
    if x.dim() == 1:
      x = x.unsqueeze(-1)

    # SE block 1
    # Each SE stage: global-average-pool to (batch, channels), squeeze through a
    # bottleneck (channels // 16), sigmoid, then rescale the conv output per channel.
    conv1 = self.conv_block1(x)
    a1,b1, _, _, _ = conv1.shape
    glob_pool_conv1 = self.glob_pool(conv1).view(a1, b1)
    SE_block1 = self.SE_block1(glob_pool_conv1)   
    SE_block1a = self.relu(SE_block1)
    SE_block1_ = self.SE_block1_(SE_block1a)
    SE_block1_a = self.sigmoid(SE_block1_).view(a1, b1, 1, 1, 1)  
    se1 = conv1 * SE_block1_a.expand_as(conv1)  

    # residual blocks
    conv1_res1 = self.res_block1(se1)
    conv1_res12 = conv1_res1 + se1
    conv1_res2 = self.res_block2(conv1_res12)
    # NOTE(review): the second skip connection adds se1 again rather than
    # conv1_res12 (the input of res_block2) — confirm this is intended.
    conv1_res2_2 = conv1_res2 + se1

    # SE block 2
    conv2 = self.conv_block2(conv1_res2_2)
    a2,b2, _, _, _ = conv2.shape
    glob_pool_conv2 = self.glob_pool(conv2).view(a2, b2)
    SE_block2 = self.SE_block2(glob_pool_conv2)        
    SE_block2a = self.relu(SE_block2)
    SE_block2_ = self.SE_block2_(SE_block2a)
    SE_block2_a = self.sigmoid(SE_block2_).view(a2, b2, 1, 1, 1)  
    se2 = conv2 * SE_block2_a.expand_as(conv2)  

    # Pooling layer
    pool2 = self.max_pool(se2)

    # SE block 3
    conv3 = self.conv_block3(pool2)
    a3,b3, _, _, _ = conv3.shape
    glob_pool_conv3 = self.glob_pool(conv3).view(a3, b3)
    SE_block3 = self.SE_block3(glob_pool_conv3)       
    SE_block3a = self.relu(SE_block3)
    SE_block3_ = self.SE_block3_(SE_block3a)
    SE_block3_a = self.sigmoid(SE_block3_).view(a3, b3, 1, 1, 1)  
    se3 = conv3 * SE_block3_a.expand_as(conv3)  

    # Pooling layer
    # (no pooling is applied after the third stage; the SE output is used as-is)
    pool3 = se3

    # Flatten
    flatten = pool3.view(pool3.size(0), -1)

    # Linear layer 1
    linear1_z = self.linear1(flatten)
    linear1_y = self.relu(linear1_z)
    # batch norm is skipped when the batch holds a single sample
    linear1 = self.linear1_bn(linear1_y) if linear1_y.shape[0]>1 else linear1_y

    # Linear layer 2
    linear2_z = self.linear2(linear1)

    return linear2_z, flatten

In [None]:
""" Define fully-connected network class """
class Model_Linear(nn.Module):
    """Two-layer fully-connected regressor over 2048-dimensional feature vectors.

    ``forward`` returns the scalar prediction together with the
    pre-activation output of the first linear layer.
    """

    def __init__(self, use_cuda=True):
        super().__init__()
        self.use_cuda = use_cuda

        # Layers are created in a fixed order so that seeded initialisation
        # yields the same weights on every run.
        self.fc1 = nn.Linear(2048, 100)
        nn.init.normal_(self.fc1.weight, mean=0, std=1)
        self.dropout1 = nn.Dropout(0.0)
        self.fc1_bn = nn.BatchNorm1d(num_features=100, affine=True, momentum=0.3).train()
        self.fc2 = nn.Linear(100, 1)
        nn.init.normal_(self.fc2.weight, mean=0, std=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``(prediction, fc1 pre-activation)`` for a batch ``x``."""
        hidden_pre = self.fc1(x)
        hidden = self.dropout1(self.relu(hidden_pre))
        # batch norm is only applied when the batch holds more than one sample
        if hidden.shape[0] > 1:
            hidden = self.fc1_bn(hidden)
        return self.fc2(hidden), hidden_pre

## Functions

In [None]:
"""Define a function to train the GCN components"""

def _gcn_batch_to_data_list(batch, data_hdf):
    """Convert one DataListLoader batch into a list of torch_geometric ``Data`` graphs.

    Inputs:
    1) batch: list of (pdbid, distance_matrix) tuples as returned by GCN_Dataset
    2) data_hdf: open hdf file that stores, for every pdbid, the per-atom array and the
       'affinity' and 'van_der_waals' attributes
    Output:
    1) list with one Data object per sample (x, edge_index, edge_attr, y)
    """
    data_list = []
    for sample in batch:
        pdbid, dists = sample[0], sample[1]
        entry = data_hdf[pdbid]
        affinity = entry.attrs['affinity'].reshape(1, -1)
        vdw_radii = entry.attrs['van_der_waals'].reshape(-1, 1)
        # node features: Van der Waals radius followed by columns 3..21 of the stored array
        node_feats = np.concatenate([vdw_radii, entry[:, 3:22]], axis=1)
        # every non-zero entry of the dense distance matrix becomes an edge
        edge_index, edge_attr = dense_to_sparse(torch.from_numpy(dists).float())
        x = torch.from_numpy(node_feats).float()
        y = torch.FloatTensor(affinity).view(-1, 1)
        data_list.append(Data(x=x, edge_index=edge_index, edge_attr=edge_attr.view(-1, 1), y=y))
    return data_list


def train_gcn(train_data, val_data, checkpoint_name, best_checkpoint_name, load_checkpoint_path = None, best_previous_checkpoint=None):

    '''
    Train the GCN component, validating and checkpointing after every epoch.

    Inputs:
    1) train_data: training hdf file name
    2) val_data: validation hdf file name
    3) checkpoint_name: path at which the latest checkpoint is saved (overwritten every epoch)
    4) best_checkpoint_name: path at which the best checkpoint is saved, i.e. the epoch with the
       highest average of Pearson and Spearman correlation on the validation set
    5) load_checkpoint_path: path to checkpoint file to load; default is None, i.e. training from scratch
    6) best_previous_checkpoint: path to the best checkpoint from the previous round of training
       (required whenever load_checkpoint_path is given); default is None, i.e. training from scratch
    Output:
    1) checkpoint files written to checkpoint_name and best_checkpoint_name; nothing is returned
    Raises:
    1) ValueError if load_checkpoint_path is given without best_previous_checkpoint
    '''

    if load_checkpoint_path is not None and best_previous_checkpoint is None:
        raise ValueError('best_previous_checkpoint is required when load_checkpoint_path is given')

    # define parameters
    epochs = 300                   # number of training epochs
    batch_size = 7                 # batch size to use for training
    learning_rate = 0.001          # learning rate to use for training
    gather_width = 128             # gather width
    prop_iter = 4                  # number of propagation iterations
    dist_cutoff = 3.5              # common cutoff for donor-to-acceptor distance for energetically significant H bonds in proteins is 3.5 Å
    feature_size = 20              # number of features: 19 + Van der Waals radius

    # seed all random number generators and set cudnn settings for deterministic
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = '0'

    def worker_init_fn(worker_id):
        np.random.seed(int(0))

    # initialize checkpoint parameters
    checkpoint_epoch = 0
    checkpoint_step = 0
    epoch_train_losses, epoch_val_losses, epoch_avg_corr = [], [], []
    best_average_corr = float('-inf')

    # set loss as MSE
    criterion = nn.MSELoss().float()

    # define function to perform validation
    def validate(model, val_dataloader, val_data_hdf, n_val, epoch):
        model.eval()
        y_true = np.zeros((n_val,), dtype=np.float32)
        y_pred = np.zeros((n_val,), dtype=np.float32)
        n_seen = 0
        # no gradients are needed for validation
        with torch.no_grad():
            for batch_ind, batch in enumerate(val_dataloader):
                data_list = _gcn_batch_to_data_list(batch, val_data_hdf)
                y_ = model(data_list)
                y = torch.cat([d.y for d in data_list])
                n_batch = y.shape[0]
                y_true[n_seen:n_seen+n_batch] = y.cpu().float().numpy()[:,0]
                y_pred[n_seen:n_seen+n_batch] = y_.cpu().float().numpy()[:,0]
                n_seen += n_batch
                loss = criterion(y.float(), y_.cpu().float())
                print('[%d/%d-%d/%d] validation loss: %.3f' % (epoch+1, epochs, batch_ind+1, len(val_dataloader), loss.item()))

        # the dataloader drops the last incomplete batch; only score samples that were
        # actually evaluated instead of the zero padding left at the end of the arrays
        y_true, y_pred = y_true[:n_seen], y_pred[:n_seen]

        # compute r^2
        r2 = r2_score(y_true=y_true, y_pred=y_pred)
        # compute mae
        mae = mean_absolute_error(y_true=y_true, y_pred=y_pred)
        # compute mse
        mse = mean_squared_error(y_true=y_true, y_pred=y_pred)
        # compute pearson correlation coefficient
        pearsonr = stats.pearsonr(y_true.reshape(-1), y_pred.reshape(-1))[0]
        # compute spearman correlation coefficient
        spearmanr = stats.spearmanr(y_true.reshape(-1), y_pred.reshape(-1))[0]
        # write out metrics
        print('r2: {}\tmae: {}\trmse: {}\tpearsonr: {}\t spearmanr: {}'.format(r2, mae, mse**(1/2), pearsonr, spearmanr))
        model.train()
        return {'r2': r2, 'mse': mse, 'mae': mae, 'pearsonr': pearsonr, 'spearmanr': spearmanr,
                'y_true': y_true, 'y_pred': y_pred}

    # construct model
    model = GeometricDataParallel(GCN(in_channels=feature_size, gather_width=gather_width, prop_iter=prop_iter, dist_cutoff=dist_cutoff)).float()

    train_dataset = GCN_Dataset(data_file=train_data)
    val_dataset = GCN_Dataset(data_file=val_data)

    # construct training and validation dataloaders to be fed to model
    train_dataloader = DataListLoader(train_dataset, batch_size=batch_size, shuffle=True, worker_init_fn=worker_init_fn, drop_last=True)
    val_dataloader = DataListLoader(val_dataset, batch_size=batch_size, shuffle=False, worker_init_fn=worker_init_fn, drop_last=True)

    # load checkpoint file
    if load_checkpoint_path is not None:
        map_location = None if torch.cuda.is_available() else torch.device('cpu')
        model_train_dict = torch.load(load_checkpoint_path, map_location=map_location)
        best_checkpoint = torch.load(best_previous_checkpoint, map_location=map_location)
        model.load_state_dict(model_train_dict['model_state_dict'])
        checkpoint_epoch = model_train_dict['epoch']
        checkpoint_step = model_train_dict['step']
        epoch_train_losses = model_train_dict['epoch_train_losses']
        epoch_val_losses = model_train_dict['epoch_val_losses']
        epoch_avg_corr = model_train_dict['epoch_avg_corr']
        torch.save(best_checkpoint, best_checkpoint_name)
        best_average_corr = best_checkpoint["best_avg_corr"]
        # best checkpoints written by earlier versions stored the best value from *before*
        # their own epoch; recover the real value from their validation metrics
        previous_val = best_checkpoint.get('validate_dict')
        if previous_val is not None:
            best_average_corr = max(best_average_corr, (previous_val['pearsonr'] + previous_val['spearmanr'])/2)

    model.train()
    # only move to GPU 0 when a GPU exists, so that CPU-only runs do not crash
    if torch.cuda.is_available():
        model.to(0)

    # set Adam optimizer
    optimizer = Adam(model.parameters(), lr=learning_rate)

    # the hdf files are closed even if training is interrupted by an exception
    with h5py.File(train_data, 'r') as train_data_hdf, h5py.File(val_data, 'r') as val_data_hdf:

        # train model
        step = checkpoint_step
        for epoch in range(checkpoint_epoch, epochs):
            model.train()
            y_true = np.zeros((len(train_dataset),), dtype=np.float32)
            y_pred = np.zeros((len(train_dataset),), dtype=np.float32)
            n_seen = 0
            for batch_ind, batch in enumerate(train_dataloader):
                data_list = _gcn_batch_to_data_list(batch, train_data_hdf)

                optimizer.zero_grad()
                y_ = model(data_list)
                y = torch.cat([d.y for d in data_list])
                n_batch = y.shape[0]
                y_true[n_seen:n_seen+n_batch] = y.cpu().float().numpy()[:,0]
                y_pred[n_seen:n_seen+n_batch] = y_.detach().cpu().float().numpy()[:,0]
                n_seen += n_batch

                # compute loss and update parameters
                loss = criterion(y.float(), y_.cpu().float())
                loss.backward()
                optimizer.step()
                step += 1
                print("[%d/%d-%d/%d] training loss: %.3f" % (epoch+1, epochs, batch_ind+1, len(train_dataloader), loss.item()))

            # score only the samples that were actually seen (the last incomplete batch is dropped)
            y_true, y_pred = y_true[:n_seen], y_pred[:n_seen]
            r2 = r2_score(y_true=y_true, y_pred=y_pred)
            mae = mean_absolute_error(y_true=y_true, y_pred=y_pred)
            mse = mean_squared_error(y_true, y_pred)
            epoch_train_losses.append(mse)
            pearsonr = stats.pearsonr(y_true.reshape(-1), y_pred.reshape(-1))
            spearmanr = stats.spearmanr(y_true.reshape(-1), y_pred.reshape(-1))

            # write training summary for the epoch
            print('epoch: {}\trmse:{:0.4f}\tr2: {:0.4f}\t pearsonr: {:0.4f}\tspearmanr: {:0.4f}\tmae: {:0.4f}\tpred'.format(epoch+1, mse**(1/2), r2, float(pearsonr[0]),
                        float(spearmanr[0]), float(mae)))

            # validate, then update the running best *before* building the checkpoint so
            # that the stored 'best_avg_corr' is correct when training is resumed
            validate_dict = validate(model, val_dataloader, val_data_hdf, len(val_dataset), epoch)
            average_corr = (validate_dict['pearsonr'] + validate_dict['spearmanr'])/2
            epoch_val_losses.append(validate_dict['mse'])
            epoch_avg_corr.append(average_corr)
            is_best = average_corr > best_average_corr
            if is_best:
                best_average_corr = average_corr
            validate_dict['best_average_corr'] = best_average_corr

            checkpoint_dict = {'model_state_dict': model.state_dict(), 'step': step, 'epoch': epoch+1, 'validate_dict': validate_dict,
                               'epoch_train_losses': epoch_train_losses, 'epoch_val_losses': epoch_val_losses, 'epoch_avg_corr': epoch_avg_corr, 'best_avg_corr': best_average_corr}
            if is_best:
                torch.save(checkpoint_dict, best_checkpoint_name)
            torch.save(checkpoint_dict, checkpoint_name)

    # learning curve and correlation plot
    epoch_axis = np.arange(1, len(epoch_train_losses)+1)
    fig, axs = plt.subplots(2)
    axs[0].plot(epoch_axis, np.array(epoch_train_losses), label = 'training')
    axs[0].plot(epoch_axis, np.array(epoch_val_losses), label = 'validation')
    axs[0].set_xlabel('Epoch', fontsize=20)
    axs[0].set_ylabel('Loss', fontsize=20)
    axs[0].legend(fontsize=18)
    axs[1].plot(epoch_axis, np.array(epoch_avg_corr))
    axs[1].set_xlabel('Epoch', fontsize=20)
    axs[1].set_ylabel('Validation Correlation', fontsize=20)
    axs[1].set_ylim(0,1)
    plt.show()

In [None]:
"""Define a function to train the 3D-CNN component"""

def train_3dcnn(train_hdf, val_hdf, checkpoint_dir, best_checkpoint_dir, previous_checkpoint = None, best_previous_checkpoint=None):
    '''
    Train the 3D-CNN component, validating and checkpointing after every epoch.

    Inputs:
    1) train_hdf: training hdf file name
    2) val_hdf: validation hdf file name
    3) checkpoint_dir: path to save checkpoint file: 'path/to/file.pt'
    4) best_checkpoint_dir: path to save best checkpoint file: 'path/to/file.pt'
    5) previous_checkpoint: path to the checkpoint at which training should be started; default = None, i.e. training from scratch
    6) best_previous_checkpoint: path to the best checkpoint from the previous round of training
       (required whenever previous_checkpoint is given); default = None, i.e. training from scratch
    Output:
    1) checkpoint file from the endpoint of the training
    2) checkpoint file from the epoch with highest average correlation on validation
    Raises:
    1) ValueError if previous_checkpoint is given without best_previous_checkpoint
    '''
    from contextlib import closing

    if previous_checkpoint is not None and best_previous_checkpoint is None:
        raise ValueError('best_previous_checkpoint is required when previous_checkpoint is given')

    # define parameters
    batch_size = 50
    learning_rate = .0007
    learning_decay_iter = 150      # number of optimizer steps between learning rate decays
    decay_rate = 0.95              # multiplicative learning rate decay factor
    epoch_count = 100

    # set CUDA for PyTorch; fall back to CPU when no GPU is available
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        device = torch.device('cuda:0')
        torch.cuda.set_device(0)
    else:
        device = torch.device('cpu')

    # initialize Datasets; closing() releases the hdf handles even if training fails
    with closing(CNN_Dataset(train_hdf)) as dataset, closing(CNN_Dataset(val_hdf)) as val_dataset:

        # initialize Dataloaders
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        batch_count = len(dataloader)
        val_batch_count = len(val_dataloader)

        # define model and helper functions
        model = Model_3DCNN(use_cuda=use_cuda)
        model.to(device)
        loss_func = nn.MSELoss().float()
        optimizer = RMSprop(model.parameters(), lr=learning_rate)
        # the scheduler is stepped once per batch, so step_size counts optimizer steps
        scheduler = lr_scheduler.StepLR(optimizer, step_size=learning_decay_iter, gamma=decay_rate)

        # initialize training variables
        epoch_start = 0
        step = 0
        epoch_train_losses, epoch_val_losses, epoch_avg_corr = [], [], []
        best_average_corr = float('-inf')

        #load previous checkpoint if applicable
        if previous_checkpoint is not None:
            best_checkpoint = torch.load(best_previous_checkpoint, map_location = device)
            torch.save(best_checkpoint, best_checkpoint_dir)
            best_average_corr = best_checkpoint["best_avg_corr"]
            checkpoint = torch.load(previous_checkpoint, map_location=device)
            model_state_dict = checkpoint.pop('model_state_dict')
            # NOTE: strict=False silently ignores missing/unexpected keys
            model.load_state_dict(model_state_dict, strict=False)
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            # checkpoints written before the scheduler state was stored do not have this key
            if 'scheduler_state_dict' in checkpoint:
                scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            epoch_start = checkpoint['epoch'] + 1
            step=checkpoint['step']
            epoch_train_losses = checkpoint['epoch_train_losses']
            epoch_val_losses = checkpoint['epoch_val_losses']
            epoch_avg_corr = checkpoint['epoch_avg_corr']
            print('checkpoint loaded: %s' % previous_checkpoint)

        def validate_model(epoch_ind):
            y_true_array = np.zeros((len(val_dataset),), dtype=np.float32)
            y_pred_array = np.zeros((len(val_dataset),), dtype=np.float32)
            offset = 0
            model.eval()
            with torch.no_grad():
                for batch_ind, batch in enumerate(val_dataloader):

                    # transfer to GPU
                    x_batch_cpu, y_batch_cpu, _ = batch
                    ypred_batch, _ = model(x_batch_cpu.to(device))

                    # compute and print batch loss
                    loss = loss_func(ypred_batch.cpu().float(), y_batch_cpu.float())
                    print('[%d/%d-%d/%d] validation loss: %.3f' % (epoch_ind+1, epoch_count, batch_ind+1, val_batch_count, loss.item()))

                    #assemble the full datasets
                    bsize = x_batch_cpu.shape[0]
                    y_true_array[offset:offset+bsize] = y_batch_cpu.float().numpy()[:,0]
                    y_pred_array[offset:offset+bsize] = ypred_batch.cpu().float().numpy()[:,0]
                    offset += bsize

            #compute average correlation
            pearsonr = stats.pearsonr(y_true_array, y_pred_array)[0]
            spearmanr = stats.spearmanr(y_true_array, y_pred_array)[0]
            avg_corr = (pearsonr + spearmanr)/2
            mse = mean_squared_error(y_true_array, y_pred_array)

            # print information during training
            print('[%d/%d] validation results-- pearsonr: %.3f, spearmanr: %.3f, rmse: %.3f, mae: %.3f, r2: %.3f' % (epoch_ind+1, epoch_count, pearsonr, spearmanr, mse**(1/2),
                  mean_absolute_error(y_true_array, y_pred_array), r2_score(y_true_array, y_pred_array)))
            return mse, avg_corr

        for epoch_ind in range(epoch_start, epoch_count):
            # validate_model() switches to eval mode, so re-enable training mode every epoch
            model.train()
            y_true_epoch = np.zeros((len(dataset),), dtype=np.float32)
            y_pred_epoch = np.zeros((len(dataset),), dtype=np.float32)
            offset = 0
            for batch_ind, batch in enumerate(dataloader):

                # transfer to GPU and save in the epoch array
                x_batch_cpu, y_batch_cpu, _ = batch
                bsize = x_batch_cpu.shape[0]
                ypred_batch, _ = model(x_batch_cpu.to(device))
                y_true_epoch[offset:offset+bsize] = y_batch_cpu.float().numpy()[:,0]
                y_pred_epoch[offset:offset+bsize] = ypred_batch.detach().cpu().float().numpy()[:,0]
                offset += bsize

                # compute loss
                loss = loss_func(ypred_batch.cpu().float(), y_batch_cpu.float())
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()
                step += 1
                print("[%d/%d-%d/%d] training loss: %.3f" % (epoch_ind+1, epoch_count, batch_ind+1, batch_count, loss.item()))

            epoch_train_losses.append(mean_squared_error(y_true_epoch, y_pred_epoch))
            val_loss, average_corr = validate_model(epoch_ind)
            epoch_val_losses.append(val_loss)
            epoch_avg_corr.append(average_corr)

            is_best = average_corr > best_average_corr
            if is_best:
                best_average_corr = average_corr

            # 'loss' is stored detached so that the autograd graph is not kept alive
            checkpoint_dict = {'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'scheduler_state_dict': scheduler.state_dict(),
                               'loss': loss.detach(),'step': step,'epoch': epoch_ind,
                               'epoch_val_losses': epoch_val_losses,'epoch_train_losses': epoch_train_losses,'epoch_avg_corr' : epoch_avg_corr,'best_avg_corr': best_average_corr}

            if is_best:
                torch.save(checkpoint_dict, best_checkpoint_dir)
                print("best checkpoint saved: %s" % best_checkpoint_dir)
            torch.save(checkpoint_dict, checkpoint_dir)
            print('checkpoint saved: %s' % checkpoint_dir)

    # create learning curve and correlation plot
    epoch_axis = np.arange(1, len(epoch_train_losses)+1)
    fig, axs = plt.subplots(2)
    axs[0].plot(epoch_axis, np.array(epoch_train_losses), label = 'training')
    axs[0].plot(epoch_axis, np.array(epoch_val_losses), label = 'validation')
    axs[0].set_xlabel('Epoch', fontsize=20)
    axs[0].set_ylabel('Loss', fontsize=20)
    axs[0].legend(fontsize=18)
    axs[1].plot(epoch_axis, np.array(epoch_avg_corr))
    axs[1].set_xlabel('Epoch', fontsize=20)
    axs[1].set_ylabel('Validation Correlation', fontsize=20)
    axs[1].set_ylim(0,1)
    plt.show()

In [None]:
"""Define a function to extract flattened features from trained 3D-CNN"""

def savefeat_3dcnn(hdf_path, checkpoint_path, npy_path):

    """
    Run a trained 3D-CNN over a dataset and save its flattened features.

    Inputs:
    1) hdf_path: path/to/file.hdf
    2) checkpoint_path: path/to/checkpoint/file.pt
    3) npy_path: path/to/save/features.npy
    Output:
    1) numpy file containing the saved features (one row per complex, in dataset order),
       with the last column being the true affinity value.
    """

    # define parameters
    batch_size = 50
    device_name = "cuda:0"
    # set CUDA for PyTorch
    use_cuda = torch.cuda.is_available()
    cuda_count = torch.cuda.device_count()
    if use_cuda:
        device = torch.device(device_name)
        torch.cuda.set_device(int(device_name.split(':')[1]))
    else:
        device = torch.device("cpu")
    print(use_cuda, cuda_count, device)

    # load testing
    dataset = CNN_Dataset(hdf_path)
    try:
        # initialize testing data loader (order must be preserved, hence shuffle=False)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0, worker_init_fn=None)

        # define model
        model = Model_3DCNN(use_cuda=use_cuda)
        model.to(device)

        # load checkpoint file
        checkpoint = torch.load(checkpoint_path, map_location=device)
        # model state dict
        # NOTE: strict=False silently ignores missing/unexpected keys
        model.load_state_dict(checkpoint.pop("model_state_dict"), strict=False)

        # the flattened features are exactly the input of the first dense layer;
        # one extra column holds the true affinity
        feat_len = model.linear1.in_features
        flatfeat_arr = np.zeros((len(dataset), feat_len + 1))

        model.eval()
        offset = 0
        with torch.no_grad():
            for x_batch_cpu, y_batch_cpu, _ in dataloader:
                bsize = x_batch_cpu.shape[0]
                _, flatfeat_batch = model(x_batch_cpu.to(device))
                flatfeat_arr[offset:offset+bsize, :-1] = flatfeat_batch.cpu().numpy()
                flatfeat_arr[offset:offset+bsize, -1] = y_batch_cpu.float().numpy()[:,0]
                offset += bsize
    finally:
        # release the hdf file handle held by the dataset
        dataset.close()

    np.save(npy_path, flatfeat_arr)

In [None]:
"""Define a function to train the fully-connected network with extracted features from 3D-CNN"""

def train_Linear(input_train_data, input_val_data, checkpoint_dir, best_checkpoint_dir, learning_decay_iter = 150, load_previous_checkpoint = False, previous_checkpoint = None, best_previous_checkpoint = None):
    """
    Train the fully-connected network (Model_Linear) on features extracted from the 3D-CNN.

    After every optimizer step the model is evaluated on the validation data and both the
    running checkpoint and (if the validation correlation improved) the best checkpoint are
    written. At the end of every epoch the mean training loss, validation loss and validation
    correlation are appended to the per-epoch histories stored in the checkpoint.

    Inputs:
    1) input_train_data: path to train.npy data
    2) input_val_data: path to val.npy data
    3) checkpoint_dir: path to save checkpoint file: 'path/to/file.pt'
    4) best_checkpoint_dir: path to save best checkpoint file: 'path/to/file.pt'
    5) learning_decay_iter: number of optimizer steps (batches) between learning rate decays by a multiplicative factor of decay_rate; set to 150 by default
    6) load_previous_checkpoint: boolean variable indicating whether or not training should be started from an existing checkpoint. False by default
    7) previous_checkpoint: path to the checkpoint at which training should be started. None by default.
    8) best_previous_checkpoint: path to the checkpoint which was saved as "best" in the previous training. None by default
    Outputs (written to disk; the function itself returns None):
    1) checkpoint file from the endpoint of the training
    2) best checkpoint file, defined as that which maximizes the average of pearson and spearman correlations obtained from the validation data
    """

    # define parameters
    batch_size = 50
    learning_rate = .0007
    decay_iter = learning_decay_iter
    decay_rate = 0.95
    epoch_count = 100
    device_name = "cuda:0"
    multi_gpus = False

    # set CUDA for PyTorch
    use_cuda = torch.cuda.is_available()
    cuda_count = torch.cuda.device_count()
    if use_cuda:
        device = torch.device(device_name)
        torch.cuda.set_device(int(device_name.split(':')[1]))
    else:
        device = torch.device("cpu")
    print(use_cuda, cuda_count, device)

    # build training and validation dataset variables
    dataset = Linear_Dataset(input_train_data)
    val_dataset = Linear_Dataset(input_val_data)

    # check multi-gpus
    num_workers = 0
    if multi_gpus and cuda_count > 1:
        num_workers = cuda_count

    def validate_model():
        """Evaluate the current model on the validation set.

        Returns a tuple (mean squared error, average of Pearson and Spearman correlations).
        Leaves the model in eval mode; the training loop switches back to train mode.
        """
        ytrue_arr = np.zeros((len(val_dataset),), dtype=np.float32)
        ypred_arr = np.zeros((len(val_dataset),), dtype=np.float32)
        model.eval()
        with torch.no_grad():
            for batch_ind, batch in enumerate(val_dataloader):
                x_batch_cpu, y_batch_cpu = batch
                # transfer to GPU
                x_batch = x_batch_cpu.to(device)
                ypred_batch, _ = model(x_batch)
                # assemble the full datasets (the last batch may be smaller than batch_size)
                bsize = x_batch.shape[0]
                start = batch_ind * batch_size
                ytrue_arr[start:start + bsize] = y_batch_cpu.float().data.numpy()[:, 0]
                ypred_arr[start:start + bsize] = ypred_batch.cpu().float().data.numpy()[:, 0]
        # compute average correlation
        pearson_corr = stats.pearsonr(ytrue_arr, ypred_arr)[0]
        spearman_corr = stats.spearmanr(ytrue_arr, ypred_arr)[0]
        avg_corr = (pearson_corr + spearman_corr) / 2
        return mean_squared_error(ytrue_arr, ypred_arr), avg_corr

    # initialize training data loader
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, worker_init_fn=None, shuffle=True)

    # initialize validation data loader
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, worker_init_fn=None)

    # define model
    # NOTE(review): `verbose` is not defined inside this function; it has to exist in the
    # notebook's global scope, otherwise this line raises NameError - confirm.
    model = Model_Linear(use_cuda=use_cuda, verbose=verbose)
    if multi_gpus and cuda_count > 1:
        model = nn.DataParallel(model)
    model.to(device)
    if isinstance(model, (DistributedDataParallel, DataParallel)):
        model = model.module

    # define loss
    loss_fn = nn.MSELoss().float()
    # define optimizer
    optimizer = RMSprop(model.parameters(), lr=learning_rate)
    # define scheduler (stepped once per batch, so decay_iter is counted in optimizer steps)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=decay_iter, gamma=decay_rate)

    # training state
    epoch_start = 0
    step = 0
    epoch_train_losses, epoch_val_losses, epoch_avg_corr = [], [], []
    best_average_corr = float('-inf')

    def save_checkpoints(loss, step, epoch_ind, average_corr):
        """Write the running checkpoint, and the best checkpoint if average_corr improved."""
        nonlocal best_average_corr
        checkpoint_dict = {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": loss,
            "step": step,
            "epoch": epoch_ind,
            "epoch_val_losses": epoch_val_losses,
            "epoch_train_losses": epoch_train_losses,
            "epoch_avg_corr" : epoch_avg_corr,
            "best_avg_corr" : best_average_corr
        }
        if average_corr > best_average_corr:
            best_average_corr = average_corr
            checkpoint_dict["best_avg_corr"] = best_average_corr
            torch.save(checkpoint_dict, best_checkpoint_dir)
        torch.save(checkpoint_dict, checkpoint_dir)

    # load previous checkpoint if applicable
    if load_previous_checkpoint:
        # only the best score is needed from the previous "best" checkpoint
        best_checkpoint = torch.load(best_previous_checkpoint, map_location = device)
        best_average_corr = best_checkpoint["best_avg_corr"]
        checkpoint = torch.load(previous_checkpoint, map_location=device)
        model.load_state_dict(checkpoint.pop("model_state_dict"), strict=False)
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        epoch_start = checkpoint["epoch"]
        # continue the step counter so the saved "step" stays cumulative across runs
        step = checkpoint.get("step", 0)
        epoch_train_losses = checkpoint["epoch_train_losses"]
        epoch_val_losses = checkpoint["epoch_val_losses"]
        epoch_avg_corr = checkpoint["epoch_avg_corr"]

    # train model
    for epoch_ind in range(epoch_start, epoch_count):
        losses = []
        for batch in dataloader:
            model.train()
            x_batch_cpu, y_batch_cpu = batch
            # transfer to GPU
            x_batch = x_batch_cpu.to(device)
            ypred_batch, _ = model(x_batch)
            # compute loss (on CPU; gradients still flow back through .cpu())
            loss = loss_fn(ypred_batch.cpu().float(), y_batch_cpu.float())
            losses.append(loss.cpu().data.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            # validate and checkpoint after every optimizer step
            val_loss, average_corr = validate_model()
            save_checkpoints(loss, step, epoch_ind, average_corr)
            step += 1
        # per-epoch bookkeeping: must run once per epoch so the histories get one entry each
        val_loss, average_corr = validate_model()
        epoch_train_losses.append(np.mean(losses))
        epoch_val_losses.append(val_loss)
        epoch_avg_corr.append(average_corr)
        save_checkpoints(loss, step, epoch_ind, average_corr)

## Train

In [None]:
# train GCN_0 for testing on PDBbind 2016 core set benchmark
# Training and validation data are the "2020_train_minus_core" / "2020_val_minus_core" HDF files
# stored on the mounted Google Drive.
# NOTE(review): the checkpoint paths are relative ('content/...'), so they resolve against the
# current working directory - confirm that a 'content' directory exists there before running.

train_gcn(train_data = '/content/drive/MyDrive/HAC-Net/HAC-Net_files/train_val_test_files/MP-GCN/test_on_core_2016/2020_train_minus_core.hdf',
          val_data = '/content/drive/MyDrive/HAC-Net/HAC-Net_files/train_val_test_files/MP-GCN/test_on_core_2016/2020_val_minus_core.hdf',
          checkpoint_name = 'content/gcn0_core2016.pt',
          best_checkpoint_name = 'content/best_gcn0_core2016.pt')

In [None]:
# train GCN_1 for testing on PDBbind 2016 core set benchmark
# Uses the same training and validation HDF files as the GCN_0 cell; only the checkpoint
# file names differ, so the two runs do not overwrite each other.
# NOTE(review): the checkpoint paths are relative ('content/...'), so they resolve against the
# current working directory - confirm that a 'content' directory exists there before running.

train_gcn(train_data = '/content/drive/MyDrive/HAC-Net/HAC-Net_files/train_val_test_files/MP-GCN/test_on_core_2016/2020_train_minus_core.hdf',
          val_data = '/content/drive/MyDrive/HAC-Net/HAC-Net_files/train_val_test_files/MP-GCN/test_on_core_2016/2020_val_minus_core.hdf',
          checkpoint_name = 'content/gcn1_core2016.pt',
          best_checkpoint_name = 'content/best_gcn1_core2016.pt')

In [None]:
# train 3D-CNN for testing on PDBbind 2016 core set benchmark
# Training and validation data are the "vox_fixed_2020_*_minus_core" HDF files stored on the
# mounted Google Drive.
# The best checkpoint written here ('content/best_3dcnn_core2016.pt') is the one loaded by the
# feature-extraction cell that follows.
# NOTE(review): the checkpoint paths are relative ('content/...'), so they resolve against the
# current working directory - confirm that a 'content' directory exists there before running.

train_3dcnn(train_hdf = '/content/drive/MyDrive/HAC-Net/HAC-Net_files/train_val_test_files/3D-CNN/test_on_core_2016/vox_fixed_2020_train_minus_core.hdf',
          val_hdf = '/content/drive/MyDrive/HAC-Net/HAC-Net_files/train_val_test_files/3D-CNN/test_on_core_2016/vox_fixed_2020_val_minus_core.hdf',
          checkpoint_dir = 'content/3dcnn_core2016.pt',
          best_checkpoint_dir = 'content/best_3dcnn_core2016.pt')

In [None]:
"""Extract features from trained 3D-CNN"""

# All three calls load the same best 3D-CNN checkpoint ('content/best_3dcnn_core2016.pt') and
# write one .npy feature file per data split. The train and validation .npy files are the
# inputs of the train_Linear cell below.
# NOTE(review): the checkpoint and .npy paths are relative ('content/...'), so they resolve
# against the current working directory - confirm that a 'content' directory exists there.

# training set
savefeat_3dcnn(hdf_path = '/content/drive/MyDrive/HAC-Net/HAC-Net_files/train_val_test_files/3D-CNN/test_on_core_2016/vox_fixed_2020_train_minus_core.hdf',
               checkpoint_path = 'content/best_3dcnn_core2016.pt',
               npy_path = 'content/3dcnn_features_train.npy')

# validation set
savefeat_3dcnn(hdf_path = '/content/drive/MyDrive/HAC-Net/HAC-Net_files/train_val_test_files/3D-CNN/test_on_core_2016/vox_fixed_2020_val_minus_core.hdf',
               checkpoint_path = 'content/best_3dcnn_core2016.pt',
               npy_path = 'content/3dcnn_features_val.npy')

# test set (PDBbind 2016 core set voxel file)
savefeat_3dcnn(hdf_path = '/content/drive/MyDrive/HAC-Net/HAC-Net_files/train_val_test_files/3D-CNN/test_on_core_2016/vox_2016_core.hdf',
               checkpoint_path = 'content/best_3dcnn_core2016.pt',
               npy_path = 'content/3dcnn_features_test.npy')

In [None]:
# train fully-connected network for use in HAC-Net
# Inputs are the train/validation feature files written by the savefeat_3dcnn cell above.
# The remaining train_Linear arguments keep their defaults (learning_decay_iter = 150,
# no resume from a previous checkpoint).
# NOTE(review): all paths are relative ('content/...'), so they resolve against the current
# working directory - confirm that a 'content' directory exists there before running.
train_Linear(input_train_data = 'content/3dcnn_features_train.npy',
             input_val_data = 'content/3dcnn_features_val.npy', 
             checkpoint_dir = 'content/3dcnn_fully-connected.pt', 
             best_checkpoint_dir = 'content/best_3dcnn_fully-connected.pt')