In [55]:
import networkx as nx
from torch_geometric.datasets import ZINC
from torch_geometric.data import DataLoader
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as op
from torch_geometric.nn import global_add_pool
from torch_geometric.datasets import QM9, ZINC
from torch_geometric.data import DataLoader
from torch_scatter import scatter_add

class BasicMPNNLayer(nn.Module):
    def __init__(self, num_hidden):
        super().__init__()
        self.message_mlp = nn.Linear(3 * num_hidden, num_hidden)
        self.update_mlp = nn.Linear(2 * num_hidden, num_hidden)

    def forward(self, h, edge_index, edge_attr):
        send, rec = edge_index
        h_send, h_rec = h[send], h[rec]
        messages = self.message_mlp(torch.cat((h_send, h_rec, edge_attr), dim=1))
        messages_agg = scatter_add(messages, rec, dim=0)
        out = self.update_mlp(torch.cat((h, messages_agg), dim=1))

        return out

In [56]:
class BasicMPNN(nn.Module):
    def __init__(self, feat_in, edge_feat_in, num_hidden, num_layers):
        super().__init__()
        self.embed = nn.Linear(feat_in, num_hidden)
        self.edge_embed = nn.Linear(edge_feat_in, num_hidden)
        self.layers = nn.ModuleList([BasicMPNNLayer(num_hidden) for _ in range(num_layers)])
        self.predict = nn.Linear(num_hidden, 1)

    def forward(self, graph):
        #print(graph.x, graph.pos)
        h, edge_index, edge_attr,  batch = graph.x, graph.edge_index, graph.edge_attr, graph.batch

        # h_nodes, h_edges, h_triangles
        h = h.float()
        h = self.embed(h)
        edge_attr = edge_attr.unsqueeze(1).float()
        edge_attr = self.edge_embed(edge_attr)

        for layer in self.layers:
            # h_nodes, h_edges, h_triangles = layer(..., edge_nodes_nodes, edge_nodes_edges, ...)
            h = h + nn.functional.relu(layer(h, edge_index, edge_attr))

        h_agg = global_add_pool(h, batch)
        final_prediction = self.predict(h_agg)

        return final_prediction.squeeze(1)

In [60]:
# transform = PEAddWR(....)
# transform = AddRandomWalkPE(walk_length=4)
data = ZINC('datasets/ZINC_basic') #QM9('datasets/QM9', pre_transform=transform)

train_loader = DataLoader(data[:10], batch_size=32)
val_loader = DataLoader(data[10:12], batch_size=32)
test_loader = DataLoader(data[12:14], batch_size=32)

model = BasicMPNN(1, 1, 32, 4)
optimizer = op.Adam(model.parameters(), lr=1e-4)
criterion = nn.L1Loss(reduce='sum')

# test

In [61]:
for _ in range(10):
    # train
    train_loss = 0
    model.train()
    for batch in train_loader:
        optimizer.zero_grad()
        label = batch.y  # alpha

        out = model(batch)

        loss = criterion(out, label)
        train_loss += loss.item()

        loss.backward()
        optimizer.step()

    print(train_loss / len(train_loader.dataset))
    


2.6160797119140624
2.3987836837768555
2.184476852416992
1.9739900588989259
1.7663282394409179
1.5605850219726562
1.3572912216186523
1.1572065353393555
0.9608661651611328
0.7666574954986572


In [17]:
batch

DataBatch(x=[253, 1], edge_index=[2, 546], edge_attr=[546], y=[10], batch=[253], ptr=[11])

In [26]:
data[50]

Data(x=[29, 1], edge_index=[2, 64], edge_attr=[64], y=[1])

In [25]:
data[50].edge_index

tensor([[ 0,  1,  1,  1,  2,  2,  2,  3,  3,  3,  4,  5,  5,  6,  6,  6,  7,  7,
          8,  8,  9,  9,  9, 10, 11, 11, 12, 12, 13, 13, 13, 14, 14, 15, 15, 16,
         16, 17, 17, 18, 18, 18, 19, 19, 19, 20, 20, 20, 21, 22, 22, 22, 23, 23,
         24, 24, 25, 25, 25, 26, 27, 27, 28, 28],
        [ 1,  0,  2, 19,  1,  3, 13,  2,  4,  5,  3,  3,  6,  5,  7, 12,  6,  8,
          7,  9,  8, 10, 11,  9,  9, 12,  6, 11,  2, 14, 18, 13, 15, 14, 16, 15,
         17, 16, 18, 13, 17, 19,  1, 18, 20, 19, 21, 22, 20, 20, 23, 28, 22, 24,
         23, 25, 24, 26, 27, 25, 25, 28, 22, 27]])

In [None]:

model.eval()
test_loss = 0
for batch in test_loader:
    out = model(batch)
    label = batch.y[:, 1]
    loss = criterion(out, label)

    test_loss += loss.item()

print(test_loss / len(test_loader.dataset))