<a href="https://colab.research.google.com/github/krzs13/dgl_tensorflow/blob/main/densegraphconv.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np

In [6]:
class DenseGraphConv(layers.Layer):
    """Graph convolution layer operating on a dense adjacency matrix.

    TensorFlow port of DGL's ``DenseGraphConv``. Rows of ``adj`` index
    destination nodes and columns index source nodes, so aggregation is
    ``adj @ feat``. The layer computes::

        norm='both':  D_dst^{-1/2} @ adj @ D_src^{-1/2} @ feat @ W (+ b)
        norm='right': D_dst^{-1}   @ adj @ feat @ W                (+ b)
        norm='none':                adj @ feat @ W                 (+ b)

    Source degrees are the column sums of ``adj`` and destination degrees
    the row sums, both clipped below at 1 to avoid division by zero.

    Args:
        in_feats: Size of the input node features.
        out_feats: Size of the output node features.
        norm: One of ``'both'``, ``'right'`` or ``'none'``.
        bias: If True, add a learnable zero-initialised bias of shape
            ``(out_feats,)``.
        activation: Optional callable applied to the output.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).

    Raises:
        ValueError: If ``norm`` is not one of the supported values.
    """

    _VALID_NORMS = ('none', 'both', 'right')

    def __init__(self,
                 in_feats,
                 out_feats,
                 norm='both',
                 bias=True,
                 activation=None,
                 **kwargs):
        super(DenseGraphConv, self).__init__(**kwargs)
        if norm not in self._VALID_NORMS:
            raise ValueError(
                "Invalid norm value {!r}; expected one of {}.".format(
                    norm, self._VALID_NORMS))
        self._in_feats = in_feats
        self._out_feats = out_feats
        self._norm = norm
        # add_weight registers the variables with Keras so they show up in
        # trainable_weights and are updated by optimizers / model.fit.
        self.weight = self.add_weight(
            name='weight',
            shape=(in_feats, out_feats),
            initializer=tf.keras.initializers.GlorotUniform(),
            dtype='float32',
            trainable=True)
        if bias:
            self.bias = self.add_weight(
                name='bias',
                shape=(out_feats,),
                initializer=tf.keras.initializers.Zeros(),
                dtype='float32',
                trainable=True)
        else:
            self.bias = None
        self._activation = activation

    @staticmethod
    def _broadcastable(norm, like):
        """Reshape per-node vector ``norm`` to ``(N, 1, ..., 1)`` matching the rank of ``like``."""
        return tf.reshape(norm, [-1] + [1] * (like.shape.rank - 1))

    def call(self, adj, feat):
        """Apply the graph convolution.

        Args:
            adj: Dense adjacency matrix, presumably ``(N, N)``; it is cast
                to ``feat.dtype`` before use.
            feat: Node features, assumed ``(N, in_feats)`` (matmul with the
                ``(in_feats, out_feats)`` weight requires this).

        Returns:
            Tensor of shape ``(N, out_feats)``.
        """
        adj = tf.cast(adj, feat.dtype)
        feat_src = feat

        if self._norm == 'both':
            # Column sums of adj are the source-node degrees.
            src_degrees = tf.maximum(tf.reduce_sum(adj, axis=0), 1.0)
            feat_src = feat_src * self._broadcastable(
                tf.pow(src_degrees, -0.5), feat_src)

        if self._in_feats > self._out_feats:
            # Multiply by W first to shrink the features before aggregating.
            feat_src = tf.matmul(feat_src, self.weight)
            rst = tf.matmul(adj, feat_src)
        else:
            # Aggregate first, then multiply by W.
            rst = tf.matmul(adj, feat_src)
            rst = tf.matmul(rst, self.weight)

        if self._norm != 'none':
            # Row sums of adj are the destination-node degrees.
            dst_degrees = tf.maximum(tf.reduce_sum(adj, axis=1), 1.0)
            if self._norm == 'both':
                norm_dst = tf.pow(dst_degrees, -0.5)
            else:  # 'right'
                norm_dst = 1.0 / dst_degrees
            rst = rst * self._broadcastable(norm_dst, rst)

        # Bias and activation apply for every norm mode, including 'none'.
        if self.bias is not None:
            rst = rst + self.bias
        if self._activation is not None:
            rst = self._activation(rst)
        return rst
