<a href="https://colab.research.google.com/github/msterpa87/GNN-CL/blob/master/GNN_CL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
!pip install dgl -f https://data.dgl.ai/wheels/repo.html

Looking in links: https://data.dgl.ai/wheels/repo.html
Collecting dgl
  Downloading https://data.dgl.ai/wheels/dgl-0.7.2-cp37-cp37m-manylinux1_x86_64.whl (5.7 MB)
[K     |████████████████████████████████| 5.7 MB 2.2 MB/s 
Installing collected packages: dgl
Successfully installed dgl-0.7.2


In [11]:
import dgl
import tensorflow as tf
from tensorflow import keras
from keras.layers import BatchNormalization, Dense
from tensorflow.nn import relu
from dgl.nn import DenseSAGEConv

In [12]:
class GNN(tf.keras.Model):
  """Three-layer dense GraphSAGE encoder with batch norm and skip-concatenation.

  Each DenseSAGEConv layer is followed by ReLU and BatchNormalization. The
  outputs of all three layers are concatenated along the feature axis and,
  when ``lin`` is True, projected by a final Dense + ReLU.

  Args:
    in_channels: width of the input node features.
    hidden_channels: output width of the first two conv layers.
    out_channels: output width of the third conv layer, and of the final
      Dense projection when ``lin`` is True.
    normalize: if True, L2-normalize each conv layer's output along the
      feature axis; if False, apply no normalization.
    lin: if True, append a Dense(out_channels) + ReLU over the concatenated
      features; otherwise return the concatenation directly.
  """

  def __init__(self, in_channels, hidden_channels, out_channels,
               normalize=False, lin=True):
    super(GNN, self).__init__()
    # DGL's DenseSAGEConv takes `norm` as a callable (or None), not a bool:
    # any non-None value is invoked on the layer output, so passing the
    # default `False` would fail at call time. Map the flag accordingly.
    # NOTE(review): assumes DGL runs with the TensorFlow backend
    # (DGLBACKEND=tensorflow) so conv outputs are tf tensors — confirm.
    norm = (lambda h: tf.math.l2_normalize(h, axis=-1)) if normalize else None

    self.conv1 = DenseSAGEConv(in_channels, hidden_channels, norm=norm)
    self.bn1 = BatchNormalization()
    self.conv2 = DenseSAGEConv(hidden_channels, hidden_channels, norm=norm)
    self.bn2 = BatchNormalization()
    self.conv3 = DenseSAGEConv(hidden_channels, out_channels, norm=norm)
    self.bn3 = BatchNormalization()

    if lin:
      # Input width is the feature-axis concatenation built in `call`.
      self.lin = Dense(out_channels, input_shape=(2 * hidden_channels + out_channels,))
    else:
      self.lin = None

  def bn(self, i, x, training=None):
    """Apply the i-th BatchNormalization layer (``self.bn{i}``) to ``x``.

    Keras BatchNormalization (default ``axis=-1``) computes statistics over
    every axis except the last. Applying it directly to a
    (batch, nodes, channels) tensor is therefore equivalent to flattening to
    (batch * nodes, channels), normalizing, and reshaping back. It also
    avoids the PyTorch-only ``Tensor.size()`` / ``Tensor.view()`` methods,
    which ``tf.Tensor`` does not provide.

    Args:
      i: index (1, 2 or 3) selecting which BatchNormalization layer to use.
      x: node-feature tensor; the last axis is normalized.
      training: forwarded to the BatchNormalization layer (batch statistics
        when True, moving averages when False, Keras default when None).

    Returns:
      Tensor with the same shape as ``x``.
    """
    layer = getattr(self, 'bn{}'.format(i))
    return layer(x, training=training)

  def call(self, x, adj, mask=None, training=None):
    """Run the three conv/ReLU/BN stages and concatenate their outputs.

    Args:
      x: node features, presumably (batch, nodes, in_channels) — TODO confirm.
      adj: dense adjacency passed through to each DenseSAGEConv.
      mask: optional node mask passed through to each DenseSAGEConv.
      training: forwarded to the BatchNormalization layers.

    Returns:
      The feature-axis concatenation of the three stage outputs
      (``2 * hidden_channels + out_channels`` features), or its Dense + ReLU
      projection to ``out_channels`` when ``lin`` was True.
    """
    # NOTE(review): the (x, adj, mask) argument order mirrors PyTorch
    # Geometric's DenseSAGEConv. DGL's DenseSAGEConv.forward is documented
    # as (adj, feat) with no mask argument and a single (N, N) adjacency —
    # verify against the installed DGL version before relying on this.
    x0 = x
    x1 = self.bn(1, relu(self.conv1(x0, adj, mask)), training=training)
    x2 = self.bn(2, relu(self.conv2(x1, adj, mask)), training=training)
    x3 = self.bn(3, relu(self.conv3(x2, adj, mask)), training=training)

    # Concatenate along the feature axis. axis=1 would stack along the node
    # axis of a (batch, nodes, channels) tensor and fail to match self.lin.
    x = tf.concat([x1, x2, x3], axis=-1)

    if self.lin is not None:
      x = relu(self.lin(x))

    return x