In [1]:
from __future__ import print_function
import torch
import torch.optim as optim
import torch.nn as nn
from torch.autograd import Variable
import time
import argparse
import os,sys
import numpy as np
import networkx as nx
import random
import argparse
from rdkit import Chem
from rdkit.Chem import ChemicalFeatures
from rdkit import RDConfig
from joblib import Parallel, delayed
from os import listdir
from os.path import isfile, join
import xml.etree.ElementTree as ET
import shutil
import torch.utils.data as data
import torch.nn.functional as F
from torch.autograd.variable import Variable
from rdkit import Chem
from rdkit.Chem import ChemicalFeatures
from rdkit import RDConfig
from os import listdir
import rdkit
import multiprocessing
from os.path import isfile, join
from datasets import utils
random.seed(2)
np.random.seed(2)
best_er1=0



In [2]:
parser = argparse.ArgumentParser(description='Neural message passing')
parser.add_argument('--root', nargs=1, help='Specify the data directory.', default=['data-parse/chemfindata/'])
parser.add_argument('--dataset', default='our_data')
parser.add_argument('--datasetPath', default='./data-parse/chemfindata/', help='dataset path')
parser.add_argument('--logPath', default='./log/our_data/mpnn/', help='log path')
parser.add_argument('--plotLr', default=False, help='allow plotting the data')
parser.add_argument('--plotPath', default='./plot/our_data/mpnn/', help='plot path')
parser.add_argument('--resume', default='./checkpoint/our_data/mpnn/',
                    help='path to latest checkpoint')
parser.add_argument('--batch-size', type=int, default=10, metavar='N',
                    help='Input batch size for training (default: 20)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='Enables CUDA training')
parser.add_argument('--epochs', type=int, default=360, metavar='N',
                    help='Number of epochs to train (default: 360)')
parser.add_argument('--lr', type=lambda x: restricted_float(x, [1e-5, 1e-2]), default=1e-3, metavar='LR',
                    help='Initial learning rate [1e-5, 5e-4] (default: 1e-4)')
parser.add_argument('--lr-decay', type=lambda x: restricted_float(x, [.01, 1]), default=0.6, metavar='LR-DECAY',
                    help='Learning rate decay factor [.01, 1] (default: 0.6)')
parser.add_argument('--schedule', type=list, default=[0.1, 0.9], metavar='S',
                    help='Percentage of epochs to start the learning rate decay [0, 1] (default: [0.1, 0.9])')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--log-interval', type=int, default=20, metavar='N',
                    help='How many batches to wait before logging training status')
parser.add_argument('--prefetch', type=int, default=2, help='Pre-fetching threads.')
args = parser.parse_args([])
args.cuda = not args.no_cuda and torch.cuda.is_available()

