In [82]:
import tensorflow as tf
import numpy as np
import jax.numpy as jnp
import jax.experimental.optimizers as optimizers
import jax
import jax.nn as jnn
import math
from scipy.special import softmax
import warnings

In [14]:
def data_parse(record):
    features = {
        'N': tf.io.FixedLenFeature([], tf.int64),
        'labels': tf.io.FixedLenFeature([16], tf.float32),
        'elements': tf.io.VarLenFeature(tf.int64),
        'coords': tf.io.VarLenFeature(tf.float32),
    }
    parsed_features = tf.io.parse_single_example(
        serialized=record, features=features)
    coords = tf.reshape(tf.sparse.to_dense(parsed_features['coords'], default_value=0),[-1,4])
    elements = tf.sparse.to_dense(parsed_features['elements'], default_value=0)
    return (elements, coords), parsed_features['labels']
data = tf.data.TFRecordDataset(
    'qm9.tfrecords', compression_type='GZIP').map(data_parse)

[6 1 1 1 1] [[-1.2698136e-02  1.0858041e+00  8.0009960e-03 -5.3568900e-01]
 [ 2.1504159e-03 -6.0313176e-03  1.9761203e-03  1.3392100e-01]
 [ 1.0117308e+00  1.4637512e+00  2.7657481e-04  1.3392200e-01]
 [-5.4081506e-01  1.4475266e+00 -8.7664372e-01  1.3392299e-01]
 [-5.2381361e-01  1.4379326e+00  9.0639728e-01  1.3392299e-01]] [ 1.0000000e+00  1.5771181e+02  1.5770998e+02  1.5770699e+02
  0.0000000e+00  1.3210000e+01 -3.8769999e-01  1.1710000e-01
  5.0480002e-01  3.5364101e+01  4.4748999e-02 -4.0478931e+01
 -4.0476063e+01 -4.0475117e+01 -4.0498596e+01  6.4689999e+00]


In [43]:
{'C':6,'H':1,'O':8,'N':7,'F':9}

def make_graph(e,x):
    e = e.numpy()
    x = x.numpy()
    r = x[:,:3]
    r2 = np.sum((r - r[:,np.newaxis,:])**2,axis=-1)
    edges = np.where(r2!=0, 1/r2,0.0)
    nodes = np.zeros((len(e),9))
    nodes[np.arange(len(e)), e-1] = 1
    return nodes,edges

def get_label(y):
    return y.numpy()[13]

In [45]:
for d in data:
    (e,x),y = d
    nodes,edges = make_graph(e,x)
    label = get_label(y)
    print (nodes.shape,edges.shape,label)
    break

(5, 9) (5, 5) -40.475117


  


In [115]:
def gcn_layer(nodes,edges,train_weights):
    
    query = jnp.dot(nodes,train_weights[0]) #wq has shape (9,10), output from query has shape (5,10)
    
    #convert r into embeddings
    
    #pairwise distances are used here wk has shape of (9,10), output from keys has shape (5,5,10)
    keys = jnp.dot(jnp.repeat(nodes[jnp.newaxis,...],nodes.shape[0],axis=0), train_weights[1]) * edges
    d_sq = math.sqrt(keys.shape[-1])
    b = jnn.softmax(query[jnp.newaxis,...] * keys/d_sq)
    
    #wv has shape (9,10), output shape (5,5,10)
    values = jnp.dot(jnp.repeat(nodes[jnp.newaxis,...],nodes.shape[0],axis=0), train_weights[2])
    
    messages = b * values #out shape (5,5,10)
    
    net_message = jnp.mean(messages,axis= 1)
    self_message = nodes @ train_weights[3]
    out = jnn.relu((net_message+self_message)) 
    return out,edges
    #messages =jnp.dot(jnp.repeat(nodes[jnp.newaxis,...],nodes.shape[0],axis=0), train_weights[0])
    
def graph_level_fts(nodes):
    node_avg = jnp.mean(nodes,axis=1)
    return node_avg

In [102]:
r = np.repeat(nodes[np.newaxis,...],nodes.shape[0],axis=0) @ np.ones((9,10 )) 
r.shape

(5, 5, 10)

In [99]:
a = jnp.ones((5,5))
b = jnp.ones((5,10))
c = a@b

In [100]:
print(c.shape)
#print(softmax(c).shape)

(5, 10)


In [114]:
out_dim = 10
embed_dim = 4

#get node embeddings instead of one-hot 
element_embeddings = np.random.normal(size=(9,embed_dim))
embed_nodes = nodes @ element_embeddings

#get edge embeddings from pairwise distances
edge_embeddings = np.random.normal(size=(1,len(edges),out_dim))
embed_edges = edges[...,np.newaxis] * edge_embeddings

#trainable weights
w1 = np.random.normal(size = (4,embed_dim,out_dim))
w2 = np.random.normal(size = (4,embed_dim,out_dim))


#call gcn

n,e = gcn_layer(embed_nodes,embed_edges,w1)
n,e = gcn_layer(embed_nodes,embed_edges,w2)
n = graph_level_fts(n)
print(n.shape,e)

(5, 10) [[[-0.00000000e+00 -0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00 -0.00000000e+00  0.00000000e+00
    0.00000000e+00  0.00000000e+00]
  [-1.22371526e+00  8.93893943e-01 -4.57451839e-01  4.39772426e-01
    3.25835304e-02  3.07619013e-01  6.95134339e-02 -1.58712913e-01
    1.29516147e+00  4.09459342e-01]
  [ 2.79695709e-01  1.04352116e+00  4.08119803e-01 -9.51305289e-01
   -1.10860121e-03  2.94458796e-02 -7.27332328e-01  1.31587317e+00
    5.74063674e-01 -8.52763956e-02]
  [ 9.33926729e-01  1.27769865e+00 -9.49628075e-01 -1.04375695e+00
   -1.61830670e-01  6.14040358e-01 -9.49619987e-01 -8.31977120e-01
   -1.09164526e-02 -1.28190247e+00]
  [ 1.70567426e-01  3.31296255e-01  4.72902912e-01  6.85578945e-01
   -1.07110293e+00 -8.07125034e-01 -1.45487169e-01  1.31759370e+00
   -2.92419599e-01 -1.75344817e+00]]

 [[-8.07275589e-02 -7.39898960e-02  4.08378002e-03  4.40485782e-01
    8.25081290e-01  7.01286254e-01 -6.15533140e-01  3.90065873e-01
    2.

In [110]:
print(out.shape)

(5, 10)
