In [None]:
import os
import random
import shutil

import numpy as np

import torch

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed every RNG the pipeline may draw from (Python, NumPy, PyTorch CPU and all
# CUDA devices) so that re-running the notebook is as reproducible as possible
# (some GPU kernels may still be nondeterministic).
seed = 123

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

n_folds = 10  # n-fold cross-validation
batch_size = 10  # number of meshes in a minibatch
# Early-stopping patience: the Adam training loop stops once the validation loss
# has not improved for more than this many epochs.
# NOTE: this name shadows Python's built-in `breakpoint()`; it is kept because the
# training cell reads it.
breakpoint = 100
post_train = True  # post-train with SGD + momentum (lr=0.001, momentum=0.9)

misalign = False  # whether to misalign the meshes (i.e. apply random isometric transformations to them)
realign = False  # whether to realign the meshes

In [None]:
from dataset import HCPDataset
from data_utils import align_meshes, connect_nodes, RandomIsoTransform
from torch_geometric.transforms import Compose
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.transforms import NormalizeScale

data_path = 'data'

# Drop the cached processed dataset so `pre_transform` is re-applied with the
# current settings (presumably HCPDataset, like other torch_geometric datasets,
# caches pre-transformed data under `<root>/processed` — confirm). Unlike
# `!rm -r`, this is portable and quiet when the cache does not exist yet.
processed_dir = os.path.join(data_path, 'processed')
if os.path.isdir(processed_dir):
    shutil.rmtree(processed_dir)

pre_transform = Compose(
    (
        connect_nodes,
        NormalizeFeatures(),
        NormalizeScale(),
        # NOTE(review): assumes global_=True applies one shared transform to all
        # meshes (keeping them aligned) and global_=False a per-mesh one — confirm.
        RandomIsoTransform(global_=not misalign),
    )
)

dataset = HCPDataset(data_path, pre_transform=pre_transform)

if realign:
    dataset = align_meshes(dataset)

In [None]:
# import models
from mlp import MLP
from gnn import GNN
from egnn import EGNN

In [None]:
from torch_geometric.loader import DataLoader
from data_utils import split_fold
from train_test import train_model, test_model
from metrics import plot_learning_curve

test_losses = []
test_IoUs = []


def _snapshot_state(model):
    """Return an in-memory copy of `model`'s state_dict (independent tensors)."""
    return {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}


def fit_with_checkpoint(model, optimizer, train_loader, val_loader, n_epochs, checkpoint_path,
                        val_losses, val_IoUs, best_val_loss=float('inf'), patience=None):
    """Train `model` for up to `n_epochs` epochs, keeping the best validation weights.

    Each epoch runs `train_model` on `train_loader` and then `test_model` on
    `val_loader`. The validation loss and IoU are appended to `val_losses` and
    `val_IoUs` in place.

    Whenever the validation loss drops below the best seen so far (initially
    `best_val_loss`), the epoch is marked with "*", the whole model is pickled to
    `checkpoint_path`, and its weights are snapshotted in memory. If `patience` is
    given, training stops once more than `patience` consecutive epochs pass
    without improvement.

    On return, `model` holds the best weights seen during this call. If no epoch
    beat `best_val_loss`, it holds the weights it entered with.

    Returns:
        The best validation loss: the input `best_val_loss` if no epoch improved.
    """
    # Restore from memory rather than re-loading the pickle. This keeps the same
    # model object (so optimizers bound to its parameters stay valid) and avoids
    # `torch.load`'s `weights_only` restrictions in recent PyTorch versions.
    best_state = _snapshot_state(model)
    epochs_since_best = 0
    for epoch in range(n_epochs):
        train_loss, train_IoU = train_model(train_loader, model, optimizer)
        val_loss, val_IoU = test_model(val_loader, model)
        val_loss = val_loss.cpu().detach().numpy()
        val_losses.append(val_loss)
        val_IoUs.append(val_IoU)
        if val_loss < best_val_loss:
            new_min = "*"
            best_val_loss = val_loss
            epochs_since_best = 0
            best_state = _snapshot_state(model)
            torch.save(model, checkpoint_path)
        else:
            new_min = " "
            epochs_since_best += 1
        print(new_min,
              "Epoch: %d, train loss: %1.3f, train IoU/class: %1.3f %1.3f %1.3f, val loss: %1.3f, val IoU/class: %1.3f %1.3f %1.3f"
              % (epoch, train_loss, train_IoU[0], train_IoU[1], train_IoU[2], val_loss, val_IoU[0], val_IoU[1], val_IoU[2]))
        if patience is not None and epochs_since_best > patience:
            break
    model.load_state_dict(best_state)
    return best_val_loss


os.makedirs("models", exist_ok=True)  # torch.save does not create missing directories

for fold in range(n_folds):
    train_subset, val_subset, test_subset = split_fold(dataset, fold, n_folds)
    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    # Validation and test sets are evaluated as a single full batch; no need to shuffle.
    val_loader = DataLoader(val_subset, batch_size=len(val_subset), shuffle=False)
    test_loader = DataLoader(test_subset, batch_size=len(test_subset), shuffle=False)
    model = MLP(device, 12, 32, 3)
    # model = GNN(device, 12, 32, 3, residual=False)
    # model = EGNN(device, 9, 32, 3, residual=False)
    print(model)
    print('Number of parameters: ', sum(p.numel() for p in model.parameters() if p.requires_grad))
    checkpoint_path = "models/best_model_%d.pkl" % fold
    val_losses = []
    val_IoUs = []

    # Stage 1: Adam with early stopping (patience = `breakpoint` epochs).
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    best_val_loss = fit_with_checkpoint(model, optimizer, train_loader, val_loader, 200 * breakpoint,
                                        checkpoint_path, val_losses, val_IoUs, patience=breakpoint)

    # Stage 2: fine-tune the best Adam weights with SGD + momentum for a fixed
    # 200 epochs. The checkpoint is replaced only if this stage beats stage 1.
    if post_train:
        # The SGD optimizer must be the one handed to train_model; it is bound to
        # the parameters of the model currently being trained.
        optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
        best_val_loss = fit_with_checkpoint(model, optimizer, train_loader, val_loader, 200,
                                            checkpoint_path, val_losses, val_IoUs,
                                            best_val_loss=best_val_loss)

    # test model (holds the best validation weights across both stages)
    test_loss, test_IoU = test_model(test_loader, model)
    test_loss = test_loss.cpu().detach().numpy()
    test_losses.append(test_loss)
    test_IoUs.append(test_IoU)
    print("Fold: %d, test loss: %1.3f, test IoU/class: %1.3f %1.3f %1.3f" % (fold, test_loss, test_IoU[0], test_IoU[1], test_IoU[2]))
    plot_learning_curve(val_losses, val_IoUs)

# Mean test loss and mean per-class test IoU over all folds.
print(sum(test_losses) / len(test_losses))
print(sum(test_IoUs) / len(test_IoUs))

In [None]:
# Show each fold's test IoU (per class) on its own line, in fold order.
for fold_IoU in test_IoUs:
    print(fold_IoU)