In [3]:
class NNet(nn.Module):

    def __init__(self, n_in, n_out, hlayers=(128, 256, 128)):
        super(NNet, self).__init__()
        self.n_hlayers = len(hlayers)
        self.fcs = nn.ModuleList([nn.Linear(n_in, hlayers[i]) if i == 0 else
                                  nn.Linear(hlayers[i-1], n_out) if i == self.n_hlayers else
                                  nn.Linear(hlayers[i-1], hlayers[i]) for i in range(self.n_hlayers+1)])

    def forward(self, x):
        x = x.contiguous().view(-1, self.num_flat_features(x))
        for i in range(self.n_hlayers):
            x = F.relu(self.fcs[i](x))
        x = self.fcs[-1](x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

In [4]:
class MessageFunction(nn.Module):

    # Constructor
    def __init__(self, message_def='duvenaud', args={}):
        super(MessageFunction, self).__init__()
        self.m_definition = ''
        self.m_function = None
        self.args = {}
        self.__set_message(message_def, args)

    # Message from h_v to h_w through e_vw
    def forward(self, h_v, h_w, e_vw, args=None):
        return self.m_function(h_v, h_w, e_vw, args)

    # Set a message function
    def __set_message(self, message_def, args={}):
        self.m_definition = message_def.lower()

        self.m_function = {
                    'duvenaud':         self.m_duvenaud,
                    'ggnn':             self.m_ggnn,
                    'intnet':           self.m_intnet,
                    'mpnn':             self.m_mpnn,
                    'mgc':              self.m_mgc,
                    'bruna':            self.m_bruna,
                    'defferrard':       self.m_deff,
                    'kipf':             self.m_kipf
                }.get(self.m_definition, None)

        if self.m_function is None:
            print('WARNING!: Message Function has not been set correctly\n\tIncorrect definition ' + message_def)
            quit()

        init_parameters = {
            'duvenaud': self.init_duvenaud,
            'ggnn':     self.init_ggnn,
            'intnet':   self.init_intnet,
            'mpnn':     self.init_mpnn
        }.get(self.m_definition, lambda x: (nn.ParameterList([]), nn.ModuleList([]), {}))

        self.learn_args, self.learn_modules, self.args = init_parameters(args)

        self.m_size = {
                'duvenaud':     self.out_duvenaud,
                'ggnn':         self.out_ggnn,
                'intnet':       self.out_intnet,
                'mpnn':         self.out_mpnn
            }.get(self.m_definition, None)

    # Get the name of the used message function
    def get_definition(self):
        return self.m_definition

    # Get the message function arguments
    def get_args(self):
        return self.args

    # Get Output size
    def get_out_size(self, size_h, size_e, args=None):
        return self.m_size(size_h, size_e, args)

    # Definition of various state of the art message functions
    
    # Duvenaud et al. (2015), Convolutional Networks for Learning Molecular Fingerprints
    def m_duvenaud(self, h_v, h_w, e_vw, args):
        m = torch.cat([h_w, e_vw], 2)
        return m

    def out_duvenaud(self, size_h, size_e, args):
        return size_h + size_e

    def init_duvenaud(self, params):
        learn_args = []
        learn_modules = []
        args = {}
        return nn.ParameterList(learn_args), nn.ModuleList(learn_modules), args

    # Li et al. (2016), Gated Graph Neural Networks (GG-NN)
    def m_ggnn(self, h_v, h_w, e_vw, opt={}):

        m = Variable(torch.zeros(h_w.size(0), h_w.size(1), self.args['out']).type_as(h_w.data))

        for w in range(h_w.size(1)):
            if torch.nonzero(e_vw[:, w, :].data).size():
                for i, el in enumerate(self.args['e_label']):
                    ind = (el == e_vw[:,w,:]).type_as(self.learn_args[0][i])

                    parameter_mat = self.learn_args[0][i][None, ...].expand(h_w.size(0), self.learn_args[0][i].size(0),
                                                                            self.learn_args[0][i].size(1))

                    m_w = torch.transpose(torch.bmm(torch.transpose(parameter_mat, 1, 2),
                                                                        torch.transpose(torch.unsqueeze(h_w[:, w, :], 1),
                                                                                        1, 2)), 1, 2)
                    m_w = torch.squeeze(m_w)
                    m[:,w,:] = ind.expand_as(m_w)*m_w
        return m

    def out_ggnn(self, size_h, size_e, args):
        return self.args['out']

    def init_ggnn(self, params):
        learn_args = []
        learn_modules = []
        args = {}

        args['e_label'] = params['e_label']
        args['in'] = params['in']
        args['out'] = params['out']

        # Define a parameter matrix A for each edge label.
        learn_args.append(nn.Parameter(torch.randn(len(params['e_label']), params['in'], params['out'])))

        return nn.ParameterList(learn_args), nn.ModuleList(learn_modules), args

    # Battaglia et al. (2016), Interaction Networks
    def m_intnet(self, h_v, h_w, e_vw, args):
        m = torch.cat([h_v[:, None, :].expand_as(h_w), h_w, e_vw], 2)
        b_size = m.size()

        m = m.view(-1, b_size[2])

        m = self.learn_modules[0](m)
        m = m.view(b_size[0], b_size[1], -1)
        return m

    def out_intnet(self, size_h, size_e, args):
        return self.args['out']

    def init_intnet(self, params):
        learn_args = []
        learn_modules = []
        args = {}
        args['in'] = params['in']
        args['out'] = params['out']
        learn_modules.append(NNet(n_in=params['in'], n_out=params['out']))
        return nn.ParameterList(learn_args), nn.ModuleList(learn_modules), args

    # Gilmer et al. (2017), Neural Message Passing for Quantum Chemistry
    def m_mpnn(self, h_v, h_w, e_vw, opt={}):
        # Matrices for each edge
        edge_output = self.learn_modules[0](e_vw)
        edge_output = edge_output.view(-1, self.args['out'], self.args['in'])

        h_w_rows = h_w[..., None].expand(h_w.size(0), h_w.size(1), h_v.size(1)).contiguous()

        h_w_rows = h_w_rows.view(-1, self.args['in'])

        h_multiply = torch.bmm(edge_output, torch.unsqueeze(h_w_rows,2))

        m_new = torch.squeeze(h_multiply)

        return m_new

    def out_mpnn(self, size_h, size_e, args):
        return self.args['out']

    def init_mpnn(self, params):
        learn_args = []
        learn_modules = []
        args = {}

        args['in'] = params['in']
        args['out'] = params['out']

        # Define a parameter matrix A for each edge label.
        learn_modules.append(NNet(n_in=params['edge_feat'], n_out=(params['in']*params['out'])))

        return nn.ParameterList(learn_args), nn.ModuleList(learn_modules), args

    # Kearnes et al. (2016), Molecular Graph Convolutions
    def m_mgc(self, h_v, h_w, e_vw, args):
        m = e_vw
        return m
    
    # Laplacian based methods
    # Bruna et al. (2013)
    def m_bruna(self, h_v, h_w, e_vw, args):
        # TODO
        m = [] 
        return m

    # Defferrard et al. (2016)
    def m_deff(self, h_v, h_w, e_vw, args):
        # TODO
        m = []
        return m

    # Kipf & Welling (2016)
    def m_kipf(self, h_v, h_w, e_vw, args):
        # TODO
        m = []
        return m

In [5]:
class UpdateFunction(nn.Module):

    # Constructor
    def __init__(self, update_def='nn', args={}):
        super(UpdateFunction, self).__init__()
        self.u_definition = ''
        self.u_function = None
        self.args = {}
        self.__set_update(update_def, args)

    # Update node hv given message mv
    def forward(self, h_v, m_v, opt={}):
        return self.u_function(h_v, m_v, opt)

    # Set update function
    def __set_update(self, update_def, args):
        self.u_definition = update_def.lower()

        self.u_function = {
                    'duvenaud':         self.u_duvenaud,
                    'ggnn':             self.u_ggnn,
                    'intnet':           self.u_intnet,
                    'mpnn':             self.u_mpnn
                }.get(self.u_definition, None)

        if self.u_function is None:
            print('WARNING!: Update Function has not been set correctly\n\tIncorrect definition ' + update_def)

        init_parameters = {
            'duvenaud':         self.init_duvenaud,
            'ggnn':             self.init_ggnn,
            'intnet':           self.init_intnet,
            'mpnn':             self.init_mpnn
        }.get(self.u_definition, lambda x: (nn.ParameterList([]), nn.ModuleList([]), {}))

        self.learn_args, self.learn_modules, self.args = init_parameters(args)

    # Get the name of the used update function
    def get_definition(self):
        return self.u_definition

    # Get the update function arguments
    def get_args(self):
        return self.args

    ## Definition of various state of the art update functions

    # Duvenaud
    def u_duvenaud(self, h_v, m_v, opt):

        param_sz = self.learn_args[0][opt['deg']].size()
        parameter_mat = torch.t(self.learn_args[0][opt['deg']])[None, ...].expand(m_v.size(0), param_sz[1], param_sz[0])

        aux = torch.bmm(parameter_mat, torch.transpose(m_v, 1, 2))

        return torch.transpose(torch.nn.Sigmoid()(aux), 1, 2)

    def init_duvenaud(self, params):
        learn_args = []
        learn_modules = []
        args = {}

        # Filter degree 0 (the message will be 0 and therefore there is no update
        args['deg'] = [i for i in params['deg'] if i!=0]
        args['in'] = params['in']
        args['out'] = params['out']

        # Define a parameter matrix H for each degree.
        learn_args.append(torch.nn.Parameter(torch.randn(len(args['deg']), args['in'], args['out'])))

        return nn.ParameterList(learn_args), nn.ModuleList(learn_modules), args

    # GG-NN, Li et al.
    def u_ggnn(self, h_v, m_v, opt={}):
        h_v.contiguous()
        m_v.contiguous()
        h_new = self.learn_modules[0](torch.transpose(m_v, 0, 1), torch.unsqueeze(h_v, 0))[0]  # 0 or 1???
        return torch.transpose(h_new, 0, 1)

    def init_ggnn(self, params):
        learn_args = []
        learn_modules = []
        args = {}

        args['in_m'] = params['in_m']
        args['out'] = params['out']

        # GRU
        learn_modules.append(nn.GRU(params['in_m'], params['out']))

        return nn.ParameterList(learn_args), nn.ModuleList(learn_modules), args

    # Battaglia et al. (2016), Interaction Networks
    def u_intnet(self, h_v, m_v, opt):
        if opt['x_v'].ndimension():
            input_tensor = torch.cat([h_v, opt['x_v'], torch.squeeze(m_v)], 1)
        else:
            input_tensor = torch.cat([h_v, torch.squeeze(m_v)], 1)

        return self.learn_modules[0](input_tensor)

    def init_intnet(self, params):
        learn_args = []
        learn_modules = []
        args = {}

        args['in'] = params['in']
        args['out'] = params['out']

        learn_modules.append(NNet(n_in=params['in'], n_out=params['out']))

        return nn.ParameterList(learn_args), nn.ModuleList(learn_modules), args

    def u_mpnn(self, h_v, m_v, opt={}):
        h_in = h_v.view(-1,h_v.size(2))
        m_in = m_v.view(-1,m_v.size(2))
        h_new = self.learn_modules[0](m_in[None,...],h_in[None,...])[0] # 0 or 1???
        return torch.squeeze(h_new).view(h_v.size())

    def init_mpnn(self, params):
        learn_args = []
        learn_modules = []
        args = {}

        args['in_m'] = params['in_m']
        args['out'] = params['out']

        # GRU
        learn_modules.append(nn.GRU(params['in_m'], params['out']))

        return nn.ParameterList(learn_args), nn.ModuleList(learn_modules), args

In [6]:
class ReadoutFunction(nn.Module):

    # Constructor
    def __init__(self, readout_def='nn', args={}):
        super(ReadoutFunction, self).__init__()
        self.r_definition = ''
        self.r_function = None
        self.args = {}
        self.__set_readout(readout_def, args)

    # Readout graph given node values at las layer
    def forward(self, h_v):
        return self.r_function(h_v)

    # Set a readout function
    def __set_readout(self, readout_def, args):
        self.r_definition = readout_def.lower()

        self.r_function = {
                    'duvenaud': self.r_duvenaud,
                    'ggnn':     self.r_ggnn,
                    'intnet':   self.r_intnet,
                    'mpnn':     self.r_mpnn
                }.get(self.r_definition, None)

        if self.r_function is None:
            print('WARNING!: Readout Function has not been set correctly\n\tIncorrect definition ' + readout_def)
            quit()

        init_parameters = {
            'duvenaud': self.init_duvenaud,
            'ggnn':     self.init_ggnn,
            'intnet':   self.init_intnet,
            'mpnn':     self.init_mpnn
        }.get(self.r_definition, lambda x: (nn.ParameterList([]), nn.ModuleList([]), {}))

        self.learn_args, self.learn_modules, self.args = init_parameters(args)

    # Get the name of the used readout function
    def get_definition(self):
        return self.r_definition

    ## Definition of various state of the art update functions

    # Duvenaud
    def r_duvenaud(self, h):
        # layers
        aux = []
        for l in range(len(h)):
            param_sz = self.learn_args[l].size()
            parameter_mat = torch.t(self.learn_args[l])[None, ...].expand(h[l].size(0), param_sz[1],
                                                                                      param_sz[0])

            aux.append(torch.transpose(torch.bmm(parameter_mat, torch.transpose(h[l], 1, 2)), 1, 2))

            for j in range(0, aux[l].size(1)):
                # Mask whole 0 vectors
                aux[l][:, j, :] = nn.Softmax()(aux[l][:, j, :].clone())*(torch.sum(aux[l][:, j, :] != 0, 1) > 0).expand_as(aux[l][:, j, :]).type_as(aux[l])

        aux = torch.sum(torch.sum(torch.stack(aux, 3), 3), 1)
        return self.learn_modules[0](torch.squeeze(aux))

    def init_duvenaud(self, params):
        learn_args = []
        learn_modules = []
        args = {}

        args['out'] = params['out']

        # Define a parameter matrix W for each layer.
        for l in range(params['layers']):
            learn_args.append(nn.Parameter(torch.randn(params['in'][l], params['out'])))

        # learn_modules.append(nn.Linear(params['out'], params['target']))

        learn_modules.append(NNet(n_in=params['out'], n_out=params['target']))
        return nn.ParameterList(learn_args), nn.ModuleList(learn_modules), args

    # GG-NN, Li et al.
    def r_ggnn(self, h):

        aux = Variable( torch.Tensor(h[0].size(0), self.args['out']).type_as(h[0].data).zero_() )
        # For each graph
        for i in range(h[0].size(0)):
            nn_res = nn.Sigmoid()(self.learn_modules[0](torch.cat([h[0][i,:,:], h[-1][i,:,:]], 1)))*self.learn_modules[1](h[-1][i,:,:])

            # Delete virtual nodes
            nn_res = (torch.sum(h[0][i,:,:],1).expand_as(nn_res)>0).type_as(nn_res)* nn_res

            aux[i,:] = torch.sum(nn_res,0)

        return aux

    def init_ggnn(self, params):
        learn_args = []
        learn_modules = []
        args = {}

        # i
        learn_modules.append(NNet(n_in=2*params['in'], n_out=params['target']))

        # j
        learn_modules.append(NNet(n_in=params['in'], n_out=params['target']))

        args['out'] = params['target']

        return nn.ParameterList(learn_args), nn.ModuleList(learn_modules), args


    # Battaglia et al. (2016), Interaction Networks
    def r_intnet(self, h):

        aux = torch.sum(h[-1],1)

        return self.learn_modules[0](aux)

    def init_intnet(self, params):
        learn_args = []
        learn_modules = []
        args = {}

        learn_modules.append(NNet(n_in=params['in'], n_out=params['target']))

        return nn.ParameterList(learn_args), nn.ModuleList(learn_modules), args

    def r_mpnn(self, h):

        aux = Variable( torch.Tensor(h[0].size(0), self.args['out']).type_as(h[0].data).zero_() )
        # For each graph
        for i in range(h[0].size(0)):
            nn_res = nn.Sigmoid()(self.learn_modules[0](torch.cat([h[0][i,:,:], h[-1][i,:,:]], 1)))*self.learn_modules[1](h[-1][i,:,:])

            # Delete virtual nodes
            nn_res = (torch.sum(h[0][i,:,:],1)[...,None].expand_as(nn_res)>0).type_as(nn_res)* nn_res

            aux[i,:] = torch.sum(nn_res,0)

        return aux

    def init_mpnn(self, params):
        learn_args = []
        learn_modules = []
        args = {}

        # i
        learn_modules.append(NNet(n_in=2*params['in'], n_out=params['target']))

        # j
        learn_modules.append(NNet(n_in=params['in'], n_out=params['target']))

        args['out'] = params['target']

        return nn.ParameterList(learn_args), nn.ModuleList(learn_modules), args


In [19]:
class MpnnGGNN(nn.Module):
    """
        MPNN as proposed by Li et al..

        This class implements the whole Li et al. model following the functions proposed by Gilmer et al. as
        Message, Update and Readout.

        Parameters
        ----------
        e : int list.
            Possible edge labels for the input graph.
        hidden_state_size : int
            Size of the hidden states (the input will be padded with 0's to this size).
        message_size : int
            Message function output vector size.
        n_layers : int
            Number of iterations Message+Update (weight tying).
        l_target : int
            Size of the output.
        type : str (Optional)
            Classification | [Regression (default)]. If classification, LogSoftmax layer is applied to the output vector.
    """

    def __init__(self, e, hidden_state_size, message_size, n_layers, l_target, type='regression'):
        super(MpnnGGNN, self).__init__()

        # Define message
        self.m = nn.ModuleList([MessageFunction('ggnn', args={'e_label': e, 'in': hidden_state_size, 'out': message_size})])

        # Define Update
        self.u = nn.ModuleList([UpdateFunction('ggnn',
                                                args={'in_m': message_size,
                                                'out': hidden_state_size})])

        # Define Readout
        self.r = ReadoutFunction('ggnn',
                                 args={'in': hidden_state_size,
                                       'target': l_target})

        self.type = type

        self.args = {}
        self.args['out'] = hidden_state_size

        self.n_layers = n_layers

    def forward(self, g, h_in, e):

        h = []

        # Padding to some larger dimension d
        h_t = torch.cat([h_in, Variable(torch.Tensor(h_in.size(0), h_in.size(1), h_in.size(2)).type_as(h_in.data).zero_())], 2)
        

        h.append(h_t.clone())

        # Layer
        for t in range(0, self.n_layers):

            h_t = Variable(torch.zeros(h[0].size(0), h[0].size(1), h[0].size(2)).type_as(h_in.data))

            # Apply one layer pass (Message + Update)
            for v in range(0, h_in.size(1)):
                print(h[t].shape)
                print(e.shape)
                m = self.m[0].forward(h[t][:, v, :], h[t], e[:, v, :])
                #m = self.m[0].forward(h[t][:, v, :], h[t], e[:, v, :])

                # Nodes without edge set message to 0
                m = g[:, v, :, None].expand_as(m) * m

                m = torch.sum(m, 1)

                # Update
                h_t[:, v, :] = self.u[0].forward(h[t][:, v, :], m)

            # Delete virtual nodes
            h_t = (torch.sum(h_in, 2)[..., None].expand_as(h_t) > 0).type_as(h_t) * h_t
            h.append(h_t.clone())

        # Readout
        res = self.r.forward(h)
        if self.type == 'classification':
            res = nn.LogSoftmax()(res)
        return res

In [22]:
class MpnnDuvenaud(nn.Module):
    """
        MPNN as proposed by Duvenaud et al..

        This class implements the whole Duvenaud et al. model following the functions proposed by Gilmer et al. as 
        Message, Update and Readout.

        Parameters
        ----------
        d : int list.
            Possible degrees for the input graph.
        in_n : int list
            Sizes for the node and edge features.
        out_update : int list
            Output sizes for the different Update functions.
        hidden_state_readout : int
            Input size for the neural net used inside the readout function.
        l_target : int
            Size of the output.
        type : str (Optional)
            Classification | [Regression (default)]. If classification, LogSoftmax layer is applied to the output vector.
    """

    def __init__(self, d, in_n, out_update, hidden_state_readout, l_target, type='regression'):
        super(MpnnDuvenaud, self).__init__()

        n_layers = len(out_update)

        # Define message 1 & 2
        self.m = nn.ModuleList([MessageFunction('duvenaud') for _ in range(n_layers)])

        # Define Update 1 & 2
        self.u = nn.ModuleList([UpdateFunction('duvenaud', args={'deg': d, 'in': self.m[i].get_out_size(in_n[0], in_n[1]), 'out': out_update[0]}) if i == 0 else
                                UpdateFunction('duvenaud', args={'deg': d, 'in': self.m[i].get_out_size(out_update[i-1], in_n[1]), 'out': out_update[i]}) for i in range(n_layers)])

        # Define Readout
        self.r = ReadoutFunction('duvenaud',
                                 args={'layers': len(self.m) + 1,
                                       'in': [in_n[0] if i == 0 else out_update[i-1] for i in range(n_layers+1)],
                                       'out': hidden_state_readout,
                                       'target': l_target})

        self.type = type

    def forward(self, g, h_in, e, plotter=None):

        h = []
        h.append(h_in)

        # Layer
        for t in range(0, len(self.m)):

            u_args = self.u[t].get_args()

            h_t = Variable(torch.zeros(h_in.size(0), h_in.size(1), u_args['out']).type_as(h[t].data))

            # Apply one layer pass (Message + Update)
            for v in range(0, h_in.size(1)):

                m = self.m[t].forward(h[t][:, v, :], h[t], e[:, v, :])

                # Nodes without edge set message to 0
                m = g[:, v, :, None].expand_as(m) * m

                m = torch.sum(m, 1)

                # Duvenaud
                deg = torch.sum(g[:, v, :].data, 1)

                # Separate degrees
                for i in range(len(u_args['deg'])):
                    ind = deg == u_args['deg'][i]
                    ind = Variable(torch.squeeze(torch.nonzero(torch.squeeze(ind))), volatile=True)

                    opt = {'deg': i}

                    # Update
                    if len(ind) != 0:
                        aux = self.u[t].forward(torch.index_select(h[t], 0, ind)[:, v, :], torch.index_select(m, 0, ind), opt)

                        ind = ind.data.cpu().numpy()
                        for j in range(len(ind)):
                            h_t[ind[j], v, :] = aux[j, :]

            if plotter is not None:
                num_feat = h_t.size(2)
                color = h_t[0,:,:].data.cpu().numpy()
                for i in range(num_feat):
                    plotter(color[:, i], 'layer_' + str(t) + '_element_' + str(i) + '.png')

            h.append(h_t.clone())
        # Readout
        res = self.r.forward(h)
        if self.type == 'classification':
            res = nn.LogSoftmax()(res)
        return res

In [7]:
class MPNN(nn.Module):
    """
        MPNN as proposed by Gilmer et al..

        This class implements the whole Gilmer et al. model following the functions Message, Update and Readout.

        Parameters
        ----------
        in_n : int list
            Sizes for the node and edge features.
        hidden_state_size : int
            Size of the hidden states (the input will be padded with 0's to this size).
        message_size : int
            Message function output vector size.
        n_layers : int
            Number of iterations Message+Update (weight tying).
        l_target : int
            Size of the output.
        type : str (Optional)
            Classification | [Regression (default)]. If classification, LogSoftmax layer is applied to the output vector.
    """

    def __init__(self, in_n, hidden_state_size, message_size, n_layers, l_target, type='regression'):
        super(MPNN, self).__init__()

        # Define message
        self.m = nn.ModuleList(
            [MessageFunction('mpnn', args={'edge_feat': in_n[1], 'in': hidden_state_size, 'out': message_size})])

        # Define Update
        self.u = nn.ModuleList([UpdateFunction('mpnn',
                                               args={'in_m': message_size,
                                                     'out': hidden_state_size})])

        # Define Readout
        self.r = ReadoutFunction('mpnn',
                                 args={'in': hidden_state_size,
                                       'target': l_target})

        self.type = type

        self.args = {}
        self.args['out'] = hidden_state_size

        self.n_layers = n_layers

    def forward(self, g, h_in, e):

        h = []

        # Padding to some larger dimension d
        h_t = torch.cat([h_in, Variable(
            torch.zeros(h_in.size(0), h_in.size(1), self.args['out'] - h_in.size(2)).type_as(h_in.data))], 2)

        h.append(h_t.clone())

        # Layer
        for t in range(0, self.n_layers):
            e_aux = e.view(-1, e.size(3))

            h_aux = h[t].view(-1, h[t].size(2))

            m = self.m[0].forward(h[t], h_aux, e_aux)
            m = m.view(h[0].size(0), h[0].size(1), -1, m.size(1))

            # Nodes without edge set message to 0
            m = torch.unsqueeze(g, 3).expand_as(m) * m

            m = torch.squeeze(torch.sum(m, 1))

            h_t = self.u[0].forward(h[t], m)

            # Delete virtual nodes
            h_t = (torch.sum(h_in, 2)[..., None].expand_as(h_t) > 0).type_as(h_t) * h_t
            h.append(h_t)

        # Readout
        res = self.r.forward(h)

        if self.type == 'classification':
            res = nn.LogSoftmax()(res)
        return res

In [8]:
def qm9_nodes(g, hydrogen=False):
    h = []
    for n, d in g.nodes_iter(data=True):
        h_t = []
        # Atom type (One-hot H, C, N, O F)
        h_t += [int(d['a_type'] == x) for x in ['H', 'C', 'N', 'O', 'F']]
        # Atomic number
        h_t.append(d['a_num'])
        # Partial Charge
        h_t.append(d['pc'])
        # Acceptor
        h_t.append(d['acceptor'])
        # Donor
        h_t.append(d['donor'])
        # Aromatic
        h_t.append(int(d['aromatic']))
        # Hybradization
        h_t += [int(d['hybridization'] == x) for x in [rdkit.Chem.rdchem.HybridizationType.SP, rdkit.Chem.rdchem.HybridizationType.SP2, rdkit.Chem.rdchem.HybridizationType.SP3]]
        # If number hydrogen is used as a
        if hydrogen:
            h_t.append(d['num_h'])
        h.append(h_t)
    return h


def qm9_edges(g, e_representation='raw_distance'):
    remove_edges = []
    e={}    
    for n1, n2, d in g.edges_iter(data=True):
        e_t = []
        # Raw distance function
        if e_representation == 'chem_graph':
            if d['b_type'] is None:
                remove_edges += [(n1, n2)]
            else:
                e_t += [i+1 for i, x in enumerate([rdkit.Chem.rdchem.BondType.SINGLE, rdkit.Chem.rdchem.BondType.DOUBLE,
                                                rdkit.Chem.rdchem.BondType.TRIPLE, rdkit.Chem.rdchem.BondType.AROMATIC])
                        if x == d['b_type']]
        elif e_representation == 'distance_bin':
            if d['b_type'] is None:
                step = (6-2)/8.0
                start = 2
                b = 9
                for i in range(0, 9):
                    if d['distance'] < (start+i*step):
                        b = i
                        break
                e_t.append(b+5)
            else:
                e_t += [i+1 for i, x in enumerate([rdkit.Chem.rdchem.BondType.SINGLE, rdkit.Chem.rdchem.BondType.DOUBLE,
                                                   rdkit.Chem.rdchem.BondType.TRIPLE, rdkit.Chem.rdchem.BondType.AROMATIC])
                        if x == d['b_type']]
        elif e_representation == 'raw_distance':
            if d['b_type'] is None:
                remove_edges += [(n1, n2)]
            else:
                e_t.append(d['distance'])
                e_t += [int(d['b_type'] == x) for x in [rdkit.Chem.rdchem.BondType.SINGLE, rdkit.Chem.rdchem.BondType.DOUBLE,
                                                        rdkit.Chem.rdchem.BondType.TRIPLE, rdkit.Chem.rdchem.BondType.AROMATIC]]
        else:
            print('Incorrect Edge representation transform')
            quit()
        if e_t:
            e[(n1, n2)] = e_t
    for edg in remove_edges:
        g.remove_edge(*edg)
    return nx.to_numpy_matrix(g), e
    

def normalize_data(data, mean, std):
    data_norm = (data-mean)/std
    return data_norm


In [9]:
def init_graph_our(g_tag,prop):
    
    prop = prop.split()
    g_H=float(prop[0])
    g_p=float(prop[1])

    labels = [g_H, g_p]
    return nx.Graph(tag=g_tag,H=g_H, p=g_p), labels

In [10]:
def xyz_graph_reader(graph_file):

    with open(graph_file,'r') as f:
        # Number of atoms
        #na = int(f.readline())
        tag=f.readline().strip()
        # Graph properties
        properties = f.readline()
        na=int(f.readline())
        #g, l = init_graph(properties)
        g,l=init_graph_our(tag, properties)
        atom_properties = []
        # Atoms properties
        for i in range(na):
            a_properties = f.readline()
            a_properties = a_properties.replace('.*^', 'e')
            a_properties = a_properties.replace('*^', 'e')
            a_properties = a_properties.split()
            atom_properties.append(a_properties)

        # Frequencies
        #f.readline()

        # SMILES
        smiles = f.readline()
        smiles = smiles.split()
        smiles = smiles[0]
        
        m = Chem.MolFromSmiles(smiles)
        m = Chem.AddHs(m)
        
        fdef_name = os.path.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
        factory = ChemicalFeatures.BuildFeatureFactory(fdef_name)
        feats = factory.GetFeaturesForMol(m)

        # Create nodes
        for i in range(0, m.GetNumAtoms()):
            atom_i = m.GetAtomWithIdx(i)

            g.add_node(i, a_type=atom_i.GetSymbol(), a_num=atom_i.GetAtomicNum(), acceptor=0, donor=0,
                       aromatic=atom_i.GetIsAromatic(), hybridization=atom_i.GetHybridization(),
                       num_h=atom_i.GetTotalNumHs(), coord=np.array(atom_properties[i][1:4]).astype(np.float),
                       pc=float(atom_properties[i][4]))

        for i in range(0, len(feats)):
            if feats[i].GetFamily() == 'Donor':
                node_list = feats[i].GetAtomIds()
                for i in node_list:
                    g.node[i]['donor'] = 1
            elif feats[i].GetFamily() == 'Acceptor':
                node_list = feats[i].GetAtomIds()
                for i in node_list:
                    g.node[i]['acceptor'] = 1

        # Read Edges
        for i in range(0, m.GetNumAtoms()):
            for j in range(0, m.GetNumAtoms()):
                e_ij = m.GetBondBetweenAtoms(i, j)
                if e_ij is not None:
                    g.add_edge(i, j, b_type=e_ij.GetBondType(),
                               distance=np.linalg.norm(g.node[i]['coord']-g.node[j]['coord']))
                else:
                    # Unbonded
                    g.add_edge(i, j, b_type=None,
                               distance=np.linalg.norm(g.node[i]['coord'] - g.node[j]['coord']))
    return g , l


In [11]:
def restricted_float(x, inter):
    x = float(x)
    if x < inter[0] or x > inter[1]:
        raise argparse.ArgumentTypeError("%r not in range [1e-5, 1e-4]"%(x,))
    return x

In [12]:
class Qm9(data.Dataset):

    # Constructor
    def __init__(self, root_path, ids, vertex_transform=utils.qm9_nodes, edge_transform=utils.qm9_edges,
                 target_transform=None, e_representation='raw_distance'):
        self.root = root_path
        self.ids = ids
        self.vertex_transform = vertex_transform
        self.edge_transform = edge_transform
        self.target_transform = target_transform
        self.e_representation = e_representation

    def __getitem__(self, index):
        g, target = xyz_graph_reader(os.path.join(self.root, self.ids[index]))
        if self.vertex_transform is not None:
            h = self.vertex_transform(g)

        if self.edge_transform is not None:
            g, e = utils.qm9_edges(g, self.e_representation)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return (g, h, e), target

    def __len__(self):
        return len(self.ids)

    def set_target_transform(self, target_transform):
        self.target_transform = target_transform

In [13]:
root = args.root[0]
files = [f for f in os.listdir(root) if os.path.isfile(os.path.join(root, f))]

idx = np.random.permutation(len(files))
idx = idx.tolist()

valid_ids = [files[i] for i in idx[0:1000]]
test_ids  = [files[i] for i in idx[1000:2000]]
train_ids = [files[i] for i in idx[2000:]]

data_train = Qm9(root, train_ids, vertex_transform=utils.qm9_nodes, edge_transform=lambda g: utils.qm9_edges(g, e_representation='raw_distance'))
data_valid = Qm9(root, valid_ids)
data_test = Qm9(root, test_ids)

In [14]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [15]:
utils.get_graph_stats(data_valid, 'degrees')
stat_dict = utils.get_graph_stats(data_valid, ['target_mean', 'target_std'])

data_train.set_target_transform(lambda x: utils.normalize_data(x,stat_dict['target_mean'],stat_dict['target_std']))
data_valid.set_target_transform(lambda x: utils.normalize_data(x, stat_dict['target_mean'],stat_dict['target_std']))
data_test.set_target_transform(lambda x: utils.normalize_data(x, stat_dict['target_mean'], stat_dict['target_std']))


In [16]:
train_loader = torch.utils.data.DataLoader(data_train,
                                           batch_size=args.batch_size, shuffle=True,
                                           collate_fn=utils.collate_g,
                                        pin_memory=True)
valid_loader = torch.utils.data.DataLoader(data_valid,
                                           batch_size=args.batch_size, collate_fn=utils.collate_g,
                                           pin_memory=True)
test_loader = torch.utils.data.DataLoader(data_test,
                                          batch_size=args.batch_size, collate_fn=utils.collate_g,
                                          pin_memory=True)

In [17]:
def train(train_loader, model, criterion, optimizer, epoch, evaluation, logger):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    error_ratio = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (g, h, e, target) in enumerate(train_loader):

        # Prepare input data
        if args.cuda:
            g, h, e, target = g.cuda(), h.cuda(), e.cuda(), target.cuda()
        g, h, e, target = Variable(g), Variable(h), Variable(e), Variable(target)

        # Measure data loading time
        data_time.update(time.time() - end)

        optimizer.zero_grad()

        # Compute output
        output = model(g, h, e)
        train_loss = criterion(output, target)

        # Logs
        losses.update(train_loss.item(), g.size(0))
        error_ratio.update(evaluation(output, target).item(), g.size(0))

        # compute gradient and do SGD step
        train_loss.backward()
        optimizer.step()

        # Measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.log_interval == 0 and i > 0:

            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Error Ratio {err.val:.4f} ({err.avg:.4f})'
                  .format(epoch+1, i, len(train_loader), batch_time=batch_time,
                          data_time=data_time, loss=losses, err=error_ratio))
                          

    print('Epoch: [{0}] Avg Error Ratio {err.avg:.3f}; Average Loss {loss.avg:.3f}; Avg Time x Batch {b_time.avg:.3f}'
          .format(epoch, err=error_ratio, loss=losses, b_time=batch_time))

