In [154]:
import copy as cp
import math
from contextlib import suppress

import tensorflow as tf
from spektral.layers import GraphSageConv
from tensorflow.keras.activations import relu
from tensorflow.keras.layers import Dense, BatchNormalization

In [91]:
class GNN(tf.keras.Model):
    """Three stacked GraphSAGE stages with ReLU + batch norm, plus an optional projection.

    Each stage computes ``bn(relu(conv([h, adj])))`` and feeds its result to the
    next stage. The three stage outputs are concatenated on the last axis. When
    ``add_linear`` is true, a Dense layer followed by ReLU maps that
    concatenation to ``out_channels`` features.

    Args:
        hidden_channels: width of the first two GraphSageConv layers.
        out_channels: width of the third GraphSageConv layer and, when present,
            of the Dense projection.
        add_linear: whether to append the Dense + ReLU projection.
    """

    def __init__(self, hidden_channels, out_channels, add_linear=True):
        super().__init__()
        # Attribute names are kept stable so that checkpoints keep matching.
        self.conv1 = GraphSageConv(hidden_channels)
        self.bn1 = BatchNormalization()
        self.conv2 = GraphSageConv(hidden_channels)
        self.bn2 = BatchNormalization()
        self.conv3 = GraphSageConv(out_channels)
        self.bn3 = BatchNormalization()
        self.lin = Dense(out_channels) if add_linear else None

    def call(self, inputs):
        """Run the three conv stages on ``inputs = [x, adj]`` and fuse their outputs."""
        x, adj = inputs
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )

        h = x
        stage_outputs = []
        for conv, bn in stages:
            h = bn(relu(conv([h, adj])))
            stage_outputs.append(h)

        fused = tf.concat(stage_outputs, axis=-1)
        if self.lin is None:
            return fused
        return relu(self.lin(fused))

In [101]:
# Scratch check: a (10, 0) tensor still reports rank 2, so a rank test based on
# the shape (as used in dense_diff_pool) holds even with a zero-sized dimension.
tf.zeros((10, 0)).shape.rank

2

In [173]:
def dense_diff_pool(x, adj, s):
    """Differentiable pooling (DiffPool), ported from PyTorch Geometric.

    Port of ``torch_geometric.nn.dense.diff_pool.dense_diff_pool``:
    https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/dense/diff_pool.html#dense_diff_pool
    PyG's ``mask`` argument is not implemented.

    Args:
        x: node features, rank 2 ``(N, F)`` or rank 3 ``(B, N, F)``.
        adj: adjacency, rank 2 ``(N, N)`` or rank 3 ``(B, N, N)``; dense or
            ``tf.SparseTensor``.
        s: unnormalised cluster-assignment scores, rank 2 ``(N, C)`` or rank 3
            ``(B, N, C)``; dense or ``tf.SparseTensor``. A softmax is applied
            over the last axis.

    Returns:
        ``(out, out_adj, link_loss, ent_loss)``:
          - ``out``: pooled features ``S^T X``, shape ``(B, C, F)``.
          - ``out_adj``: pooled adjacency ``S^T A S``, shape ``(B, C, C)``.
          - ``link_loss``: ``||A - S S^T||_2`` divided by the number of
            elements in ``A``.
          - ``ent_loss``: mean over nodes of the entropy of each row of ``S``.
    """
    # Densify each sparse input on its own. Rank-2 inputs then get a batch axis.
    if isinstance(adj, tf.SparseTensor):
        adj = tf.sparse.to_dense(adj)
    if isinstance(s, tf.SparseTensor):
        s = tf.sparse.to_dense(s)

    x = tf.expand_dims(x, axis=0) if len(x.shape) == 2 else x
    adj = tf.expand_dims(adj, axis=0) if len(adj.shape) == 2 else adj
    s = tf.expand_dims(s, axis=0) if len(s.shape) == 2 else s

    s = tf.nn.softmax(s, axis=-1)
    st = tf.transpose(s, (0, 2, 1))

    out = tf.matmul(st, x)
    out_adj = tf.matmul(tf.matmul(st, adj), s)

    # With axis=None, tf.norm flattens the tensor. So ord=2 is the 2-norm over
    # every entry, matching torch.norm(p=2) in the PyG original.
    link_loss = tf.norm(adj - tf.matmul(s, st), ord=2)
    link_loss = link_loss / tf.cast(tf.size(adj), link_loss.dtype)

    # 1e-15 guards log(0) for assignment probabilities that underflow to zero.
    ent_loss = tf.reduce_mean(tf.reduce_sum(-s * tf.math.log(s + 1e-15), axis=-1))

    return out, out_adj, link_loss, ent_loss

