# Make MPNN Neural Network
This notebook creates a Message Passing Neural Network using [nfp](http://github.com/nrel/nfp) in a form ready for training.

In [1]:
from jcesr_ml.benchmark import load_benchmark_data, dielectric_constants
from jcesr_ml.keras import cartesian_product
from jcesr_ml.mpnn import (SolvationPreprocessor, PartialChargesPreprocessor, save_model_files,
                           atom_feature_element_only, bond_feature_type_only)
from nfp.layers import (MessageLayer, GRUStep, Set2Set, ReduceAtomToMol, 
                        Embedding2D, Embedding2DCompressed, Squeeze)
from keras.layers import (Add, Input, Dense, BatchNormalization, Reshape, Concatenate,
                          Activation, Dropout, Embedding, Lambda)
from keras import backend as K
from nfp.models import GraphModel
from nfp.preprocessing import GraphSequence
import tensorflow as tf
import pickle as pkl
import numpy as np
import shutil
import json
import os

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
Using TensorFlow backend.


In [2]:
# load_benchmark_data returns two values; only the first is kept (as train_data), the second is discarded
train_data, _ = load_benchmark_data()

Get the columns for the outputs

In [3]:
# Output columns: every column of train_data whose name starts with 'sol_'
sol_cols = [x for x in train_data.columns if x.startswith('sol_')]

In [4]:
# Work on a copy so that removing ACN does not alter sol_cols
sols_without_acn = list(sol_cols)
sols_without_acn.remove('sol_acn')  # raises ValueError if 'sol_acn' is not one of the columns

## Make the Preprocessing Tools
These tools convert the SMILES representation of a molecule into a set of features needed for the graph training.

In [5]:
preprocessor = SolvationPreprocessor([], explicit_hs=True)

In [6]:
preprocessor.fit(train_data['smiles_0']);

100%|██████████| 117232/117232 [01:01<00:00, 1903.16it/s]


## Make Utility Functions
Make a model-building function and a tool to save a model to disk

In [7]:
def build_fn(preprocessor, embedding=128, mol_features=1024,
             message_steps=6, bond_2d=False, activation='relu',
             set2set=False, set2set_steps=3, num_output=1):
    """Assemble a message-passing neural network with a single dense output network

    Adapted from: https://github.com/NREL/nfp/blob/master/examples/run_2D_model_noatom_bn.py

    Args:
        preprocessor (SmilesPreprocessor): Tool to generate inputs from SMILES string.
            Its ``atom_classes`` and ``bond_classes`` set the size of the embedding tables
        embedding (int): Size of the atom/bond embedding
        mol_features (int): Number of features to use to describe a molecule
        message_steps (int): Number of message-passing steps
        bond_2d (bool): Whether to use 2D embeddings (``Embedding2D``) rather than
            ``Embedding2DCompressed`` for the bonds
        activation (str): Activation function of the two hidden dense layers of the output network
        set2set (bool): Whether to pool atoms with Set2Set instead of a sum
        set2set_steps (int): Number of set2set iterations
        num_output (int): Number of output features
    Returns:
        (GraphModel) Model that maps the four graph inputs to ``num_output`` values
    """

    # Integer inputs describing a batch of molecular graphs
    #  node_graph_indices - Which molecule each atom belongs to
    #  atom / bond - Categorical type of each atom and of each bond
    #  connectivity - The two atoms at the ends of each bond
    node_graph_indices = Input(shape=(1,), name='node_graph_indices', dtype='int32')
    atom_types = Input(shape=(1,), name='atom', dtype='int32')
    bond_types = Input(shape=(1,), name='bond', dtype='int32')
    connectivity = Input(shape=(2,), name='connectivity', dtype='int32')
    graph_inputs = [node_graph_indices, atom_types, bond_types, connectivity]

    # Index and type inputs carry one feature per entry;
    #  a single shared Squeeze layer strips that singleton dimension
    drop_singleton = Squeeze()
    mol_index = drop_singleton(node_graph_indices)
    atom_ids = drop_singleton(atom_types)
    bond_ids = drop_singleton(bond_types)

    # Learned embedding for each atom type
    atom_embedding = Embedding(preprocessor.atom_classes, embedding, name='atom_embedding')
    atom_state = atom_embedding(atom_ids)

    # Learned embedding for each bond type, in the requested flavor
    bond_embedding_cls = Embedding2D if bond_2d else Embedding2DCompressed
    bond_embedding = bond_embedding_cls(preprocessor.bond_classes, embedding, name='bond_embedding')
    bond_matrix = bond_embedding(bond_ids)

    # Message passing: the same two layers are reused at every step
    #  MessageLayer - builds an update message for each atom from the state of its neighbors
    #  GRUStep - recurrent update of the atom state given that message
    atom_rnn_layer = GRUStep(embedding)
    message_layer = MessageLayer(reducer='sum')
    for _step in range(message_steps):
        message = message_layer([atom_state, bond_matrix, connectivity])
        atom_state = atom_rnn_layer([message, atom_state])

    # Transform the final atom states into per-atom fingerprints
    atom_fingerprint = Dense(mol_features, activation='sigmoid')(atom_state)

    # Pool the atoms of each molecule: Set2Set if requested, otherwise a plain sum
    if set2set:
        pooling = Set2Set(set2set_steps)
    else:
        pooling = ReduceAtomToMol(reducer='sum')
    mol_out = pooling([atom_fingerprint, mol_index])

    # Output network: two (batch norm -> dense) stages of shrinking width, then a linear layer
    for width in (mol_features // 2, mol_features // 4):
        mol_out = BatchNormalization(momentum=0.9)(mol_out)
        mol_out = Dense(width, activation=activation)(mol_out)
    mol_out = Dense(num_output)(mol_out)

    # Final dense layer, named 'scale' so that it can be looked up by name
    mol_out = Dense(num_output, name='scale')(mol_out)

    return GraphModel(graph_inputs, [mol_out])

## Single-Task Network
Predict only solvation energy in water

In [8]:
model = build_fn(preprocessor)

Instructions for updating:
Colocations handled automatically by placer.


In [9]:
model.summary()

Model: "graphmodel_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
atom (InputLayer)               (None, 1)            0                                            
__________________________________________________________________________________________________
bond (InputLayer)               (None, 1)            0                                            
__________________________________________________________________________________________________
squeeze_1 (Squeeze)             (None,)              0           node_graph_indices[0][0]         
                                                                 atom[0][0]                       
                                                                 bond[0][0]                       
_______________________________________________________________________________________

In [10]:
save_model_files('single-task', preprocessor, model)

Already output. Skipping


## Multi-Task Network
Predict the solvation energy in every solvent of the training set

In [11]:
model = build_fn(preprocessor, num_output=len(sol_cols))

In [12]:
save_model_files('multi-task', preprocessor, model, output_props=sol_cols)

Already output. Skipping


Make a version that does not train on ACN

In [13]:
model = build_fn(preprocessor, num_output=len(sols_without_acn))

In [14]:
save_model_files('multi-task_no-acn', preprocessor, model, output_props=sols_without_acn)

Already output. Skipping


Multi-task where you have a different dense layer for each output

In [15]:
def build_fn(preprocessor, embedding=128, mol_features=1024,
             message_steps=6, bond_2d=False, activation='relu',
             set2set=False, set2set_steps=3, num_output=1):
    """Build a MPNN Keras model with a separate dense network for each output

    Adapted from: https://github.com/NREL/nfp/blob/master/examples/run_2D_model_noatom_bn.py

    Args:
        preprocessor (SmilesPreprocessor): Tool to generate inputs from SMILES string.
            Its ``atom_classes`` and ``bond_classes`` set the size of the embedding tables
        embedding (int): Size of the atom/bond embedding
        mol_features (int): Number of features to use to describe a molecule
        message_steps (int): Number of message-passing steps
        activation (str): Activation function
        bond_2d (bool): Whether to use 2D embeddings
        set2set (bool): Whether to use set2set for the input
        set2set_steps (int): Number of set2set iterations
        num_output (int): Number of output features (one dense network is built per output)
    Returns:
        (GraphModel) Model whose single output has ``num_output`` features
    Raises:
        ValueError: If ``num_output`` is less than 1
    """
    if num_output < 1:
        raise ValueError('num_output must be at least 1')

    # Raw (integer) graph inputs
    #  node_graph_indices - Maps the atom index to which molecule it came from
    #  atom_types - Categorical type of each atom
    #  bond_types - Categorical type of each bond
    #  connectivity - Atoms on each end of each bond
    node_graph_indices = Input(shape=(1,), name='node_graph_indices', dtype='int32')
    atom_types = Input(shape=(1,), name='atom', dtype='int32')
    bond_types = Input(shape=(1,), name='bond', dtype='int32')
    connectivity = Input(shape=(2,), name='connectivity', dtype='int32')

    # The "indices" and "type" inputs have 1 feature per "entry"
    #  The Squeeze layer removes this singleton dimension to make the data easier to use
    squeeze = Squeeze()
    snode_graph_indices = squeeze(node_graph_indices)
    satom_types = squeeze(atom_types)
    sbond_types = squeeze(bond_types)

    # Create the embedding for each atom type
    atom_state = Embedding(
        preprocessor.atom_classes,
        embedding, name='atom_embedding')(satom_types)

    # Create the embedding for each bond type
    if bond_2d:
        bond_matrix = Embedding2D(
            preprocessor.bond_classes,
            embedding, name='bond_embedding')(sbond_types)
    else:
        bond_matrix = Embedding2DCompressed(
            preprocessor.bond_classes,
            embedding, name='bond_embedding')(sbond_types)

    # The core of the message passing framework: Recurrent and Message-passing layers
    #  The Message Layer computes an update message for each atom given the state of its neighbors
    #  The Recurrent Layer (GRUStep) computes how the state of the atom changes given a message
    atom_rnn_layer = GRUStep(embedding)
    message_layer = MessageLayer(reducer='sum')

    # Perform the message passing
    for _ in range(message_steps):

        # Get the message updates to each atom
        message = message_layer([atom_state, bond_matrix, connectivity])

        # Update memory and atom states
        atom_state = atom_rnn_layer([message, atom_state])

    # After the message passing step, we allow the atom state to be transformed with a dense layer
    atom_fingerprint = Dense(mol_features, activation='sigmoid')(atom_state)

    # To create a representation for the molecule, we sum over all its atoms
    if set2set:
        # Or set2set
        mol_out = Set2Set(set2set_steps)([atom_fingerprint, snode_graph_indices])
    else:
        mol_out = ReduceAtomToMol(reducer='sum')([atom_fingerprint, snode_graph_indices])

    # Final dense steps to map the molecular representation to properties
    #  One independent network for each output, all reading the same molecular representation
    outputs = []
    for _ in range(num_output):
        solv_out = BatchNormalization(momentum=0.9)(mol_out)
        solv_out = Dense(mol_features // 2, activation=activation)(solv_out)

        solv_out = BatchNormalization(momentum=0.9)(solv_out)
        solv_out = Dense(mol_features // 4, activation=activation)(solv_out)
        solv_out = Dense(1)(solv_out)

        outputs.append(solv_out)

    # Join the per-output predictions into a single tensor
    #  Concatenate requires at least two inputs, so a lone output is used directly
    output = Concatenate()(outputs) if len(outputs) > 1 else outputs[0]

    # Add a scale layer on top of the joined outputs
    #  It must consume `output`; otherwise the layer is not part of the returned model
    output = Dense(num_output, name='scale')(output)

    return GraphModel([node_graph_indices, atom_types, bond_types, connectivity], [output])

In [16]:
model = build_fn(preprocessor, num_output=len(sol_cols))

In [17]:
save_model_files('multi-task_multi-dense', preprocessor, model, output_props=sol_cols)

Already output. Skipping


## Including Dielectric Constant
The multiple solvation energies may be distinct output classes, but they are really the same class with three different inputs.
Here, we define a network that also takes the dielectric constants of the solvents as inputs.
Because the solvents are fixed in the training set (we have the same 5 solvents for each molecule), we hard-code the values of these inputs in our models.

In [18]:
def build_fn_with_dec(preprocessor, embedding=128, mol_features=1024,
                      message_steps=6, bond_2d=False, activation='relu',
                      reduce_atom='sum', set2set=False, set2set_steps=3):
    """Assemble a message-passing network that takes dielectric constants as an extra input

    Adapted from: https://github.com/NREL/nfp/blob/master/examples/run_2D_model_noatom_bn.py

    Args:
        preprocessor (SolvationPreprocessor): Tool to generate inputs from SMILES string.
            Its ``atom_classes`` and ``bond_classes`` set the size of the embedding tables
        embedding (int): Size of the atom/bond embedding
        mol_features (int): Number of features to use to describe a molecule
        message_steps (int): Number of message-passing steps
        bond_2d (bool): Whether to use 2D embeddings
        activation (str): Desired activation function
        reduce_atom (str): Method used to reduce atom contribution to mol representation
            (ignored when ``set2set`` is true)
        set2set (bool): Whether to use set2set for the input
        set2set_steps (int): Number of set2set iterations
    Returns:
        (GraphModel) Model taking the four graph inputs plus ``dielectric_constants``
    """

    # Integer inputs describing a batch of molecular graphs
    #  node_graph_indices - Which molecule each atom belongs to
    #  atom / bond - Categorical type of each atom and of each bond
    #  connectivity - The two atoms at the ends of each bond
    node_graph_indices = Input(shape=(1,), name='node_graph_indices', dtype='int32')
    atom_types = Input(shape=(1,), name='atom', dtype='int32')
    bond_types = Input(shape=(1,), name='bond', dtype='int32')
    connectivity = Input(shape=(2,), name='connectivity', dtype='int32')
    dielectric_cnst_input = Input(shape=(None,), name='dielectric_constants')
    model_inputs = [node_graph_indices, atom_types, bond_types, connectivity,
                    dielectric_cnst_input]

    # Collapse the dielectric input along axis 0 into a single row (keepdims),
    #  then transpose that row into a column with one entry per constant
    dielectric_row = Lambda(K.max, arguments={'axis': 0, 'keepdims': True})(dielectric_cnst_input)
    dielectric_cnst = Lambda(K.transpose)(dielectric_row)

    # Index and type inputs carry one feature per entry;
    #  a single shared Squeeze layer strips that singleton dimension
    drop_singleton = Squeeze()
    mol_index = drop_singleton(node_graph_indices)
    atom_ids = drop_singleton(atom_types)
    bond_ids = drop_singleton(bond_types)

    # Learned embedding for each atom type
    atom_embedding = Embedding(preprocessor.atom_classes, embedding, name='atom_embedding')
    atom_state = atom_embedding(atom_ids)

    # Learned embedding for each bond type, in the requested flavor
    bond_embedding_cls = Embedding2D if bond_2d else Embedding2DCompressed
    bond_embedding = bond_embedding_cls(preprocessor.bond_classes, embedding, name='bond_embedding')
    bond_matrix = bond_embedding(bond_ids)

    # Message passing: the same two layers are reused at every step
    #  MessageLayer - builds an update message for each atom from the state of its neighbors
    #  GRUStep - recurrent update of the atom state given that message
    atom_rnn_layer = GRUStep(embedding)
    message_layer = MessageLayer(reducer='sum')
    for _step in range(message_steps):
        message = message_layer([atom_state, bond_matrix, connectivity])
        atom_state = atom_rnn_layer([message, atom_state])

    # Transform the final atom states into per-atom fingerprints
    atom_fingerprint = Dense(mol_features, activation='sigmoid')(atom_state)

    # Pool the atoms of each molecule: Set2Set if requested, otherwise the chosen reducer
    if set2set:
        pooling = Set2Set(set2set_steps)
    else:
        pooling = ReduceAtomToMol(reducer=reduce_atom)
    mol_out = pooling([atom_fingerprint, mol_index])

    # Pair every molecule representation with every dielectric constant
    #  Per the original note, the result has shape [n_mols, n_dielectrics, mol_features+1]
    mol_out = Lambda(cartesian_product)([mol_out, dielectric_cnst])

    # Output network: two (batch norm -> dense) stages of shrinking width, then a linear layer
    for width in (mol_features // 2, mol_features // 4):
        mol_out = BatchNormalization(momentum=0.9)(mol_out)
        mol_out = Dense(width, activation=activation)(mol_out)
    mol_out = Dense(1)(mol_out)

    # Dense layer named 'scale', intended to simplify learning the scaling of the outputs
    mol_out = Dense(1, name='scale')(mol_out)

    # Drop the trailing singleton dimension left by the single-unit dense layers
    mol_out = Squeeze(axis=-1)(mol_out)

    return GraphModel(model_inputs, [mol_out])

Make a version where we use all solvents in the training set

In [19]:
# Store one dielectric constant per output column on the preprocessor, in the same order as sol_cols
preprocessor.dielectric_cnsts = [dielectric_constants[s] for s in sol_cols]

In [20]:
model = build_fn_with_dec(preprocessor)

In [21]:
save_model_files('dielectric-constant', preprocessor, model, output_props=sol_cols, normalize=False)

Already output. Skipping


Make a version with softplus activation

In [22]:
model = build_fn_with_dec(preprocessor, activation='softplus')

In [23]:
save_model_files('dielectric-constant-softplus', preprocessor, model, output_props=sol_cols, normalize=False)

Already output. Skipping


Make a version where we withhold a solvent. Here, we choose ACN because it is in the middle

In [24]:
preprocessor.dielectric_cnsts = [dielectric_constants[s] for s in sols_without_acn]

In [25]:
model = build_fn_with_dec(preprocessor)

In [26]:
save_model_files('dielectric-constant_no-acn', preprocessor, model, normalize=False, output_props=sols_without_acn)

Already output. Skipping


Do the same thing, but with softplus activation

In [27]:
model = build_fn_with_dec(preprocessor, activation='softplus')

In [28]:
save_model_files('dielectric-constant-softplus_no-acn', preprocessor, model, normalize=False, output_props=sols_without_acn)

Already output. Skipping


## Models with Atomic Partial Charges
Have a model that incorporates the partial charges on each atom

Make a preprocessor that injects partial charges as inputs

In [29]:
# NOTE(review): pickle.load can execute arbitrary code; only load charge files from a trusted source
with open(os.path.join('..', 'partial-charges', 'mapped_charges.pkl'), 'rb') as fp:
    partial_charges = pkl.load(fp)

In [30]:
charges_preprocessor = PartialChargesPreprocessor(partial_charges, [])

In [31]:
charges_preprocessor.fit(train_data['smiles_0']);

100%|██████████| 117232/117232 [01:02<00:00, 1872.12it/s]


In [32]:
def build_fn_with_dec_charges(preprocessor, embedding=128, mol_features=1024,
                              message_steps=6, bond_2d=False, activation='relu',
                              reduce_atom='sum', set2set=False, set2set_steps=3):
    """Assemble a message-passing network with partial charges and dielectric constants as inputs

    Adapted from: https://github.com/NREL/nfp/blob/master/examples/run_2D_model_noatom_bn.py

    Args:
        preprocessor (SolvationPreprocessor): Tool to generate inputs from SMILES string.
            Its ``atom_classes`` and ``bond_classes`` set the size of the embedding tables
        embedding (int): Size of the atom/bond embedding. The learned atom embedding has
            ``embedding - 1`` features so that appending the partial charge gives ``embedding``
        mol_features (int): Number of features to use to describe a molecule
        message_steps (int): Number of message-passing steps
        bond_2d (bool): Whether to use 2D embeddings
        activation (str): Desired activation function
        reduce_atom (str): Method used to reduce atom contribution to mol representation
            (ignored when ``set2set`` is true)
        set2set (bool): Whether to use set2set for the input
        set2set_steps (int): Number of set2set iterations
    Returns:
        (GraphModel) Model taking the four graph inputs plus ``dielectric_constants``
            and ``partial_charges``
    """

    # Inputs describing a batch of molecular graphs
    #  node_graph_indices - Which molecule each atom belongs to
    #  atom / bond - Categorical type of each atom and of each bond
    #  connectivity - The two atoms at the ends of each bond
    #  partial_charges - One floating-point value per atom
    node_graph_indices = Input(shape=(1,), name='node_graph_indices', dtype='int32')
    atom_types = Input(shape=(1,), name='atom', dtype='int32')
    bond_types = Input(shape=(1,), name='bond', dtype='int32')
    connectivity = Input(shape=(2,), name='connectivity', dtype='int32')
    partial_charges_input = Input(shape=(1,), name='partial_charges', dtype=K.floatx())
    dielectric_cnst_input = Input(shape=(None,), name='dielectric_constants')
    model_inputs = [node_graph_indices, atom_types, bond_types, connectivity,
                    dielectric_cnst_input, partial_charges_input]

    # Collapse the dielectric input along axis 0 into a single row (keepdims),
    #  then transpose that row into a column with one entry per constant
    dielectric_row = Lambda(K.max, arguments={'axis': 0, 'keepdims': True})(dielectric_cnst_input)
    dielectric_cnst = Lambda(K.transpose)(dielectric_row)

    # Index and type inputs carry one feature per entry;
    #  a single shared Squeeze layer strips that singleton dimension
    drop_singleton = Squeeze()
    mol_index = drop_singleton(node_graph_indices)
    atom_ids = drop_singleton(atom_types)
    bond_ids = drop_singleton(bond_types)

    # Learned embedding for each atom type, one feature short to leave room for the charge
    atom_embedding = Embedding(preprocessor.atom_classes, embedding - 1, name='atom_embedding')
    atom_state = atom_embedding(atom_ids)

    # Learned embedding for each bond type, in the requested flavor
    bond_embedding_cls = Embedding2D if bond_2d else Embedding2DCompressed
    bond_embedding = bond_embedding_cls(preprocessor.bond_classes, embedding, name='bond_embedding')
    bond_matrix = bond_embedding(bond_ids)

    # Append the partial charge of each atom to its embedding
    atom_state = Concatenate(name='add_charges')([atom_state, partial_charges_input])

    # Message passing: the same two layers are reused at every step
    #  MessageLayer - builds an update message for each atom from the state of its neighbors
    #  GRUStep - recurrent update of the atom state given that message
    atom_rnn_layer = GRUStep(embedding)
    message_layer = MessageLayer(reducer='sum')
    for _step in range(message_steps):
        message = message_layer([atom_state, bond_matrix, connectivity])
        atom_state = atom_rnn_layer([message, atom_state])

    # Transform the final atom states into per-atom fingerprints
    atom_fingerprint = Dense(mol_features, activation='sigmoid')(atom_state)

    # Pool the atoms of each molecule: Set2Set if requested, otherwise the chosen reducer
    if set2set:
        pooling = Set2Set(set2set_steps)
    else:
        pooling = ReduceAtomToMol(reducer=reduce_atom)
    mol_out = pooling([atom_fingerprint, mol_index])

    # Pair every molecule representation with every dielectric constant
    #  Per the original note, the result has shape [n_mols, n_dielectrics, mol_features+1]
    mol_out = Lambda(cartesian_product)([mol_out, dielectric_cnst])

    # Output network: two (batch norm -> dense) stages of shrinking width, then a linear layer
    for width in (mol_features // 2, mol_features // 4):
        mol_out = BatchNormalization(momentum=0.9)(mol_out)
        mol_out = Dense(width, activation=activation)(mol_out)
    mol_out = Dense(1)(mol_out)

    # Dense layer named 'scale', intended to simplify learning the scaling of the outputs
    mol_out = Dense(1, name='scale')(mol_out)

    # Drop the trailing singleton dimension left by the single-unit dense layers
    mol_out = Squeeze(axis=-1)(mol_out)

    return GraphModel(model_inputs, [mol_out])

Make a version trained on all solvents

In [33]:
# Size the embeddings from charges_preprocessor, the preprocessor this model is saved with
model = build_fn_with_dec_charges(charges_preprocessor, activation='softplus')

In [34]:
charges_preprocessor.dielectric_cnsts = [dielectric_constants[s] for s in sol_cols]

In [35]:
save_model_files('dielectric-constant-softplus-charges', charges_preprocessor,
           model, output_props=sol_cols, normalize=False)

Already output. Skipping


Make a version that withholds ACN from training

In [36]:
# Switch the preprocessor to the solvent list without ACN; no new model is built, so `model` is reused as-is
charges_preprocessor.dielectric_cnsts = [dielectric_constants[s] for s in sols_without_acn]

In [37]:
save_model_files('dielectric-constant-softplus-charges_no-acn', charges_preprocessor,
                 model, output_props=sols_without_acn, normalize=False)

Already output. Skipping


Models where we use either Gasteiger partial charges or those predicted from an MPNN

In [38]:
# Replace the charge lookup on the existing preprocessor with the Gasteiger charges file
# NOTE(review): pickle.load can execute arbitrary code; only load charge files from a trusted source
with open(os.path.join('..', 'partial-charges', 'gasteiger-charges.pkl'), 'rb') as fp:
    charges_preprocessor.charges_lookup = pkl.load(fp)

In [39]:
save_model_files('dielectric-constant-softplus-gasteiger-charges_no-acn', charges_preprocessor,
                 model, output_props=sols_without_acn, normalize=False)

Already output. Skipping


In [40]:
# Replace the charge lookup on the existing preprocessor with the MPNN charges file
# NOTE(review): pickle.load can execute arbitrary code; only load charge files from a trusted source
with open(os.path.join('..', 'partial-charges', 'mpnn-charges.pkl'), 'rb') as fp:
    charges_preprocessor.charges_lookup = pkl.load(fp)

In [41]:
save_model_files('dielectric-constant-softplus-mpnn-charges_no-acn', charges_preprocessor,
                 model, output_props=sols_without_acn, normalize=False)

Already output. Skipping


Save a version where we use all solvation energies and MPNN charges

In [42]:
charges_preprocessor.dielectric_cnsts = [dielectric_constants[s] for s in sol_cols]

In [43]:
save_model_files('dielectric-constant-softplus-mpnn-charges', charges_preprocessor,
                 model, output_props=sol_cols, normalize=False)

Already output. Skipping


## Single-Task with Atomic Contributions
Breaking up the predictions into per-atom contributions. Hopefully, this will allow for scaling to larger molecules

In [44]:
def build_fn(preprocessor, embedding=128, mol_features=1024,
             message_steps=6, bond_2d=False, activation='relu',
             num_output=1):
    """Build a MPNN Keras model whose prediction is a sum of per-atom contributions

    Adapted from: https://github.com/NREL/nfp/blob/master/examples/run_2D_model_noatom_bn.py

    Args:
        preprocessor (SmilesPreprocessor): Tool to generate inputs from SMILES string.
            Its ``atom_classes`` and ``bond_classes`` set the size of the embedding tables
        embedding (int): Size of the atom/bond embedding
        mol_features (int): Number of features to use to describe each atom before the output network
        message_steps (int): Number of message-passing steps
        bond_2d (bool): Whether to use 2D embeddings
        activation (str): Activation function of the two hidden dense layers of the output network
        num_output (int): Number of output features
    Returns:
        (GraphModel) Model whose single output has ``num_output`` features
    """

    # Raw (integer) graph inputs
    #  node_graph_indices - Maps the atom index to which molecule it came from
    #  atom_types - Categorical type of each atom
    #  bond_types - Categorical type of each bond
    #  connectivity - Atoms on each end of each bond
    node_graph_indices = Input(shape=(1,), name='node_graph_indices', dtype='int32')
    atom_types = Input(shape=(1,), name='atom', dtype='int32')
    bond_types = Input(shape=(1,), name='bond', dtype='int32')
    connectivity = Input(shape=(2,), name='connectivity', dtype='int32')

    # The "indices" and "type" inputs have 1 feature per "entry"
    #  The Squeeze layer removes this singleton dimension to make the data easier to use
    squeeze = Squeeze()
    snode_graph_indices = squeeze(node_graph_indices)
    satom_types = squeeze(atom_types)
    sbond_types = squeeze(bond_types)

    # Create the embedding for each atom type
    atom_state = Embedding(
        preprocessor.atom_classes,
        embedding, name='atom_embedding')(satom_types)

    # Create the embedding for each bond type
    if bond_2d:
        bond_matrix = Embedding2D(
            preprocessor.bond_classes,
            embedding, name='bond_embedding')(sbond_types)
    else:
        bond_matrix = Embedding2DCompressed(
            preprocessor.bond_classes,
            embedding, name='bond_embedding')(sbond_types)

    # The core of the message passing framework: Recurrent and Message-passing layers
    #  The Message Layer computes an update message for each atom given the state of its neighbors
    #  The Recurrent Layer (GRUStep) computes how the state of the atom changes given a message
    atom_rnn_layer = GRUStep(embedding)
    message_layer = MessageLayer(reducer='sum')

    # Perform the message passing
    for _ in range(message_steps):

        # Get the message updates to each atom
        message = message_layer([atom_state, bond_matrix, connectivity])

        # Update memory and atom states
        atom_state = atom_rnn_layer([message, atom_state])

    # After the message passing step, we allow the atom state to be transformed with a dense layer
    atom_out = Dense(mol_features, activation='sigmoid')(atom_state)

    # Dense steps applied per atom, mapping each atomic representation to its contribution
    atom_out = BatchNormalization(momentum=0.9)(atom_out)
    atom_out = Dense(mol_features // 2, activation=activation)(atom_out)

    atom_out = BatchNormalization(momentum=0.9)(atom_out)
    atom_out = Dense(mol_features // 4, activation=activation)(atom_out)
    atom_out = Dense(num_output)(atom_out)

    # The molecular prediction is the sum of the contributions of its atoms
    mol_out = ReduceAtomToMol(reducer='sum')([atom_out, snode_graph_indices])

    # Make a "scaling layer" out of a dense layer
    #  Its width follows num_output so that the model really emits num_output features
    mol_out = Dense(num_output, activation='linear', name='scale')(mol_out)

    return GraphModel([node_graph_indices, atom_types, bond_types, connectivity], [mol_out])

In [45]:
model = build_fn(preprocessor)

In [46]:
save_model_files('single-task-atomic-contrib', preprocessor,
                 model, output_props=['sol_water'], normalize=False)

Already output. Skipping


### Multi-Task with Atomic Contributions
Same thing, but with the dielectric constant as an additional input

In [47]:
def build_fn(preprocessor, embedding=128, mol_features=1024,
             message_steps=6, bond_2d=False, activation='relu',
             reduce_atom='sum'):
    """Build a multi-task MPNN Keras model that combines per-atom contributions

    The state of each atom after message passing is paired with every dielectric
    constant, mapped to a single scalar contribution by a stack of dense layers,
    and those contributions are then reduced over the atoms of each molecule.

    Adapted from: https://github.com/NREL/nfp/blob/master/examples/run_2D_model_noatom_bn.py

    Args:
        preprocessor (SolvationPreprocessor): Tool to generate inputs from SMILES string.
            Only its ``atom_classes`` and ``bond_classes`` attributes are read here
        embedding (int): Size of the atom/bond embedding
        mol_features (int): Width of the first dense layer applied to each atom state.
            The following dense layers use 1/2 and 1/4 of this width
        message_steps (int): Number of message-passing steps
        bond_2d (bool): Whether to use 2D embeddings
        activation (str): Desired activation function for the hidden dense layers
        reduce_atom (str): Method used to reduce atom contribution to mol representation
    Returns:
        (GraphModel) Model whose inputs are the node graph indices, atom types, bond types,
        connectivity and dielectric constants, and which has a single output tensor
    """

    # Raw (integer) graph inputs
    #  node_graph_indices - Maps the atom index to which molecule it came from
    #  atom_types - Categorical type of each atom
    #  bond_types - Categorical type of each bond
    #  connectivity - Atoms on each end of each bond
    node_graph_indices = Input(shape=(1,), name='node_graph_indices', dtype='int32')
    atom_types = Input(shape=(1,), name='atom', dtype='int32')
    bond_types = Input(shape=(1,), name='bond', dtype='int32')
    connectivity = Input(shape=(2,), name='connectivity', dtype='int32')
    dielectric_cnst_input = Input(shape=(None,), name='dielectric_constants')

    # Expand dielectric constants to a Nx1 array:
    #  the max over axis 0 (with keepdims) collapses the leading axis to length 1,
    #  and the transpose then turns that single row into a column
    dielectric_cnst = Lambda(K.max, arguments={'axis': 0, 'keepdims': True})(dielectric_cnst_input)
    dielectric_cnst = Lambda(K.transpose)(dielectric_cnst)

    # The "indices" and "type" inputs have 1 feature per "entry"
    #  The Squeeze layer removes this singleton dimension to make the data easier to use
    squeeze = Squeeze()
    snode_graph_indices = squeeze(node_graph_indices)
    satom_types = squeeze(atom_types)
    sbond_types = squeeze(bond_types)

    # Create the embedding for each atom type
    atom_state = Embedding(
        preprocessor.atom_classes,
        embedding, name='atom_embedding')(satom_types)

    # Create the embedding for each bond type
    if bond_2d:
        bond_matrix = Embedding2D(
            preprocessor.bond_classes,
            embedding, name='bond_embedding')(sbond_types)
    else:
        bond_matrix = Embedding2DCompressed(
            preprocessor.bond_classes,
            embedding, name='bond_embedding')(sbond_types)

    # The core of the message passing framework: Recurrent and Message-passing layers
    #  The Message Layer computes an update message for each atom given the state of its neighbors
    #  The Recurrent Layer (GRUStep) computes how the state of the atom changes given a message
    #  Both layers are created once, so every step shares the same layer instances
    atom_rnn_layer = GRUStep(embedding)
    message_layer = MessageLayer(reducer='sum')

    # Perform the message passing
    for _ in range(message_steps):

        # Get the message updates to each atom
        message = message_layer([atom_state, bond_matrix, connectivity])

        # Update memory and atom states
        atom_state = atom_rnn_layer([message, atom_state])

    # After the message passing step, we allow the atom state to be transformed with a dense layer
    atom_out = Dense(mol_features, activation='sigmoid')(atom_state)

    # Pair each atom's features with each dielectric constant
    #  NOTE(review): output shape is presumably [n_atoms, n_dielectrics, mol_features+1],
    #  since the reduction from atoms to molecules only happens further down - confirm
    #  against `cartesian_product`
    atom_out = Lambda(cartesian_product)([atom_out, dielectric_cnst])

    # Final dense steps to map the atomic representation to a per-atom contribution
    atom_out = BatchNormalization(momentum=0.9)(atom_out)
    atom_out = Dense(mol_features // 2, activation=activation)(atom_out)

    atom_out = BatchNormalization(momentum=0.9)(atom_out)
    atom_out = Dense(mol_features // 4, activation=activation)(atom_out)
    atom_out = Dense(1)(atom_out)

    # To create a representation for the molecule, we reduce over all its atoms
    #  The reduction honors `reduce_atom` (it was previously hard-coded to 'sum')
    mol_out = ReduceAtomToMol(reducer=reduce_atom)([atom_out, snode_graph_indices])

    # Layer to simplify learning scaling for the outputs
    mol_out = Dense(1, name='scale')(mol_out)

    # Remove the trailing singleton axis left by the single-unit dense layers
    mol_out = Squeeze(axis=-1)(mol_out)

    return GraphModel([node_graph_indices, atom_types, bond_types, connectivity,
                       dielectric_cnst_input], [mol_out])

In [48]:
# Build the multi-task network, using softplus as the activation of the hidden dense layers
model = build_fn(preprocessor, activation='softplus')

Save an "all solvents" version

In [49]:
# One dielectric constant per solvent column, in the same order as `sol_cols`
preprocessor.dielectric_cnsts = [dielectric_constants[solvent] for solvent in sol_cols]

In [50]:
# Save the multi-task model files with every solvent column as an output property
save_model_files(
    'dielectric-constant-softplus-atomic-contrib',
    preprocessor,
    model,
    output_props=sol_cols,
    normalize=False,
)

Already output. Skipping


And a "no-ACN" version

In [51]:
# Dielectric constants for every solvent column except `sol_acn`
preprocessor.dielectric_cnsts = [dielectric_constants[solvent] for solvent in sols_without_acn]

In [52]:
# The no-ACN model must list only the non-ACN outputs, so that the output properties
#  line up with the dielectric constants assigned to the preprocessor just above
#  (passing `sol_cols` here would still include `sol_acn`)
save_model_files(
    'dielectric-constant-softplus-atomic-contrib_no-acn',
    preprocessor,
    model,
    output_props=sols_without_acn,
    normalize=False,
)

Already output. Skipping


### Version with Reduced Features
Make a version of the network that uses only the element as the atom feature
and only the bond type (order + whether it is conjugated) as the bond feature.

In [53]:
# Preprocessor restricted to element-only atom features and type-only bond features,
#  fitted on the training-set SMILES strings
preprocessor = SolvationPreprocessor(
    [dielectric_constants[solvent] for solvent in sol_cols],
    explicit_hs=True,
    atom_features=atom_feature_element_only,
    bond_features=bond_feature_type_only,
)
preprocessor.fit(train_data['smiles_0']);

100%|██████████| 117232/117232 [00:40<00:00, 2867.59it/s]


In [55]:
# Rebuild the network so its embeddings are sized from the reduced-feature preprocessor
model = build_fn(preprocessor, activation='softplus')

Save an "all solvents" version

In [56]:
# NOTE(review): the same list was already passed to the SolvationPreprocessor constructor
#  when this preprocessor was created; this reassignment looks redundant - confirm
preprocessor.dielectric_cnsts = [dielectric_constants[solvent] for solvent in sol_cols]

In [57]:
# Save the reduced-feature ("featureless") model files with every solvent column as an output
save_model_files(
    'dielectric-constant-softplus-atomic-contrib-featureless',
    preprocessor,
    model,
    output_props=sol_cols,
    normalize=False,
)