<a href="https://colab.research.google.com/github/choderalab/gimlet/blob/master/lime/scripts/190923_yuanqing_gimlet_potential.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Gimlet Potential
================

sep 23, 2019 yuanqing

# Introduction

We propose a flavor of machine learning potential where we preserve the traditional functional form of a Molecular Mechanics (MM) potential where the energy is the sum of bond, angle, torsion, and non-bonded terms, which are functions of the corresponding bond length, angles, dihedral angles, and distances, respectively. 

\begin{equation}
E = \sum\limits_{r} E_\text{bonded}(r) + \sum\limits_{\theta}E_\text{angle}(\theta) + \sum\limits_{\phi}E_\text{torsion}(\phi) + \sum\limits_{l} E_\text{non-bonded}(l).
\end{equation}

The difference is, however, that for any of these terms, rather than using a harmonic or polynomial expression, we design a graph net to parameterize a _functional_ that takes the numeric value of $r, \theta, \phi, l$ and outputs the energy contribution.

# Prep
boring stuff.
please, please, please don't look inside.

In [0]:
! wget https://s3-us-west-1.amazonaws.com/deepchem.io/datasets/molnet_publish/qm9.zip
! unzip qm9.zip

--2019-09-24 18:39:34--  https://s3-us-west-1.amazonaws.com/deepchem.io/datasets/molnet_publish/qm9.zip
Resolving s3-us-west-1.amazonaws.com (s3-us-west-1.amazonaws.com)... 52.219.116.152
Connecting to s3-us-west-1.amazonaws.com (s3-us-west-1.amazonaws.com)|52.219.116.152|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 44827467 (43M) [application/zip]
Saving to: ‘qm9.zip’


2019-09-24 18:39:38 (14.4 MB/s) - ‘qm9.zip’ saved [44827467/44827467]

Archive:  qm9.zip
  inflating: gdb9.sdf                
  inflating: gdb9.sdf.csv            
  inflating: QM9_README              


In [0]:
! pip install tensorflow==2.0.0-beta1
! pip uninstall gin-config -y
! rm -rf gimlet
! git clone https://github.com/choderalab/gimlet.git

