In [30]:
%load_ext autoreload
%autoreload 2

import numpy
import matplotlib as mpl
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.nn import functional as F

from torch_geometric.datasets import CoraFull, TUDataset, Planetoid
from torch_geometric.loader import DataLoader
from torch_geometric import utils, nn as gnn, transforms as T
from torch_geometric.nn import GCNConv, MessagePassing
plt.style.use("seaborn-v0_8")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [31]:
# Instantiate the bare MessagePassing base class (mean aggregation, explicit
# source -> target flow) only to inspect its repr in the next cell.
# NOTE(review): `model` is reassigned later to the custom GCN layer, so this
# object is throwaway.
model = MessagePassing("mean", flow="source_to_target")

In [32]:
model

MessagePassing()

In [33]:
# Load the Planetoid dataset called "cora" from a local root directory.
# NOTE(review): the root is an absolute, machine-specific path; consider
# hoisting it into a configurable DATA_DIR so the notebook runs elsewhere.
dataset = Planetoid("/mnt/dl/datasets/gnn/cora", name="cora")

In [34]:
dataset

cora()

In [35]:
len(dataset)

1

In [36]:
data = dataset[0]

In [37]:
data

Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])

In [38]:
data.edge_index

tensor([[   0,    0,    0,  ..., 2707, 2707, 2707],
        [ 633, 1862, 2582,  ...,  598, 1473, 2706]])

In [39]:
data.x.ravel().unique()

tensor([0., 1.])

In [40]:
data.y.unique()

tensor([0, 1, 2, 3, 4, 5, 6])

In [41]:
data.edge_index

tensor([[   0,    0,    0,  ..., 2707, 2707, 2707],
        [ 633, 1862, 2582,  ...,  598, 1473, 2706]])

In [42]:
# Wrap the dataset in a PyG DataLoader to look at the batching metadata
# (`batch`, `ptr`) it attaches.
# NOTE(review): len(dataset) == 1 above, so presumably every batch is just that
# single graph and batch_size=32 has no practical effect here.
loader = DataLoader(dataset, batch_size=32)

In [43]:
next(iter(loader))

DataBatch(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708], batch=[2708], ptr=[2])

In [44]:
next(iter(loader)).batch

tensor([0, 0, 0,  ..., 0, 0, 0])

In [45]:
data.x

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [46]:
data.edge_index

tensor([[   0,    0,    0,  ..., 2707, 2707, 2707],
        [ 633, 1862, 2582,  ...,  598, 1473, 2706]])

In [47]:
# Move the node features and connectivity to the accelerator when one is
# available; fall back to the CPU instead of crashing on CUDA-less machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = data.x.to(device)
edge_index = data.edge_index.to(device)
# Displayed as the cell output: (num_nodes, num_features) and (2, num_edges).
x.shape, edge_index.shape

(torch.Size([2708, 1433]), torch.Size([2, 10556]))

In [48]:
# Preview what add_self_loops returns: a tuple whose first element is the
# edge index with the extra (i, i) edges appended; the second element is None
# in the output below. The result is only displayed, not assigned, so the
# `edge_index` variable itself is left unchanged by this cell.
utils.add_self_loops(edge_index)

(tensor([[   0,    0,    0,  ..., 2705, 2706, 2707],
         [ 633, 1862, 2582,  ..., 2705, 2706, 2707]], device='cuda:0'),
 None)

In [49]:
# Count the edges whose first-row (edge_index[0]) endpoint is node 0.
(edge_index[0] == 0).sum()

tensor(3, device='cuda:0')

