In this week, you are required to implement a toy GATConv and SAGEConv based on document. Also, you need to implement both in PyG and DGL. In this work, you will get a further understanding of tensor-centric in PyG and graph-centric in DGL.

# PYG 复现GATConv和SAGEConv

In [35]:
import numpy as np
from torch_geometric.utils import add_self_loops, degree
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, softmax

In [36]:
class PyG_GATConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(PyG_GATConv, self).__init__(aggr='mean')
        self.lin = nn.Linear(in_channels, out_channels)
        self.att = nn.Parameter(torch.Tensor(1, 2 * out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.lin.weight)
        nn.init.xavier_uniform_(self.att)

    def forward(self, x, edge_index):
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        x = self.lin(x)

        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j, edge_index_i, num_nodes):
        alpha = torch.cat([x_i, x_j], dim=-1)
        alpha = (alpha * self.att).sum(dim=-1)
        alpha = F.leaky_relu(alpha, negative_slope=0.2)
        alpha = softmax(alpha, edge_index_i, num_nodes)

        return x_j * alpha.view(-1, 1)

    def update(self, aggr_out):
        return aggr_out

In [37]:
class PyG_SAGEConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(PyG_SAGEConv, self).__init__(aggr='mean')
        self.lin = nn.Linear(in_channels, out_channels)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.lin.weight)

    def forward(self, x, edge_index):
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        x = self.lin(x)

        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        return x_j

    def update(self, aggr_out):
        return aggr_out

In [38]:
edge_index = torch.tensor([[0,1,1,2,2,4],[2,0,2,3,4,3]])
x = torch.ones((5, 8))
conv = PyG_GATConv(8, 4)
output = conv(x, edge_index)
print(output)
conv = PyG_SAGEConv(8, 4)
output = conv(x, edge_index)
print(output)

tensor([[-0.4938,  0.6922, -0.7242,  0.0556],
        [-0.9877,  1.3844, -1.4483,  0.1112],
        [-0.3292,  0.4615, -0.4828,  0.0371],
        [-0.3292,  0.4615, -0.4828,  0.0371],
        [-0.4938,  0.6922, -0.7242,  0.0556]], grad_fn=<DivBackward0>)
tensor([[ 0.0930, -3.0377,  0.0126, -1.6848],
        [ 0.0930, -3.0377,  0.0126, -1.6848],
        [ 0.0930, -3.0377,  0.0126, -1.6848],
        [ 0.0930, -3.0377,  0.0126, -1.6848],
        [ 0.0930, -3.0377,  0.0126, -1.6848]], grad_fn=<DivBackward0>)


# DGL 复现GATConv和SAGEConv

In [39]:
import torch
import numpy as np
import torch.nn as nn
import dgl
import dgl.function as fn
from torch_geometric.nn.conv import MessagePassing


from torch import nn
from torch.nn import functional as F

from dgl.utils import DGLError
from dgl.utils import check_eq_shape, expand_as_pair
from dgl.nn import SAGEConv

In [40]:
import torch
from torch import nn
from torch.nn import functional as F

class DGL_SAGEConv(nn.Module):
    def __init__(self, in_feats, out_feats, aggregator_type='mean'):
        super(DGL_SAGEConv, self).__init__()
        self.aggregator_type = aggregator_type
        self.fc = nn.Linear(in_feats * 2, out_feats)
        self.relu = nn.ReLU()
        self.norm = nn.LayerNorm(out_feats)

    def forward(self, graph, h):
        #聚合
        with graph.local_scope():
            if self.aggregator_type == 'mean':
                graph.ndata['h'] = h
                graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))
            elif self.aggregator_type == 'gcn':
                degs = graph.in_degrees().float().clamp(min=1)
                norm = torch.pow(degs, -0.5)
                norm = norm.to(h.device).unsqueeze(1)
                graph.ndata['h'] = h * norm
                graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_neigh'))
                graph.ndata['h_neigh'] = graph.ndata['h_neigh'] * norm
            elif self.aggregator_type == 'pool' or self.aggregator_type == 'lstm':
                raise NotImplementedError
            else:
                raise KeyError('Aggregator type {} not recognized.'.format(self.aggregator_type))

            h_neigh = graph.ndata['h_neigh']
            h_concat = torch.cat([h, h_neigh], dim=1)
            h_prime = self.fc(h_concat)
            
            #ReLU  norm
            h_prime = self.relu(h_prime)
            h_prime = self.norm(h_prime)
            
            return h_prime

