In [2]:
import os.path as osp
from math import ceil

import torch
import torch.nn.functional as F
from torch.nn import Linear

from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import DenseGraphConv, GCNConv, dense_mincut_pool
from torch_geometric.utils import to_dense_adj, to_dense_batch

import sys

sys.path.append("../")
import utils

import torch

import torch
import torch.nn.functional as F
import torch_geometric as tg

from torch_geometric.loader import DataLoader
from sklearn.metrics import f1_score, accuracy_score

early_stop_thresh = 25
best_macro_f1 = -1


train_dataset = utils.GraphDataset(
    "../data/", "MixedShapesSmallTrain_TRAIN", True, n_quantiles=100
)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

test_dataset = utils.GraphDataset(
    "../data/", "MixedShapesSmallTrain_TEST", True, n_quantiles=100
)
test_loader = DataLoader(test_dataset, batch_size=64)


class Net(torch.nn.Module):
    def __init__(self, in_channels, out_channels, hidden_channels=32):
        super().__init__()

        self.conv1 = GCNConv(in_channels, hidden_channels)
        num_nodes = ceil(0.5 * 100)
        self.pool1 = Linear(hidden_channels, num_nodes)

        self.conv2 = DenseGraphConv(hidden_channels, hidden_channels)
        num_nodes = ceil(0.5 * num_nodes)
        self.pool2 = Linear(hidden_channels, num_nodes)

        self.conv3 = DenseGraphConv(hidden_channels, hidden_channels)

        self.lin1 = Linear(hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index).relu()

        x, mask = to_dense_batch(x, batch)
        adj = to_dense_adj(edge_index, batch)
        s = self.pool1(x)
        x, adj, mc1, o1 = dense_mincut_pool(x, adj, s, mask)

        x = self.conv2(x, adj).relu()
        s = self.pool2(x)
        x, adj, mc2, o2 = dense_mincut_pool(x, adj, s)

        x = self.conv3(x, adj)

        x = x.mean(dim=1)
        x = self.lin1(x).relu()
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1), mc1 + mc2, o1 + o2


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net(1, 5).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, patience=5, mode="min", cooldown=2, factor=0.5, verbose=True
)


def train(epoch):
    model.train()
    loss_all = 0

    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out, mc_loss, o_loss = model(data.x, data.edge_index, data.batch)
        loss = F.nll_loss(out, data.y.view(-1)) + mc_loss + o_loss
        loss.backward()
        loss_all += data.y.size(0) * float(loss)
        optimizer.step()
    return loss_all / len(train_dataset)


@torch.no_grad()
def test(loader):
    model.eval()
    correct = 0
    loss_all = 0

    for data in loader:
        data = data.to(device)
        pred, mc_loss, o_loss = model(data.x, data.edge_index, data.batch)
        loss = F.nll_loss(pred, data.y.view(-1)) + mc_loss + o_loss
        loss_all += data.y.size(0) * float(loss)
        correct += int(pred.max(dim=1)[1].eq(data.y.view(-1)).sum())

    return loss_all / len(loader.dataset), correct / len(loader.dataset)


best_val_acc = test_acc = 0
best_val_loss = float("inf")
patience = start_patience = 50
for epoch in range(1, 2000):
    train_loss = train(epoch)
    _, train_acc = test(train_loader)
    val_loss, val_acc = test(test_loader)
    if val_loss < best_val_loss:
        test_loss, test_acc = test(test_loader)
        best_val_acc = val_acc
        patience = start_patience
    else:
        patience -= 1
        if patience == 0:
            break
    print(
        f"Epoch: {epoch:03d}, Train_Loss: {train_loss:02.4f},Test_Loss: {test_loss:02.4f},Train_acc: {train_acc:01.4f},Test_acc: {test_acc:01.4f}"
    )



Epoch: 001, Train_Loss: 2.1947,Test_Loss: 2.1892,Train_acc: 0.2000,Test_acc: 0.1889
Epoch: 002, Train_Loss: 2.1853,Test_Loss: 2.1889,Train_acc: 0.2000,Test_acc: 0.1889
Epoch: 003, Train_Loss: 2.1854,Test_Loss: 2.1915,Train_acc: 0.2000,Test_acc: 0.1724
Epoch: 004, Train_Loss: 2.1823,Test_Loss: 2.1931,Train_acc: 0.2000,Test_acc: 0.1724
Epoch: 005, Train_Loss: 2.1827,Test_Loss: 2.1940,Train_acc: 0.2000,Test_acc: 0.1724
Epoch: 006, Train_Loss: 2.1815,Test_Loss: 2.1912,Train_acc: 0.2000,Test_acc: 0.1724
Epoch: 007, Train_Loss: 2.1800,Test_Loss: 2.1887,Train_acc: 0.2000,Test_acc: 0.1724
Epoch: 008, Train_Loss: 2.1795,Test_Loss: 2.1861,Train_acc: 0.2000,Test_acc: 0.1724
Epoch: 009, Train_Loss: 2.1795,Test_Loss: 2.1813,Train_acc: 0.2000,Test_acc: 0.1724
Epoch: 010, Train_Loss: 2.1785,Test_Loss: 2.1790,Train_acc: 0.2000,Test_acc: 0.1724
Epoch: 011, Train_Loss: 2.1776,Test_Loss: 2.1790,Train_acc: 0.2000,Test_acc: 0.1724
Epoch: 012, Train_Loss: 2.1758,Test_Loss: 2.1795,Train_acc: 0.2000,Test_ac