In [2]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from graph_nets import blocks
from graph_nets import graphs
from graph_nets import modules
from graph_nets import utils_np
from graph_nets import utils_tf
from graph_nets.demos.models import EncodeProcessDecode as EPD

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import sonnet as snt
import tensorflow as tf
import h5py
from progressbar import progressbar
import matplotlib.pyplot as plt
%matplotlib inline

pi = np.pi
twopi = np.pi*2

Useful reference on using Sonnet:  
https://github.com/deepmind/sonnet/blob/v2/examples/mlp_on_mnist.ipynb  
DeepMind graph nn example on physics system:  
https://colab.research.google.com/github/deepmind/graph_nets/blob/master/graph_nets/demos/physics.ipynb#scrollTo=toCQhJIM93en  
graph_nets InteractionNetwork class:  
https://github.com/deepmind/graph_nets/blob/master/docs/graph_nets.md#class-modulesinteractionnetwork

#### Let's grab a few graphs

In [3]:
def get_node_coord_dict(h5):
    node_np = h5['node_coords']
    d = {}
    for i,coords in enumerate(node_np):
        d.update({i:(coords[0],coords[1])})
    return d

def draw_graph(graph, node_pos_dict, col_lims=None):
    if col_lims:
        vmin,vmax = col_lims[0], col_lims[1]
        e_vmin,e_vmax = col_lims[2], col_lims[3]
    else:
        vmin,vmax = -0.5, 10
        e_vmin,e_vmax = -0.5, 5

    nodecols = graph.nodes[:,0]
    edgecols = graph.edges[:,0]

    graphs_nx = utils_np.graphs_tuple_to_networkxs(graph)
    fig,ax = plt.subplots(figsize=(15,15))
    nx.draw(graphs_nx[0],ax=ax,pos=node_pos_dict,node_color=nodecols,
            edge_color=edgecols,node_size=100,
            cmap=plt.cm.winter,edge_cmap=plt.cm.winter,
            vmin=vmin,vmax=vmax,edge_vmin=e_vmin,edge_vmax=e_vmax,
            arrowsize=10)
    return fig,ax

In [4]:
runname = "secondring_t1.0v1.0l10"
inputfname = "nn_inputs/"+runname+".hdf5"
h5in = h5py.File(inputfname,'r')
N_NODE_FEAT = 3
N_EDGE_FEAT = 6
N_NODE = h5in.attrs['n_nodes']
N_EDGE = h5in.attrs['n_edges']
NTG = h5in.attrs['nTG']
node_pos_np = h5in['node_coords'][:]
node_pos = get_node_coord_dict(h5in)
h5in.close()

In [5]:
def snap2graph(h5file,day,tg,placeholder=False):
    snapstr = 'day'+str(day)+'tg'+str(tg)
    glbls = h5file['glbl_features/'+snapstr][0] # Seems glbls have extra dimension
    nodes = h5file['node_features/'+snapstr]
    edges = h5file['edge_features/'+snapstr]
    senders = h5in['senders']
    receivers = h5in['receivers']

    graphdat_dict = {
        "globals": glbls[:].astype(np.float),
        "nodes": nodes[:].astype(np.float),
        "edges": edges[:].astype(np.float),
        "senders": senders[:],
        "receivers": receivers[:]
    }

    if not placeholder:
        graphs_tuple = utils_np.data_dicts_to_graphs_tuple([graphdat_dict])
    else:
        graphs_tuple = utils_tf.placeholders_from_data_dicts([graphdat_dict])
     
    return graphs_tuple

In [None]:
h5in = h5py.File(inputfname,'r')
h5g = snap2graph(h5in,day=1,tg=72)
h5in.close()

SE link for calculating running mu and var of data  
https://math.stackexchange.com/questions/20593/calculate-variance-from-a-stream-of-sample-values  
Let's ignore these stats for now

In [None]:
col_lims = [-1., np.max(h5g.nodes[:,0])*0.5, -1., np.max(h5g.edges[:,0])*0.5]
# col_lims = [0, 1, 0, 1]
fig, ax = draw_graph(h5g,node_pos, col_lims=col_lims)

In [9]:
class MyMLP(snt.Module):
    def __init__(self,name=None):
