In [2]:
# python
import os
import math
import csv

# random
import random
#data analysis libraries
import numpy as np
import pandas as pd

# machine learning
import sklearn

# deep learning
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn.functional as F
import pytorch_lightning as pl

# graph dl
import networkx as nx
# import torch_geometric
from torch_geometric.nn import conv
from torch_geometric import utils

# For plotting learning curve
from torch.utils.tensorboard import SummaryWriter

#visualization libraries
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

# For Progress Bar
from tqdm import tqdm

#ignore warnings
import warnings

warnings.filterwarnings('ignore')

# auto load change
%load_ext autoreload
%autoreload 2

In [65]:
import torch
from torch_geometric.data import Data
from torch_geometric.nn.conv import MessagePassing

X = torch.tensor([[0.0, 0.0, 0.0], [2.0, 1.0, 1.0], [2.0, 2.0, 2.0]])
edge_index = torch.tensor([[0, 1, 2, 0], [1, 2, 0, 2]])
graph = Data(x=X, edge_index=edge_index)

### GCNConv
`row, col = edge_index`
+ message: x_j[E, out_channels] according to row(outgoing node point to other node)
+ $D$: use `deg[row]`
+ notice `inf` use `deg[deg == float('inf')] = 0`

In [73]:
class Gcnconv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.lin = nn.Linear(in_channels, out_channels)

    def _get_norm(self, edge_index):
        edge_index, _ = utils.add_self_loops(edge_index)
        row, col = edge_index
        deg = torch.bincount(col)
        deg[deg == float('inf')] = 0
        norm = deg[row] ** (-0.5) * torch.sqrt(deg[col])  ** (-0.5)
        return norm, edge_index

    def forward(self, X, edge_index):
        H = self.lin(X)
        norm, edge_index = self._get_norm(edge_index)
        H = self.propagate(edge_index, x=H, norm=norm)
        return H

    def message(self, x_j, norm):  # how to confirm norm is right
        return norm.reshape(-1, 1) * x_j


In [74]:
net = Gcnconv(3, 5)
net(X, edge_index)

tensor([[-0.4896, -0.5754,  0.0499, -0.9470, -0.2284],
        [-0.5525, -0.3090,  0.1347, -1.0212, -0.0999],
        [-0.9325, -0.7956,  0.2554, -1.5310, -0.2120]],
       grad_fn=<ScatterAddBackward0>)