In [177]:
class Net(tf.keras.Model):
    """Two-level DiffPool graph classifier built from ``GNN`` blocks.

    The structure appears to follow PyTorch Geometric's dense DiffPool example:
    1. Pool the input graph to ``ceil(max_nodes / 4)`` clusters.
    2. Pool again by another factor of 4.
    3. Embed the result.
    4. Mean-pool over clusters and classify with two Dense layers.

    Args:
        in_channels: accepted for interface compatibility only. Keras infers
            input widths on the first call, so this value is unused.
        num_classes: number of output classes.
        max_nodes: maximum node count of an input graph; sets the cluster
            counts of both pooling levels.
    """

    def __init__(self, in_channels=3, num_classes=6, max_nodes=200):
        super().__init__()

        # Level 1: assign the input nodes to ceil(max_nodes / 4) clusters.
        # Use a Python int: it is a layer width, not a tensor.
        num_nodes = math.ceil(0.25 * max_nodes)
        self.gnn1_pool = GNN(64, num_nodes)
        self.gnn1_embed = GNN(64, 64, add_linear=False)

        # Level 2: shrink by another factor of 4 relative to level 1.
        num_nodes = math.ceil(0.25 * num_nodes)
        self.gnn2_pool = GNN(64, num_nodes)
        self.gnn2_embed = GNN(64, 64, add_linear=False)

        self.gnn3_embed = GNN(64, 64, add_linear=False)

        self.lin1 = Dense(64)
        self.lin2 = Dense(num_classes)

    def call(self, inputs):
        """Classify a graph given ``inputs = [x, adj]``.

        Returns:
            ``(log_probs, link_loss, ent_loss)``: log-softmax class scores, plus
            the DiffPool link and entropy regularisers summed over both levels.
        """
        x, adj = inputs

        s = self.gnn1_pool([x, adj])
        x = self.gnn1_embed([x, adj])
        x, adj, l1, e1 = dense_diff_pool(x, adj, s)

        # NOTE(review): from here on x/adj are dense and carry a batch axis
        # (dense_diff_pool output). spektral's GraphSageConv is a message-passing
        # layer that presumably expects a sparse adjacency. Confirm these calls
        # work, or switch to a dense-mode convolution.
        s = self.gnn2_pool([x, adj])
        x = self.gnn2_embed([x, adj])
        x, adj, l2, e2 = dense_diff_pool(x, adj, s)

        x = self.gnn3_embed([x, adj])

        # Graph readout: mean over the cluster axis of the (B, C, F) features.
        x = tf.reduce_mean(x, axis=1)
        x = relu(self.lin1(x))
        x = self.lin2(x)

        return tf.nn.log_softmax(x, axis=-1), l1 + l2, e1 + e2

In [180]:
# Smoke test: run one GNN on a small random graph and check the output shape.
SEED = 42
tf.random.set_seed(SEED)  # make the random features/adjacency reproducible across runs

num_nodes = 10
num_features = 5

x = tf.Variable(tf.random.normal((num_nodes, num_features)))
# Random 0/1 adjacency (not necessarily symmetric), passed as a SparseTensor.
adj = tf.sparse.from_dense(tf.round(tf.random.uniform((num_nodes, num_nodes))))

gnn = GNN(64, 8)
gnn_out = gnn([x, adj])
assert gnn_out.shape == (num_nodes, 8), gnn_out.shape
gnn_out

<tf.Tensor: shape=(10, 8), dtype=float32, numpy=
array([[0.07391645, 0.        , 0.22801805, 0.        , 0.        ,
        0.19222161, 0.10885294, 0.01520273],
       [0.        , 0.12564272, 0.17030147, 0.0176785 , 0.        ,
        0.        , 0.11459391, 0.        ],
       [0.        , 0.21832334, 0.19624911, 0.        , 0.        ,
        0.04222625, 0.08397265, 0.        ],
       [0.        , 0.15233031, 0.15863347, 0.0828501 , 0.        ,
        0.1525804 , 0.18321669, 0.06215439],
       [0.        , 0.02776095, 0.26309374, 0.        , 0.        ,
        0.15623054, 0.        , 0.01747207],
       [0.01652861, 0.25542983, 0.17841187, 0.07956162, 0.        ,
        0.        , 0.27369484, 0.        ],
       [0.02258293, 0.        , 0.05399239, 0.11350705, 0.        ,
        0.        , 0.        , 0.        ],
       [0.        , 0.3260104 , 0.09149566, 0.13620564, 0.        ,
        0.09080829, 0.19152896, 0.        ],
       [0.05179439, 0.        , 0.19057748, 0.0

TensorShape([10, 64])