In [41]:
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F

class DGL_GATConv(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(DGL_GATConv, self).__init__()
        # 定义可训练的权重
        self.fc = nn.Linear(in_channel, out_channel, bias=False)
        # 注意力权重
        self.attn_fc = nn.Linear(2*out_channel, 1, bias=False)

    def edge_attention(self, edges):
        # 计算边的注意力分数
        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
        a = self.attn_fc(z2)
        return {'e': F.leaky_relu(a)}

    def message_func(self, edges):
        # 应用注意力权重
        return {'z': edges.src['z'], 'e': edges.data['e']}

    def reduce_func(self, nodes):
        # 聚合邻居节点的特征
        alpha = F.softmax(nodes.mailbox['e'], dim=1)
        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
        return {'h': h}

    def forward(self, g, h):
        # 应用线性变换
        z = self.fc(h)
        g.ndata['z'] = z

        # 计算注意力
        g.apply_edges(self.edge_attention)

        # 聚合信息
        g.update_all(self.message_func, self.reduce_func)

        # 获取最终的特征表示
        return g.ndata.pop('h')


In [42]:
src = torch.tensor([0, 1, 1, 2, 2, 4])
dst = torch.tensor([2, 0, 2, 3, 4, 3])
h = torch.ones((5, 8))
g = dgl.graph((src, dst))

conv = DGL_GATConv(8, 4)
output = conv(g, h)
print(output)

conv = DGL_SAGEConv(8, 4)
output = conv(g, h)
print(output)

tensor([[-0.1723, -0.2970, -0.4610,  0.3394],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [-0.1723, -0.2970, -0.4610,  0.3394],
        [-0.1723, -0.2970, -0.4610,  0.3394],
        [-0.1723, -0.2970, -0.4610,  0.3394]], grad_fn=<IndexCopyBackward0>)
tensor([[-1.0420,  0.2558,  1.5199, -0.7337],
        [-0.8681, -0.4498,  1.7018, -0.3840],
        [-1.0420,  0.2558,  1.5199, -0.7337],
        [-1.0420,  0.2558,  1.5199, -0.7337],
        [-1.0420,  0.2558,  1.5199, -0.7337]],
       grad_fn=<NativeLayerNormBackward0>)


# GraphConv问题回答：

(1)公式对应：  
消息部分：mailBox=(eij/cij)*W*hj  
聚合部分：把所有消息进行求和  
更新部分：加上偏差并通过激活函数进行更新  


# GATConv问题回答：

（1）公式对应：   
消息:W*h_j   
聚合：先求出注意力系数 \e_ij\,然后,按照exp(e_ij)进行分配权重，按照权重，对邻居特征，进行特征求和   
更新:激活函数  

# SAGEConv问题回答：

（1）norm对应:  
公式第一步：利用邻居和自己的特征（如果边特征给出也包括在内）线性变换，并分别和i的特征concat在一起 并实现聚合  
公式第二部：将聚合后的特征和W参数进行操作，更新  
公式第三部：按照给定的norm范式，再次进行更新  

（2）concat对应：

在源码中：rst = self.fc_self(h_self) + h_neigh  
是将线性变换后 自身节点的特征和 线性变换后邻居节点的特征 进行相加；而不是连接

所以对于W参数而言，在源码中的方式，shape为(feature_in,feature_out)

但是，如果按照concat连接的方式，则concat后的矩阵为w*feature_in维的向量，W也应该变为(2*feature_in,feature_out)