In [1]:
import torch
from torch_geometric.nn import GraphUNet
import torch.nn.functional as F
from torch_geometric.utils import dropout_adj
from models.datasets import create_simulation_graph_set

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        pool_ratios = [.75, 0.5]
        self.unet = GraphUNet(1, 100, 2,
                              depth=3, pool_ratios=pool_ratios)

    def forward(self):
        edge_index, _ = dropout_adj(data.edge_index, p=0.2,
                                    force_undirected=True,
                                    num_nodes=data.num_nodes,
                                    training=self.training)
        x = F.dropout(data.x, p=0.92, training=self.training)
        x = self.unet(x, edge_index)
        return F.log_softmax(x, dim=1)

def train():
    model.train()
    optimizer.zero_grad()
    F.nll_loss(model()[data.train_mask], data.y[data.train_mask]).backward()
    optimizer.step()


def test():
    model.eval()
    logits, accs = model(), []
    for _, mask in data('train_mask', 'val_mask', 'test_mask'):
        pred = logits[mask].max(1)[1]
        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
        accs.append(acc)
    return accs

In [2]:
n_kp = 100
threshold = 35
train_n = 10
test_n = 2
epochs = 75
train_data = create_simulation_graph_set(n_kp, threshold, train_n)
test_data = create_simulation_graph_set(n_kp, threshold, test_n)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.001)

best_val_acc = test_acc = 0

In [3]:
for epoch in range(1, epochs):
        for data in train_data:
            train()
        train_acc, val_acc, tmp_test_acc = test()
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            test_acc = tmp_test_acc
        log = 'Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
        print(log.format(epoch, train_acc, best_val_acc, test_acc))



Epoch: 001, Train: 0.9000, Val: 0.8462, Test: 0.8409
Epoch: 002, Train: 0.9400, Val: 0.9038, Test: 0.9091
Epoch: 003, Train: 0.6600, Val: 0.9038, Test: 0.9091
Epoch: 004, Train: 0.9400, Val: 0.9038, Test: 0.9091
Epoch: 005, Train: 0.8800, Val: 0.9038, Test: 0.9091
Epoch: 006, Train: 0.9000, Val: 0.9038, Test: 0.9091
Epoch: 007, Train: 0.8800, Val: 0.9038, Test: 0.9091
Epoch: 008, Train: 0.7800, Val: 0.9038, Test: 0.9091
Epoch: 009, Train: 0.8400, Val: 0.9038, Test: 0.9091
Epoch: 010, Train: 0.8000, Val: 0.9038, Test: 0.9091
Epoch: 011, Train: 0.8000, Val: 0.9038, Test: 0.9091
Epoch: 012, Train: 0.8000, Val: 0.9038, Test: 0.9091
Epoch: 013, Train: 0.8000, Val: 0.9038, Test: 0.9091
Epoch: 014, Train: 0.5200, Val: 0.9038, Test: 0.9091
Epoch: 015, Train: 0.8000, Val: 0.9038, Test: 0.9091
Epoch: 016, Train: 0.6800, Val: 0.9038, Test: 0.9091
Epoch: 017, Train: 0.7600, Val: 0.9038, Test: 0.9091
Epoch: 018, Train: 0.6600, Val: 0.9038, Test: 0.9091
Epoch: 019, Train: 0.5200, Val: 0.9038, Test: 