In [123]:
class GCNConv(MessagePassing):
    """Instrumented, from-scratch GCN layer built on ``MessagePassing``.

    Computes ``D^{-1/2} (A + I) D^{-1/2} (X W) + b`` with sum aggregation and
    optionally prints a trace of every stage (linear transform, propagate,
    message, update) so the message-passing machinery can be inspected.

    NOTE: this class intentionally reuses the name of the ``GCNConv`` imported
    from ``torch_geometric.nn`` at the top of the notebook and therefore
    shadows it once this cell has been executed.

    Args:
        in_dim: Size of each input node feature vector.
        out_dim: Size of each output node feature vector.
        verbose: When True (default) print the debugging trace on every
            forward pass; set to False to use the layer silently.
    """

    def __init__(self, in_dim, out_dim, verbose=True):
        # Sum aggregation; the flow is left at its default, which the trace
        # shows to be source -> target (x_j is gathered from edge_index[0],
        # x_i from edge_index[1]).
        super().__init__(aggr="add")
        # The bias is kept separate so it is added once after aggregation
        # rather than once per incoming message.
        self.lin = nn.Linear(in_dim, out_dim, bias=False)
        self.bias = nn.Parameter(torch.zeros((out_dim, )))
        self.verbose = verbose

    def forward(self, x, edge_index):
        """Run one GCN layer.

        Args:
            x: Node feature matrix; its first dimension is the node count.
            edge_index: Edge list with two rows (first row = source ``j``,
                second row = target ``i``), without self loops.

        Returns:
            The aggregated node features with the bias added; same first
            dimension as ``x`` and ``out_dim`` columns.
        """
        num_nodes = x.size(0)
        # Pass num_nodes explicitly: otherwise it is inferred from the largest
        # index in edge_index and trailing isolated nodes get no self loop.
        edge_index, _ = utils.add_self_loops(edge_index, num_nodes=num_nodes)
        x = self.lin(x)
        row, col = edge_index

        if self.verbose:
            print("x size lin: ", x.size())
            if edge_index.size(1) > 0:
                # Show the endpoints of the first edge using the real indices
                # instead of hard-coded node ids.
                src, dst = row[0], col[0]
                print(f"x[{src}] in message: ", x[src])
                print(f"x[{dst}] in message: ", x[dst])

        # Symmetric normalisation. The degree is counted on the target index,
        # i.e. the node that aggregates, matching the reference GCN for
        # source -> target flow (identical to the source count on undirected
        # graphs).
        deg = utils.degree(col, num_nodes, dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        # Nodes with degree 0 would yield inf; zero them out.
        deg_inv_sqrt = deg_inv_sqrt.masked_fill(torch.isinf(deg_inv_sqrt), 0.)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        if self.verbose:
            print("Propagating edge size: ...", edge_index.size())
        out = self.propagate(edge_index, x=x, norm=norm)
        # Out-of-place add so the tensor returned by propagate is not mutated.
        out = out + self.bias

        return out

    def message(self, x_j, x_i, norm):
        """Build the per-edge messages: source features scaled by ``norm``.

        ``x_i`` (target features) is requested only so it can be shown in the
        trace; it does not contribute to the message.
        """
        if self.verbose:
            print("Messaging...")
            print("Norm size: ", norm.size())
            print(f"xi size: {x_i.size()}, xj size: {x_j.size()}")
            if x_j.size(0) > 0:
                print("x_i[0] in message: ", x_i[0])
                print("x_j[0] in message: ", x_j[0])
        # Reshape norm to a column so it broadcasts over the feature dimension.
        return norm.view(-1, 1) * x_j

    def update(self, inputs):
        """Identity update: return the aggregated messages unchanged."""
        if self.verbose:
            print("Updating....", inputs.size())
        return inputs

In [124]:
# Fix the RNG so the linear layer's initial weights (and hence the printed
# activations) are reproducible across runs.
torch.manual_seed(0)
# Input width comes from the feature matrix; 8 is the hidden width used for
# this experiment. The layer is placed on whatever device `x` lives on so the
# model and its inputs can never end up on different devices.
# NOTE: `GCNConv` here must be the class defined in this notebook; if that cell
# has not been run, the name still resolves to the torch_geometric import.
model = GCNConv(data.x.shape[-1], 8).to(x.device)

In [125]:
model(x, edge_index)

x size lin:  torch.Size([2708, 8])
x[0] in message:  tensor([ 0.0713, -0.0250,  0.0616,  0.0009,  0.1036, -0.0217,  0.0565, -0.0251],
       device='cuda:0', grad_fn=<SelectBackward0>)
x[633] in message:  tensor([ 0.0615, -0.0272, -0.0020, -0.0586,  0.0352,  0.0776, -0.0073, -0.0728],
       device='cuda:0', grad_fn=<SelectBackward0>)
Propagating edge size: ... torch.Size([2, 13264])
Messaging...
Norm size:  torch.Size([13264])
xi size: torch.Size([13264, 8]), xj size: torch.Size([13264, 8])
x_i[0] in message:  tensor([ 0.0615, -0.0272, -0.0020, -0.0586,  0.0352,  0.0776, -0.0073, -0.0728],
       device='cuda:0', grad_fn=<SelectBackward0>)
x_j[0] in message:  tensor([ 0.0713, -0.0250,  0.0616,  0.0009,  0.1036, -0.0217,  0.0565, -0.0251],
       device='cuda:0', grad_fn=<SelectBackward0>)
Updating.... torch.Size([2708, 8])


tensor([[-0.0031, -0.0112,  0.0281,  ...,  0.0170,  0.0280, -0.0118],
        [ 0.0273,  0.0128,  0.0390,  ...,  0.0147, -0.0033, -0.0430],
        [ 0.0014, -0.0163,  0.0444,  ..., -0.0294,  0.0096, -0.0534],
        ...,
        [ 0.0198,  0.0105,  0.0870,  ..., -0.1209,  0.0119, -0.0168],
        [-0.0265, -0.0301,  0.0267,  ..., -0.0185, -0.0077,  0.0171],
        [-0.0125, -0.0269,  0.0160,  ..., -0.0235,  0.0118,  0.0246]],
       device='cuda:0', grad_fn=<AddBackward0>)

In [122]:
dataset[0]

Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])

In [113]:
# Apply the NormalizeFeatures transform to the graph returned by dataset[0];
# the cells below inspect the resulting feature values and per-row sums.
# NOTE(review): `data2` is a non-descriptive name (e.g. `data_normalized` would
# read better); it is referenced by the following cells, so rename them together.
data2 = T.NormalizeFeatures()(dataset[0])

In [114]:
dataset[0].x.ravel().unique()

tensor([0., 1.])

In [67]:
data2.x.ravel().unique()

tensor([0.0000, 0.0333, 0.0357, 0.0370, 0.0385, 0.0400, 0.0417, 0.0435, 0.0455,
        0.0476, 0.0500, 0.0526, 0.0556, 0.0588, 0.0625, 0.0667, 0.0714, 0.0769,
        0.0833, 0.0909, 0.1000, 0.1111, 0.1250, 0.1429, 0.1667, 0.2000, 0.2500,
        0.3333, 0.5000, 1.0000])

In [73]:
data2.x.std(1)

tensor([0.0088, 0.0055, 0.0060,  ..., 0.0062, 0.0070, 0.0073])

In [74]:
data2.x.sum(1)

tensor([1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000])