In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GraphConv, SAGEConv, global_mean_pool

In [65]:
class BaselineGNN(nn.Module):
    def __init__(self, in_channels, hidden_channels=64, num_layers=2):
        super().__init__()
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()

        self.convs.append(SAGEConv(in_channels, hidden_channels))
        self.bns.append(nn.BatchNorm1d(hidden_channels))
        for i in range(num_layers-1):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
            self.bns.append(nn.BatchNorm1d(hidden_channels))

        self.pool = global_mean_pool
        
        self.head = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels//2),
            nn.ReLU(),
            nn.Linear(hidden_channels//2, 1)
        )

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        print(x.dtype, edge_index.dtype)
        for conv, bn in zip(self.convs, self.bns):
            x = conv(x, edge_index)
            x = bn(x)
            x = F.relu(x)
        batch = data.batch if hasattr(data, "batch") else torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        h = self.pool(x, batch)
        out = self.head(h).squeeze(-1)
        return out


In [66]:
model = BaselineGNN(32)

In [71]:
data = Data(x = torch.randint(0, 32, (4, 32)).float(), 
            edge_index = torch.randint(0, 4, (2, 30)))

In [72]:
data.edge_index

tensor([[3, 1, 1, 2, 1, 3, 2, 3, 3, 2, 3, 0, 2, 1, 3, 3, 2, 1, 1, 1, 3, 3, 2, 2,
         0, 1, 1, 3, 3, 0],
        [0, 1, 1, 0, 3, 1, 1, 1, 0, 0, 1, 0, 1, 2, 2, 0, 0, 3, 2, 1, 2, 3, 3, 3,
         0, 2, 0, 0, 1, 3]])

In [73]:
data.x

tensor([[19., 28.,  6.,  3.,  3.,  2., 11., 25.,  5., 26., 18., 28., 28., 19.,
         26.,  6., 13., 26., 12., 27., 11., 21.,  0.,  4., 20., 30., 22., 24.,
          4.,  8., 17., 29.],
        [ 0., 24.,  0.,  9., 11.,  7., 24., 28., 26., 22., 18., 22.,  8., 14.,
         22.,  9., 21., 17., 23., 21., 23.,  0., 24., 30., 15., 21., 23., 11.,
          6., 16., 31., 24.],
        [16.,  6., 16.,  5.,  0., 23., 28.,  0., 13., 22., 31., 25.,  7., 16.,
         16., 12.,  7., 10.,  2., 28.,  8., 21., 11., 20., 28.,  7.,  7., 28.,
         26., 20.,  5., 26.],
        [16., 10.,  2., 17.,  6.,  5.,  8., 24., 10., 20., 16.,  5., 29., 26.,
          1., 19.,  7., 21., 29., 27., 30., 22.,  7.,  0., 16.,  6., 26., 11.,
          1.,  4., 25.,  9.]])

In [74]:
model(data)

torch.float32 torch.int64


tensor([0.1311], grad_fn=<SqueezeBackward1>)