<a href="https://colab.research.google.com/github/jcsh4326/notebook/blob/master/graph_nets_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install tensorflow_probability

[Install tensorflow_probability_gpu from source](https://colab.research.google.com/github/jcsh4326/notebook/blob/master/%E4%BB%8E%E6%BA%90%E7%A0%81%E5%AE%89%E8%A3%85tensorflow_probability_gpu.ipynb)

In [0]:
#@title build tensorflow_probability from source  { form-width: "30%" }

!wget https://github.com/bazelbuild/bazel/releases/download/0.19.1/bazel-0.19.1-installer-linux-x86_64.sh
!chmod +x bazel-0.19.1-installer-linux-x86_64.sh
!./bazel-0.19.1-installer-linux-x86_64.sh --user

!git clone https://github.com/tensorflow/probability.git
import os
os.chdir('probability')

!$HOME/bin/bazel build --copt=-O3 --copt=-march=native :pip_pkg
  
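import tempfile
# mkdtemp returns the path of a fresh directory. On Linux,
# TemporaryFile(...).name is only a file descriptor number, not a real path.
PKGDIR = tempfile.mkdtemp()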

!./bazel-bin/pip_pkg $PKGDIR
!pip install --user --upgrade $PKGDIR/*.whl

# Install graph_nets_gpu

In [0]:
#@title install graph_nets for gpu with pip
!pip install graph_nets

# Hello World
What is a graph network made of?

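First we need some nodes. A node is the abstract representation of a concrete sample and carries that sample's features. The sample may be a whole entity, such as a person in a social network or a paper in a citation dataset. It may be a part of an entity, such as one joint in a human pose or a pixel or patch of pixels in an image. It may also be something more abstract. Note that graph_nets requires features to be defined: even when a feature is conceptually "empty" or "none" for the problem at hand, it should be defined explicitly rather than reaching the framework as "undefined".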

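Next we need some edges, which express the relationships between nodes. The endpoints of each edge are divided into senders and receivers. As the names suggest, these are the source and the target of the relationship, so the graph is a directed graph. Edges have features of their own, and these must satisfy the same requirement.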
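
To make the direction of edges concrete, here is a minimal sketch in plain Python with no graph_nets calls. The `senders` and `receivers` lists mirror the ones used in the code cell further down. The helper `successors_of` is a hypothetical name introduced only for illustration.

```python
from collections import defaultdict

# Edge i points from senders[i] to receivers[i].
senders = [1, 1, 2, 2, 3, 4, 5, 3]
receivers = [2, 3, 4, 5, 6, 8, 8, 7]


def successors_of(senders, receivers):
    """Group receivers by sender (hypothetical helper, for illustration only)."""
    out = defaultdict(list)
    for sender, receiver in zip(senders, receivers):
        out[sender].append(receiver)
    return dict(out)


adjacency = successors_of(senders, receivers)
print(adjacency[1])    # nodes reached from node 1 -> [2, 3]
print(8 in adjacency)  # node 8 only receives edges -> False
```

Swapping the two lists would reverse every edge, which is why the order of senders and receivers matters.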

Finally, since edges and nodes need features, does the graph itself need them too? Yes, it does.
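
The three levels of features can be pictured as one dictionary per graph. The sketch below uses plain Python lists and the key names of the graph_nets data-dict convention (`nodes`, `edges`, `senders`, `receivers`, `globals`). The `missing_keys` helper and the `EXPECTED_KEYS` tuple are hypothetical and only mirror the rule described above. They are not part of graph_nets, and whether graph_nets accepts a particular combination of `None` fields is not checked here.

```python
# Key names follow the graph_nets data-dict convention.
EXPECTED_KEYS = ("nodes", "edges", "senders", "receivers", "globals")


def missing_keys(data_dict):
    """Return the expected keys that are absent (hypothetical helper)."""
    return [key for key in EXPECTED_KEYS if key not in data_dict]


graph = {
    "nodes": [[0.0], [1.0], [2.0]],  # one feature vector per node
    "edges": [[10.0], [20.0]],       # one feature vector per edge
    "senders": [0, 1],               # edge i leaves senders[i]
    "receivers": [1, 2],             # edge i enters receivers[i]
    "globals": [0.0],                # one feature vector for the whole graph
}

# Features that carry no information are still spelled out, here as None.
empty_features = dict(graph, nodes=None, edges=None, globals=None)

print(missing_keys(graph))               # []
print(missing_keys(empty_features))      # []
print(missing_keys({"nodes": [[0.0]]}))  # the four keys that were left out
```

A value of `None` counts as defined in this sketch, and only an absent key is reported.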

Taking networkx as an example, the graph_nets test suite includes a case in which every feature is `None`:


```
def test_networkxs_to_graphs_tuple_with_none_fields(self):
    graph_nx = nx.OrderedMultiDiGraph()
    data_dict = utils_np.networkx_to_data_dict(
        graph_nx,
        node_shape_hint=None,
        edge_shape_hint=None)
    self.assertEqual(None, data_dict["edges"])
    self.assertEqual(None, data_dict["globals"])
    self.assertEqual(None, data_dict["nodes"])
    graph_nx.add_node(0, features=None)
    data_dict = utils_np.networkx_to_data_dict(
        graph_nx,
        node_shape_hint=1,
        edge_shape_hint=None)
    self.assertEqual(None, data_dict["nodes"])
    graph_nx.add_edge(0, 0, features=None)
    data_dict = utils_np.networkx_to_data_dict(
        graph_nx,
        node_shape_hint=[1],
        edge_shape_hint=[1])
    self.assertEqual(None, data_dict["edges"])
    graph_nx.graph["features"] = None
    utils_np.networkx_to_data_dict(graph_nx)
    self.assertEqual(None, data_dict["globals"])
```

What happens if "features" is not defined? For example, if the nodes have no features:


```
KeyError                                  Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/graph_nets/utils_np.py in networkx_to_data_dict(graph_nx, node_shape_hint, edge_shape_hint, data_type_hint)
    168           x[1][GRAPH_NX_FEATURES_KEY]
--> 169           for x in graph_nx.nodes(data=True)
    170           if x[1][GRAPH_NX_FEATURES_KEY] is not None

/usr/local/lib/python3.6/dist-packages/graph_nets/utils_np.py in <listcomp>(.0)
    169           for x in graph_nx.nodes(data=True)
--> 170           if x[1][GRAPH_NX_FEATURES_KEY] is not None
    171       ]

KeyError: 'features'

During handling of the above exception, another exception occurred:

KeyError                                  Traceback (most recent call last)
<ipython-input-7-3496952e54e1> in <module>()
----> 1 input_graphs = utils_np.networkxs_to_graphs_tuple([nxs])
      2 # Create the graph network.
      3 graph_net_module = gn.modules.GraphNetwork(
      4     edge_model_fn=lambda: snt.nets.MLP([32, 32]),
      5     node_model_fn=lambda: snt.nets.MLP([32, 32]),

/usr/local/lib/python3.6/dist-packages/graph_nets/utils_np.py in networkxs_to_graphs_tuple(graph_nxs, node_shape_hint, edge_shape_hint, data_type_hint)
    335     for graph_nx in graph_nxs:
    336       data_dict = networkx_to_data_dict(graph_nx, node_shape_hint,
--> 337                                         edge_shape_hint, data_type_hint)
    338       data_dicts.append(data_dict)
    339   except TypeError:

/usr/local/lib/python3.6/dist-packages/graph_nets/utils_np.py in networkx_to_data_dict(graph_nx, node_shape_hint, edge_shape_hint, data_type_hint)
    176         nodes = np.array(nodes_data)
    177     except KeyError:
--> 178       raise KeyError("Missing 'node' field from the graph nodes. "
    179                      "This could be due to the node having been silently added "
    180                      "as a consequence of an edge addition when creating the "

KeyError: "Missing 'node' field from the graph nodes. This could be due to the node having been silently added as a consequence of an edge addition when creating the networkx instance"
```
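
As the error message hints, a node can end up without a `features` attribute when it is created implicitly by an edge addition. A small pre-check can list such nodes before conversion. This is a sketch in plain Python. `nodes_missing_features` is a hypothetical helper, not a graph_nets function, and it takes `(node_id, attr_dict)` pairs, which is the shape that networkx's `G.nodes(data=True)` yields.

```python
def nodes_missing_features(nodes_with_data, key="features"):
    """Return the ids of nodes whose attribute dict lacks `key`.

    Hypothetical helper: `nodes_with_data` is an iterable of
    (node_id, attr_dict) pairs.
    """
    return [node for node, attrs in nodes_with_data if key not in attrs]


# Node 2 stands in for a node that was added implicitly by an edge,
# so its attribute dict is empty. Node 1 has features explicitly set to None.
nodes = [(0, {"features": [0.0]}), (1, {"features": None}), (2, {})]
print(nodes_missing_features(nodes))  # [2]
```

With a real graph, passing `G.nodes(data=True)` to this helper before calling `utils_np.networkx_to_data_dict` would show which nodes trigger the `KeyError`.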




In [0]:
#@title import

import networkx as nx               # preinstalled on Colab
import matplotlib.pyplot as plt     # preinstalled on Colab
import numpy as np                  # preinstalled on Colab
import tensorflow as tf
import sonnet as snt                # dm-sonnet is installed as a dependency of graph-nets
import graph_nets as gn
from graph_nets import utils_np
from graph_nets import utils_tf

In [0]:
#@title functions
def get_ordered_multi_digraph():
  G = nx.OrderedMultiDiGraph()
  # Define the nodes of the graph.
  for node_index in range(10):
    # add_node accepts keyword attributes (**attr).
    # graph_nets' networkx_to_data_dict requires every node to carry a
    # "features" attribute, which holds that node's feature vector.
    # np.float64 replaces the np.float alias, which newer NumPy releases removed.
    G.add_node(node_index, features=np.array([node_index], dtype=np.float64))

  # Define senders and receivers: edge i runs from senders[i] to receivers[i].
  # Integer ids are used so that they match the node ids added above.
  senders = [1, 1, 2, 2, 3, 4, 5, 3]
  receivers = [2, 3, 4, 5, 6, 8, 8, 7]
  # Define the edges.
  for edge_index, (receiver, sender) in enumerate(zip(receivers, senders)):
    # Removing the "index" key makes this test fail 100%.
    edge_data = {"features": np.array([edge_index], dtype=np.float64), "index": edge_index}
    G.add_edge(sender, receiver, **edge_data)
  # The graph itself also needs features (the globals).
  G.graph["features"] = np.array([0], dtype=np.float64)
  #H = nx.path_graph(10)  # create 10 nodes labelled 0-9, linearly connected by 10-1 edges
  #G.add_nodes_from(H) # add the 10 nodes to the graph
  #G.add_edges_from([(1,2),(1,3),(2,4),(2,5),(3,6),(4,8),(5,8),(3,7)], weight=3) # connect the specified nodes
  print(list(G.nodes(data=True)))
  return G

def show_graph(G):
  nx.draw(G, with_labels=True)
  plt.show()


In [0]:
#@title create graphs
nxs = get_ordered_multi_digraph()
show_graph(nxs)

In [0]:
#@title run module

data_dic = utils_np.networkx_to_data_dict(nxs)
print(data_dic)

#input_graphs = utils_np.networkxs_to_graphs_tuple([nxs])
input_graphs = utils_tf.data_dicts_to_graphs_tuple([data_dic])
# Create the graph network.
graph_net_module = gn.modules.GraphNetwork(
    # Each MLP has two layers of 32 units, so the output feature size is 32.
    edge_model_fn=lambda: snt.nets.MLP([32, 32]),
    node_model_fn=lambda: snt.nets.MLP([32, 32]),
    global_model_fn=lambda: snt.nets.MLP([32, 32]))

# Pass the input graphs to the graph network, and return the output graphs.
output_graphs = graph_net_module(input_graphs)
print(output_graphs)


In [7]:
#@title show result
print(output_graphs.n_node)
print(output_graphs.n_edge)

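# graphs_tuple_to_data_dicts expects NumPy arrays. Here output_graphs still
# holds symbolic tensors (assuming TF1 graph mode, as the Tensor output below
# suggests), and passing it directly raised a ValueError. Evaluate it first.
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  output_graphs_np = sess.run(output_graphs)

output_dicts = utils_np.graphs_tuple_to_data_dicts(output_graphs_np)

print(output_dicts)
#res = utils_np.graphs_tuple_to_networkxs(output_graphs)
#show_graph(res)

Tensor("data_dicts_to_graphs_tuple/stack_1:0", shape=(1,), dtype=int32)
Tensor("data_dicts_to_graphs_tuple/stack_2:0", shape=(1,), dtype=int32)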