In [1]:
import tensorflow as tf
import numpy as np
import jax
import jax.numpy as jnp
import jax.experimental.optimizers as optimizers
from jax.experimental import stax
from jax.experimental.stax import (Conv, Dense, MaxPool, Relu, Flatten)
from jax import jit, grad, random,vmap,value_and_grad
import jax.nn as jnn
from jax.tree_util import tree_multimap
import math
from scipy.special import softmax
import matplotlib.pyplot as plt
import warnings
from functools import partial # for use with vmap

In [2]:
def data_parse(record):
    features = {
        'N': tf.io.FixedLenFeature([], tf.int64),
        'labels': tf.io.FixedLenFeature([16], tf.float32),
        'elements': tf.io.VarLenFeature(tf.int64),
        'coords': tf.io.VarLenFeature(tf.float32),
    }
    parsed_features = tf.io.parse_single_example(
        serialized=record, features=features)
    coords = tf.reshape(tf.sparse.to_dense(parsed_features['coords'], default_value=0),[-1,4])
    elements = tf.sparse.to_dense(parsed_features['elements'], default_value=0)
    return (elements, coords), parsed_features['labels']
data = tf.data.TFRecordDataset(
    'qm9.tfrecords', compression_type='GZIP').map(data_parse)

In [3]:
def convert_record(d):
    # break up record
    (e, x), y = d
    # 
    e = e.numpy()
    x = x.numpy()
    r = x[:, :3]    
    # use nearest power of 2 (16)
    ohc = np.zeros((len(e), 16))
    ohc[np.arange(len(e)), e - 1] = 1    
    return (ohc, r), y.numpy()[13]

for d in data:
    (e,x), y = convert_record(d)
    print('Element one hots\n', e)
    print('Coordinates\n', x)
    print('Label:', y)
    break
    
def x2e(x):
    '''convert xyz coordinates to inverse pairwise distance'''    
    r2 = jnp.sum((x - x[:, jnp.newaxis, :])**2, axis=-1)
    e = jnp.where(r2 != 0, 1 / r2, 0.)
    return e

def gnn_layer(nodes, edges, features, we, wv, wu):
    '''Implementation of the GNN'''
    # make nodes be N x N so we can just multiply directly
    ek = jax.nn.relu(
        jnp.repeat(nodes[jnp.newaxis,...], nodes.shape[0], axis=0) @ we * edges[...,jnp.newaxis])
    ebar = jnp.sum(ek, axis=1)
    new_nodes = jax.nn.relu(ebar @ wv) + nodes
    
    global_node_features = jnp.sum(new_nodes, axis=0)
    #print(global_node_features.shape)
    new_features = jax.nn.relu(global_node_features  @ wu) + features    
    return new_nodes, edges, new_features

Element one hots
 [[0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
Coordinates
 [[-1.2698136e-02  1.0858041e+00  8.0009960e-03]
 [ 2.1504159e-03 -6.0313176e-03  1.9761203e-03]
 [ 1.0117308e+00  1.4637512e+00  2.7657481e-04]
 [-5.4081506e-01  1.4475266e+00 -8.7664372e-01]
 [-5.2381361e-01  1.4379326e+00  9.0639728e-01]]
Label: -40.475117


In [4]:
graph_feature_len = 8
node_feature_len = 16
msg_feature_len = 16

# make our weights
def init_weights(g, n, m):
    we = np.random.normal(size=(n, m), scale=1e-1)
    wv = np.random.normal(size=(m, n), scale=1e-1)
    wu = np.random.normal(size=(n, g), scale=1e-1)
    return we, wv, wu

# make a graph
nodes = e
edges = x2e(x)
features = jnp.zeros(graph_feature_len)

# eval
out = gnn_layer(nodes, edges, features, *init_weights(graph_feature_len, node_feature_len, msg_feature_len))
print('input feautres', features)
print('output features', out[2])



input feautres [0. 0. 0. 0. 0. 0. 0. 0.]
output features [0.         0.         0.         0.         0.03054902 0.
 0.38822997 0.49316457]


In [5]:
w1 = init_weights(graph_feature_len, node_feature_len, msg_feature_len)
w2 = init_weights(graph_feature_len, node_feature_len, msg_feature_len)
w3 = np.random.normal(size=(graph_feature_len))
b = -325. # starting guess

@jax.jit
def model(nodes, coords, w1, w2, w3, b):
    f0 = jnp.zeros(graph_feature_len)
    e0 = x2e(coords)
    n0 = nodes
    n,e,f = gnn_layer(n0, e0, f0, *w1)
    n,e,f = gnn_layer(n, e, f, *w2)
    yhat = f @ w3 + b
    return yhat

def lossA(nodes, coords, y, w1, w2, w3, b):
    return (model(nodes, coords, w1, w2, w3, b) - y)**2
loss_grad = jax.grad(lossA, (3, 4, 5, 6))

In [8]:
test_set = data.take(100)
valid_set = data.skip(100).take(10)
train_set = data.skip(110).take(50).shuffle(50)

epochs = 16
batch_size = 32
eta = 1e-2
val_loss = [0. for _ in range(epochs)]
for epoch in range(epochs):
    bi = 0
    grad_est = None
    for d in train_set:         
        # do training step
        # but do not update
        # until have enough points
        (e,x), y = convert_record(d)
        if grad_est is None:
            grad_est = loss_grad(e, x, y, w1, w2, w3, b)
        else:
            grad_est += loss_grad(e, x, y, w1, w2, w3, b)
        bi += 1
        if bi == batch_size:
            # have enough to update            
            # update regression weights
            w3 -= eta * grad_est[2]  / batch_size
            b -= eta * grad_est[3]  / batch_size
            # update GNN weights            
            for i,w in [(0, w1), (1, w2)]:
                for j, param in enumerate(w):
                    param -= eta * grad_est[i][j] / batch_size
            # reset tracking of batch index
            bi = 0            
            grad_est = None            
    # compute validation loss    
    for v in valid_set:
        (e,x), y = convert_record(v)
        # convert SE to RMSE
        val_loss[epoch] += jnp.sqrt(lossA(e, x, y, w1, w2, w3, b) / 1000)
    print('epoch:', epoch, 'loss: {:.2f}'.format(val_loss[epoch]))

epoch: 0 loss: 31.86
epoch: 1 loss: 31.80
epoch: 2 loss: 31.67
epoch: 3 loss: 31.44
epoch: 4 loss: 31.31
epoch: 5 loss: 31.20
epoch: 6 loss: 31.06
epoch: 7 loss: 30.99
epoch: 8 loss: 30.96
epoch: 9 loss: 30.77
epoch: 10 loss: 30.66
epoch: 11 loss: 30.58
epoch: 12 loss: 30.44
epoch: 13 loss: 30.31
epoch: 14 loss: 30.25
epoch: 15 loss: 30.15
