In [104]:
# Necessary imports

from pathlib import Path

from torch.utils.data import Dataset  # not the one from PyG!
from torch_geometric.loader import DataLoader
import torch

from torch_geometric.nn import NNConv
from torch.nn.functional import relu, binary_cross_entropy_with_logits, linear
from torch_geometric.utils import add_self_loops, degree
import torch.nn as nn


In [16]:
# MyDataset class for handling file loading

class MyDataset(Dataset):
    """Map-style dataset over graphs stored one-per-file in a directory.

    Each item is whatever ``torch.load`` returns for the corresponding file.
    Files are re-read from disk on every ``__getitem__`` call (no caching).

    Args:
        path: directory to scan (a ``str`` is accepted too).
        pattern: glob pattern selecting the files; ``"*.pt"`` by default.
    """

    def __init__(self, path: Path, pattern: str = "*.pt"):
        super().__init__()
        # Path.glob yields entries in filesystem order, which is not guaranteed
        # to be stable; sort so that index i always refers to the same file.
        self.graphs = sorted(Path(path).glob(pattern))

    def __getitem__(self, idx):
        # NOTE(review): torch.load unpickles the file -- only load trusted data.
        # Newer torch releases default to weights_only=True, which may refuse
        # non-tensor objects such as PyG Data; confirm against the installed version.
        return torch.load(self.graphs[idx])

    def __len__(self) -> int:
        return len(self.graphs)

In [46]:
# Load first batch (batch_1_0)

DATA_DIR = Path("./dataset/batch_1_0")
dataset = MyDataset(DATA_DIR)
# An empty glob (wrong path / missing data) would otherwise surface as an
# opaque IndexError on dataset[0] below.
assert len(dataset) > 0, f"no .pt files found in {DATA_DIR}"

# Preliminary analysis of dataset behavior.
# Every dataset[i] call re-reads the file from disk, so load the first graph once.
sample = dataset[0]

print("Number of graphs in set:", len(dataset))
print("Set of keys in each graph", sample.keys)
print("Number of nodes in first graph:", sample.num_nodes)
print("Number of edges in first graph:", sample.num_edges)
print("Number of node features:", sample.num_node_features)
print("Number of edge features:", sample.num_edge_features)
print("Is this graph undirected? ", sample.is_undirected())

# `y` holds the training targets; the training cell compares it against one
# logit per edge, so presumably one label per edge -- TODO confirm.
print("Edge labels (y) of first graph:", sample['y'])

loader = DataLoader(dataset, batch_size=32)

Number of graphs in set: 1000
Set of keys in each graph ['y', 'edge_attr', 'x', 'edge_index']
Number of nodes in first graph: 419
Number of edges in first graph: 4882
Number of node features: 6
Number of edge features: 4
Is this graph undirected?  False
Number of edge features: tensor([1., 1., 1.,  ..., 1., 1., 1.])


In [127]:
class GNN(torch.nn.Module):
    """Edge classifier: a stack of NNConv layers followed by a per-edge MLP.

    The NNConv stack maps node features through the ``hidden_layers`` widths
    and finally to ``intra_hidden`` channels. Each edge is then scored from the
    concatenation of its own attributes and the final embedding of its source
    node (``edge_index[0]``). ``forward`` returns one logit per edge.

    Args:
        sample: a graph used only to read ``num_node_features`` and
            ``num_edge_features``.
        hidden_layers: output widths of the intermediate NNConv layers. May be
            empty, in which case a single NNConv maps node features straight
            to ``intra_hidden`` channels.
        intra_hidden: hidden width of every internal MLP, and the width of the
            final node embedding fed to the edge predictor.
    """

    def __init__(self, sample, hidden_layers=(8,), intra_hidden=10):
        super().__init__()

        n_node = sample.num_node_features
        n_edge = sample.num_edge_features

        # Channel widths along the conv stack: input -> hidden... -> intra_hidden.
        # Always ending at intra_hidden keeps edge_predictor's input size
        # (n_edge + intra_hidden) valid, including when hidden_layers is empty.
        widths = [n_node, *hidden_layers, intra_hidden]

        convs = []
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            # NNConv's edge network must emit an (in_ch x out_ch) weight matrix
            # per edge, flattened to in_ch * out_ch values.
            edge_net = self.simple_nn(n_edge, intra_hidden, in_ch * out_ch)
            convs.append(NNConv(in_ch, out_ch, edge_net))

        # Plain-list alias kept for backward compatibility; the ModuleList below
        # is what registers the layers' parameters with this module.
        self.conv = convs

        # Last layer for making edge predictions from node + edge data
        self.edge_predictor = self.simple_nn(n_edge + intra_hidden, intra_hidden, 1)

        self.conv_module = torch.nn.ModuleList(convs)

    def simple_nn(self, a_1, a_2, a_3):
        """Return a two-layer MLP ``Linear(a_1, a_2) -> ReLU -> Linear(a_2, a_3)``.

        Args:
            a_1: input width.
            a_2: hidden width.
            a_3: output width.
        """
        return torch.nn.Sequential(
            torch.nn.Linear(a_1, a_2),
            torch.nn.ReLU(),
            torch.nn.Linear(a_2, a_3),
        )

    def forward(self, data):
        """Return one logit per edge of ``data``, shape ``(num_edges,)``.

        ``data`` must provide ``x``, ``edge_index`` and ``edge_attr``.
        """
        x, edge_attr, edge_index = data.x, data.edge_attr, data.edge_index

        for conv in self.conv_module:
            x = relu(conv(x, edge_index, edge_attr))

        # Concatenate each edge's attributes with its source node's embedding.
        src_node_attrs = x[edge_index[0]]
        conv_attributes = torch.cat((edge_attr, src_node_attrs), dim=1)
        edge_scores = self.edge_predictor(conv_attributes)

        # squeeze(-1) rather than squeeze(): a single-edge graph must still give
        # shape (1,), not a 0-d scalar that would mismatch the targets.
        return edge_scores.squeeze(-1)

In [128]:
# Train on a single graph (dataset[0]); the DataLoader above is not used here.
# Seed first so the random weight initialisation -- and therefore the loss
# curve -- is reproducible across Restart & Run All.
torch.manual_seed(0)

NUM_EPOCHS = 200

device = torch.device('cpu')
data = dataset[0].to(device)  # read the graph from disk once
model = GNN(data).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

losses = []

model.train()
for epoch in range(NUM_EPOCHS):
    optimizer.zero_grad()
    out = model(data)
    loss = binary_cross_entropy_with_logits(out, data.y)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"Loss: epoch 1 = {losses[0]:.4f}, epoch {NUM_EPOCHS} = {losses[-1]:.4f}")

1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2
2 of 2
1 of 2