In [18]:
def validate(val_loader, model, criterion, evaluation, logger=None):
    batch_time = AverageMeter()
    losses = AverageMeter()
    error_ratio = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    for i, (g, h, e, target) in enumerate(val_loader):

        # Prepare input data
        if args.cuda:
            g, h, e, target = g.cuda(), h.cuda(), e.cuda(), target.cuda()
        g, h, e, target = Variable(g), Variable(h), Variable(e), Variable(target)

        # Compute output
        output = model(g, h, e)

        # Logs
        losses.update(criterion(output, target).item(), g.size(0))
        error_ratio.update(evaluation(output, target).item(), g.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.log_interval == 0 and i > 0:
            
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Error Ratio {err.val:.4f} ({err.avg:.4f})'
                  .format(i, len(val_loader), batch_time=batch_time,
                          loss=losses, err=error_ratio))

    print(' * Average Error Ratio {err.avg:.3f}; Average Loss {loss.avg:.3f}'
          .format(err=error_ratio, loss=losses))


    return error_ratio.avg

In [28]:
g_tuple, l = data_train[0]
g, h_t, e = g_tuple
in_n = [len(h_t[0]), len(list(e.values())[0])]
hidden_state_size = 73
message_size = 73
n_layers = 3
l_target = len(l)
type ='regression'
model = MpnnDuvenaud(utils.get_graph_stats(data_valid, 'degrees'),in_n, hidden_state_size, message_size, l_target, type=type)
del in_n, hidden_state_size, message_size, n_layers, l_target, type
optimizer = optim.Adam(model.parameters(), lr=args.lr)
criterion = nn.MSELoss()
evaluation = lambda output, target: torch.mean(torch.abs(output - target) / torch.abs(target))
lr_step = (args.lr-args.lr*args.lr_decay)/(args.epochs*args.schedule[1] - args.epochs*args.schedule[0])