#         super(MyMLP,self).__init__(name=name)
        self.hidden1=snt.Linear(10,name="hidden1")
        self.output=snt.Linear(2,name="output")
    def __call__(self,x):
        x=self.hidden1(x)
        x=tf.nn.relu(x)
        x=self.output(x)
        return x

In [32]:
class timecrement(snt.Module):
    def __init__(self,name=None):
#         super(timecrement,self).__init__(name=name)
        self.ntg = 144
        self.T = tf.placeholder(np.float,shape=(1,2))
        
    def __call__(self,inputs: tf.Tensor) -> tf.Tensor:
        #tg_T = (T[1] + 1)%self.ntg
        #day_T = T[0]
#         T_ = tf.gather(T,[0])
        tg_T = tf.Variable(inputs[0,0],dtype=np.float)
        day_T = tf.Variable(inputs[0,1],dtype=np.float)
        tg_T = tf.add(tg_T,1.)
        tg_T = tf.floormod(tg_T,self.ntg)
        def f1(): return tf.floormod(tf.add(day_T,1.),7)
        def f2(): return day_T
        day_T = tf.cond(tf.math.equal(tg_T,0.),f1,f2)
#         self.T = tf.Variable([day_T,tg_T])
        return inputs

In [28]:
input_graph.globals

<tf.Tensor 'placeholders_from_data_dicts/globals:0' shape=(?, 2) dtype=float64>

In [33]:
tf.reset_default_graph()

h5in = h5py.File(inputfname,'r')
# edge_model_fn: a callable to perform per-edge computations, etc.
graphnet = modules.GraphNetwork(
    edge_model_fn=lambda: snt.Linear(output_size=N_EDGE_FEAT),
    node_model_fn=lambda: snt.Linear(output_size=N_NODE_FEAT),
    global_model_fn=lambda: timecrement(),
    global_block_opt={"use_edges":False,"use_nodes":False})

input_graph = snap2graph(h5in,day=0,tg=0,placeholder=True)
output_graph = graphnet(input_graph)
lbl_graph = snap2graph(h5in,day=0,tg=0,placeholder=True)

print("Output edges size: {}".format(output_graph.edges.shape[-1]))  # Equal to OUTPUT_EDGE_SIZE
print("Output nodes size: {}".format(output_graph.nodes.shape[-1]))  # Equal to OUTPUT_NODE_SIZE
print("Output globals size: {}".format(output_graph.globals.shape[-1]))

loss = tf.losses.mean_squared_error(labels=lbl_graph.nodes, 
                                    predictions=output_graph.nodes)\
      +tf.losses.mean_squared_error(labels=lbl_graph.edges,
                                    predictions=output_graph.edges)
opt = tf.train.AdamOptimizer(learning_rate=1e-3)
loss_op = opt.minimize(loss)
h5in.close()

ValueError: Tensor conversion requested dtype float32 for Tensor with dtype float64: <tf.Tensor 'graph_network_1/global_block/strided_slice:0' shape=() dtype=float64>

originally defined at:
  File "<ipython-input-33-04933fd3a065>", line 9, in <module>
    global_block_opt={"use_edges":False,"use_nodes":False})
  File "/usr/local/lib/python3.5/dist-packages/graph_nets/modules.py", line 286, in __init__
    global_model_fn=global_model_fn, **global_block_opt)
  File "/usr/local/lib/python3.5/dist-packages/graph_nets/blocks.py", line 611, in __init__
    super(GlobalBlock, self).__init__(name=name)
  File "/usr/local/lib/python3.5/dist-packages/sonnet/python/modules/base.py", line 180, in __init__
    custom_getter_=self._custom_getter)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow_core/python/ops/template.py", line 161, in make_template
    **kwargs)


originally defined at:
  File "<ipython-input-33-04933fd3a065>", line 9, in <module>
    global_block_opt={"use_edges":False,"use_nodes":False})
  File "/usr/local/lib/python3.5/dist-packages/graph_nets/modules.py", line 275, in __init__
    super(GraphNetwork, self).__init__(name=name)
  File "/usr/local/lib/python3.5/dist-packages/sonnet/python/modules/base.py", line 180, in __init__
    custom_getter_=self._custom_getter)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow_core/python/ops/template.py", line 161, in make_template
    **kwargs)


