In [67]:

from dgl.utils import expand_as_pair, check_eq_shape


"""
    Source:             https://github.com/dmlc/dgl/blob/master/python/dgl/nn/pytorch/conv/sageconv.py
    
    Modifications:      Removed all the aggregation types but GCN (forced to GCN aggr_type)
                        Applied message computation with edge values
"""
class MeSSAGEConv(nn.Module):
    """GraphSAGE-style layer fixed to the GCN aggregator, using per-edge values.

    Adapted from DGL's ``SAGEConv``: every aggregator except ``gcn`` has been
    removed, and an extra message-passing pass multiplies source features by
    the values supplied in ``edges`` before the usual neighbour sum.

    Parameters
    ----------
    in_feats : int or pair of int
        Input feature size; passed through ``expand_as_pair`` to obtain the
        source and destination sizes.
    out_feats : int
        Output feature size of the linear projection.
    feat_drop : float, optional
        Dropout probability applied to the input features (default ``0.``).
    bias : bool, optional
        Whether the linear projection has a bias term (default ``True``).
    norm : callable or None, optional
        Applied to the result after the activation, if given.
    activation : callable or None, optional
        Applied to the projected features, if given.
    """

    def __init__(self,
                 in_feats,
                 out_feats,
                 feat_drop=0.,
                 bias=True,
                 norm=None,
                 activation=None):
        super(MeSSAGEConv, self).__init__()

        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self.norm = norm
        self.feat_drop = nn.Dropout(feat_drop)
        self.activation = activation

        # GCN aggregator: a single projection applied to the aggregated features.
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the projection weight (Xavier uniform, ReLU gain).

        Only the weight is touched; the bias (if any) keeps whatever
        ``nn.Linear`` gave it.
        """
        gain = nn.init.calculate_gain('relu')
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)

    def forward(self, graph, feat, edges):
        """Run one convolution step.

        Parameters
        ----------
        graph : DGLGraph
            Graph to pass messages on. All node/edge fields are written inside
            ``graph.local_scope()``, so the caller's graph is left unchanged.
        feat : torch.Tensor or tuple of two torch.Tensor
            Node features. A tuple is read as ``(source, destination)``
            features; a single tensor is used for both (and, for a block
            graph, sliced to the destination nodes).
        edges : torch.Tensor
            Per-edge values stored in ``graph.edata['a']`` and multiplied into
            the source features. Presumably one row per edge that broadcasts
            against the feature dimension (e.g. shape ``(E, 1)``) -- TODO confirm.

        Returns
        -------
        torch.Tensor
            ``fc_neigh`` applied to the aggregated features, followed by the
            optional activation and then the optional norm.
        """
        with graph.local_scope():
            # Dropout is applied to the raw inputs; (u, v) tuples are handled per side.
            if isinstance(feat, tuple):
                feat_src = self.feat_drop(feat[0])
                feat_dst = self.feat_drop(feat[1])
            else:
                feat_src = feat_dst = self.feat_drop(feat)
                if graph.is_block:
                    feat_dst = feat_src[:graph.number_of_dst_nodes()]

            # Handle the case of graphs without edges
            if graph.number_of_edges() == 0:
                graph.dstdata['neigh'] = torch.zeros(
                    feat_dst.shape[0], self._in_src_feats).to(feat_dst)

            # AGGR TYPE: GCN -> forward
            check_eq_shape(feat)
            graph.srcdata['h'] = feat_src
            graph.dstdata['h'] = feat_dst     # same storage as above if homogeneous

            # Pass 1: multiply each source feature by its edge value and sum
            # the products into the destination field 'h'.
            # NOTE(review): this overwrites the destination 'h' set just above,
            # so the un-weighted self features no longer take part in h_neigh;
            # on a homogeneous graph pass 2 presumably also reads these
            # already-aggregated values. Behaviour kept as-is -- confirm intent.
            graph.edata['a'] = edges
            graph.update_all(
                fn.u_mul_e('h', 'a', 'm_a'),
                fn.sum('m_a', 'h')
            )

            # Pass 2: plain neighbour sum of 'h' into 'neigh'.
            # `copy_u` is the current name of the deprecated `copy_src` builtin.
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))

            # GCN-style mean over neighbours + current node: divide by in-degree + 1.
            degs = graph.in_degrees().to(feat_dst)
            h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)

            # Linear projection of the aggregated features.
            rst = self.fc_neigh(h_neigh)

            # activation if any
            if self.activation is not None:
                rst = self.activation(rst)

            # normalization if any
            if self.norm is not None:
                rst = self.norm(rst)
            return rst

In [68]:
# Fix the RNG so the random features and layer weights below are reproducible
# (the seed value itself is arbitrary).
torch.manual_seed(0)

# Directed 5-node cycle; column i is edge i (source on top, destination below):
# 0 1 2 3 4
# | | | | |
# V V V V V
# 1 2 3 4 0
g = dgl.graph(
    (torch.tensor([0, 1, 2, 3, 4]), torch.tensor([1, 2, 3, 4, 0])),
    num_nodes=5,
)

# node features: 5 nodes x 5 features, one tensor for each side of the (u, v) pair
u_features = torch.rand((5, 5))
v_features = torch.rand((5, 5))

# edge values: one scalar per edge, in edge order
edges = torch.tensor([[3.], [4.], [3.], [4.], [3.]])

# 5 features for each node -> 5 out features per node
conv = MeSSAGEConv(5, 5)
res = conv(g, (u_features, v_features), edges)
res

tensor([[ 0.5917, -1.2986,  4.6642,  0.4827,  1.3296],
        [-0.0451, -0.5946,  4.2177,  1.6446,  0.4137],
        [ 0.2701, -1.9960,  2.8608,  2.8209,  1.3942],
        [ 0.3006, -0.9945,  3.4080,  2.6003,  0.5482],
        [ 0.3727, -1.6936,  4.4544,  0.2479,  1.7237]],
       grad_fn=<AddmmBackward>)

In [70]:
class MessageNet(nn.Module):
    """Two stacked ``MeSSAGEConv`` layers (5 -> 16 -> 5) with a ReLU in between."""

    def __init__(self):
        # Must name this class: `super(Net, self)` referred to a class that is
        # not defined in this cell and only worked through leftover kernel state.
        super(MessageNet, self).__init__()

        # Two MeSSAGEConv layers:
        # 5 input features per node -> 16 hidden -> 5 output features
        self.layer1 = MeSSAGEConv(5, 16)
        self.layer2 = MeSSAGEConv(16, 5)

    def forward(self, g, features, edges):
        """Apply both layers to ``features`` on graph ``g``.

        ``features`` is handed to the first layer unchanged (a tensor or a
        ``(source, destination)`` tuple); the second layer receives the single
        tensor produced by the first. The same ``edges`` values are used by
        both layers. Returns the raw (un-activated) output of the second layer.
        """
        # Continuous output: no activation after the last layer
        x = F.relu(self.layer1(g, features, edges))
        x = self.layer2(g, x, edges)
        return x

net = MessageNet()

In [72]:
net(g, (u_features, v_features), edges)

tensor([[ 3.1789, -3.6413, -8.9632, 21.3088, -5.5338],
        [ 2.9888, -2.1161, -9.3941, 20.4855, -7.0903],
        [ 2.1981, -2.8307, -9.6823, 21.5778, -7.1989],
        [ 2.1544, -3.7740, -7.8020, 20.6272, -4.4232],
        [ 2.9180, -4.8823, -7.2430, 20.8156, -2.6052]],
       grad_fn=<AddmmBackward>)