In [1]:
import warnings
from itertools import islice

import torch
import torch_geometric.nn as nn
from torch_geometric.loader import DataLoader
from torch_geometric.nn import summary

from _04_mnist_digits.graph_dataset import GraphDataset # Needed for loading pickled dataset

In [2]:
# Load the graph dataset, derive model metadata from it, and build seeded
# train/val/test splits plus their DataLoaders.
SPLIT_SEED = 42  # fixed seed so split membership is identical on every Restart & Run All
SPLIT_NAMES = ("train", "val", "test")
SPLIT_FRACTIONS = (0.7, 0.15, 0.15)
TRAIN_BATCH_SIZE = 32

# NOTE(review): weights_only=False unpickles arbitrary Python objects (required here
# because the file stores a pickled GraphDataset). Only load this file from a trusted source.
dataset = torch.load("_04_mnist_digits/data/graph_dataset.pt", weights_only=False)

# Node features may be stored flat (one scalar per node) or as a 2-D matrix whose
# second axis is the feature dimension.
example_x = dataset[0].x
num_features = example_x.shape[1] if example_x.dim() == 2 else 1

# Number of distinct labels across the full dataset.
num_classes = len(set(dataset.labels.tolist()))

split_generator = torch.Generator().manual_seed(SPLIT_SEED)
split_subsets = torch.utils.data.random_split(dataset, SPLIT_FRACTIONS, generator=split_generator)

datasets = {}
loaders = {}
# Use a distinct loop name so `dataset` keeps referring to the full dataset after this cell
# (previously it was silently rebound to the test subset).
for split, subset in zip(SPLIT_NAMES, split_subsets):
    is_train = split == "train"
    datasets[split] = subset
    # Val/test are evaluated as a single full-size batch; only training data is shuffled.
    loaders[split] = DataLoader(subset,
                                batch_size=(TRAIN_BATCH_SIZE if is_train else len(subset)),
                                shuffle=is_train)

In [3]:
from shared.models.GNNClassifier import GNNClassifier
from shared.training import train_classifier

# model = GNNClassifier(hidden_ch=128, 
#                       num_node_features=num_features,
#                       num_classes=num_classes,
#                       seed=6,
#                       fc_hidden_dim=64,
#                       fc_layers=3,)


from torch_geometric.nn import EdgeCNN


class Classifier(GNNClassifier):
    """Graph classifier: EdgeCNN node encoder -> global max pool -> parent's ``mlp`` head.

    The parent ``GNNClassifier`` is constructed only to obtain its ``mlp`` head; the
    message-passing stage is replaced by a PyG ``EdgeCNN``.

    Args:
        in_channels: Number of input features per node.
        num_classes: Number of output classes, passed to the parent (presumably sizes
            the ``mlp`` output layer; TODO confirm).
        num_layers: Number of EdgeCNN layers.
        hidden_channels: EdgeCNN hidden width. Also passed to the parent as
            ``hidden_ch``, presumably so the ``mlp`` input width matches the EdgeCNN
            output (TODO confirm against GNNClassifier).
        **kwargs: Accepted for signature compatibility but not used; a warning is
            emitted if any are supplied instead of discarding them silently.
    """

    def __init__(self, in_channels, num_classes, num_layers, hidden_channels, **kwargs):
        super().__init__(num_node_features=in_channels,
                         num_classes=num_classes,
                         hidden_ch=hidden_channels,)
        # NOTE(review): the parent constructor also registers its own ``convolutions``
        # ModuleList (visible in the model summary). forward() below never uses it, so
        # those parameters are dead weight in the optimizer and parameter count.
        self.edgecnn = EdgeCNN(in_channels=in_channels,
                               num_layers=num_layers,
                               hidden_channels=hidden_channels,
                               )
        if kwargs:
            warnings.warn(
                f"Classifier ignores unused keyword arguments: {sorted(kwargs)}",
                stacklevel=2,
            )

    def forward(self, x, edge_index, batch=None, edge_weight=None):
        """Compute class logits with one row per graph.

        Args:
            x: Node feature tensor, expected as (num_nodes, in_channels).
            edge_index: Edge list tensor, expected as (2, num_edges).
            batch: Per-node graph-id vector used for pooling, or None to treat all
                nodes as a single graph.
            edge_weight: Forwarded to EdgeCNN. NOTE(review): EdgeConv does not appear
                to consume edge weights, so this is likely ignored; confirm.

        Returns:
            Output of the ``mlp`` head applied to the pooled graph embeddings.
        """
        x = self.edgecnn(x, edge_index, edge_weight=edge_weight)

        if batch is None:
            # Single graph: max over all nodes, keeping a leading graph dim of size 1.
            x = torch.max(x, dim=0, keepdim=True).values
        else:
            # One max-pooled embedding per graph id in ``batch``.
            x = nn.global_max_pool(x, batch)

        return self.mlp(x)

# Instantiate the EdgeCNN classifier. Input width and class count come from the
# loading cell (``num_features`` / ``num_classes``) instead of being hard-coded,
# so the model stays consistent with the dataset on disk.
model = Classifier(in_channels=num_features,
                   num_classes=num_classes,
                   num_layers=3,
                   hidden_channels=128,)


def _model_inputs_from_batch(batch):
    """Map a PyG batch to the positional arguments of ``Classifier.forward``.

    Flat node features (one scalar per node) get a trailing feature axis so the
    encoder receives a 2-D tensor. Features that are already 2-D pass through
    unchanged instead of being wrongly expanded to 3-D.
    """
    x = batch.x if batch.x.dim() == 2 else batch.x.unsqueeze(-1)
    return x, batch.edge_index, batch.batch


# Hook used to unpack batches (presumably by the training helper; TODO confirm).
model.get_model_inputs_from_batch = _model_inputs_from_batch

loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Sanity check on the first two training batches: the batch itself, the tensor
# shapes fed to the model, and a per-layer summary.
for batch in islice(loaders["train"], 2):
    print(batch)
    model_inputs = model.get_model_inputs_from_batch(batch)
    print([t.shape for t in model_inputs])
    print(summary(model, *model_inputs))

print(model)

DataBatch(x=[872], edge_index=[2, 2763], y=[32], edge_weight=[2763], batch=[872], ptr=[33])
[torch.Size([872, 1]), torch.Size([2, 2763]), torch.Size([872])]
+----------------------------+----------------------------+----------------+----------+
| Layer                      | Input Shape                | Output Shape   | #Param   |
|----------------------------+----------------------------+----------------+----------|
| Classifier                 | [872, 1], [2, 2763], [872] | [32, 10]       | 183,178  |
| ├─(convolutions)ModuleList | --                         | --             | 66,176   |
| │    └─(0)GraphConv        | --                         | --             | 384      |
| │    └─(1)GraphConv        | --                         | --             | 32,896   |
| │    └─(2)GraphConv        | --                         | --             | 32,896   |
| ├─(mlp)Sequential          | [32, 128]                  | [32, 10]       | 1,290    |
| │    └─(0)Linear           | [32, 128]           