In [10]:
import tensorflow as tf
import numpy as np
import jax
import jax.numpy as jnp
import jax.experimental.optimizers as optimizers
from jax.experimental import stax
from jax.experimental.stax import (Conv, Dense, MaxPool, Relu, Flatten)
from jax import jit, grad, random,vmap,value_and_grad
import jax.nn as jnn
from jax.tree_util import tree_multimap
import math
from scipy.special import softmax
import matplotlib.pyplot as plt
import warnings
from functools import partial # for use with vmap

In [2]:
def data_parse(record):
    features = {
        'N': tf.io.FixedLenFeature([], tf.int64),
        'labels': tf.io.FixedLenFeature([16], tf.float32),
        'elements': tf.io.VarLenFeature(tf.int64),
        'coords': tf.io.VarLenFeature(tf.float32),
    }
    parsed_features = tf.io.parse_single_example(
        serialized=record, features=features)
    coords = tf.reshape(tf.sparse.to_dense(parsed_features['coords'], default_value=0),[-1,4])
    elements = tf.sparse.to_dense(parsed_features['elements'], default_value=0)
    return (elements, coords), parsed_features['labels']
data = tf.data.TFRecordDataset(
    'qm9.tfrecords', compression_type='GZIP').map(data_parse)


In [3]:
def make_graph(e,x):
    e = e.numpy()
    x = x.numpy()
    r = x[:,:3]
    r2 = np.sum((r - r[:,np.newaxis,:])**2,axis=-1)
    edges = np.where(r2!=0, 1/r2,0.0) #[N,N]
    nodes = np.zeros((len(e),9))
    nodes[np.arange(len(e)), e-1] = 1
    return nodes,edges

def get_label(y):
    return y.numpy()[13]

## Check this code

In [4]:
#Definition of GCN with attention

def GCN(out_dim,embed_dim):
    def init_fun(global_ft_len):
        #output_shape = input_shape[:-1] + (out_dim,)
        #trainable weights
        #w = np.random.normal(size =(4, embed_dim,out_dim), scale=1e-1)
        wq = np.random.normal(size =(embed_dim,out_dim), scale=1e-1)
        wk = np.random.normal(size =(embed_dim,out_dim), scale=1e-1)
        wv = np.random.normal(size =(embed_dim,out_dim), scale=1e-1)
        wn = np.random.normal(size =(embed_dim,out_dim), scale=1e-1)
        wu = np.random.normal(size =(out_dim,global_ft_len), scale=1e-1)
    
        return (wq,wk,wv,wn,wu)
    
    def apply_fun(train_weights,nodes,edges,features, **kwargs):
       
        query = jnp.dot(nodes,train_weights[0]) 
        
        keys = jnp.dot(jnp.repeat(nodes[jnp.newaxis,...],nodes.shape[0],axis=0), 
                       train_weights[1])* edges[...,jnp.newaxis]
        
        d_sq = math.sqrt(keys.shape[-1])
        b = jnn.softmax(query[jnp.newaxis,...] * keys/d_sq)
       
        values = jnp.dot(jnp.repeat(nodes[jnp.newaxis,...],nodes.shape[0],axis=0), train_weights[2])

        messages = b * values 
        
        net_message = jnp.mean(messages,axis= 1)
        
        self_message = jnp.dot(nodes, train_weights[3])

        #self loop
        out_nodes = jnn.relu((net_message))+self_message
        
        #global features
        global_node_features = jnp.sum(out_nodes, axis=0)
       
        new_features = jax.nn.relu(global_node_features@train_weights[4]) + features

        return out_nodes,edges,new_features
    
    return init_fun,apply_fun
  

def y_hat(n,e,params1,params2,params3,b):
    init_fts = jnp.zeros(global_ft_len)
    n,e,fts = gcn_apply(params1,n,e,init_fts)
    n,e,fts = gcn_apply(params2,n,e,fts)
    y_hat = fts @ params3 + b
    
    return y_hat


def loss(nodes,edges, targets,params1,params2,params3,b):
    predictions = y_hat(nodes,edges,params1,params2,params3,b)
    
    return (targets - predictions)**2

#gradient of loss wrt params1,params2,params3,b
loss_grad = jax.grad(loss, (3, 4, 5,6))



# Training 

In [22]:
#Parameters
out_dim = 9
embed_dim = 9
global_ft_len = 8

#You have to plot for different datasets
#Plot the test-loss for test set

test_set = data.take(100)
valid_test_len = 100
valid_set = data.skip(100).take(10)
train_set = data.skip(110).take(50).shuffle(50)

gcn_init,gcn_apply = GCN(out_dim,embed_dim)
params1 = gcn_init(global_ft_len)
params2 = gcn_init(global_ft_len)
params3 = np.random.normal(size=(global_ft_len))

b = 245.


epochs = 16
batch_size = 32
eta = 1e-2
val_loss = [0. for _ in range(epochs)]


In [24]:
#Now this is training.

for epoch in range(epochs):
    bi = 0
    grad_est = None
    for d in train_set:         
        # do training step
        # but do not update
        # until have enough points
        (e,x), y = d
        nodes,edges = make_graph(e,x)
        label = get_label(y)
        
        if grad_est is None:
            grad_est = loss_grad(nodes,edges, label, params1, params2, params3, b)
        else:
            grad_est += loss_grad(nodes,edges, label, params1, params2, params3, b)
        bi += 1
        if bi == batch_size:
            # have enough to update            
            # update regression weights
            params3 -= eta * grad_est[2]  / batch_size
            b -= eta * grad_est[3]  / batch_size
            # update GNN weights            
            for i,w in [(0, params1), (1, params2)]:
                for j, param in enumerate(w):
                    param -= eta * grad_est[i][j] / batch_size
            # reset tracking of batch index
            bi = 0            
            grad_est = None            
    # compute validation loss    
    for v in valid_set:
        (e,x), y = v
        nodes,edges = make_graph(e,x)
        label = get_label(y)
        # convert SE to RMSE
        val_loss[epoch] += jnp.sqrt(loss(nodes,edges, label, params1, params2, params3, b) / valid_test_len)
    print('epoch:', epoch, 'loss: {:.2f}'.format(val_loss[epoch]))

  


epoch: 0 loss: 923.17
epoch: 1 loss: 922.10
epoch: 2 loss: 921.03
epoch: 3 loss: 919.95
epoch: 4 loss: 918.90
epoch: 5 loss: 917.85
epoch: 6 loss: 916.76
epoch: 7 loss: 915.69
epoch: 8 loss: 914.65
epoch: 9 loss: 913.58
epoch: 10 loss: 912.53
epoch: 11 loss: 911.44
epoch: 12 loss: 910.39
epoch: 13 loss: 909.34
epoch: 14 loss: 452.00
epoch: 15 loss: 451.70


## Ignore the following cell, used to debug

In [160]:
features = jnp.zeros(global_ft_len)
print(nodes.shape,'embed nodes')

out = gcn_apply(gcn_init(global_ft_len),nodes,edges,features)

print('input feautres', features)
print('output features', out[2].shape)
pred = y_hat(nodes,edges, gcn_init(global_ft_len), gcn_init(global_ft_len), np.random.normal(size=(global_ft_len)), 222.)
print(pred)
grad_est = loss_grad(nodes,edges, label, gcn_init(global_ft_len), gcn_init(global_ft_len)
                     , np.random.normal(size=(global_ft_len)), 222.)
print(grad_est[3])

(5, 9) embed nodes
input feautres [0. 0. 0. 0. 0. 0. 0. 0.]
output features (8,)
221.61868
525.4042