In [31]:
losses = []
sel_nodes_out, sel_nodes_input = [], []

with h5py.File(inputfname,'r') as h5in:
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for i in progressbar(range(1000)):
            tg = i%NTG
            day = (i//NTG)%7
            lbltg = (i+1)%NTG
            lblday = (day+((tg+1)//NTG))%7
#             day,lblday = 0, 0
            graph = snap2graph(h5in,day=day,tg=tg)
            lbl = snap2graph(h5in,day=lblday,tg=lbltg)
            train_dict = sess.run({
                "loss": loss,
                "loss_op": loss_op,
                "outputs": output_graph,
                "train_vars": graphnet.trainable_variables
            }, feed_dict={
                input_graph: graph,
                lbl_graph: lbl
            })
            if (i)%100==0:
#                 print(day,tg,train_dict['loss'])
#                 sel_nodes_out.append([train_dict['outputs'].nodes[:3]])
                losses.append(train_dict['loss'])

        # Test

InvalidArgumentError: You must feed a value for placeholder tensor 'placeholders_from_data_dicts/globals' with dtype double and shape [?,2]
	 [[node placeholders_from_data_dicts/globals (defined at /usr/local/lib/python3.5/dist-packages/tensorflow_core/python/framework/ops.py:1748) ]]

Original stack trace for 'placeholders_from_data_dicts/globals':
  File "/usr/lib/python3.5/runpy.py", line 184, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.5/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.5/dist-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/usr/local/lib/python3.5/dist-packages/traitlets/config/application.py", line 664, in launch_instance
    app.start()
  File "/usr/local/lib/python3.5/dist-packages/ipykernel/kernelapp.py", line 563, in start
    self.io_loop.start()
  File "/usr/local/lib/python3.5/dist-packages/tornado/platform/asyncio.py", line 148, in start
    self.asyncio_loop.run_forever()
  File "/usr/lib/python3.5/asyncio/base_events.py", line 345, in run_forever
    self._run_once()
  File "/usr/lib/python3.5/asyncio/base_events.py", line 1312, in _run_once
    handle._run()
  File "/usr/lib/python3.5/asyncio/events.py", line 125, in _run
    self._callback(*self._args)
  File "/usr/local/lib/python3.5/dist-packages/tornado/ioloop.py", line 690, in <lambda>
    lambda f: self._run_callback(functools.partial(callback, future))
  File "/usr/local/lib/python3.5/dist-packages/tornado/ioloop.py", line 743, in _run_callback
    ret = callback()
  File "/usr/local/lib/python3.5/dist-packages/tornado/gen.py", line 787, in inner
    self.run()
  File "/usr/local/lib/python3.5/dist-packages/tornado/gen.py", line 748, in run
    yielded = self.gen.send(value)
  File "/usr/local/lib/python3.5/dist-packages/ipykernel/kernelbase.py", line 361, in process_one
    yield gen.maybe_future(dispatch(*args))
  File "/usr/local/lib/python3.5/dist-packages/tornado/gen.py", line 209, in wrapper
    yielded = next(result)
  File "/usr/local/lib/python3.5/dist-packages/ipykernel/kernelbase.py", line 268, in dispatch_shell
    yield gen.maybe_future(handler(stream, idents, msg))
  File "/usr/local/lib/python3.5/dist-packages/tornado/gen.py", line 209, in wrapper
    yielded = next(result)
  File "/usr/local/lib/python3.5/dist-packages/ipykernel/kernelbase.py", line 541, in execute_request
    user_expressions, allow_stdin,
  File "/usr/local/lib/python3.5/dist-packages/tornado/gen.py", line 209, in wrapper
    yielded = next(result)
  File "/usr/local/lib/python3.5/dist-packages/ipykernel/ipkernel.py", line 300, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/usr/local/lib/python3.5/dist-packages/ipykernel/zmqshell.py", line 536, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/IPython/core/interactiveshell.py", line 2855, in run_cell
    raw_cell, store_history, silent, shell_futures)
  File "/usr/local/lib/python3.5/dist-packages/IPython/core/interactiveshell.py", line 2881, in _run_cell
    return runner(coro)
  File "/usr/local/lib/python3.5/dist-packages/IPython/core/async_helpers.py", line 68, in _pseudo_sync_runner
    coro.send(None)
  File "/usr/local/lib/python3.5/dist-packages/IPython/core/interactiveshell.py", line 3058, in run_cell_async
    interactivity=interactivity, compiler=compiler, result=result)
  File "/usr/local/lib/python3.5/dist-packages/IPython/core/interactiveshell.py", line 3249, in run_ast_nodes
    if (await self.run_code(code, result,  async_=asy)):
  File "/usr/local/lib/python3.5/dist-packages/IPython/core/interactiveshell.py", line 3326, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-30-04933fd3a065>", line 11, in <module>
    input_graph = snap2graph(h5in,day=0,tg=0,placeholder=True)
  File "<ipython-input-5-ab94a858e412>", line 20, in snap2graph
    graphs_tuple = utils_tf.placeholders_from_data_dicts([graphdat_dict])
  File "/usr/local/lib/python3.5/dist-packages/graph_nets/utils_tf.py", line 279, in placeholders_from_data_dicts
    graph, force_dynamic_num_graphs=force_dynamic_num_graphs)
  File "/usr/local/lib/python3.5/dist-packages/graph_nets/utils_tf.py", line 214, in _placeholders_from_graphs_tuple
    force_dynamic_num_graphs=force_dynamic_num_graphs)
  File "/usr/local/lib/python3.5/dist-packages/graph_nets/utils_tf.py", line 191, in _build_placeholders_from_specs
    dct[field] = tf.placeholder(dtype, shape=shape, name=field)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow_core/python/ops/array_ops.py", line 2619, in placeholder
    return gen_array_ops.placeholder(dtype=dtype, shape=shape, name=name)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow_core/python/ops/gen_array_ops.py", line 6669, in placeholder
    "Placeholder", dtype=dtype, shape=shape, name=name)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow_core/python/framework/op_def_library.py", line 794, in _apply_op_helper
    op_def=op_def)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow_core/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow_core/python/framework/ops.py", line 3357, in create_op
    attrs, op_def, compute_device)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow_core/python/framework/ops.py", line 3426, in _create_op_internal
    op_def=op_def)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow_core/python/framework/ops.py", line 1748, in __init__
    self._traceback = tf_stack.extract_stack()