TypeError: object of type 'int' has no len()

In [58]:
if args.resume:
    checkpoint_dir = args.resume
    best_model_file = os.path.join(checkpoint_dir, 'model_best.pth')
    if not os.path.isdir(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    if os.path.isfile(best_model_file):
        print("=> loading best model '{}'".format(best_model_file))
        checkpoint = torch.load(best_model_file)
        args.start_epoch = checkpoint['epoch']
        best_acc1 = checkpoint['best_er1']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded best model '{}' (epoch {})".format(best_model_file, checkpoint['epoch']))
    else:
        print("=> no best model found at '{}'".format(best_model_file))

=> loading best model './checkpoint/our_data/mpnn/model_best.pth'
=> loaded best model './checkpoint/our_data/mpnn/model_best.pth' (epoch 44)


In [26]:
if args.cuda:
    print('\t* Cuda')
    model = model.cuda()
    criterion = criterion.cuda()
model

	* Cuda


MpnnGGNN(
  (m): ModuleList(
    (0): MessageFunction(
      (learn_args): ParameterList(  (0): Parameter containing: [torch.cuda.FloatTensor of size 25x73x73 (GPU 0)])
      (learn_modules): ModuleList()
    )
  )
  (u): ModuleList(
    (0): UpdateFunction(
      (learn_args): ParameterList()
      (learn_modules): ModuleList(
        (0): GRU(73, 73)
      )
    )
  )
  (r): ReadoutFunction(
    (learn_args): ParameterList()
    (learn_modules): ModuleList(
      (0): NNet(
        (fcs): ModuleList(
          (0): Linear(in_features=146, out_features=128, bias=True)
          (1): Linear(in_features=128, out_features=256, bias=True)
          (2): Linear(in_features=256, out_features=128, bias=True)
          (3): Linear(in_features=128, out_features=2, bias=True)
        )
      )
      (1): NNet(
        (fcs): ModuleList(
          (0): Linear(in_features=73, out_features=128, bias=True)
          (1): Linear(in_features=128, out_features=256, bias=True)
          (2): Linear(in_

In [64]:
for epoch in range(0, args.epochs):

    if epoch > args.epochs * args.schedule[0] and epoch < args.epochs * args.schedule[1]:
        args.lr -= lr_step
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.lr

    # train for one epoch
    train(train_loader, model, criterion, optimizer, epoch, evaluation, logger=None)

    # evaluate on test set
    er1 = validate(valid_loader, model, criterion, evaluation)

    is_best = er1 > best_er1
    best_er1 = min(er1, best_er1)
    utils.save_checkpoint({'epoch': epoch + 1, 'state_dict': model.state_dict(), 'best_er1': best_er1,
                           'optimizer': optimizer.state_dict(), }, is_best=is_best, directory=args.resume)


RuntimeError: CUDA out of memory. Tried to allocate 1.12 GiB (GPU 0; 4.00 GiB total capacity; 1.84 GiB already allocated; 115.79 MiB free; 905.70 MiB cached)

In [44]:
validate(test_loader, model, criterion, evaluation)

Test: [20/500]	Time 0.046 (0.052)	Loss 0.1834 (0.3290)	Error Ratio 0.8179 (1.4637)
Test: [40/500]	Time 0.035 (0.056)	Loss 0.2218 (0.3783)	Error Ratio 1.0000 (1.3206)
Test: [60/500]	Time 0.054 (0.057)	Loss 0.1126 (0.9725)	Error Ratio 1.0472 (1.3199)
Test: [80/500]	Time 0.072 (0.056)	Loss 0.1723 (1.0757)	Error Ratio 1.2968 (1.3300)
Test: [100/500]	Time 0.050 (0.056)	Loss 0.3681 (0.9656)	Error Ratio 1.4976 (1.2820)
Test: [120/500]	Time 0.045 (0.056)	Loss 0.4640 (0.8590)	Error Ratio 1.3237 (1.3108)
Test: [140/500]	Time 0.053 (0.056)	Loss 0.2501 (0.9449)	Error Ratio 1.2683 (1.3011)
Test: [160/500]	Time 0.104 (0.057)	Loss 1.9813 (0.8884)	Error Ratio 0.8331 (1.4272)
Test: [180/500]	Time 0.043 (0.057)	Loss 0.1123 (0.8331)	Error Ratio 0.7974 (1.3982)
Test: [200/500]	Time 0.054 (0.056)	Loss 0.4142 (0.8103)	Error Ratio 1.5463 (1.3763)
Test: [220/500]	Time 0.041 (0.056)	Loss 0.2909 (0.7667)	Error Ratio 1.3218 (1.3967)
Test: [240/500]	Time 0.057 (0.056)	Loss 0.4908 (0.7543)	Error Ratio 1.0000 (1.43

1.5121180399656295

In [74]:
data_train[0]

((matrix([[0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
           0., 0., 0., 0., 0., 0., 0., 0.],
          [1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
           0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           1., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 1., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 1., 0., 0., 0., 0., 0., 0.],
          [1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 1., 0., 0., 0., 0., 0.],
          [0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 1., 0., 1