diff --git a/docs/autogen.py b/docs/autogen.py
index 4e6bafae..e30f9755 100644
--- a/docs/autogen.py
+++ b/docs/autogen.py
@@ -44,7 +44,8 @@
             layers.GraphAttention,
             layers.GraphConvSkip,
             layers.APPNP,
-            layers.GINConv
+            layers.GINConv,
+            layers.DiffusionConvolution
         ]
     },
     {
@@ -319,7 +320,8 @@ def count_leading_spaces(s):
 
 def process_list_block(docstring, starting_point, leading_spaces, marker):
     ending_point = docstring.find('\n\n', starting_point)
-    block = docstring[starting_point:(None if ending_point == -1 else ending_point - 1)]
+    block = docstring[starting_point:(
+        None if ending_point == -1 else ending_point - 1)]
     # Place marker for later reinjection.
     docstring = docstring.replace(block, marker)
     lines = block.split('\n')
@@ -329,7 +331,8 @@ def process_list_block(docstring, starting_point, leading_spaces, marker):
     # These have to be removed, but first the list roots have to be detected.
     top_level_regex = r'^    ([^\s\\\(]+):(.*)'
     top_level_replacement = r'- __\1__:\2'
-    lines = [re.sub(top_level_regex, top_level_replacement, line) for line in lines]
+    lines = [re.sub(top_level_regex, top_level_replacement, line)
+             for line in lines]
     # All the other lines simply get the 4 leading spaces (if present) removed
     lines = [re.sub(r'^    ', '', line) for line in lines]
     # Fix text lines after lists
@@ -366,13 +369,14 @@ def process_docstring(docstring):
             index = tmp[3:].find('```') + 6
             snippet = tmp[:index]
             # Place marker in docstring for later reinjection.
-            docstring = docstring.replace(snippet, '$CODE_BLOCK_%d' % len(code_blocks))
+            docstring = docstring.replace(
+                snippet, '$CODE_BLOCK_%d' % len(code_blocks))
             snippet_lines = snippet.split('\n')
             # Remove leading spaces.
             num_leading_spaces = snippet_lines[-1].find('`')
             snippet_lines = ([snippet_lines[0]] +
                              [line[num_leading_spaces:]
-                              for line in snippet_lines[1:]])
+                             for line in snippet_lines[1:]])
             # Most code snippets have 3 or 4 more leading spaces
             # on inner lines, but not all. Remove them.
             inner_lines = snippet_lines[1:-1]
@@ -432,7 +436,8 @@ def process_docstring(docstring):
     # Spektral-specific code
     docstring = re.sub(r':param', '\n**Arguments** \n:param', docstring, 1)
     docstring = re.sub(r':param(.*):', r'\n- `\1`:', docstring)
-    docstring = re.sub(r':return: ([a-z])', lambda m: ':return: {}'.format(m.group(1).upper()), docstring)
+    docstring = re.sub(
+        r':return: ([a-z])', lambda m: ':return: {}'.format(m.group(1).upper()), docstring)
     docstring = re.sub(r':return:', '\n**Return** \n', docstring)
 
     return docstring
@@ -494,7 +499,7 @@ def read_page_data(page_data, type):
             continue
         module_member = getattr(module, name)
         if (inspect.isclass(module_member) and type == 'classes' or
-            inspect.isfunction(module_member) and type == 'functions'):
+                inspect.isfunction(module_member) and type == 'functions'):
             instance = module_member
             if module.__name__ in instance.__module__:
                 if instance not in module_data:
@@ -587,9 +592,12 @@ def read_page_data(page_data, type):
 if not os.path.exists('sources/custom_theme/img/'):
     os.makedirs('sources/custom_theme/img/')
 
-shutil.copy('./stylesheets/extra.css', './sources/stylesheets/extra.css')
+shutil.copy('./stylesheets/extra.css',
+            './sources/stylesheets/extra.css')
 shutil.copy('./js/macros.js', './sources/js/macros.js')
 for file in glob.glob(r'./img/*.svg'):
     shutil.copy(file, './sources/img/')
-shutil.copy('./img/favicon.ico', './sources/custom_theme/img/favicon.ico')
-shutil.copy('./templates/google8a76765aa72fa8c1.html', './sources/google8a76765aa72fa8c1.html')
+shutil.copy('./img/favicon.ico',
+            './sources/custom_theme/img/favicon.ico')
+shutil.copy('./templates/google8a76765aa72fa8c1.html',
+            './sources/google8a76765aa72fa8c1.html')
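For context, a quick sketch of what the Spektral-specific substitutions in `process_docstring` (hunk `@@ -432,7 +436,8 @@` above) do to a reST-style docstring; the sample docstring here is made up:

```python
import re

# A made-up reST-style docstring, run through the same substitutions
# that autogen.py applies above.
doc = ":param channels: number of output channels\n:return: the processed docstring"
doc = re.sub(r':param', '\n**Arguments** \n:param', doc, 1)
doc = re.sub(r':param(.*):', r'\n- `\1`:', doc)
doc = re.sub(r':return: ([a-z])',
             lambda m: ':return: {}'.format(m.group(1).upper()), doc)
doc = re.sub(r':return:', '\n**Return** \n', doc)
print(doc)
# The ':param' entries become a markdown bullet list under an
# '**Arguments**' heading, and ':return:' becomes a '**Return**'
# heading with the first letter of its description capitalized.
```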
diff --git a/docs/local_build.sh b/docs/local_build.sh
new file mode 100644
index 00000000..a0621d9a
--- /dev/null
+++ b/docs/local_build.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+
+# install current branch
+cd ../
+python setup.py install
+cd docs/
+
+# delete old docs
+rm -r sources/
+
+# generate new docs
+python autogen.py
+
+# serve new docs
+python -m mkdocs serve
\ No newline at end of file
diff --git a/spektral/layers/convolutional/__init__.py b/spektral/layers/convolutional/__init__.py
index 81c7166f..0dc58c40 100644
--- a/spektral/layers/convolutional/__init__.py
+++ b/spektral/layers/convolutional/__init__.py
@@ -6,4 +6,5 @@
 from .gcn import GraphConv
 from .gcs import GraphConvSkip
 from .gin import GINConv
-from .graphsage import GraphSageConv
\ No newline at end of file
+from .graphsage import GraphSageConv
+from .diffconv import DiffusionConvolution
diff --git a/spektral/layers/convolutional/diffconv.py b/spektral/layers/convolutional/diffconv.py
new file mode 100644
index 00000000..21e464cb
--- /dev/null
+++ b/spektral/layers/convolutional/diffconv.py
@@ -0,0 +1,200 @@
+import tensorflow as tf
+import tensorflow.keras.layers as layers
+from spektral.layers.convolutional.gcn import GraphConv
+
+
+class DiffuseFeatures(layers.Layer):
+    r"""Utility layer calculating a single channel of the
+    diffusion convolution.
+
+    The procedure is based on https://arxiv.org/pdf/1707.01926.pdf
+
+    **Input**
+
+    - Node features of shape `([batch], N, F)`;
+    - Normalized adjacency or attention coefficient matrix \(\hat \A \) of
+    shape `([batch], N, N)`; use `DiffusionConvolution.preprocess` to
+    normalize.
+
+    **Output**
+
+    - Node features with the same shape as the input, but with the last
+    dimension changed to \(1\).
+
+    **Arguments**
+
+    - `num_diffusion_steps`: number of diffusion steps to consider (\(K\) in the paper);
+    - `kernel_initializer`: initializer for the kernel matrix;
+    - `kernel_regularizer`: regularization applied to the kernel vectors;
+    - `kernel_constraint`: constraint applied to the kernel vectors;
+    """
+
+    def __init__(
+        self,
+        num_diffusion_steps: int,
+        kernel_initializer,
+        kernel_regularizer,
+        kernel_constraint,
+        **kwargs
+    ):
+        super(DiffuseFeatures, self).__init__(**kwargs)
+
+        # number of diffusion steps (K in paper)
+        self.K = num_diffusion_steps
+
+        # get regularizer, initializer and constraint for kernel
+        self.kernel_initializer = kernel_initializer
+        self.kernel_regularizer = kernel_regularizer
+        self.kernel_constraint = kernel_constraint
+
+    def build(self, input_shape):
+
+        # Initializing the kernel vector (R^K)
+        # (theta in paper)
+        self.kernel = self.add_weight(
+            shape=(self.K,),
+            initializer=self.kernel_initializer,
+            regularizer=self.kernel_regularizer,
+            constraint=self.kernel_constraint,
+        )
+
+    def call(self, inputs):
+
+        # Get signal X and adjacency A
+        X, A = inputs
+
+        # Calculate diffusion matrix: sum kernel_k * Attention_t^k.
+        # tf.math.polyval needs a list of tensors as coefficients,
+        # so we unstack the kernel.
+        diffusion_matrix = tf.math.polyval(tf.unstack(self.kernel), A)
+
+        # Apply it to X to get a matrix C = [C_1, ..., C_F] (N x F)
+        # of diffused features
+        diffused_features = tf.matmul(diffusion_matrix, X)
+
+        # Now we add up all diffused features (columns of the above matrix)
+        # to obtain H_{:,q} (eq. 3 in the paper); the non-linearity is
+        # applied later, by the parent DiffusionConvolution layer
+        H = tf.math.reduce_sum(diffused_features, axis=-1)
+
+        # H has shape ([batch], N) but as it is the sum of columns
+        # we reshape it to ([batch], N, 1)
+        return tf.expand_dims(H, -1)
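A note on the `tf.math.polyval` call in `DiffuseFeatures.call` above: `tf.math.polyval(coeffs, x)` evaluates `coeffs[0] * x**(n-1) + ... + coeffs[n-1]`, so the first unstacked coefficient is attached to the highest power. A tiny sketch with made-up scalar values:

```python
import tensorflow as tf

theta = tf.constant([2.0, 3.0, 5.0])  # stands in for the learned kernel
x = tf.constant(10.0)                 # stands in for the normalized adjacency

# Horner's scheme: ((2 * x) + 3) * x + 5 = 2*x**2 + 3*x + 5
y = tf.math.polyval(tf.unstack(theta), x)
print(y.numpy())  # 235.0
```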
+
+
+class DiffusionConvolution(GraphConv):
+    r"""Applies Graph Diffusion Convolution as described by
+    [Li et al. (2018)](https://arxiv.org/pdf/1707.01926.pdf).
+
+    **Mode**: single, mixed, batch.
+
+    Given a number of diffusion steps \(K\) and a row-normalized adjacency
+    matrix \(\hat \A \), this layer calculates the \(q\)-th channel as:
+    $$
+    \mathbf{H}_{~:,~q} = \sigma(
+        \sum_{f=1}^{F}
+        \left(
+            \sum_{k=0}^{K-1} \theta_k {\hat \A}^k
+        \right)
+        \X_{~:,~f}
+    )
+    $$
+
+    **Input**
+
+    - Node features of shape `([batch], N, F)`;
+    - Normalized adjacency or attention coefficient matrix \(\hat \A \) of
+    shape `([batch], N, N)`; use `DiffusionConvolution.preprocess` to
+    normalize.
+
+    **Output**
+
+    - Node features with the same shape as the input, but with the last
+    dimension changed to `channels`.
+
+    **Arguments**
+
+    - `channels`: number of output channels;
+    - `num_diffusion_steps`: number of diffusion steps to consider (\(K\) in the paper);
+    - `activation`: activation function \(\sigma\) (\(\tanh\) by default);
+    - `kernel_initializer`: initializer for the kernel matrix;
+    - `kernel_regularizer`: regularization applied to the kernel vectors;
+    - `kernel_constraint`: constraint applied to the kernel vectors;
+    """
+
+    def __init__(
+        self,
+        channels: int,
+        num_diffusion_steps: int = 6,
+        kernel_initializer='glorot_uniform',
+        kernel_regularizer=None,
+        kernel_constraint=None,
+        activation='tanh',
+        **kwargs
+    ):
+        super().__init__(channels,
+                         activation=activation,
+                         kernel_initializer=kernel_initializer,
+                         kernel_regularizer=kernel_regularizer,
+                         kernel_constraint=kernel_constraint,
+                         **kwargs)
+
+        # number of features to generate (Q in paper)
+        assert channels > 0
+        self.Q = channels
+
+        # number of diffusion steps for each output feature;
+        # +1 so that the kernel also covers the k=0 term
+        self.K = num_diffusion_steps + 1
+
+    def build(self, input_shape):
+
+        # We expect to receive (X, A)
+        # X - graph signal ([batch], N, F)
+        # A - adjacency matrix ([batch], N, N)
+        X_shape, A_shape = input_shape
+
+        # initialise Q diffusion convolution filters
+        self.filters = []
+
+        for _ in range(self.Q):
+            layer = DiffuseFeatures(
+                num_diffusion_steps=self.K,
+                kernel_initializer=self.kernel_initializer,
+                kernel_regularizer=self.kernel_regularizer,
+                kernel_constraint=self.kernel_constraint,
+            )
+            self.filters.append(layer)
+
+    def apply_filters(self, X, A):
+        """Applies diffusion convolution self.Q times to get a
+        ([batch], N, Q) diffused graph signal.
+        """
+
+        # This will be a list of Q diffused features.
+        # Each diffused feature is a ([batch], N, 1) tensor.
+        # Later we will concat all the features to get one
+        # ([batch], N, Q) diffused graph signal
+        diffused_features = []
+
+        # Iterating over all Q diffusion filters
+        for diffusion in self.filters:
+            diffused_feature = diffusion((X, A))
+            diffused_features.append(diffused_feature)
+
+        # Concat them into ([batch], N, Q) diffused graph signal
+        H = tf.concat(diffused_features, -1)
+
+        return H
+
+    def call(self, inputs):
+
+        # Get graph signal X and adjacency tensor A
+        X, A = inputs
+
+        # 'single', 'batch' and 'mixed' modes are supported by
+        # default, since we access the dimensions from the end
+        # and everything else is broadcast accordingly if it
+        # is missing.
+        H = self.apply_filters(X, A)
+        H = self.activation(H)
+
+        return H
diff --git a/spektral/layers/pooling/globalpool.py b/spektral/layers/pooling/globalpool.py
index 235fae6c..5926f7d4 100644
--- a/spektral/layers/pooling/globalpool.py
+++ b/spektral/layers/pooling/globalpool.py
@@ -345,7 +345,7 @@ def get_config(self):
 class SortPool(Layer):
     r"""
     SortPool layer pooling the top \(k\) most relevant nodes as described by
-    (Zhang et al.)[https://www.cse.wustl.edu/~muhan/papers/AAAI_2018_DGCNN.pdf]
+    [Zhang et al.](https://www.cse.wustl.edu/~muhan/papers/AAAI_2018_DGCNN.pdf)
 
     This layer takes a graph signal \(\X\) and sorts the rows by the elements
     of its last column. It then keeps the top \(k\) rows.
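The sort-and-truncate step described in this docstring boils down to something like the following single-mode sketch (a hypothetical helper, not the layer's actual implementation, which also handles batch mode and the \(k > N\) case):

```python
import tensorflow as tf

def sort_pool_sketch(X, k):
    # Sort the rows of X by their last feature, descending,
    # and keep the top k rows.
    order = tf.argsort(X[:, -1], direction='DESCENDING')
    return tf.gather(X, order[:k], axis=0)

X = tf.random.normal((10, 4))          # N=10 nodes, F=4 features
print(sort_pool_sketch(X, k=6).shape)  # (6, 4)
```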
@@ -447,3 +447,9 @@ def get_config(self):
         }
         base_config = super().get_config()
         return dict(list(base_config.items()) + list(config.items()))
+
+    def compute_output_shape(self, input_shape):
+        if self.data_mode == 'single':
+            return (self.k, input_shape[-1])
+        elif self.data_mode == 'batch':
+            return (input_shape[0], self.k, input_shape[-1])
diff --git a/tests/test_layers/test_convolutional.py b/tests/test_layers/test_convolutional.py
index f97e6336..8702f8a5 100644
--- a/tests/test_layers/test_convolutional.py
+++ b/tests/test_layers/test_convolutional.py
@@ -3,7 +3,7 @@
 from tensorflow.keras import Model, Input
 
 from spektral.layers import GraphConv, ChebConv, EdgeConditionedConv, GraphAttention, GraphConvSkip, ARMAConv, APPNP, \
-    GraphSageConv, GINConv
+    GraphSageConv, GINConv, DiffusionConvolution
 
 tf.keras.backend.set_floatx('float64')
 SINGLE, BATCH, MIXED = 1, 2, 3  # Single, batch, mixed
@@ -54,6 +54,11 @@
         LAYER_K_: GINConv,
         MODES_K_: [SINGLE],
         KWARGS_K_: {'channels': 8, 'activation': 'relu', 'mlp_hidden': [16], 'sparse': True}
+    },
+    {
+        LAYER_K_: DiffusionConvolution,
+        MODES_K_: [SINGLE, BATCH, MIXED],
+        KWARGS_K_: {'channels': 8, 'activation': 'tanh', 'num_diffusion_steps': 5}
     }
 ]
diff --git a/tests/test_layers/test_global_pooling.py b/tests/test_layers/test_global_pooling.py
index 96617da6..17632fc8 100644
--- a/tests/test_layers/test_global_pooling.py
+++ b/tests/test_layers/test_global_pooling.py
@@ -2,7 +2,7 @@
 import tensorflow as tf
 from tensorflow.keras import Input, Model
 
-from spektral.layers import GlobalSumPool, GlobalAttnSumPool, GlobalAttentionPool, GlobalAvgPool, GlobalMaxPool
+from spektral.layers import GlobalSumPool, GlobalAttnSumPool, GlobalAttentionPool, GlobalAvgPool, GlobalMaxPool, SortPool
 
 tf.keras.backend.set_floatx('float64')
 batch_size = 32
@@ -57,10 +57,34 @@ def _test_graph_mode(layer, **kwargs):
     output = model([X, S])
     assert output.shape == (batch_size, kwargs.get('channels', F))
     # When creating the actual graph, the batch dimension is None
-    assert output.shape[1:] == layer_instance.compute_output_shape([X.shape, S.shape])[1:]
+    assert output.shape[1:] == layer_instance.compute_output_shape([X.shape, S.shape])[
+        1:]
     _check_output_and_model_output_shapes(output.shape, model.output_shape)
 
 
+def _test_sortpool_single(layer, k):
+    X = np.random.normal(size=(N, F))
+    X_in = Input(shape=(F,))
+    layer_instance = layer(k=k)
+    output = layer_instance(X_in)
+    model = Model(X_in, output)
+
+    output = model(X)
+    assert output.shape == (k, F)
+    assert output.shape == layer_instance.compute_output_shape(X.shape)
+
+
+def _test_sortpool_batched(layer, k):
+    X = np.random.normal(size=(batch_size, N, F))
+    X_in = Input(shape=(N, F))
+    layer_instance = layer(k=k)
+    output = layer_instance(X_in)
+    model = Model(X_in, output)
+    output = model(X)
+    assert output.shape == (batch_size, k, F)
+    assert output.shape == layer_instance.compute_output_shape(X.shape)
+
+
 def test_global_sum_pool():
     _test_single_mode(GlobalSumPool)
     _test_batch_mode(GlobalSumPool)
@@ -91,3 +115,8 @@ def test_global_attention_pool():
     _test_single_mode(GlobalAttentionPool, channels=F_)
     _test_batch_mode(GlobalAttentionPool, channels=F_)
     _test_graph_mode(GlobalAttentionPool, channels=F_)
+
+
+def test_global_sort_pool():
+    _test_sortpool_single(SortPool, k=6)
+    _test_sortpool_batched(SortPool, k=6)
diff --git a/tests/test_layers/test_pooling.py b/tests/test_layers/test_pooling.py
index f437ff25..b3b987aa 100644
--- a/tests/test_layers/test_pooling.py
+++ b/tests/test_layers/test_pooling.py
@@ -29,11 +29,6 @@
         MODES_K_: [SINGLE, BATCH],
         KWARGS_K_: {'k': 5, 'return_mask': True, 'sparse': True}
     },
-    {
-        LAYER_K_: SortPool,
-        MODES_K_: [SINGLE, BATCH],
-        KWARGS_K_: {'k': 5}
-    },
 ]
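Finally, a minimal single-mode usage sketch of the new layer, following the same pattern as the tests above (the sizes and the random graph are made up):

```python
import numpy as np
from tensorflow.keras import Input, Model
from spektral.layers import DiffusionConvolution

N, F = 10, 3
X = np.random.normal(size=(N, F)).astype('float32')         # node features
A = np.random.randint(0, 2, size=(N, N)).astype('float32')  # made-up adjacency
fltr = DiffusionConvolution.preprocess(A).astype('float32') # normalize, as the docstring advises

X_in = Input(shape=(F,))
A_in = Input(shape=(N,))
output = DiffusionConvolution(channels=8, num_diffusion_steps=5)([X_in, A_in])
model = Model(inputs=[X_in, A_in], outputs=output)

print(model([X, fltr]).shape)  # (N, 8): one row per node, `channels` features
```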