# -*- coding: utf-8 -*-
"""graph_matching_networks.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/github/deepmind/deepmind_research/blob/master/graph_matching_networks/graph_matching_networks.ipynb
##### Copyright 2019 Google LLC.
Licensed under the Apache License, Version 2.0 (the "License");
"""
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""# Graph Matching Networks for Learning the Similarity of Graph Structured Objects
This is the example code for our ICML 2019 paper. Please refer to the paper for more details:
> Yujia Li, Chenjie Gu, Thomas Dullien, Oriol Vinyals, Pushmeet Kohli. *Graph Matching Networks for Learning the Similarity of Graph Structured Objects*. ICML 2019. [\[arXiv\]](https://arxiv.org/abs/1904.12787)
## Graph similarity learning
Our goal is to learn a similarity function between graphs. Given two graphs $G_1, G_2$, a graph similarity model can be written as a function $f(G_1, G_2)$ that computes a scalar similarity value.
In this project we build models that learn such a similarity function from examples of similar / dissimilar pairs or triplets. Because the similarity function is learned, our model can adapt to different notions of similarity and to different types of graph structure, as long as training data is available.
In the following we will sometimes use the term "distance" and say the model learns a "distance function" $d(G_1, G_2)$ between graphs when convenient. A distance is simply the negative of a similarity, i.e. $f(G_1, G_2) = - d(G_1, G_2)$.
## Some dependencies and imports
If you want to run the notebook locally, first make sure you have all the dependencies installed, for example with the following command:
```
pip3 install --user -r requirements.txt
```
Note that the code should work with both Python 3 and Python 2, but Python 3 is recommended.
"""
# Let's disable all the warnings first
import warnings
warnings.simplefilter("ignore")
"""These are all the dependencies that will be used in this notebook."""
import abc
import collections
import contextlib
import copy
import random
import time
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import six
import sonnet as snt
import tensorflow as tf
"""## The models
### The graph embedding model
The simpler variant of our model embeds each graph **independently** into a vector and then uses an existing distance (or similarity) metric in the vector space to compute the distance between graphs. More concretely, we define
$$d(G_1, G_2) = d_H(embed(G_1), embed(G_2)),$$
where $embed$ is a model that maps any graph $G$ into an $H$-dimensional vector, and $d_H$ is a distance metric in that vector space. Typical examples are Euclidean distance in $\mathbb{R}^H$, i.e. $d_H(x, y) = \sqrt{\sum_{i=1}^H (x_i - y_i)^2}$, or Hamming distance in $H$-dimensional space of binary vectors, i.e. $d_H(x, y)=\sum_{i=1}^H \mathbb{I}[x_i \ne y_i]$.
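As a concrete illustration of the two example metrics, here is a small NumPy sketch; the function names and the example vectors are hypothetical and not part of this codebase:

```python
import numpy as np


def euclidean_distance(x, y):
  # d_H(x, y) = sqrt(sum_i (x_i - y_i)^2) for real-valued H-dim vectors.
  return float(np.sqrt(np.sum((x - y) ** 2)))


def hamming_distance(x, y):
  # d_H(x, y) = sum_i 1[x_i != y_i] for binary H-dim vectors.
  return int(np.sum(x != y))


# Hypothetical 4-dimensional graph embeddings.
emb_g1 = np.array([1.0, 0.0, 2.0, 0.0])
emb_g2 = np.array([1.0, 3.0, 2.0, 4.0])
print(euclidean_distance(emb_g1, emb_g2))  # sqrt(9 + 16) = 5.0

# Hypothetical binary codes, e.g. obtained by thresholding embeddings.
code_g1 = np.array([1, 0, 1, 1])
code_g2 = np.array([1, 1, 1, 0])
print(hamming_distance(code_g1, code_g2))  # 2
```

In this setup, a similarity score is simply the negative of either distance.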
Each graph input contains a set of nodes $V$ and edges $E$. Each node $i\in V$ may have a feature vector $x_i$ associated with it, and each edge $(i, j)\in E$ may also have a feature vector $x_{ij}$ encoding e.g. edge type or attributes. The embedding model will therefore jointly reason about the graph structure as well as the graph features to come up with an embedding that reflects the notion of similarity described by the training examples.
The embedding model is composed of 3 parts:
1. An encoder that maps $x_i$ and $x_{ij}$ into a nice hidden representation space. Here we use separate MLPs (fully connected neural nets) for node and edge representations:
$$\begin{array}{rcl}
h_i^{(0)} &=& \mathrm{MLP_{node}}(x_i) \\
e_{ij} &=& \mathrm{MLP_{edge}}(x_{ij})
\end{array}
$$
2. A graph neural network (GNN) that communicates information across the graph and computes node representations that encode local neighborhood structure and semantics. More concretely, the GNN computes node representations through an iterative message passing process. In the $t$-th round of message passing, we compute a message vector on each edge, and then each node aggregates all the incoming messages and updates its own representation:
$$\begin{array}{rcl}
m_{i\rightarrow j} &=& f_\mathrm{message}(h_i^{(t)}, h_j^{(t)}, e_{ij}) \\
h_i^{(t+1)} &=& f_\mathrm{node}(h_i^{(t)}, \sum_{j:(j,i)\in E} m_{j\rightarrow i})
\end{array}
$$
Here both $f_\mathrm{message}$ and $f_\mathrm{node}$ are neural modules. We use MLPs for $f_\mathrm{message}$, while $f_\mathrm{node}$ can also be an MLP or even a recurrent neural network core like an LSTM or GRU. GNNs have the nice property of being equivariant to node permutations: corresponding nodes in isomorphic graphs (with the same node and edge features) get the same representations regardless of the node ordering.
3. After obtaining the final node representations from $T$ rounds of message passing, we aggregate across them to get graph representations $h_G=f_G(\{h_i^{(T)}\}_{i\in V})$. This could be implemented as a simple sum that reduces the node representations into a single vector, followed by a transformation:
$$h_G = \mathrm{MLP_G}\left(\sum_{i\in V} h_i^{(T)}\right).$$
We used the following gated aggregation module proposed in [Li et al., 2015](https://arxiv.org/abs/1511.05493) which we found to work consistently better:
$$h_G = \mathrm{MLP_G}\left(\sum_{i\in V} \sigma(\mathrm{MLP_{gate}}(h_i^{(T)})) \odot \mathrm{MLP}(h_i^{(T)})\right).$$
The key requirement for this function is that it is invariant to node orderings, and both of the above forms satisfy this condition. The gated variant gives the model the capacity to explicitly modulate each node's contribution to the graph representation.
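The three steps above can be put together in a toy NumPy sketch. Everything here is an assumption made for illustration: one-layer ReLU stand-ins for the MLPs, random untrained weights, residual node updates, messages sent only along the given edge direction, and a single linear layer standing in for MLP_G. It is not the Sonnet implementation below; it is meant to show the shapes involved and to check the claimed invariance to node orderings.

```python
import numpy as np

rng = np.random.RandomState(0)
D_NODE, D_EDGE, D_HID, D_GRAPH = 3, 2, 8, 4


def init(n_in, n_out):
  # Random weights scaled by 1/sqrt(fan_in) to keep activations moderate.
  return rng.randn(n_in, n_out) / np.sqrt(n_in)


def mlp(x, w):
  # One-layer stand-in for an MLP: a linear map followed by ReLU.
  return np.maximum(x @ w, 0.0)


W_node, W_edge = init(D_NODE, D_HID), init(D_EDGE, D_HID)  # step 1 encoders
W_msg = init(3 * D_HID, D_HID)   # f_message on [h_i, h_j, e_ij]
W_upd = init(2 * D_HID, D_HID)   # f_node on [h_i, summed messages]
W_gate, W_feat = init(D_HID, D_GRAPH), init(D_HID, D_GRAPH)
W_graph = init(D_GRAPH, D_GRAPH)


def embed(x, e, from_idx, to_idx, n_rounds=3):
  h, e = mlp(x, W_node), mlp(e, W_edge)            # 1. encode nodes and edges
  for _ in range(n_rounds):                        # 2. message passing
    m = mlp(np.concatenate([h[from_idx], h[to_idx], e], axis=1), W_msg)
    agg = np.zeros_like(h)
    np.add.at(agg, to_idx, m)                      # sum messages at targets
    h = h + mlp(np.concatenate([h, agg], axis=1), W_upd)  # residual update
  gates = 1.0 / (1.0 + np.exp(-(h @ W_gate)))      # 3. gated aggregation
  # Gated sum over nodes, then one linear layer in place of MLP_G.
  return np.sum(gates * (h @ W_feat), axis=0) @ W_graph


# A 4-node directed cycle with random node and edge features.
x = rng.randn(4, D_NODE)
e = rng.randn(4, D_EDGE)
from_idx = np.array([0, 1, 2, 3])
to_idx = np.array([1, 2, 3, 0])
h_g = embed(x, e, from_idx, to_idx)

# Relabel the nodes: new node k is old node perm[k], old node i is new inv[i].
perm = np.array([2, 0, 3, 1])
inv = np.argsort(perm)
h_g_perm = embed(x[perm], e, inv[from_idx], inv[to_idx])
print(np.allclose(h_g, h_g_perm))  # True
```

Relabeling the nodes permutes the rows of `h` but leaves the gated sum over nodes unchanged (up to floating-point rounding), which is why the final check holds.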
#### The graph encoder
"""
class GraphEncoder(snt.AbstractModule):
  """Encoder module that projects node and edge features to some embeddings."""

  def __init__(self,
               node_hidden_sizes=None,
               edge_hidden_sizes=None,
               name='graph-encoder'):
    """Constructor.

    Args:
      node_hidden_sizes: if provided should be a list of ints, hidden sizes of
        node encoder network, the last element is the size of the node outputs.
        If not provided, node features will pass through as is.
      edge_hidden_sizes: if provided should be a list of ints, hidden sizes of
        edge encoder network, the last element is the size of the edge outputs.
        If not provided, edge features will pass through as is.
      name: name of this module.
    """
    super(GraphEncoder, self).__init__(name=name)

    # this also handles the case of an empty list
    self._node_hidden_sizes = node_hidden_sizes if node_hidden_sizes else None
    self._edge_hidden_sizes = edge_hidden_sizes if edge_hidden_sizes else None

  def _build(self, node_features, edge_features=None):
    """Encode node and edge features.

    Args:
      node_features: [n_nodes, node_feat_dim] float tensor.
      edge_features: if provided, should be [n_edges, edge_feat_dim] float
        tensor.

    Returns:
      node_outputs: [n_nodes, node_embedding_dim] float tensor, node embeddings.
      edge_outputs: if edge_features is not None and edge_hidden_sizes is not
        None, this is [n_edges, edge_embedding_dim] float tensor, edge
        embeddings; otherwise just the input edge_features.
    """
    if self._node_hidden_sizes is None:
      node_outputs = node_features
    else:
      node_outputs = snt.nets.MLP(
          self._node_hidden_sizes, name='node-feature-mlp')(node_features)

    if edge_features is None or self._edge_hidden_sizes is None:
      edge_outputs = edge_features
    else:
      edge_outputs = snt.nets.MLP(
          self._edge_hidden_sizes, name='edge-feature-mlp')(edge_features)

    return node_outputs, edge_outputs
"""#### The message passing layers"""
def graph_prop_once(node_states,
                    from_idx,
                    to_idx,
                    message_net,
                    aggregation_module=tf.unsorted_segment_sum,
                    edge_features=None):
  """One round of propagation (message passing) in a graph.

  Args:
    node_states: [n_nodes, node_state_dim] float tensor, node state vectors, one
      row for each node.
    from_idx: [n_edges] int tensor, index of the from nodes.
    to_idx: [n_edges] int tensor, index of the to nodes.
    message_net: a network that maps concatenated edge inputs to message
      vectors.
    aggregation_module: a module that aggregates messages on edges to aggregated
      messages for each node. Should be a callable and can be called like the
      following,
      `aggregated_messages = aggregation_module(messages, to_idx, n_nodes)`,
      where messages is [n_edges, edge_message_dim] tensor, to_idx is the index
      of the to nodes, i.e. where each message should go to, and n_nodes is an
      int which is the number of nodes to aggregate into.
    edge_features: if provided, should be a [n_edges, edge_feature_dim] float
      tensor, extra features for each edge.

  Returns:
    aggregated_messages: an [n_nodes, edge_message_dim] float tensor, the
      aggregated messages, one row for each node.
  """
  from_states = tf.gather(node_states, from_idx)
  to_states = tf.gather(node_states, to_idx)

  edge_inputs = [from_states, to_states]
  if edge_features is not None:
    edge_inputs.append(edge_features)

  edge_inputs = tf.concat(edge_inputs, axis=-1)
  messages = message_net(edge_inputs)

  return aggregation_module(messages, to_idx, tf.shape(node_states)[0])
class GraphPropLayer(snt.AbstractModule):
  """Implementation of a graph propagation (message passing) layer."""

  def __init__(self,
               node_state_dim,
               edge_hidden_sizes,
               node_hidden_sizes,
               edge_net_init_scale=0.1,
               node_update_type='residual',
               use_reverse_direction=True,
               reverse_dir_param_different=True,
               layer_norm=False,
               name='graph-net'):
    """Constructor.

    Args:
      node_state_dim: int, dimensionality of node states.
      edge_hidden_sizes: list of ints, hidden sizes for the edge message
        net, the last element in the list is the size of the message vectors.
      node_hidden_sizes: list of ints, hidden sizes for the node update
        net.
      edge_net_init_scale: initialization scale for the edge networks. This
        is typically set to a small value such that the gradient does not blow
        up.
      node_update_type: type of node updates, one of {mlp, gru, residual}.
      use_reverse_direction: set to True to also propagate messages in the
        reverse direction.
      reverse_dir_param_different: set to True to have the messages computed
        using a different set of parameters than for the forward direction.
      layer_norm: set to True to use layer normalization in a few places.
      name: name of this module.
    """
    super(GraphPropLayer, self).__init__(name=name)

    self._node_state_dim = node_state_dim
    self._edge_hidden_sizes = edge_hidden_sizes[:]

    # output size is node_state_dim
    self._node_hidden_sizes = node_hidden_sizes[:] + [node_state_dim]
    self._edge_net_init_scale = edge_net_init_scale
    self._node_update_type = node_update_type

    self._use_reverse_direction = use_reverse_direction
    self._reverse_dir_param_different = reverse_dir_param_different

    self._layer_norm = layer_norm

  def _compute_aggregated_messages(
      self, node_states, from_idx, to_idx, edge_features=None):
    """Compute aggregated messages for each node.

    Args:
      node_states: [n_nodes, input_node_state_dim] float tensor, node states.
      from_idx: [n_edges] int tensor, from node indices for each edge.
      to_idx: [n_edges] int tensor, to node indices for each edge.
      edge_features: if not None, should be [n_edges, edge_embedding_dim]
        tensor, edge features.

    Returns:
      aggregated_messages: [n_nodes, aggregated_message_dim] float tensor, the
        aggregated messages for each node.
    """
    self._message_net = snt.nets.MLP(
        self._edge_hidden_sizes,
        initializers={
            'w': tf.variance_scaling_initializer(
                scale=self._edge_net_init_scale),
            'b': tf.zeros_initializer()},
        name='message-mlp')

    aggregated_messages = graph_prop_once(
        node_states,
        from_idx,
        to_idx,
        self._message_net,
        aggregation_module=tf.unsorted_segment_sum,
        edge_features=edge_features)

    # optionally compute message vectors in the reverse direction
    if self._use_reverse_direction:
      if self._reverse_dir_param_different:
        self._reverse_message_net = snt.nets.MLP(
            self._edge_hidden_sizes,
            initializers={
                'w': tf.variance_scaling_initializer(
                    scale=self._edge_net_init_scale),
                'b': tf.zeros_initializer()},
            name='reverse-message-mlp')
      else:
        self._reverse_message_net = self._message_net

      reverse_aggregated_messages = graph_prop_once(
          node_states,
          to_idx,
          from_idx,
          self._reverse_message_net,
          aggregation_module=tf.unsorted_segment_sum,
          edge_features=edge_features)

      aggregated_messages += reverse_aggregated_messages

    if self._layer_norm:
      aggregated_messages = snt.LayerNorm()(aggregated_messages)

    return aggregated_messages

  def _compute_node_update(self,
                           node_states,
                           node_state_inputs,
                           node_features=None):
    """Compute node updates.

    Args:
      node_states: [n_nodes, node_state_dim] float tensor, the input node
        states.
      node_state_inputs: a list of tensors used to compute node updates. Each
        element tensor should have shape [n_nodes, feat_dim], where feat_dim can
        be different. These tensors will be concatenated along the feature
        dimension.
      node_features: extra node features if provided, should be of size
        [n_nodes, extra_node_feat_dim] float tensor, can be used to implement
        different types of skip connections.

    Returns:
      new_node_states: [n_nodes, node_state_dim] float tensor, the new node
        state tensor.

    Raises:
      ValueError: if node update type is not supported.
    """
    if self._node_update_type in ('mlp', 'residual'):
      node_state_inputs.append(node_states)
    if node_features is not None:
      node_state_inputs.append(node_features)

    if len(node_state_inputs) == 1:
      node_state_inputs = node_state_inputs[0]
    else:
      node_state_inputs = tf.concat(node_state_inputs, axis=-1)

    if self._node_update_type == 'gru':
      _, new_node_states = snt.GRU(self._node_state_dim)(
          node_state_inputs, node_states)
      return new_node_states
    else:
      mlp_output = snt.nets.MLP(
          self._node_hidden_sizes, name='node-mlp')(node_state_inputs)
      if self._layer_norm:
        mlp_output = snt.LayerNorm()(mlp_output)
      if self._node_update_type == 'mlp':
        return mlp_output
      elif self._node_update_type == 'residual':
        return node_states + mlp_output
      else:
        raise ValueError('Unknown node update type %s' % self._node_update_type)

  def _build(self,
             node_states,
             from_idx,
             to_idx,
             edge_features=None,
             node_features=None):
    """Run one propagation step.

    Args:
      node_states: [n_nodes, input_node_state_dim] float tensor, node states.
      from_idx: [n_edges] int tensor, from node indices for each edge.
      to_idx: [n_edges] int tensor, to node indices for each edge.
      edge_features: if not None, should be [n_edges, edge_embedding_dim]
        tensor, edge features.
      node_features: extra node features if provided, should be of size
        [n_nodes, extra_node_feat_dim] float tensor, can be used to implement
        different types of skip connections.

    Returns:
      node_states: [n_nodes, node_state_dim] float tensor, new node states.
    """
    aggregated_messages = self._compute_aggregated_messages(
        node_states, from_idx, to_idx, edge_features=edge_features)

    return self._compute_node_update(node_states,
                                     [aggregated_messages],
                                     node_features=node_features)
"""#### Graph aggregator"""
AGGREGATION_TYPE = {
    'sum': tf.unsorted_segment_sum,
    'mean': tf.unsorted_segment_mean,
    'sqrt_n': tf.unsorted_segment_sqrt_n,
    'max': tf.unsorted_segment_max,
}


class GraphAggregator(snt.AbstractModule):
  """This module computes graph representations by aggregating from parts."""

  def __init__(self,
               node_hidden_sizes,
               graph_transform_sizes=None,
               gated=True,
               aggregation_type='sum',
               name='graph-aggregator'):
    """Constructor.

    Args:
      node_hidden_sizes: the hidden layer sizes of the node transformation nets.
        The last element is the size of the aggregated graph representation.
      graph_transform_sizes: sizes of the transformation layers on top of the
        graph representations. The last element of this list is the final
        dimensionality of the output graph representations.
      gated: set to True to do gated aggregation, False not to.
      aggregation_type: one of {sum, max, mean, sqrt_n}.
      name: name of this module.
    """
    super(GraphAggregator, self).__init__(name=name)

    self._node_hidden_sizes = node_hidden_sizes
    self._graph_transform_sizes = graph_transform_sizes
    self._graph_state_dim = node_hidden_sizes[-1]
    self._gated = gated
    self._aggregation_type = aggregation_type
    self._aggregation_op = AGGREGATION_TYPE[aggregation_type]

  def _build(self, node_states, graph_idx, n_graphs):
    """Compute aggregated graph representations.

    Args:
      node_states: [n_nodes, node_state_dim] float tensor, node states of a
        batch of graphs concatenated together along the first dimension.
      graph_idx: [n_nodes] int tensor, graph ID for each node.
      n_graphs: integer, number of graphs in this batch.

    Returns:
      graph_states: [n_graphs, graph_state_dim] float tensor, graph
        representations, one row for each graph.
    """
    # copy the list so that the gated variant does not modify the sizes list
    # passed in by the caller
    node_hidden_sizes = list(self._node_hidden_sizes)
    if self._gated:
      node_hidden_sizes[-1] = self._graph_state_dim * 2

    node_states_g = snt.nets.MLP(
        node_hidden_sizes, name='node-state-g-mlp')(node_states)

    if self._gated:
      gates = tf.nn.sigmoid(node_states_g[:, :self._graph_state_dim])
      node_states_g = node_states_g[:, self._graph_state_dim:] * gates

    graph_states = self._aggregation_op(node_states_g, graph_idx, n_graphs)

    # unsorted_segment_max does not handle empty graphs in the way we want:
    # it assigns the lowest possible float to empty segments, and we want to
    # reset them to zero.
    if self._aggregation_type == 'max':
      # reset everything that's smaller than -1e5 to 0.
      graph_states *= tf.cast(graph_states > -1e5, tf.float32)

    # transform the reduced graph states further

    # pylint: disable=g-explicit-length-test
    if (self._graph_transform_sizes is not None and
        len(self._graph_transform_sizes) > 0):
      graph_states = snt.nets.MLP(
          self._graph_transform_sizes, name='graph-transform-mlp')(graph_states)

    return graph_states
"""#### Putting them together"""
class GraphEmbeddingNet(snt.AbstractModule):
  """A graph to embedding mapping network."""

  def __init__(self,
               encoder,
               aggregator,
               node_state_dim,
               edge_hidden_sizes,
               node_hidden_sizes,
               n_prop_layers,
               share_prop_params=False,
               edge_net_init_scale=0.1,
               node_update_type='residual',
               use_reverse_direction=True,
               reverse_dir_param_different=True,
               layer_norm=False,
               name='graph-embedding-net'):
    """Constructor.

    Args:
      encoder: GraphEncoder, encoder that maps features to embeddings.
      aggregator: GraphAggregator, aggregator that produces graph
        representations.
      node_state_dim: dimensionality of node states.
      edge_hidden_sizes: sizes of the hidden layers of the edge message nets.
      node_hidden_sizes: sizes of the hidden layers of the node update nets.
      n_prop_layers: number of graph propagation layers.
      share_prop_params: set to True to share propagation parameters across all
        graph propagation layers, False not to.
      edge_net_init_scale: scale of initialization for the edge message nets.
      node_update_type: type of node updates, one of {mlp, gru, residual}.
      use_reverse_direction: set to True to also propagate messages in the
        reverse direction.
      reverse_dir_param_different: set to True to have the messages computed
        using a different set of parameters than for the forward direction.
      layer_norm: set to True to use layer normalization in a few places.
      name: name of this module.
    """
    super(GraphEmbeddingNet, self).__init__(name=name)

    self._encoder = encoder
    self._aggregator = aggregator
    self._node_state_dim = node_state_dim
    self._edge_hidden_sizes = edge_hidden_sizes
    self._node_hidden_sizes = node_hidden_sizes
    self._n_prop_layers = n_prop_layers
    self._share_prop_params = share_prop_params
    self._edge_net_init_scale = edge_net_init_scale
    self._node_update_type = node_update_type
    self._use_reverse_direction = use_reverse_direction
    self._reverse_dir_param_different = reverse_dir_param_different
    self._layer_norm = layer_norm

    self._prop_layers = []
    self._layer_class = GraphPropLayer

  def _build_layer(self, layer_id):
    """Build one layer in the network."""
    return self._layer_class(
        self._node_state_dim,
        self._edge_hidden_sizes,
        self._node_hidden_sizes,
        edge_net_init_scale=self._edge_net_init_scale,
        node_update_type=self._node_update_type,
        use_reverse_direction=self._use_reverse_direction,
        reverse_dir_param_different=self._reverse_dir_param_different,
        layer_norm=self._layer_norm,
        name='graph-prop-%d' % layer_id)

  def _apply_layer(self,
                   layer,
                   node_states,
                   from_idx,
                   to_idx,
                   graph_idx,
                   n_graphs,
                   edge_features):
    """Apply one layer on the given inputs."""
    del graph_idx, n_graphs
    return layer(node_states, from_idx, to_idx, edge_features=edge_features)

  def _build(self,
             node_features,
             edge_features,
             from_idx,
             to_idx,
             graph_idx,
             n_graphs):
    """Compute graph representations.

    Args:
      node_features: [n_nodes, node_feat_dim] float tensor.
      edge_features: [n_edges, edge_feat_dim] float tensor.
      from_idx: [n_edges] int tensor, index of the from node for each edge.
      to_idx: [n_edges] int tensor, index of the to node for each edge.
      graph_idx: [n_nodes] int tensor, graph id for each node.
      n_graphs: int, number of graphs in the batch.

    Returns:
      graph_representations: [n_graphs, graph_representation_dim] float tensor,
        graph representations.
    """
    # build only the layers that do not exist yet; with shared parameters,
    # every layer after the first one reuses the first layer
    for i in range(len(self._prop_layers), self._n_prop_layers):
      if i == 0 or not self._share_prop_params:
        layer = self._build_layer(i)
      else:
        layer = self._prop_layers[0]
      self._prop_layers.append(layer)

    node_features, edge_features = self._encoder(node_features, edge_features)
    node_states = node_features

    layer_outputs = [node_states]

    # apply only the first n_prop_layers layers, so that reset_n_prop_layers
    # can also reduce the number of propagation steps
    for layer in self._prop_layers[:self._n_prop_layers]:
      # node_features could be wired in here as well, leaving it out for now as
      # it is already in the inputs
      node_states = self._apply_layer(
          layer,
          node_states,
          from_idx,
          to_idx,
          graph_idx,
          n_graphs,
          edge_features)
      layer_outputs.append(node_states)

    # these tensors may be used e.g. for visualization
    self._layer_outputs = layer_outputs
    return self._aggregator(node_states, graph_idx, n_graphs)

  def reset_n_prop_layers(self, n_prop_layers):
    """Set n_prop_layers to the provided new value.

    This allows us to train with a certain number of propagation layers and
    evaluate with a different number of propagation layers.

    This only works if n_prop_layers is smaller than the number used for
    training, or when share_prop_params is set to True, in which case this can
    be arbitrarily large.

    Args:
      n_prop_layers: the new number of propagation layers to set.
    """
    self._n_prop_layers = n_prop_layers

  @property
  def n_prop_layers(self):
    return self._n_prop_layers

  def get_layer_outputs(self):
    """Get the outputs at each layer."""
    if hasattr(self, '_layer_outputs'):
      return self._layer_outputs
    else:
      raise ValueError('No layer outputs available.')
"""### The graph matching networks
The graph matching networks (GMNs) compute the similarity score for a pair of graphs jointly on the pair. In our current formulation, it still computes a representation for each graph, but the representations for a pair of graphs are computed jointly on the pair, through a cross-graph attention-based matching mechanism.
More concretely, the graph matching model can be formulated as
$$d(G_1, G_2) = d_H(embed\_and\_match(G_1, G_2))$$
where $embed\_and\_match(G_1, G_2)$ returns a pair of graph representations.
Similar to the embedding model, our GMNs compute graph representations in 3 steps. The difference from the embedding model is in the message passing step, where each node not only receives messages from within the same graph, but also receives cross-graph messages by attending to all the nodes in the other graph. This can be formulated as follows.
We first have within-graph messages as before:
$$
m_{i\rightarrow j} = f_\mathrm{message}(h_i^{(t)}, h_j^{(t)}, e_{ij}).
$$
In addition, we allow each node in one graph to attend to all the nodes in the other graph. The cross-graph attention weight (node $i$ in one graph attending to node $j$ in the other graph, and vice versa) is computed as
$$\begin{array}{rcl}
a_{i\rightarrow j} &=& \frac{\exp(s(h_i^{(t)}, h_j^{(t)}))}{\sum_j \exp(s(h_i^{(t)}, h_j^{(t)}))} \\
a_{j\rightarrow i} &=& \frac{\exp(s(h_i^{(t)}, h_j^{(t)}))}{\sum_i \exp(s(h_i^{(t)}, h_j^{(t)}))},
\end{array}
$$
where $s(., .)$ is again a vector space similarity function, such as Euclidean, dot-product or cosine similarity. Note the different indices being summed over in the two normalizers.
The cross-graph message is then computed as
$$\begin{array}{rcl}
\mu_i &=& \sum_j a_{i\rightarrow j} (h_i^{(t)} - h_j^{(t)}) = h_i^{(t)} - \sum_j a_{i\rightarrow j} h_j^{(t)}, \\
\mu_j &=& \sum_i a_{j\rightarrow i} (h_j^{(t)} - h_i^{(t)}) = h_j^{(t)} - \sum_i a_{j\rightarrow i} h_i^{(t)}.
\end{array}
$$
Here we compute an attention-weighted sum of all the node representations from the other graph, and then take the difference. This essentially **matches** each node in one graph to the nodes most similar to it in the other graph, and then computes the difference.
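The attention weights and cross-graph messages can be computed for a single pair of graphs with a few matrix products. A NumPy sketch assuming dot-product similarity for $s$ and random node states (the variable names are illustrative only):

```python
import numpy as np


def softmax(z, axis):
  # Subtract the max for numerical stability before exponentiating.
  z = z - np.max(z, axis=axis, keepdims=True)
  ez = np.exp(z)
  return ez / np.sum(ez, axis=axis, keepdims=True)


rng = np.random.RandomState(0)
h1 = rng.randn(3, 5)   # node states of graph 1 (nodes i)
h2 = rng.randn(4, 5)   # node states of graph 2 (nodes j)

s = h1 @ h2.T                  # dot-product similarity s(h_i, h_j), [3, 4]
a_i_to_j = softmax(s, axis=1)  # normalized over j
a_j_to_i = softmax(s, axis=0)  # normalized over i

mu1 = h1 - a_i_to_j @ h2       # mu_i = h_i - sum_j a_{i->j} h_j, [3, 5]
mu2 = h2 - a_j_to_i.T @ h1     # mu_j = h_j - sum_i a_{j->i} h_i, [4, 5]

# The explicit attention-weighted sum of differences gives the same mu_i.
mu1_explicit = np.einsum('ij,ijd->id', a_i_to_j, h1[:, None, :] - h2[None, :, :])
print(np.allclose(mu1, mu1_explicit))  # True
```

The last check confirms the identity used above: because each row of `a_i_to_j` sums to one, the attention-weighted sum of differences equals `h_i` minus the attention-weighted average of the other graph's nodes.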
The node updates are then computed as
$$
h_i^{(t+1)} = f_\mathrm{node}\left(h_i^{(t)}, \sum_{j:(j,i)\in E} m_{j\rightarrow i}, \mu_i\right).
$$
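A minimal sketch of this node update in NumPy, under stated assumptions: `f_node` is taken to be a single ReLU layer applied to the concatenation of $h_i^{(t)}$, the summed within-graph messages and `mu_i`, with a residual connection, loosely mirroring the `'residual'` option of `GraphPropLayer`. All inputs and weights are random placeholders.

```python
import numpy as np

rng = np.random.RandomState(0)
n_nodes, d = 3, 5

h = rng.randn(n_nodes, d)        # h_i^(t)
msg_sum = rng.randn(n_nodes, d)  # sum_j m_{j->i}, within-graph messages
mu = rng.randn(n_nodes, d)       # mu_i, cross-graph matching vectors

# Hypothetical one-layer f_node weights acting on [h_i, sum m, mu_i].
W = rng.randn(3 * d, d) / np.sqrt(3 * d)


def node_update(h, msg_sum, mu, W):
  # Concatenate the three inputs along the feature axis, apply a ReLU layer,
  # and add the result back to h (a residual update).
  inputs = np.concatenate([h, msg_sum, mu], axis=1)
  return h + np.maximum(inputs @ W, 0.0)


h_next = node_update(h, msg_sum, mu, W)
print(h_next.shape)  # (3, 5)
```

Swapping in a deeper MLP or a GRU cell would only change `node_update`.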
The graph encoder and the graph aggregators are the same as in the embedding model.
#### A few similarity functions
These are the functions $s(., .)$ that will be used in the cross-graph attention.
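`pairwise_euclidean_similarity` below avoids building all N x M difference vectors by expanding $-|x_i - y_j|^2 = 2 x_i^T y_j - |x_i|^2 - |y_j|^2$. The NumPy sketch below mirrors the three functions for illustration only (`tf.nn.l2_normalize` additionally guards against division by zero) and checks the Euclidean expansion against the brute-force definition:

```python
import numpy as np


def euclidean_sim(x, y):
  # -|x_i - y_j|^2 = 2 x_i^T y_j - |x_i|^2 - |y_j|^2, via one matrix product.
  sq_x = np.sum(x * x, axis=1, keepdims=True)  # [N, 1]
  sq_y = np.sum(y * y, axis=1)[None, :]        # [1, M]
  return 2 * x @ y.T - sq_x - sq_y


def dot_product_sim(x, y):
  return x @ y.T


def cosine_sim(x, y):
  x = x / np.linalg.norm(x, axis=1, keepdims=True)
  y = y / np.linalg.norm(y, axis=1, keepdims=True)
  return x @ y.T


rng = np.random.RandomState(0)
x, y = rng.randn(3, 4), rng.randn(5, 4)

# Brute force: explicit pairwise differences, shape [N, M, D].
brute = -np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
print(np.allclose(euclidean_sim(x, y), brute))  # True
```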
"""
def pairwise_euclidean_similarity(x, y):
  """Compute the pairwise Euclidean similarity between x and y.

  This function computes the following similarity value between each pair of x_i
  and y_j: s(x_i, y_j) = -|x_i - y_j|^2.

  Args:
    x: NxD float tensor.
    y: MxD float tensor.

  Returns:
    s: NxM float tensor, the pairwise euclidean similarity.
  """
  s = 2 * tf.matmul(x, y, transpose_b=True)
  diag_x = tf.reduce_sum(x * x, axis=-1, keepdims=True)
  diag_y = tf.reshape(tf.reduce_sum(y * y, axis=-1), (1, -1))
  return s - diag_x - diag_y


def pairwise_dot_product_similarity(x, y):
  """Compute the dot product similarity between x and y.

  This function computes the following similarity value between each pair of x_i
  and y_j: s(x_i, y_j) = x_i^T y_j.

  Args:
    x: NxD float tensor.
    y: MxD float tensor.

  Returns:
    s: NxM float tensor, the pairwise dot product similarity.
  """
  return tf.matmul(x, y, transpose_b=True)


def pairwise_cosine_similarity(x, y):
  """Compute the cosine similarity between x and y.

  This function computes the following similarity value between each pair of x_i
  and y_j: s(x_i, y_j) = x_i^T y_j / (|x_i||y_j|).

  Args:
    x: NxD float tensor.
    y: MxD float tensor.

  Returns:
    s: NxM float tensor, the pairwise cosine similarity.
  """
  x = tf.nn.l2_normalize(x, axis=-1)
  y = tf.nn.l2_normalize(y, axis=-1)
  return tf.matmul(x, y, transpose_b=True)


PAIRWISE_SIMILARITY_FUNCTION = {
    'euclidean': pairwise_euclidean_similarity,
    'dotproduct': pairwise_dot_product_similarity,
    'cosine': pairwise_cosine_similarity,
}


def get_pairwise_similarity(name):
  """Get pairwise similarity metric by name.

  Args:
    name: string, name of the similarity metric, one of {dotproduct, cosine,
      euclidean}.

  Returns:
    similarity: a (x, y) -> sim function.

  Raises:
    ValueError: if name is not supported.
  """
  if name not in PAIRWISE_SIMILARITY_FUNCTION:
    raise ValueError('Similarity metric name "%s" not supported.' % name)
  else:
    return PAIRWISE_SIMILARITY_FUNCTION[name]
"""#### The cross-graph attention
We implement this cross-graph attention in batches of pairs.
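In a batch, pairs are laid out as consecutive blocks: blocks 2k and 2k+1 hold the nodes of the two graphs in the k-th pair, and `block_idx` gives each node's block. A NumPy sketch of the partition-then-attend loop, assuming dot-product similarity (`block_pair_attention_np` is a hypothetical name, not part of this file):

```python
import numpy as np


def softmax(z, axis):
  # Subtract the max for numerical stability before exponentiating.
  z = z - np.max(z, axis=axis, keepdims=True)
  ez = np.exp(z)
  return ez / np.sum(ez, axis=axis, keepdims=True)


def block_pair_attention_np(data, block_idx, n_blocks):
  # Assumes rows of data are sorted by block, so that concatenating the
  # per-block results in block order lines up with the input rows.
  assert n_blocks % 2 == 0
  results = []
  for i in range(0, n_blocks, 2):
    x = data[block_idx == i]       # nodes of the first graph in the pair
    y = data[block_idx == i + 1]   # nodes of the second graph in the pair
    s = x @ y.T
    results.append(softmax(s, axis=1) @ y)    # attention_x
    results.append(softmax(s, axis=0).T @ x)  # attention_y
  return np.concatenate(results, axis=0)


rng = np.random.RandomState(0)
# Two pairs: graphs 0 and 1 (2 and 3 nodes), graphs 2 and 3 (1 and 2 nodes).
block_idx = np.array([0, 0, 1, 1, 1, 2, 3, 3])
data = rng.randn(len(block_idx), 4)
out = block_pair_attention_np(data, block_idx, n_blocks=4)
print(out.shape)  # (8, 4)
```

Like the TensorFlow version, which concatenates the `tf.dynamic_partition` outputs in block order, this sketch only returns rows aligned with the input when nodes are sorted by block. In the example, graph 2 has a single node, so both nodes of graph 3 attend entirely to it.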
"""
def compute_cross_attention(x, y, sim):
"""Compute cross attention.
x_i attend to y_j:
a_{i->j} = exp(sim(x_i, y_j)) / sum_j exp(sim(x_i, y_j))
y_j attend to x_i:
a_{j->i} = exp(sim(x_i, y_j)) / sum_i exp(sim(x_i, y_j))
attention_x = sum_j a_{i->j} y_j
attention_y = sum_i a_{j->i} x_i
Args:
x: NxD float tensor.
y: MxD float tensor.
sim: a (x, y) -> similarity function.
Returns:
attention_x: NxD float tensor.
attention_y: NxD float tensor.
"""
a = sim(x, y)
a_x = tf.nn.softmax(a, axis=1) # i->j
a_y = tf.nn.softmax(a, axis=0) # j->i
attention_x = tf.matmul(a_x, y)
attention_y = tf.matmul(a_y, x, transpose_a=True)
return attention_x, attention_y


def batch_block_pair_attention(data,
                               block_idx,
                               n_blocks,
                               similarity='dotproduct'):
  """Compute batched attention between pairs of blocks.

  This function partitions the batch data into blocks according to block_idx.
  For each pair of blocks, x = data[block_idx == 2i] and
  y = data[block_idx == 2i+1], we compute

  x_i attends to y_j:
  a_{i->j} = exp(sim(x_i, y_j)) / sum_{j'} exp(sim(x_i, y_{j'}))
  y_j attends to x_i:
  a_{j->i} = exp(sim(x_i, y_j)) / sum_{i'} exp(sim(x_{i'}, y_j))

  and

  attention_x = sum_j a_{i->j} y_j
  attention_y = sum_i a_{j->i} x_i.

  Args:
    data: NxD float tensor.
    block_idx: N-dim int tensor, the block id of each row. Rows are assumed to
      be grouped by block in increasing block order, because the per-block
      results are concatenated in that order.
    n_blocks: integer.
    similarity: a string, the similarity metric.

  Returns:
    attention_output: NxD float tensor, each row replaced by its cross-graph
      attention output (attention_x for even blocks, attention_y for odd
      blocks).

  Raises:
    ValueError: if n_blocks is not an integer or not a multiple of 2.
  """
  if not isinstance(n_blocks, int):
    raise ValueError('n_blocks (%s) has to be an integer.' % str(n_blocks))

  if n_blocks % 2 != 0:
    raise ValueError('n_blocks (%d) must be a multiple of 2.' % n_blocks)

  sim = get_pairwise_similarity(similarity)

  results = []

  # Partitioning once is probably cheaper than a boolean_mask per block.
  partitions = tf.dynamic_partition(data, block_idx, n_blocks)

  # It is rather complicated to allow n_blocks to be a tf tensor and do this in
  # a dynamic loop, and probably unnecessary. Therefore we restrict n_blocks to
  # an integer constant here and use a plain for loop.
  for i in range(0, n_blocks, 2):
    x = partitions[i]
    y = partitions[i + 1]
    attention_x, attention_y = compute_cross_attention(x, y, sim)
    results.append(attention_x)
    results.append(attention_y)

  results = tf.concat(results, axis=0)
  # The size of the first dimension is lost after concat; set it back.
  results.set_shape(data.shape)
  return results
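"""A NumPy sketch of the same block-pair loop, assuming dot-product similarity; `cross_attention` and `batch_block_pair_attention_np` are hypothetical names. One deliberate difference: the TensorFlow code concatenates per-block results, which keeps rows in block order, so its output lines up with `data` only when `block_idx` is sorted. The sketch instead writes each result back to the row positions it came from.

```python
import numpy as np


def cross_attention(x, y):
    # Dot-product cross attention between node sets x [N, D] and y [M, D].
    a = x @ y.T
    a_x = np.exp(a - a.max(axis=1, keepdims=True))
    a_x /= a_x.sum(axis=1, keepdims=True)  # softmax over j for each i
    a_y = np.exp(a - a.max(axis=0, keepdims=True))
    a_y /= a_y.sum(axis=0, keepdims=True)  # softmax over i for each j
    return a_x @ y, a_y.T @ x


def batch_block_pair_attention_np(data, block_idx, n_blocks):
    if n_blocks % 2 != 0:
        raise ValueError('n_blocks (%d) must be a multiple of 2.' % n_blocks)
    out = np.zeros_like(data)
    for i in range(0, n_blocks, 2):
        x_rows = np.flatnonzero(block_idx == i)      # even block i
        y_rows = np.flatnonzero(block_idx == i + 1)  # odd block i + 1
        att_x, att_y = cross_attention(data[x_rows], data[y_rows])
        # Write results back to the rows they came from.
        out[x_rows] = att_x
        out[y_rows] = att_y
    return out


rng = np.random.default_rng(0)
# Pair (graph 0: 2 nodes, graph 1: 3 nodes) and pair (graph 2: 1 node, graph 3: 2 nodes).
data = rng.normal(size=(8, 4))
block_idx = np.array([0, 0, 1, 1, 1, 2, 3, 3])
out = batch_block_pair_attention_np(data, block_idx, n_blocks=4)
print(out.shape)  # (8, 4)
```

With nodes packed graph by graph, as in this example, the two approaches give the same rows. When one graph has a single node (graph 2 here), every node of its partner attends entirely to that node, so their attention outputs equal its state.
"""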
"""#### Graph matching layer and graph matching networks
The graph matching model requires only a small set of changes to the graph embedding model.
"""
class GraphPropMatchingLayer(GraphPropLayer):
  """A graph propagation layer that also does cross graph matching.

  It assumes the incoming graph data is batched and paired, i.e. graphs 0 and 1
  form the first pair, graphs 2 and 3 form the second pair, etc., and computes
  cross-graph attention-based matching for each pair.
  """

  def _build(self,
             node_states,
             from_idx,
             to_idx,
             graph_idx,
             n_graphs,
             similarity='dotproduct',
             edge_features=None,
             node_features=None):
    """Run one propagation step with cross-graph matching.

    Args:
      node_states: [n_nodes, node_state_dim] float tensor, node states.
      from_idx: [n_edges] int tensor, from node indices for each edge.
      to_idx: [n_edges] int tensor, to node indices for each edge.
      graph_idx: [n_nodes] int tensor, graph id for each node.
      n_graphs: integer, number of graphs in the batch.
      similarity: type of similarity to use for the cross graph attention.
      edge_features: if not None, should be [n_edges, edge_feat_dim] tensor,
        extra edge features.
      node_features: if not None, should be [n_nodes, node_feat_dim] tensor,
        extra node features.

    Returns:
      node_states: [n_nodes, node_state_dim] float tensor, new node states.

    Raises:
      ValueError: if some options are not provided correctly.
    """
    aggregated_messages = self._compute_aggregated_messages(
        node_states, from_idx, to_idx, edge_features=edge_features)

    # Cross-graph matching: each node attends to the nodes of its paired graph,
    # and the difference between its state and the attention output is passed
    # to the node update together with the aggregated messages.
    cross_graph_attention = batch_block_pair_attention(
        node_states, graph_idx, n_graphs, similarity=similarity)
    attention_input = node_states - cross_graph_attention

    return self._compute_node_update(node_states,
                                     [aggregated_messages, attention_input],
                                     node_features=node_features)
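"""The only new input to the node update is `attention_input = node_states - cross_graph_attention`. A small NumPy sketch (dot-product similarity assumed; `cross_attention_x` is a hypothetical helper) suggests why this difference is informative: when attention is sharp, a node with a close counterpart in the other graph gets a difference near zero, while a node without one gets a large difference.

```python
import numpy as np


def cross_attention_x(x, y):
    # attention_x_i = sum_j softmax_j(x_i^T y_j) y_j (dot-product similarity).
    a = x @ y.T
    w = np.exp(a - a.max(axis=1, keepdims=True))
    w /= w.sum(axis=1, keepdims=True)
    return w @ y


# Graph 1: three nodes with well-separated states, scaled so attention is sharp.
x = 10.0 * np.eye(3)
# Graph 2a: the same node states in a different order, so every node has a match.
y_match = x[[2, 0, 1]]
# Graph 2b: the third node is replaced by a state with no counterpart in graph 1.
y_partial = np.vstack([x[:2], [[0.0, 0.0, -10.0]]])

# The matching layer's attention input: node_states - cross_graph_attention.
diff_match = x - cross_attention_x(x, y_match)
diff_partial = x - cross_attention_x(x, y_partial)
print(np.abs(diff_match).max())              # close to 0
print(np.linalg.norm(diff_partial, axis=1))  # last entry is large
```

The scale factor 10 is chosen only to make the softmax nearly one-hot. With smaller similarities the attention is softer, and even matched nodes no longer get differences near zero.
"""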


class GraphMatchingNet(GraphEmbeddingNet):
  """Graph matching net.

  This class uses graph matching layers instead of the plain graph propagation
  layers.

  It assumes the incoming graph data is batched and paired, i.e. graphs 0 and 1
  form the first pair, graphs 2 and 3 form the second pair, etc., and computes
  cross-graph attention-based matching for each pair.
  """

  def __init__(self,
               encoder,
               aggregator,
               node_state_dim,
               edge_hidden_sizes,
               node_hidden_sizes,
               n_prop_layers,
               share_prop_params=False,
               edge_net_init_scale=0.1,
               node_update_type='residual',
               use_reverse_direction=True,
               reverse_dir_param_different=True,
               layer_norm=False,
               similarity='dotproduct',
               name='graph-matching-net'):
    super(GraphMatchingNet, self).__init__(
        encoder,
        aggregator,
        node_state_dim,
        edge_hidden_sizes,
        node_hidden_sizes,
        n_prop_layers,
        share_prop_params=share_prop_params,
        edge_net_init_scale=edge_net_init_scale,
        node_update_type=node_update_type,
        use_reverse_direction=use_reverse_direction,
        reverse_dir_param_different=reverse_dir_param_different,
        layer_norm=layer_norm,
        name=name)
    self._similarity = similarity
    self._layer_class = GraphPropMatchingLayer

  def _apply_layer(self,
                   layer,
                   node_states,
                   from_idx,
                   to_idx,
                   graph_idx,
                   n_graphs,
                   edge_features):
    """Apply one layer on the given inputs."""
    return layer(node_states, from_idx, to_idx, graph_idx, n_graphs,
                 similarity=self._similarity, edge_features=edge_features)
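"""`GraphMatchingNet` expects each batch to hold pairs as consecutive graphs. A sketch of the index layout this implies, using a hypothetical helper `pack_graph_pairs` that only builds `graph_idx` and `n_graphs` (node features, edges, and the offsets for `from_idx`/`to_idx` would have to be concatenated in the same order, which is not shown here):

```python
import numpy as np


def pack_graph_pairs(pair_sizes):
    # pair_sizes: list of (n_nodes_first, n_nodes_second), one tuple per pair.
    # Graphs 2k and 2k + 1 form the k-th pair; node rows are grouped graph by graph.
    sizes = [n for pair in pair_sizes for n in pair]
    graph_idx = np.repeat(np.arange(len(sizes)), sizes)
    n_graphs = len(sizes)  # always even, as batch_block_pair_attention requires
    return graph_idx, n_graphs


graph_idx, n_graphs = pack_graph_pairs([(3, 2), (1, 4)])
partner_graph = graph_idx ^ 1  # id of the graph each node is matched against
print(graph_idx)      # [0 0 0 1 1 2 3 3 3 3]
print(partner_graph)  # [1 1 1 0 0 3 2 2 2 2]
print(n_graphs)       # 4
```

Because graphs 2k and 2k+1 are partners, the partner of graph g is `g ^ 1` (bitwise XOR), and `graph_idx` comes out sorted, which is the layout the block-wise concatenation in `batch_block_pair_attention` relies on.
"""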
"""## Training
### Labeled data examples
We train on either pairs of graphs or triplets of graphs. For pairs of graphs, we assume each pair $(G_1, G_2)$ comes with a label $t\in\{-1, 1\}$. $t=1$ if $G_1$ and $G_2$ are similar, and $t=-1$ otherwise.
For triplets of graphs, we assume within each triplet $(G_1, G_2, G_3)$, $G_1$ is similar to $G_2$ but not similar to $G_3$.
The goal of training is to learn the parameters of the function $f(G_1, G_2)$ such that similar graphs have high similarity (or small distance) and dissimilar graphs have low similarity (or high distance).
### Training on pairs
Given a dataset of pairs $(G_1, G_2)$ and labels $t\in\{-1, 1\}$, we can use the following margin-based loss if using Euclidean distance:
$$
L_\mathrm{pair} = \mathbb{E}_{(G_1, G_2, t)}[\max\{0, \gamma - t(1 - d(G_1, G_2))\}]
$$
This loss encourages similar graphs to have distance smaller than $1-\gamma$, and dissimilar graphs to have distance greater than $1 + \gamma$, where $\gamma$ is a margin parameter.
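A NumPy sketch of the per-pair margin loss, taking the distances as given (for example from `euclidean_distance` below, which returns the squared Euclidean distance); `margin_pair_loss` is a hypothetical name. With $\gamma = 0.5$, the similar pair at distance 0.9 and the dissimilar pair at distance 1.2 fall inside the margin and incur a loss, while the other two pairs do not.

```python
import numpy as np


def margin_pair_loss(d, t, gamma):
    # Per-pair loss max{0, gamma - t * (1 - d)} for labels t in {-1, +1};
    # averaging over pairs estimates the expectation in L_pair.
    return np.maximum(0.0, gamma - t * (1.0 - d))


d = np.array([0.2, 0.9, 1.2, 2.0])  # distances d(G_1, G_2)
t = np.array([1, 1, -1, -1])        # first two similar, last two dissimilar
losses = margin_pair_loss(d, t, gamma=0.5)
print(losses, losses.mean())
```
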
Alternatively, many applications benefit from binary graph representations, which allow efficient indexing and hashing. In this case Hamming distance (or similarity) is more appropriate. However, Hamming distance is not differentiable, so we use a smooth approximation
$$
s(G_1, G_2) = \frac{1}{H}\sum_{i=1}^H \tanh(h_{G_1, i}) \cdot \tanh(h_{G_2, i}),
$$
where $s$ is now a similarity (rather than a distance) function and $h_{G,i}$ is the $i$-th dimension of the real-valued representation vector $h_G$ of graph $G$. We get binary codes by thresholding $h_{G,i}$ at 0, i.e. $\hat{h}_{G,i}=1$ if $h_{G,i}\ge 0$ and $-1$ otherwise.
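A NumPy sketch of the smooth approximation and the thresholded codes; the function names are hypothetical. For codes in {-1, 1} of length $H$, the mean of the elementwise product equals $1 - 2 d_H / H$, where $d_H$ is the Hamming distance, so it is an exact Hamming similarity. Since tanh approaches the sign function as its argument grows, the smooth $s$ approaches this exact value when every $|h_{G,i}|$ is large; dimensions near 0 contribute almost nothing to $s$ instead of ±1.

```python
import numpy as np


def approx_hamming_similarity_np(h1, h2):
    # s = (1/H) * sum_i tanh(h1_i) * tanh(h2_i), which lies in [-1, 1].
    return np.mean(np.tanh(h1) * np.tanh(h2), axis=-1)


def binarize(h):
    # Threshold at 0: +1 where h >= 0, -1 otherwise.
    return np.where(h >= 0, 1.0, -1.0)


def exact_hamming_similarity(b1, b2):
    # For +-1 codes of length H: mean(b1 * b2) = 1 - 2 * hamming_distance / H.
    return np.mean(b1 * b2, axis=-1)


h1 = np.array([3.0, -2.5, 4.0, 0.1])
h2 = np.array([2.0, 3.0, 5.0, -0.2])
s_exact = exact_hamming_similarity(binarize(h1), binarize(h2))  # codes differ in 2 of 4 bits
s_smooth = approx_hamming_similarity_np(h1, h2)
# Scaling up the representations pushes tanh toward sign(), so s approaches s_exact.
s_sharp = approx_hamming_similarity_np(100.0 * h1, 100.0 * h2)
print(s_exact, s_smooth, s_sharp)
```
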
The loss we use with these binary representations is defined as
$$
L_\mathrm{pair} = \mathbb{E}_{(G_1, G_2, t)}[(t - s(G_1, G_2))^2] / 4.
$$
The factor of $1/4$ normalizes the loss to $[0, 1]$: since $s(G_1, G_2) \in [-1, 1]$ and $t \in \{-1, 1\}$, $(t - s(G_1, G_2))^2$ is at most 4.
These are just two possible losses; many other losses could also be used.
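A short sketch of this loss for per-pair similarities $s$ (for example from `approximate_hamming_similarity` below); `hamming_pair_loss` is a hypothetical name. The extreme cases show the normalization: $s = -t$ gives a loss of 1 and $s = t$ gives 0.

```python
import numpy as np


def hamming_pair_loss(s, t):
    # Per-pair loss (t - s)^2 / 4 for t in {-1, +1} and s in [-1, 1].
    return (t - s) ** 2 / 4.0


s = np.array([-1.0, 1.0, 0.0])
t = np.array([1, 1, -1])
print(hamming_pair_loss(s, t))
```
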
"""


def euclidean_distance(x, y):
  """This is the squared Euclidean distance."""
  return tf.reduce_sum((x - y)**2, axis=-1)


def approximate_hamming_similarity(x, y):
  """Approximate Hamming similarity."""