In [2]:
from math import ceil

import torch
import torch.nn.functional as F
from torch.nn import Linear

from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import DenseGraphConv, GCNConv, dense_mincut_pool
from torch_geometric.utils import to_dense_adj, to_dense_batch

import sys

sys.path.append("../")
import utils

import torch

import torch
import torch.nn.functional as F
import torch_geometric as tg

from torch_geometric.loader import DataLoader
from sklearn.metrics import f1_score, accuracy_score
import numpy as np

train_dataset = utils.GraphDataset("../data/", "GunPoint", quantile=False)
train_loader = DataLoader(train_dataset, batch_size=50, shuffle=True)

test_dataset = utils.GraphDataset("../data/", "GunPoint", False, quantile=False)
test_loader = DataLoader(test_dataset, batch_size=len(test_dataset))

(50, 150)
(150, 150)


In [3]:
class Net(torch.nn.Module):
    def __init__(self, in_channels, out_channels, hidden_channels=32):
        super().__init__()

        self.conv1 = GCNConv(in_channels, hidden_channels)
        num_nodes = 100
        self.pool1 = Linear(hidden_channels, num_nodes)

        self.conv2 = DenseGraphConv(hidden_channels, hidden_channels)
        num_nodes = ceil(2 / 3 * num_nodes)
        self.pool2 = Linear(hidden_channels, num_nodes)

        self.conv3 = DenseGraphConv(hidden_channels, hidden_channels)

        self.lin1 = Linear(hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, out_channels)

    def forward(self, data):
        x, edge_index, batch, edge_weight = (
            data.x,
            data.edge_index,
            data.batch,
            data.edge_weight,
        )
        x = self.conv1(x, edge_index, edge_weight).relu()

        x, mask = to_dense_batch(x, batch)
        adj = to_dense_adj(edge_index, batch)
        s = self.pool1(x)
        x, adj, mc1, o1 = dense_mincut_pool(x, adj, s, mask)

        x = self.conv2(x, adj).relu()
        s = self.pool2(x)
        x, adj, mc2, o2 = dense_mincut_pool(x, adj, s)

        x = self.conv3(x, adj).relu()

        x = x.mean(dim=1)
        x = self.lin1(x).relu()
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1), mc1 + mc2, o1 + o2

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net(13, 2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [5]:
def train():
    model.train()

    total_loss = 0
    for data in train_loader:
        optimizer.zero_grad()
        y_out, mc_loss, o_loss = model(data)
        loss = F.nll_loss(y_out, data.y) + mc_loss + o_loss
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * data.num_graphs
    return total_loss / len(train_loader.dataset)


@torch.no_grad()
def test(loader):
    model.eval()
    y_pred = []
    y_true = []
    loss = 0
    for data in loader:
        y_out, mc_loss, o_loss = model(data)
        y_pred.append(y_out.argmax(dim=-1))
        y_true.append(data.y)
        loss += float(F.nll_loss(y_out, data.y) * data.num_graphs + mc_loss + o_loss)
    y_pred = np.concatenate(y_pred)
    y_true = np.concatenate(y_true)
    return (
        f1_score(y_true=y_true, y_pred=y_pred, average="macro"),
        accuracy_score(y_true=y_true, y_pred=y_pred),
        loss / len(loader.dataset),
    )

In [6]:
best_macro_f1 = -1
train_losses = []
val_losses = []

In [7]:
import time

start = time.time()
for epoch in range(340):
    train()
    train_macro_f1, train_acc, train_loss = test(train_loader)
    test_macro_f1, test_acc, test_loss = test(test_loader)
    print(
        f"Epoch: {epoch:03d}, Train_Loss: {train_loss:02.4f},Test_Loss: {test_loss:02.4f},Train_f1: {train_macro_f1:01.4f},Test_f1: {test_macro_f1:01.4f},Train_acc: {train_acc:01.4f},Test_acc: {test_acc:01.4f}"
    )
    if test_macro_f1 > best_macro_f1:
        best_accuracy = test_macro_f1
        best_epoch = epoch
        torch.save(model.state_dict(), "../data/best_model.pth")
    train_losses.append(train_loss)
    val_losses.append(test_loss)
print(time.time() - start)

Epoch: 000, Train_Loss: 14.8525,Test_Loss: 13.3482,Train_f1: 0.3243,Test_f1: 0.3363,Train_acc: 0.4800,Test_acc: 0.5067
Epoch: 001, Train_Loss: 14.2735,Test_Loss: 12.8243,Train_f1: 0.3243,Test_f1: 0.3363,Train_acc: 0.4800,Test_acc: 0.5067
Epoch: 002, Train_Loss: 13.2474,Test_Loss: 12.3190,Train_f1: 0.3243,Test_f1: 0.3363,Train_acc: 0.4800,Test_acc: 0.5067
Epoch: 003, Train_Loss: 12.5152,Test_Loss: 11.9071,Train_f1: 0.3243,Test_f1: 0.3363,Train_acc: 0.4800,Test_acc: 0.5067
Epoch: 004, Train_Loss: 12.0989,Test_Loss: 11.5041,Train_f1: 0.3243,Test_f1: 0.3363,Train_acc: 0.4800,Test_acc: 0.5067
Epoch: 005, Train_Loss: 11.6849,Test_Loss: 11.1080,Train_f1: 0.3243,Test_f1: 0.3363,Train_acc: 0.4800,Test_acc: 0.5067
Epoch: 006, Train_Loss: 11.2775,Test_Loss: 10.7165,Train_f1: 0.3243,Test_f1: 0.3363,Train_acc: 0.4800,Test_acc: 0.5067
Epoch: 007, Train_Loss: 10.8756,Test_Loss: 10.3311,Train_f1: 0.3243,Test_f1: 0.3363,Train_acc: 0.4800,Test_acc: 0.5067
Epoch: 008, Train_Loss: 10.4759,Test_Loss: 9.949

KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt

plt.plot(train_losses)
plt.plot(val_losses)
plt.show()