In [None]:
plt.semilogy(losses)

In [14]:
class timecrement(snt.Module):
    def __init__(self,ntg,name=None):
        super(timecrement,self).__init__(ntg,name=name)
        self.ntg = ntg
    def __call__(self,T):
        #tg_T = (T[1] + 1)%self.ntg
        #day_T = T[0]
#         T_ = tf.gather(T,[0])
        tg_T = tf.Variable(T[0,0])
        day_T = tf.Variable(T[0,1])
        tg_T = tf.add(tg_T,1.)
        tg_T = tf.floormod(tg_T,self.ntg)
        def f1(): return tf.floormod(tf.add(day_T,1.),7)
        def f2(): return day_T
        day_T = tf.cond(tf.math.equal(tg_T,0.),f1,f2)
#         T = tf.Variable([day_T,tg_T])
        return T

In [27]:
tf.reset_default_graph()

T = tf.Variable([[6.,140.]])
tc = timecrement(NTG)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(20):
        T = tc(T)
        print(T.eval())


[[  6. 140.]]
[[  6. 140.]]
[[  6. 140.]]
[[  6. 140.]]
[[  6. 140.]]
[[  6. 140.]]
[[  6. 140.]]
[[  6. 140.]]
[[  6. 140.]]
[[  6. 140.]]
[[  6. 140.]]
[[  6. 140.]]
[[  6. 140.]]
[[  6. 140.]]
[[  6. 140.]]
[[  6. 140.]]
[[  6. 140.]]
[[  6. 140.]]
[[  6. 140.]]
[[  6. 140.]]


In [None]:
train_dict['inet_vars']

In [None]:
graphnet.variables

In [None]:
outputs = train_dict['outputs']
outputs.nodes[:10]

In [None]:
lbl.nodes[:10]

In [None]:
a = np_graphs_tuple.replace(nodes=np_graphs_tuple.nodes[:10],
                            edges=np_graphs_tuple.edges[:10],
                            globals=np_graphs_tuple.globals[:10])