Collecting tensorflow==2.0.0-beta1
[?25l  Downloading https://files.pythonhosted.org/packages/29/6c/2c9a5c4d095c63c2fb37d20def0e4f92685f7aee9243d6aae25862694fd1/tensorflow-2.0.0b1-cp36-cp36m-manylinux1_x86_64.whl (87.9MB)
[K     |████████████████████████████████| 87.9MB 30.6MB/s 
[?25hCollecting tb-nightly<1.14.0a20190604,>=1.14.0a20190603 (from tensorflow==2.0.0-beta1)
[?25l  Downloading https://files.pythonhosted.org/packages/a4/96/571b875cd81dda9d5dfa1422a4f9d749e67c0a8d4f4f0b33a4e5f5f35e27/tb_nightly-1.14.0a20190603-py3-none-any.whl (3.1MB)
[K     |████████████████████████████████| 3.1MB 31.7MB/s 
Collecting tf-estimator-nightly<1.14.0.dev2019060502,>=1.14.0.dev2019060501 (from tensorflow==2.0.0-beta1)
[?25l  Downloading https://files.pythonhosted.org/packages/32/dd/99c47dd007dcf10d63fd895611b063732646f23059c618a373e85019eb0e/tf_estimator_nightly-1.14.0.dev2019060501-py2.py3-none-any.whl (496kB)
[K     |████████████████████████████████| 501kB 26.0MB/s 
Installing collected p

In [0]:
import sys
sys.path.append('/content/gimlet')
from sklearn import metrics
import os

import tensorflow as tf
import gin
import lime
import pandas as pd
import numpy as np

In [0]:
mols_ds = gin.i_o.from_sdf.to_ds('gdb9.sdf', has_charge=False)

attr_ds = pd.read_csv('gdb9.sdf.csv').values[:, 1:].astype(np.float32)

attr_ds = attr_ds / np.linalg.norm(attr_ds, axis=0) - np.std(attr_ds, axis=0)

attr_ds = tf.data.Dataset.from_tensor_slices(attr_ds)

ds = tf.data.Dataset.zip((mols_ds, attr_ds))

ds = ds.take(10240)

ds = ds.map(
    lambda mol, attr: (mol[0], mol[1], mol[2], attr))

ds = gin.probabilistic.gn_hyper.HyperGraphNet.batch(
    ds,
    128,
    attr_dimension=19).cache(
        str(os.getcwd()) + 'temp')

n_batches = int(gin.probabilistic.gn.GraphNet.get_number_batches(ds))
n_te = n_batches // 10

ds_te = ds.take(n_te)
ds_vl = ds.skip(n_te).take(n_te)
ds_tr = ds.skip(2 * n_te)

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


In [0]:
pd.read_csv('gdb9.sdf.csv')

Unnamed: 0,mol_id,A,B,C,mu,alpha,homo,lumo,gap,r2,zpve,u0,u298,h298,g298,cv,u0_atom,u298_atom,h298_atom,g298_atom
0,gdb_1,157.71180,157.709970,157.706990,0.0000,13.21,-0.3877,0.1171,0.5048,35.3641,0.044749,-40.478930,-40.476062,-40.475117,-40.498597,6.469,-395.999595,-398.643290,-401.014647,-372.471772
1,gdb_2,293.60975,293.541110,191.393970,1.6256,9.46,-0.2570,0.0829,0.3399,26.1563,0.034358,-56.525887,-56.523026,-56.522082,-56.544961,6.316,-276.861363,-278.620271,-280.399259,-259.338802
2,gdb_3,799.58812,437.903860,282.945450,1.8511,6.31,-0.2928,0.0687,0.3615,19.0002,0.021375,-76.404702,-76.401867,-76.400922,-76.422349,6.002,-213.087624,-213.974294,-215.159658,-201.407171
3,gdb_4,0.00000,35.610036,35.610036,0.0000,16.28,-0.2845,0.0506,0.3351,59.5248,0.026841,-77.308427,-77.305527,-77.304583,-77.327429,8.574,-385.501997,-387.237686,-389.016047,-365.800724
4,gdb_5,0.00000,44.593883,44.593883,2.8937,12.99,-0.3604,0.0191,0.3796,48.7476,0.016601,-93.411888,-93.409370,-93.408425,-93.431246,6.278,-301.820534,-302.906752,-304.091489,-288.720028
5,gdb_6,285.48839,38.982300,34.298920,2.1089,14.18,-0.2670,-0.0406,0.2263,59.9891,0.026603,-114.483613,-114.480746,-114.479802,-114.505268,6.413,-358.756935,-360.512706,-362.291066,-340.464421
6,gdb_7,80.46225,19.906490,19.906330,0.0000,23.95,-0.3385,0.1041,0.4426,109.5031,0.074542,-79.764152,-79.760666,-79.759722,-79.787269,10.098,-670.788296,-675.710476,-679.860821,-626.927299
7,gdb_8,127.83497,24.858720,23.978720,1.5258,16.97,-0.2653,0.0784,0.3437,83.7940,0.051208,-115.679136,-115.675816,-115.674872,-115.701876,8.751,-481.106758,-484.355372,-487.319724,-450.124128
8,gdb_9,160.28041,8.593230,8.593210,0.7156,28.78,-0.2609,0.0613,0.3222,177.1963,0.055410,-116.609549,-116.605550,-116.604606,-116.633775,12.482,-670.268091,-673.980434,-677.537155,-631.346845
9,gdb_10,159.03567,9.223270,9.223240,3.8266,24.45,-0.3264,0.0376,0.3640,160.7223,0.045286,-132.718150,-132.714563,-132.713619,-132.742149,10.287,-589.812024,-592.893721,-595.857446,-557.125708


# Architecture

In [0]:
config_space = {
    'D_V': [16, 32, 64, 128, 256],
    'D_E': [16, 32, 64, 128, 256],
    'D_A': [16, 32, 64, 128, 256],
    'D_T': [16, 32, 64, 128, 256],
    'D_U': [16, 32, 64, 128, 256],


    'phi_e_0': [32, 64, 128],
    'phi_e_a_0': ['elu', 'relu', 'tanh', 'sigmoid'],
    'phi_e_a_1': ['elu', 'relu', 'tanh', 'sigmoid'],

    'phi_v_0': [32, 64, 128],
    'phi_v_a_0': ['elu', 'relu', 'tanh', 'sigmoid'],
    'phi_v_a_1': ['elu', 'relu', 'tanh', 'sigmoid'],

    'phi_a_0': [32, 64, 128],
    'phi_a_a_0': ['elu', 'relu', 'tanh', 'sigmoid'],
    'phi_a_a_1': ['elu', 'relu', 'tanh', 'sigmoid'],

    'phi_t_0': [32, 64, 128],
    'phi_t_a_0': ['elu', 'relu', 'tanh', 'sigmoid'],
    'phi_t_a_1': ['elu', 'relu', 'tanh', 'sigmoid'],

    'phi_u_0': [32, 64, 128],
    'phi_u_a_0': ['elu', 'relu', 'tanh', 'sigmoid'],
    'phi_u_a_1': ['elu', 'relu', 'tanh', 'sigmoid'],

    'f_e_0': [32, 64, 128],
    'f_e_a_0': ['elu', 'relu', 'tanh', 'sigmoid'],
    'f_e_a_1': ['elu', 'relu', 'tanh', 'sigmoid'],

    'f_r': [32, 64, 128],
    'f_r_a': ['elu', 'relu', 'tanh', 'sigmoid'],

    'learning_rate': [1e-5, 1e-4, 1e-3]

}

point = dict(zip(config_space.keys(), [value[0] for value in config_space.values()]))

In [0]:
class f_v(tf.keras.Model):
    """ Featurization of nodes.
    Here we simply featurize atoms using one-hot encoding.

    """
    def __init__(self, units=point['D_V']):
        super(f_v, self).__init__()
        self.d = tf.keras.layers.Dense(units)

    @tf.function
    def call(self, x):
        x = tf.one_hot(x, 10)
        # set shape because Dense doesn't like variation
        x.set_shape([None, 10])
        return self.d(x)

In [0]:
class f_r(tf.keras.Model):
    """ Readout function
    """
    def __init__(self, units=point['f_r'], f_r_a=point['f_r_a']):
        super(f_r, self).__init__()
        self.d_k = tf.keras.layers.Dense(units, activation='tanh')
        self.d_q = tf.keras.layers.Dense(units, activation='tanh')
        self.d_pair_0 = tf.keras.layers.Dense(units, activation=f_r_a)
        self.d_pair_1 = tf.keras.layers.Dense(16)
        
        self.d_e_1 = tf.keras.layers.Dense(16)
        self.d_e_0 = tf.keras.layers.Dense(units, activation='tanh')
        
        self.d_a_1 = tf.keras.layers.Dense(16)
        self.d_a_0 = tf.keras.layers.Dense(units, activation='tanh')
        
        self.d_t_1 = tf.keras.layers.Dense(16)
        self.d_t_0 = tf.keras.layers.Dense(units, activation='tanh')

        self.units = units
        self.d_v = point['D_V']
        self.d_e = point['D_E']
        self.d_a = point['D_A']
        self.d_t = point['D_T']
        self.d_u = point['D_U']

    # @tf.function
    def call(self, h_v, h_e, h_a, h_t, h_u,
        h_v_history, h_e_history, h_a_history,
        h_t_history, h_u_history,
        atom_in_mol, bond_in_mol, angle_in_mol, torsion_in_mol,
        adjacency_map, coordinates):
        
        adjacency_map_full = tf.math.add(
            tf.transpose(
                adjacency_map),
            adjacency_map)
        
        per_mol_mask = tf.matmul(
            tf.where(
                atom_in_mol,
                tf.ones_like(atom_in_mol, dtype=tf.float32),
                tf.zeros_like(atom_in_mol, dtype=tf.float32)),
            tf.transpose(
                tf.where(
                    atom_in_mol,
                    tf.ones_like(atom_in_mol, dtype=tf.float32),
                    tf.zeros_like(atom_in_mol, dtype=tf.float32))))

        # get distance matrix
        distance = gin.deterministic.md.get_distance_matrix(coordinates)

        distance = tf.expand_dims(
            distance,
            2)

        n_atoms = tf.shape(distance, tf.int64)[0]

        # (n_atoms, n_atoms, units)
        k = tf.multiply(
            tf.tile(
                tf.expand_dims(
                    per_mol_mask,
                    2),
                [1, 1, self.units]),
            tf.tile(
                tf.expand_dims(
                    self.d_k(h_v),
                    1),
                [1, n_atoms, 1]))

        # (n_atoms, n_atoms, units)
        q = tf.multiply(
            tf.tile(
                tf.expand_dims(
                    per_mol_mask,
                    2),
                [1, 1, self.units]),
            tf.tile(
                tf.expand_dims(
                    self.d_q(h_v),
                    0),
                [n_atoms, 1, 1]))

        h_pair = tf.concat(
            [
                k,
                q,
            ],
            axis=2)

        h_pair = tf.math.multiply(
            tf.tile(
                tf.expand_dims(
                    tf.math.multiply(
                        per_mol_mask,
                        tf.where(
                            tf.equal(
                                adjacency_map_full,
                                tf.constant(0, dtype=tf.float32)),
                            tf.ones_like(adjacency_map),
                            tf.zeros_like(adjacency_map))),
                    2),
                [1, 1, 16]),
            self.d_pair_1(self.d_pair_0(h_pair)))
            
        y_pair = h_pair

        y_a = self.d_a_1(
            self.d_a_0(
                tf.reshape(
                    h_a_history,
                    [
                        tf.shape(h_a_history)[0],
                        -1
                    ])))
        
        y_e = self.d_e_1(
            self.d_e_0(
                tf.reshape(
                    h_e_history,
                    [
                        tf.shape(h_e_history)[0],
                        -1
                    ])))
        
        y_t = self.d_t_1(
            self.d_t_0(
                tf.reshape(
                    h_t_history,
                    [
                        tf.shape(h_t_history)[0],
                        -1
                    ])))

        return y_e, y_a, y_t, y_pair


In [0]:
gn = gin.probabilistic.gn_hyper.HyperGraphNet(
    f_e=lime.nets.for_gn.ConcatenateThenFullyConnect(
        (point['f_e_0'], point['f_e_a_0'], point['D_E'], point['f_e_a_1'])),
    f_a=tf.keras.layers.Dense(point['D_A']),
    f_t=tf.keras.layers.Dense(point['D_T']),
    f_v=f_v(),
    f_u=(lambda atoms, adjacency_map, batched_attr_in_mol: \
        tf.tile(
            tf.zeros((1, point['D_U'])),
            [
                 tf.math.count_nonzero(batched_attr_in_mol),
                 1
            ]
        )),
    phi_e=lime.nets.for_gn.ConcatenateThenFullyConnect(
        (point['phi_e_0'], point['phi_e_a_0'], point['D_E'], point['phi_e_a_1'])),
    phi_u=lime.nets.for_gn.ConcatenateThenFullyConnect(
        (point['phi_u_0'], point['phi_u_a_0'], point['D_U'], point['phi_u_a_1'])),
    phi_v=lime.nets.for_gn.ConcatenateThenFullyConnect(
        (point['phi_v_0'], point['phi_v_a_0'], point['D_V'], point['phi_v_a_1'])),
    phi_a=lime.nets.for_gn.ConcatenateThenFullyConnect(
        (point['phi_a_0'], point['phi_a_a_0'], point['D_A'], point['phi_a_a_1'])),
    phi_t=lime.nets.for_gn.ConcatenateThenFullyConnect(
        (point['phi_t_0'], point['phi_t_a_0'], point['D_T'], point['phi_t_a_1'])),
    f_r=f_r(),
    repeat=5)

In [0]:
optimizer = tf.keras.optimizers.Adam(1e-5)

In [0]:
for dummy_idx in range(10):
    for atoms, adjacency_map, coordinates, attr, atom_in_mol, attr_in_mol in ds_tr:
        with tf.GradientTape() as tape:
            y_e, y_a, y_t, y_pair = gn(
                atoms, adjacency_map, coordinates, atom_in_mol, attr_in_mol)
            
            bond_idxs, angle_idxs, torsion_idxs = gin.probabilistic.gn_hyper\
                .get_geometric_idxs(atoms, adjacency_map)
            
            is_bond = tf.greater(
                adjacency_map,
                tf.constant(0, dtype=tf.float32))
            
            distance_matrix = gin.deterministic.md.get_distance_matrix(
                coordinates)
            
            bond_distances = tf.boolean_mask(
                distance_matrix,
                is_bond)

            angle_angles = gin.deterministic.md.get_angles(
                coordinates,
                angle_idxs)

            torsion_dihedrals = gin.deterministic.md.get_dihedrals(
                coordinates,
                torsion_idxs)
            
            u_bond = tf.math.polyval(
                tf.transpose(y_e),
                bond_distances)
            
            u_angle = tf.math.polyval(
                tf.transpose(y_a),
                angle_angles)
            
            u_dihedral = tf.math.polyval(
                tf.transpose(y_t),
                torsion_dihedrals)
            
            u_pair = tf.math.polyval(
                tf.transpose(y_pair),
                distance_matrix)
            
        
        print(u_pair)
        
        break

        variables = gn.variables
        grad = tape.gradient(loss, variables)
        optimizer.apply_gradients(
            zip(grad, variables))

tf.Tensor(
[[-1.5051595e+00  0.0000000e+00 -6.1214119e+05 ...  0.0000000e+00
   0.0000000e+00  0.0000000e+00]
 [ 0.0000000e+00 -1.0669639e+00  0.0000000e+00 ...  0.0000000e+00
   0.0000000e+00  0.0000000e+00]
 [-5.0146897e+05  0.0000000e+00 -1.0669639e+00 ...  0.0000000e+00
   0.0000000e+00  0.0000000e+00]
 ...
 [ 0.0000000e+00  0.0000000e+00  0.0000000e+00 ...  0.0000000e+00
   0.0000000e+00  0.0000000e+00]
 [ 0.0000000e+00  0.0000000e+00  0.0000000e+00 ...  0.0000000e+00
   0.0000000e+00  0.0000000e+00]
 [ 0.0000000e+00  0.0000000e+00  0.0000000e+00 ...  0.0000000e+00
   0.0000000e+00  0.0000000e+00]], shape=(128, 128), dtype=float32)
tf.Tensor(
[[-1.5051595e+00  0.0000000e+00 -6.1214119e+05 ...  0.0000000e+00
   0.0000000e+00  0.0000000e+00]
 [ 0.0000000e+00 -1.0669639e+00  0.0000000e+00 ...  0.0000000e+00
   0.0000000e+00  0.0000000e+00]
 [-5.0146897e+05  0.0000000e+00 -1.0669639e+00 ...  0.0000000e+00
   0.0000000e+00  0.0000000e+00]
 ...
 [ 0.0000000e+00  0.0000000e+00  0.0000000

KeyboardInterrupt: ignored