Commit c9a54bf

Merge pull request #52 from LeviBorodenko/ENH_DiffConv
 ENH (issue #49): Added DiffusionConvolution, tests and a script to easily build docs locally.
danielegrattarola committed Apr 29, 2020
2 parents adf23ae + 3180832
Showing 8 changed files with 279 additions and 20 deletions.
28 changes: 18 additions & 10 deletions docs/autogen.py
@@ -44,7 +44,8 @@
layers.GraphAttention,
layers.GraphConvSkip,
layers.APPNP,
-layers.GINConv
+layers.GINConv,
+layers.DiffusionConvolution
]
},
{
@@ -319,7 +320,8 @@ def count_leading_spaces(s):

def process_list_block(docstring, starting_point, leading_spaces, marker):
ending_point = docstring.find('\n\n', starting_point)
-block = docstring[starting_point:(None if ending_point == -1 else ending_point - 1)]
+block = docstring[starting_point:(
+    None if ending_point == -1 else ending_point - 1)]
# Place marker for later reinjection.
docstring = docstring.replace(block, marker)
lines = block.split('\n')
@@ -329,7 +331,8 @@ def process_list_block(docstring, starting_point, leading_spaces, marker):
# These have to be removed, but first the list roots have to be detected.
top_level_regex = r'^    ([^\s\\\(]+):(.*)'
top_level_replacement = r'- __\1__:\2'
-lines = [re.sub(top_level_regex, top_level_replacement, line) for line in lines]
+lines = [re.sub(top_level_regex, top_level_replacement, line)
+         for line in lines]
# All the other lines simply get the 4 leading spaces (if present) removed
lines = [re.sub(r'^    ', '', line) for line in lines]
# Fix text lines after lists
@@ -366,13 +369,14 @@ def process_docstring(docstring):
index = tmp[3:].find('```') + 6
snippet = tmp[:index]
# Place marker in docstring for later reinjection.
-docstring = docstring.replace(snippet, '$CODE_BLOCK_%d' % len(code_blocks))
+docstring = docstring.replace(
+    snippet, '$CODE_BLOCK_%d' % len(code_blocks))
snippet_lines = snippet.split('\n')
# Remove leading spaces.
num_leading_spaces = snippet_lines[-1].find('`')
snippet_lines = ([snippet_lines[0]] +
[line[num_leading_spaces:]
-                             for line in snippet_lines[1:]])
+                          for line in snippet_lines[1:]])
# Most code snippets have 3 or 4 more leading spaces
# on inner lines, but not all. Remove them.
inner_lines = snippet_lines[1:-1]
@@ -432,7 +436,8 @@ def process_docstring(docstring):
# Spektral-specific code
docstring = re.sub(r':param', '\n**Arguments** \n:param', docstring, 1)
docstring = re.sub(r':param(.*):', r'\n- `\1`:', docstring)
-docstring = re.sub(r':return: ([a-z])', lambda m: ':return: {}'.format(m.group(1).upper()), docstring)
+docstring = re.sub(
+    r':return: ([a-z])', lambda m: ':return: {}'.format(m.group(1).upper()), docstring)
docstring = re.sub(r':return:', '\n**Return** \n', docstring)

return docstring
@@ -494,7 +499,7 @@ def read_page_data(page_data, type):
continue
module_member = getattr(module, name)
if (inspect.isclass(module_member) and type == 'classes' or
-        inspect.isfunction(module_member) and type == 'functions'):
+            inspect.isfunction(module_member) and type == 'functions'):
instance = module_member
if module.__name__ in instance.__module__:
if instance not in module_data:
@@ -587,9 +592,12 @@ def read_page_data(page_data, type):
if not os.path.exists('sources/custom_theme/img/'):
os.makedirs('sources/custom_theme/img/')

-shutil.copy('./stylesheets/extra.css', './sources/stylesheets/extra.css')
+shutil.copy('./stylesheets/extra.css',
+            './sources/stylesheets/extra.css')
shutil.copy('./js/macros.js', './sources/js/macros.js')
for file in glob.glob(r'./img/*.svg'):
shutil.copy(file, './sources/img/')
-shutil.copy('./img/favicon.ico', './sources/custom_theme/img/favicon.ico')
-shutil.copy('./templates/google8a76765aa72fa8c1.html', './sources/google8a76765aa72fa8c1.html')
+shutil.copy('./img/favicon.ico',
+            './sources/custom_theme/img/favicon.ico')
+shutil.copy('./templates/google8a76765aa72fa8c1.html',
+            './sources/google8a76765aa72fa8c1.html')
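Most of the autogen.py changes above are line-wrapping only. For context, the `top_level_regex` in `process_list_block` rewrites indented `name: description` docstring entries as Markdown list items; a minimal illustrative check (the sample line is made up):

```python
import re

top_level_regex = r'^    ([^\s\\\(]+):(.*)'
top_level_replacement = r'- __\1__:\2'

# A 4-space-indented docstring list entry, as process_list_block sees it
line = '    channels: number of output channels;'
print(re.sub(top_level_regex, top_level_replacement, line))
# -> - __channels__: number of output channels;
```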
15 changes: 15 additions & 0 deletions docs/local_build.sh
@@ -0,0 +1,15 @@
#!/bin/bash

# install current branch
cd ../
python setup.py install
cd docs/

# delete old docs
rm -r sources/

# generate new docs
python autogen.py

# serve new docs
python -m mkdocs serve
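The script is meant to be run from the `docs/` directory (it steps up to the repository root to install the current branch, then back into `docs/`); `mkdocs serve` hosts the generated site at http://127.0.0.1:8000 by default, assuming `mkdocs` is installed.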
3 changes: 2 additions & 1 deletion spektral/layers/convolutional/__init__.py
@@ -6,4 +6,5 @@
from .gcn import GraphConv
from .gcs import GraphConvSkip
from .gin import GINConv
-from .graphsage import GraphSageConv
+from .graphsage import GraphSageConv
+from .diffconv import DiffusionConvolution
200 changes: 200 additions & 0 deletions spektral/layers/convolutional/diffconv.py
@@ -0,0 +1,200 @@
import tensorflow as tf
import tensorflow.keras.layers as layers
from spektral.layers.convolutional.gcn import GraphConv


class DiffuseFeatures(layers.Layer):
r"""Utility layer calculating a single channel of the
diffusional convolution.
Procedure is based on https://arxiv.org/pdf/1707.01926.pdf
**Input**
- Node features of shape `([batch], N, F)`;
- Normalized adjacency or attention coefficient matrix \(\hat \A \) of shape
`([batch], N, N)`; use `DiffusionConvolution.preprocess` to normalize.
**Output**
- Node features with the same shape as the input, but with the last
dimension changed to \(1\).
**Arguments**
- `num_diffusion_steps`: how many diffusion steps to consider (\(K\) in the paper);
- `kernel_initializer`: initializer for the kernel matrix;
- `kernel_regularizer`: regularization applied to the kernel vectors;
- `kernel_constraint`: constraint applied to the kernel vectors;
"""

def __init__(
self,
num_diffusion_steps: int,
kernel_initializer,
kernel_regularizer,
kernel_constraint,
**kwargs
):
super(DiffuseFeatures, self).__init__(**kwargs)

# number of diffusion steps (K in the paper)
self.K = num_diffusion_steps

# get regularizer, initializer and constraint for kernel
self.kernel_initializer = kernel_initializer
self.kernel_regularizer = kernel_regularizer
self.kernel_constraint = kernel_constraint

def build(self, input_shape):

# Initializing the kernel vector (R^K)
# (theta in paper)
self.kernel = self.add_weight(
shape=(self.K,),
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint,
)

def call(self, inputs):

# Get signal X and adjacency A
X, A = inputs

# Calculate diffusion matrix: sum of kernel_k * A^k terms.
# tf.math.polyval needs a list of tensors as the coefficients,
# thus we unstack the kernel
diffusion_matrix = tf.math.polyval(tf.unstack(self.kernel), A)

# Apply it to X to get a matrix C = [C_1, ..., C_F] (N x F)
# of diffused features
diffused_features = tf.matmul(diffusion_matrix, X)

# Now we sum all diffused features (columns of the above matrix)
# to obtain H_{:,q} (eq. 3 in the paper); the non-linearity is
# applied later, in DiffusionConvolution.call
H = tf.math.reduce_sum(diffused_features, axis=-1)

# H has shape ([batch], N) but as it is the sum of columns
# we reshape it to ([batch], N, 1)
return tf.expand_dims(H, -1)


class DiffusionConvolution(GraphConv):
r"""Applies Graph Diffusion Convolution as descibed by
[Li et al. (2017)](https://arxiv.org/pdf/1707.01926.pdf)
**Mode**: single, mixed, batch.
Given a number of diffusion steps \(K\) and a row-normalized adjacency matrix \(\hat \A \),
this layer calculates the \(q\)-th channel as:
$$
\mathbf{H}_{~:,~q} = \sigma(
\sum_{f=1}^{F}
\left(
\sum_{k=0}^{K-1}\theta_k {\hat \A}^k
\right)
\X_{~:,~f}
)
$$
**Input**
- Node features of shape `([batch], N, F)`;
- Normalized adjacency or attention coefficient matrix \(\hat \A \) of shape
`([batch], N, N)`; use `DiffusionConvolution.preprocess` to normalize.
**Output**
- Node features with the same shape as the input, but with the last
dimension changed to `channels`.
**Arguments**
- `channels`: number of output channels;
- `num_diffusion_steps`: how many diffusion steps to consider (\(K\) in the paper);
- `activation`: activation function \(\sigma\) (\(\tanh\) by default);
- `kernel_initializer`: initializer for the kernel matrix;
- `kernel_regularizer`: regularization applied to the kernel vectors;
- `kernel_constraint`: constraint applied to the kernel vectors;
"""

def __init__(
self,
channels: int,
num_diffusion_steps: int = 6,
kernel_initializer='glorot_uniform',
kernel_regularizer=None,
kernel_constraint=None,
activation='tanh',
**kwargs
):
super().__init__(channels,
activation=activation,
kernel_initializer=kernel_initializer,
kernel_regularizer=kernel_regularizer,
kernel_constraint=kernel_constraint,
**kwargs)

# number of features to generate (Q in paper)
assert channels > 0
self.Q = channels

# number of diffusion steps for each output feature
self.K = num_diffusion_steps + 1

def build(self, input_shape):

# We expect to receive (X, A)
# A - Adjacency ([batch], N, N)
# X - graph signal ([batch], N, F)
X_shape, A_shape = input_shape

# initialise Q diffusion convolution filters
self.filters = []

for _ in range(self.Q):
layer = DiffuseFeatures(
num_diffusion_steps=self.K,
kernel_initializer=self.kernel_initializer,
kernel_regularizer=self.kernel_regularizer,
kernel_constraint=self.kernel_constraint,
)
self.filters.append(layer)

def apply_filters(self, X, A):
"""Applies diffusion convolution self.Q times to get a
([batch], N, Q) diffused graph signal
"""

# This will be a list of Q diffused features.
# Each diffused feature is a (batch, N, 1) tensor.
# Later we will concat all the features to get one
# (batch, N, Q) diffused graph signal
diffused_features = []

# Iterating over all Q diffusion filters
for diffusion in self.filters:
diffused_feature = diffusion((X, A))
diffused_features.append(diffused_feature)

# Concat them into ([batch], N, Q) diffused graph signal
H = tf.concat(diffused_features, -1)

return H

def call(self, inputs):

# Get graph signal X and adjacency tensor A
X, A = inputs

# 'single', 'batch' and 'mixed' mode are supported by
# default, since we access the dimensions from the end
# and everything else is broadcast accordingly
# if it is missing.

H = self.apply_filters(X, A)
H = self.activation(H)

return H
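As a quick sanity check (not part of the commit), the new layer can be called eagerly on toy data. Shapes follow the docstring above; the random adjacency and the use of the inherited `preprocess` are illustrative assumptions:

```python
import numpy as np
from spektral.layers import DiffusionConvolution

N, F = 4, 3  # toy graph: 4 nodes with 3 features each
X = np.random.rand(N, F).astype('float32')              # node features
A = np.random.randint(0, 2, (N, N)).astype('float32')   # binary adjacency

# preprocess is inherited from GraphConv in this commit
A_hat = DiffusionConvolution.preprocess(A)

layer = DiffusionConvolution(channels=8, num_diffusion_steps=5)
H = layer([X, A_hat])
print(H.shape)  # single mode: (4, 8), i.e. (N, channels)
```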
8 changes: 7 additions & 1 deletion spektral/layers/pooling/globalpool.py
@@ -345,7 +345,7 @@ def get_config(self):
class SortPool(Layer):
r"""
SortPool layer pooling the top \(k\) most relevant nodes as described by
-(Zhang et al.)[https://www.cse.wustl.edu/~muhan/papers/AAAI_2018_DGCNN.pdf]
+[Zhang et al.](https://www.cse.wustl.edu/~muhan/papers/AAAI_2018_DGCNN.pdf)
This layer takes a graph signal \(\X\) and sorts the rows by the elements
of its last column. It then keeps the top \(k\) rows.
@@ -447,3 +447,9 @@ def get_config(self):
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))

def compute_output_shape(self, input_shape):
if self.data_mode == 'single':
return (self.k, input_shape[-1])
elif self.data_mode == 'batch':
return (input_shape[0], self.k, input_shape[-1])
7 changes: 6 additions & 1 deletion tests/test_layers/test_convolutional.py
@@ -3,7 +3,7 @@
from tensorflow.keras import Model, Input

from spektral.layers import GraphConv, ChebConv, EdgeConditionedConv, GraphAttention, GraphConvSkip, ARMAConv, APPNP, \
-GraphSageConv, GINConv
+GraphSageConv, GINConv, DiffusionConvolution

tf.keras.backend.set_floatx('float64')
SINGLE, BATCH, MIXED = 1, 2, 3 # Single, batch, mixed
@@ -54,6 +54,11 @@
LAYER_K_: GINConv,
MODES_K_: [SINGLE],
KWARGS_K_: {'channels': 8, 'activation': 'relu', 'mlp_hidden': [16], 'sparse': True}
-}
+},
+{
+LAYER_K_: DiffusionConvolution,
+MODES_K_: [SINGLE, BATCH, MIXED],
+KWARGS_K_: {'channels': 8, 'activation': 'tanh', 'num_diffusion_steps': 5}
+}
]
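The new entry exercises DiffusionConvolution in single, batch, and mixed modes alongside the existing layers; running `pytest tests/test_layers/test_convolutional.py` from the repository root should pick it up (assuming the project's usual pytest setup).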

