In [3]:
import pymc4 as pm
import tensorflow as tf

In [4]:
from fundl.datasets import make_graph_counting_dataset
from fundl.utils import pad_graph
import numpy as onp
import networkx as nx
import jax.numpy as np
from chemgraph import atom_graph
import janitor.chemistry


In [5]:
import pandas as pd

# Load the BACE dataset and derive, per molecule, an RDKit-style mol object,
# its atom graph, and the number of nodes in that graph.
df = (
    pd.read_csv("bace.csv")
    .rename_column("mol", "structure")
    .smiles2mol("structure", "mol")
    .join_apply(lambda row: atom_graph(row["mol"]), "graph")
    .join_apply(lambda row: len(row["graph"]), "graph_size")
)

Gs = df["graph"].tolist()

print("Generating feature matrices and adjacency matrices...")
# One (n_nodes, n_features) matrix per graph, stacked from each node's
# "features" attribute.
Fs = [
    onp.vstack([attrs["features"] for _, attrs in graph.nodes(data=True)])
    for graph in Gs
]
# One dense (n_nodes, n_nodes) adjacency matrix per graph.
As = [onp.asarray(nx.adjacency_matrix(graph).todense()) for graph in Gs]

# Every graph is padded up to the node count of the largest one.
largest_graph_size = max(len(graph) for graph in Gs)

print("Preparing outputs...")
# Dummy task (disabled): predict the number of nodes in each graph.
# y = np.array([len(G) for G in Gs]).reshape(-1, 1)

# Real task: regress pIC50 as a single-column target.
y = df["pIC50"].values.reshape(-1, 1)

print("Padding graphs to correct size...")
for idx in range(len(Fs)):
    Fs[idx], As[idx] = pad_graph(Fs[idx], As[idx], largest_graph_size)


Generating feature matrices and adjacency matrices...
Preparing outputs...
Padding graphs to correct size...


In [6]:
# Collapse the per-graph matrices into single float arrays with a leading
# graph (batch) axis, then report their shapes.
Fs, As = onp.stack(Fs).astype(float), onp.stack(As).astype(float)

print(Fs.shape, As.shape, sep="\n")


(1513, 97, 9)
(1513, 97, 97)


In [63]:
# TensorFlow copies of the batched adjacency and feature arrays; these are
# the inputs the model's forward pass reads.
As_tensor, Fs_tensor = (
    tf.convert_to_tensor(batch, dtype=float) for batch in (As, Fs)
)

In [64]:
from fundl.activations import relu
from jax import lax


def mpnn(w, b, A, F, nonlin=relu):
    """Message-passing layer following the semantics of fundl.layers.graph.mpnn.

    Batch-multiplies ``A`` with ``F``, applies the affine map ``@ w + b`` to
    the result, and passes it through ``nonlin``.

    :param w: Weight matrix applied after the ``A``/``F`` product.
    :param b: Bias added after the weight multiplication.
    :param A: Batched adjacency tensor (assumed (batch, nodes, nodes) -- TODO confirm).
    :param F: Batched feature tensor (assumed (batch, nodes, features) -- TODO confirm).
    :param nonlin: Activation applied to the affine output.
    :returns: The activated output of the layer.
    """
    aggregated = tf.keras.backend.batch_dot(A, F)
    projected = tf.matmul(aggregated, w) + b
    return nonlin(projected)


def gather(F):
    """Follow semantics of fundl.layers.graph.gather.

    Sums ``F`` over axis 1, removing that axis from the result. For the
    (graphs, nodes, features) activations used in this notebook this yields
    one feature vector per graph.

    Uses ``tf.reduce_sum`` because the activations are TensorFlow tensors;
    ``np`` in this notebook is ``jax.numpy``, which has no ``reduce_sum``.

    :param F: Tensor with at least two axes.
    :returns: ``F`` summed over axis 1.
    """
    return tf.reduce_sum(F, axis=1)

def dense(w, b, x, nonlin=relu):
    """Fully connected layer following the semantics of fundl.layers.dense.

    :param w: Weight matrix right-multiplied onto ``x``.
    :param b: Bias added to ``x @ w``.
    :param x: Input tensor.
    :param nonlin: Activation applied to the affine output.
    :returns: ``nonlin(x @ w + b)``.
    """
    return nonlin(tf.matmul(x, w) + b)

In [66]:
@pm.model
def graph_neural_network():
    """Bayesian graph neural network for regressing the observed targets ``y``.

    Declares Normal(mu=0, sigma=0.1) priors over the weights and biases of two
    message-passing layers (``g1*``, ``g2*``) and two dense layers (``d1*``,
    ``d2*``), runs the forward pass ``mpnn -> mpnn -> gather -> dense -> dense``
    on the module-level ``As_tensor`` and ``Fs_tensor``, and conditions a
    Normal likelihood (with a learned noise scale ``sd``) on the module-level
    ``y``.
    """
    # NOTE(review): the traceback recorded with this notebook
    # ("In[1] ndims must be >= 2: 0 [Op:BatchMatMulV2]") shows a rank-0 second
    # operand reaching a matmul, which suggests the `shape=` keyword is not
    # producing matrix-shaped weights in this pymc4 version -- confirm the
    # correct shape keyword against the installed pymc4 API.

    # Message-passing layer 1: 9 input features -> 9 hidden features.
    g1w = yield pm.Normal("g1w", mu=0, sigma=0.1, shape=(9, 9))
    g1b = yield pm.Normal("g1b", mu=0, sigma=0.1, shape=(9,))

    # Message-passing layer 2: 9 -> 5.
    g2w = yield pm.Normal("g2w", mu=0, sigma=0.1, shape=(9, 5))
    g2b = yield pm.Normal("g2b", mu=0, sigma=0.1, shape=(5,))

    # Dense layer 1: 5 -> 5.
    d1w = yield pm.Normal("d1w", mu=0, sigma=0.1, shape=(5, 5))
    d1b = yield pm.Normal("d1b", mu=0, sigma=0.1, shape=(5,))

    # Dense layer 2: 5 -> 1. These must be bound to their own names: reusing
    # d1w/d1b here would clobber the first dense layer's parameters and leave
    # d2w/d2b undefined for the forward pass below.
    d2w = yield pm.Normal("d2w", mu=0, sigma=0.1, shape=(5, 1))
    d2b = yield pm.Normal("d2b", mu=0, sigma=0.1, shape=(1,))

    acts1 = mpnn(g1w, g1b, As_tensor, Fs_tensor)
    acts2 = mpnn(g2w, g2b, As_tensor, acts1)
    out = gather(acts2)
    out = dense(d1w, d1b, out)
    out = dense(d2w, d2b, out)

    # Prior on noise in measurement.
    # NOTE(review): `loc` is an unusual parameter name for an Exponential
    # distribution -- confirm against the pymc4 Exponential signature.
    sd = yield pm.Exponential("sd", loc=1)

    # Likelihood
    yield pm.Normal("like", mu=out, sigma=sd, observed=y)

In [67]:
# Instantiate the model and draw posterior samples from it.
gnn_model = graph_neural_network()
pm.inference.sampling.sample(gnn_model)

InvalidArgumentError: In[1] ndims must be >= 2: 0 [Op:BatchMatMulV2] name: MatMul/