In [1]:
import torch_geometric
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, TopKPooling, global_mean_pool, GraphUNet
from torch_geometric.data import Batch
from torch_geometric.utils import to_dense_adj


from utils.data import GraphDataModule
from utils.training import train_model, evaluate_model

In [2]:
# Build the graph data module from ./data and grab its training / validation loaders.
data_module = GraphDataModule("./data", num_workers=1)
train_loader, val_loader = (
    data_module.train_dataloader(),
    data_module.val_dataloader(),
)

Loading data from disk


Converting vectors to graphs: 100%|██████████| 133/133 [00:00<00:00, 974.89it/s]
Converting vectors to graphs: 100%|██████████| 34/34 [00:00<00:00, 747.66it/s]


In [7]:
class SuperResMLP(nn.Module):
    """MLP that maps each input graph's dense adjacency matrix to a dense output matrix.

    Pipeline: densify edges -> flatten -> ``n_layers`` hidden blocks of width
    ``num_hidden_nodes**2`` (Linear -> BatchNorm1d -> Dropout -> ReLU) -> Linear
    projection to ``num_nodes_output**2`` -> reshape to ``(num_nodes_output, num_nodes_output)``.

    Args:
        num_nodes_input: number of nodes per input graph; the densified adjacency is
            padded to ``num_nodes_input x num_nodes_input``.
        num_nodes_output: side length of the predicted square output matrix.
        num_hidden_nodes: hidden layers have ``num_hidden_nodes**2`` units.
        n_layers: number of hidden blocks; values below 1 still build one block.
        dropout: dropout probability used in every hidden block (default 0.1).
    """

    def __init__(self, num_nodes_input, num_nodes_output, num_hidden_nodes, n_layers, dropout=0.1):
        super().__init__()
        # Kept so forward() can pad every graph to the size the first Linear expects.
        self.num_nodes_input = num_nodes_input
        self.num_nodes_output = num_nodes_output

        hidden_width = num_hidden_nodes**2
        self.layers = nn.ModuleList(
            [
                nn.Flatten(start_dim=1),
                nn.Linear(in_features=num_nodes_input**2, out_features=hidden_width),
                # NOTE: BatchNorm1d in training mode needs more than one graph per batch.
                nn.BatchNorm1d(num_features=hidden_width),
                nn.Dropout(p=dropout),
                nn.ReLU(),
            ]
        )
        for _ in range(n_layers - 1):
            self.layers.append(nn.Linear(in_features=hidden_width, out_features=hidden_width))
            self.layers.append(nn.BatchNorm1d(num_features=hidden_width))
            self.layers.append(nn.Dropout(p=dropout))
            self.layers.append(nn.ReLU())

        self.layers.append(nn.Linear(in_features=hidden_width, out_features=num_nodes_output**2))
        self.layers.append(nn.Unflatten(dim=1, unflattened_size=(num_nodes_output, num_nodes_output)))

    @property
    def device(self):
        """Device of the model's first parameter."""
        return next(self.parameters()).device

    def forward(self, samples: Batch):
        """Predict an output matrix for every graph in ``samples``.

        Args:
            samples: a PyG ``Batch``; only ``edge_index`` and ``batch`` are read.

        Returns:
            Tensor of shape ``(num_graphs, num_nodes_output, num_nodes_output)``.
        """
        # Densify the edges of the batch passed in (never a notebook-level variable),
        # padding each graph to num_nodes_input so the flattened size is fixed.
        x = to_dense_adj(
            samples.edge_index,
            batch=samples.batch,
            max_num_nodes=self.num_nodes_input,
        )
        for layer in self.layers:
            x = layer(x)
        return x


In [4]:
# Peek at one training batch and read the node counts of its first sample.
# NOTE(review): assumes every graph shares these sizes (x rows -> input nodes,
# y rows -> output nodes) — confirm against the dataset.
batch = next(iter(train_loader))
input_dim, output_dim = batch[0].x.shape[0], batch[0].y.shape[0]

In [6]:
# Training configuration — tune here rather than inside the calls below.
SEED = 42
N_LAYERS = 3
NUM_EPOCHS = 10

# Seed torch so weight initialisation and dropout masks are repeatable across runs
# (and batch shuffling too, if the loaders draw from torch's default generator).
torch.manual_seed(SEED)

model = SuperResMLP(
    input_dim,
    output_dim,
    # hidden width sits halfway between the input and output node counts
    num_hidden_nodes=(input_dim + output_dim) // 2,
    n_layers=N_LAYERS,
)
criterion = nn.MSELoss()

train_model(
    model=model,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    criterion=criterion,
    num_epochs=NUM_EPOCHS,
)

KeyboardInterrupt: 