## **Creating Message Passing Networks**
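- Generalizing the convolution operator to irregular domains is typically expressed as a neighborhood aggregation or message passing scheme:
  $$\mathbf{x}_i^{(k)} = \gamma^{(k)}\left(\mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)}, \mathbf{e}_{j,i}\right)\right)$$
  where $\square$ is a differentiable, permutation-invariant aggregation function (e.g. sum, mean or max), and $\gamma$ and $\phi$ are differentiable functions such as MLPs.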
- `torch_geometric.nn.MessagePassing`: a base class that helps in **creating such message passing graph neural networks** by automatically taking care of **message propagation**.
- The user only has to define **the function ϕ, i.e. `message()`**, and **γ, i.e. `update()`**, as well as the **aggregation scheme** to use, **i.e. `aggr='add'`, `aggr='mean'` or `aggr='max'`**.

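### **Implementing the GCN Layer**
The GCN layer is defined as $\mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{W}^{\top} \cdot \mathbf{x}_j^{(k-1)} \right)$, and is computed in the following steps: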
- Add self-loops to the adjacency matrix.
- Linearly transform node feature matrix.
- Compute normalization coefficients.
- Normalize node features in ϕ.
- Sum up neighboring node features ("add" aggregation).
- Return new node embeddings in γ.

In [1]:
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

In [4]:
class GCNConv(MessagePassing):
    """Graph convolutional layer built on ``MessagePassing``.

    For every node i it computes
    ``x_i' = sum_{j in N(i) + {i}} (W x_j) / (sqrt(deg(i)) * sqrt(deg(j)))``
    by adding self-loops, linearly transforming the features, scaling each
    message by its symmetric normalization coefficient and summing
    ("add" aggregation).

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each output node feature vector.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # "Add" aggregation (Step 5).
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        """Apply the layer.

        Args:
            x: Node feature matrix of shape [N, in_channels].
            edge_index: Graph connectivity of shape [2, E].

        Returns:
            New node embeddings of shape [N, out_channels].
        """
        num_nodes = x.size(0)

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3: Compute normalization.
        # With MessagePassing's default flow ('source_to_target'), messages
        # travel from row (source j) to col (target i) and are summed at col,
        # so the degree is the in-degree counted over col. For a symmetric
        # (undirected) edge_index this equals counting over row.
        row, col = edge_index
        deg = degree(col, num_nodes, dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        # Defensive: a zero degree would give inf. Self-loops from Step 1
        # already make every degree >= 1.
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4-6: Start propagating messages.
        return self.propagate(edge_index,
                              size=(num_nodes, num_nodes),
                              x=x,
                              norm=norm)

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]
        # Step 4: Normalize node features (one coefficient per edge).
        return norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        # aggr_out has shape [N, out_channels]; Step 5 (summation) was done
        # by the "add" aggregation.
        # Step 6: Return new node embeddings.
        return aggr_out
    
conv = GCNConv(in_channels=16, out_channels=32)
# Printing the module lists its registered submodules (here the Linear `lin`).
print(conv)

GCNConv(
  (lin): Linear(in_features=16, out_features=32, bias=True)
)


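### **Implementing the Edge Convolution**
The edge convolution (from Dynamic Graph CNN) is defined as $\mathbf{x}_i^{(k)} = \max_{j \in \mathcal{N}(i)} h_{\mathbf{\Theta}}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)} - \mathbf{x}_i^{(k-1)}\right)$, where $h_{\mathbf{\Theta}}$ is an MLP. It therefore uses "max" aggregation, and `message()` feeds $[\mathbf{x}_i, \mathbf{x}_j - \mathbf{x}_i]$ into the MLP.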

In [5]:
import torch
from torch.nn import Sequential as Seq, Linear, ReLU
from torch_geometric.nn import MessagePassing

In [6]:
class EdgeConv(MessagePassing):
    """Edge convolution layer with "max" aggregation.

    Every edge builds the feature ``[x_i, x_j - x_i]``, passes it through a
    two-layer MLP, and each node keeps the element-wise maximum of the
    resulting messages.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Output size, also used as the MLP's hidden width.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='max')  # "Max" aggregation.
        layers = [
            Linear(2 * in_channels, out_channels),
            ReLU(),
            Linear(out_channels, out_channels),
        ]
        self.mlp = Seq(*layers)

    def forward(self, x, edge_index):
        """Run one round of message passing.

        Args:
            x: Node feature matrix of shape [N, in_channels].
            edge_index: Graph connectivity of shape [2, E].
        """
        n = x.size(0)
        return self.propagate(edge_index, size=(n, n), x=x)

    def message(self, x_i, x_j):
        # x_i and x_j are both [E, in_channels]; their concatenation is
        # [E, 2 * in_channels], which is exactly the MLP's input width.
        edge_features = torch.cat([x_i, x_j - x_i], dim=1)
        return self.mlp(edge_features)

    def update(self, aggr_out):
        """Return the aggregated [N, out_channels] messages unchanged."""
        return aggr_out

In [7]:
from torch_geometric.nn import knn_graph
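# Dynamic edge convolution: inputs such as point clouds contain only points (nodes) and no edges,
# so edges are defined by running an extra algorithm such as k-NN on the node features. The k-NN
# graph is rebuilt from the current features on every forward pass instead of being fixed, so the
# graph changes as the features change; this is what "dynamic" refers to.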
class DynamicEdgeConv(EdgeConv):
    """``EdgeConv`` whose graph is built on the fly from the input features.

    Instead of receiving an ``edge_index``, each forward pass builds a k-NN
    graph over ``x`` (without self-loops) via ``knn_graph`` and then runs the
    parent ``EdgeConv`` on it.

    Args:
        in_channels: Size of each input feature vector.
        out_channels: Size of each output feature vector.
        k: Number of nearest neighbours used to build the graph (default 6).
    """

    def __init__(self, in_channels, out_channels, k=6):
        super().__init__(in_channels, out_channels)
        self.k = k

    def forward(self, x, batch=None):
        """Build the k-NN graph for ``x`` and apply the edge convolution.

        Args:
            x: Node/point feature matrix.
            batch: Optional vector forwarded to ``knn_graph``; presumably the
                per-node example assignment in PyG — verify against caller.
        """
        knn_edges = knn_graph(x, self.k, batch, loop=False, flow=self.flow)
        return super().forward(x, knn_edges)

In [9]:
conv = DynamicEdgeConv(in_channels=3, out_channels=128, k=6)
# The MLP's first Linear takes 2 * 3 = 6 inputs: the concatenation [x_i, x_j - x_i].
print(conv)

DynamicEdgeConv(
  (mlp): Sequential(
    (0): Linear(in_features=6, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=128, bias=True)
  )
)
