In [10]:
# Necessary imports

from pathlib import Path

from torch.utils.data import Dataset  # not the one from PyG!
from torch_geometric.loader import DataLoader
import torch

from torch_geometric.nn import GCNConv
from torch.nn.functional import relu, binary_cross_entropy_with_logits


In [16]:
# MyDataset class for handling file loading

class MyDataset(Dataset):
    """Map-style dataset that lazily loads one serialized object per ``*.pt`` file.

    Parameters
    ----------
    path:
        Directory holding the ``.pt`` files (a ``Path`` or a ``str``). Only its
        direct children are matched; subdirectories are not searched.
    """

    def __init__(self, path: Path):
        super().__init__()
        # Path.glob yields entries in filesystem order, which is not stable
        # across machines or runs. Sorting guarantees that index i always maps
        # to the same file, so dataset[0], batches and splits are reproducible.
        self.graphs = sorted(Path(path).glob("*.pt"))

    def __getitem__(self, idx):
        """Load and return the object stored in the ``idx``-th file.

        The file is read from disk on every call (no caching).

        NOTE(review): ``torch.load`` unpickles the file, so only use it on
        trusted data. Newer torch releases default to ``weights_only=True``,
        which may reject non-tensor objects such as PyG ``Data``. Confirm this
        against the installed version.
        """
        return torch.load(self.graphs[idx])

    def __len__(self) -> int:
        """Number of ``*.pt`` files found in the directory."""
        return len(self.graphs)

In [46]:
# Load the first batch of graphs (batch_1_0) and inspect one of them.

dataset = MyDataset(Path("./dataset/batch_1_0"))

# Every dataset[i] access re-reads the file from disk, so load the first
# graph once and reuse it for all the checks below.
first_graph = dataset[0]

# Preliminary analysis of dataset behavior
print("Number of graphs in set:", len(dataset))
# NOTE(review): in newer PyG versions `keys` is a method (`keys()`); on those
# versions this line prints a bound method instead of the key list.
print("Set of keys in each graph", first_graph.keys)
print("Number of nodes in first graph:", first_graph.num_nodes)
print("Number of edges in first graph:", first_graph.num_edges)
print("Number of node features:", first_graph.num_node_features)
print("Number of edge features:", first_graph.num_edge_features)
print("Is this graph undirected? ", first_graph.is_undirected())

# This is the target tensor `y`, not an edge-feature count.
print("Target (y) of first graph:", first_graph['y'])

loader = DataLoader(dataset, batch_size=32)

Number of graphs in set: 1000
Set of keys in each graph ['y', 'edge_attr', 'x', 'edge_index']
Number of nodes in first graph: 419
Number of edges in first graph: 4882
Number of node features: 6
Number of edge features: 4
Is this graph undirected?  False
Number of edge features: tensor([1., 1., 1.,  ..., 1., 1., 1.])


In [51]:
class GNN(torch.nn.Module):
    """Stack of GCNConv layers producing one raw (un-activated) score per node.

    Layer widths are ``num_node_features -> *hidden_layers -> 1``. A ReLU is
    applied between layers but not after the last one, so the output is a logit.

    Parameters
    ----------
    sample:
        A graph (e.g. one item of the dataset) whose ``num_node_features``
        sets the input width of the first layer.
    hidden_layers:
        Widths of the hidden GCN layers. An empty sequence gives a single
        ``GCNConv(num_node_features, 1)`` layer.
    """

    def __init__(self, sample, hidden_layers=(8,)):
        super().__init__()
        # GCNConv convolves node features only, and forward() feeds it
        # `data.x`. Edge features are not part of `x`, so they must not be
        # counted in the first layer's input width.
        widths = [sample.num_node_features, *hidden_layers, 1]
        # ModuleList (not a plain list) registers each layer as a submodule,
        # so its weights appear in parameters()/state_dict() and follow .to().
        self.conv = torch.nn.ModuleList(
            GCNConv(in_ch, out_ch) for in_ch, out_ch in zip(widths[:-1], widths[1:])
        )

    def forward(self, data):
        """Return a ``[num_nodes, 1]`` tensor of per-node logits for ``data``."""
        x, edge_index = data.x, data.edge_index
        # GCNConv's third argument is a 1-D per-edge *weight*; it cannot
        # consume multi-dimensional edge features. Forward `edge_attr` only when
        # it already is one scalar weight per edge. To use edge features, switch
        # to a layer with `edge_dim` support (e.g. GINEConv, GATConv, NNConv).
        edge_attr = getattr(data, "edge_attr", None)
        edge_weight = edge_attr if edge_attr is not None and edge_attr.dim() == 1 else None

        last = len(self.conv) - 1
        for i, conv in enumerate(self.conv):
            x = conv(x, edge_index, edge_weight)
            if i < last:
                x = relu(x)
        return x

In [48]:
# Compute device for the notebook; everything runs on the CPU.
device: torch.device = torch.device("cpu")


In [50]:
# Toy example: append each edge's source-node features to that edge's features.
# A dedicated, seeded generator keeps the random inputs (and the printed tensor)
# reproducible without touching torch's global RNG state.
demo_rng = torch.Generator().manual_seed(0)

edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 0, 1, 3, 4],
                           [1, 2, 3, 4, 5, 0, 2, 3, 5, 5]])
edge_attr = torch.randn((10, 5), generator=demo_rng)  # 10 edges x 5 edge features
x = torch.randn((6, 3), generator=demo_rng)            # 6 nodes x 3 node features

# Row i holds the features of edge i's source node (row 0 of edge_index).
src_node_attrs = x[edge_index[0]]
concatenated_attrs = torch.cat((edge_attr, src_node_attrs), dim=1)
assert concatenated_attrs.shape == (edge_index.size(1), edge_attr.size(1) + x.size(1))
print(concatenated_attrs)

tensor([[ 2.8229,  1.1472, -1.4095,  0.9064,  0.4386,  0.8614, -0.0332,  1.3438],
        [ 0.5211,  0.3819,  0.7579,  0.2391, -1.9337, -1.1964,  1.8479,  1.9858],
        [-0.3307, -0.4691, -0.0952, -0.9481,  0.0325, -0.3639,  1.4590, -2.4435],
        [-0.2916, -0.3574, -0.2228,  1.8745,  2.1522, -0.6042,  1.5561,  2.8058],
        [-0.0472,  0.1039, -1.5517, -0.2899,  0.0590, -1.6996,  0.0304,  1.9987],
        [ 0.9412, -0.1166,  0.4736, -0.8124,  0.6360,  0.0360, -1.2121, -1.7497],
        [ 2.5634,  1.2080,  1.6952, -1.4127,  0.7093,  0.8614, -0.0332,  1.3438],
        [ 0.4937, -0.5360,  0.6529,  0.6682, -1.3964, -1.1964,  1.8479,  1.9858],
        [-0.0542, -1.4778,  0.0411,  1.8971,  1.1198, -0.6042,  1.5561,  2.8058],
        [ 1.5709,  0.6317,  1.0711, -0.1319,  0.6079, -1.6996,  0.0304,  1.9987]])
