In [1]:
import os
import math
import numpy as np
import matplotlib.pyplot as plt
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric.nn as pyg_nn
import torch_geometric.transforms as transforms
from torch_geometric.data import Dataset
from torch_geometric.loader import DataLoader

In [2]:
# defining a seed for reproducible results
np.random.seed(69)

In [3]:
# Check if CUDA is available, then MPS, otherwise use CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
    torch.cuda.empty_cache()
    # cluster path
    vector_path = "../scratch/vector"
#elif torch.backends.mps.is_available():
#    device = torch.device("mps")
    # local path
#    vector_path = "../data.nosync/vector"
else:
    device = torch.device("cpu")
    # local path
    vector_path = "../data.nosync/vector"

print(f"Device set to: {device}")

Device set to: cpu


In [4]:
# operators are always specified in this order
operator_order = ("elimination", "aggregation", "typification", "displacement", "enlargement", "simplification")

In [5]:
# Define DIN font for plots if working locally
if not torch.cuda.is_available():
    plt.rcParams["font.family"] = "DIN Alternate"

### Loading the data

In [6]:
# Setting up a Dataset object for DataLoader
class BuildingVectorDataset(Dataset):
    def __init__(self, path, transform=None):
        '''Stores the directory and filenames of the individual .pt files.'''
        super().__init__(path, transform)
        # store directory of individual files
        self.path = path
        # get filenames of individual .pt files
        self.filenames = [file for file in os.listdir(path) if file.endswith(".pt")]

        # store transformation
        self.transform = transform

    def len(self):
        '''Enables dataset length calculation.'''
        return len(self.filenames)

    def get(self, index):
        '''Enables indexing, returns HeteroData object which contains nodes, edges and labels.'''
        # get filename associated with given index
        filename = self.filenames[index]

        # load the file with the filename
        graph = torch.load(os.path.join(self.path, filename))

        # TODO: slice graph.y with respect to the specified generalization operators in the __init__ method

        # apply given transformation if specified
        #if self.transform:
        #    graph = self.transform(graph)

        return graph

### Model design

In [7]:
# Define the model
class HeteroGNN(nn.Module):
    def __init__(self, node_features, node_to_predict, metadata, n_classes):
        super().__init__()
        self.conv1 = pyg_nn.HGTConv(in_channels=node_features,
                             out_channels=64, metadata=(metadata))
        self.conv2 = pyg_nn.HGTConv(in_channels=64, out_channels=n_classes, metadata=(metadata))

    def forward(self, x_dict, edge_index_dict):
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = {k: torch.relu(x) for k, x in x_dict.items()}
        x_dict = self.conv2(x_dict, edge_index_dict)
        return x_dict[node_to_predict]  # Only return the building predictions

In [8]:
# define path to training and validation data
path_to_training_data = os.path.join(vector_path, "training_data")

# composing various random transforms that should be applied to the data
#transform = transforms.Compose([transforms.ToUndirected()])

# construct training DataLoader
training_set = BuildingVectorDataset(path_to_training_data, transform=None)
training_loader = DataLoader(training_set, batch_size=1, shuffle=True)

# construct validation DataLoader (no transformations, no shuffling)
validation_set = BuildingVectorDataset(path_to_training_data, transform=None)
validation_loader = DataLoader(training_set, batch_size=1, shuffle=False)

print(f"{len(training_set):,} samples in the training set.")
print(f"{len(training_set):,} samples in the training set.")

10 samples in the training set.
10 samples in the training set.


In [13]:
# extracting the relevant metadata from the data to set up the model
node_to_predict = "building"
n_building_features = training_set.get(0)["building"]["x"].shape[1]
n_road_features = training_set.get(0)["road"]["x"].shape[1]
node_features = {"building": n_building_features, "road": n_road_features}
n_classes = training_set.get(0)["building"]["y"].shape[1]

print(f"{n_building_features} building features, {n_road_features} road features, {n_classes} operators")

# construct the model
model = HeteroGNN(node_features, node_to_predict=node_to_predict, metadata=training_set.get(0).metadata(), n_classes=n_classes)
model.to(device)

learning_rate = 0.001

criterion = nn.BCEWithLogitsLoss() # Binary cross-entropy loss, applies a sigmoid internally and takes logits as input
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

2 building features, 2 road features, 6 operators


In [14]:
# number of epochs and batch size
n_epochs = 10
batch_size = 1

total_samples = len(training_set)
n_iterations = math.ceil(total_samples/batch_size)
    
# saving the losses from every epoch
training_losses = []
validation_losses = []

start_time = time.perf_counter()
    
for epoch in range(n_epochs):
    # tracking loss per epoch
    training_running_loss = 0.0
    n_training_batches = 0
    validation_running_loss = 0.0
    n_validation_batches = 0

    # training phase
    model.train()

    for i, graph in enumerate(training_loader):
        n_training_batches += 1
        # target operators
        operators = graph.y_dict[node_to_predict]
        
        # moving the features to device
        graph = graph.to(device)
        operators = operators.to(device)
    
        # empty the gradients
        optimizer.zero_grad()
            
        # forward pass
        pred_operators_logits = model(graph.x_dict, graph.edge_index_dict) # compute predictions, calls forward method under the hood
        loss = criterion(pred_operators_logits, operators) # calculate loss
        training_running_loss += loss.item() # tracking running loss to keep track of the loss for every epoch
    
        # backward pass
        loss.backward() # backpropagation
        optimizer.step() # update the parameters
    
        # print information every few batches
        if not (i + 1) % (n_iterations // 10):
            print(f"epoch {epoch+1}/{n_epochs}, step {i+1}/{n_iterations}")

    # validation phase
    model.eval()

    with torch.no_grad():
        for graph in validation_loader:
            n_validation_batches += 1
            # target operators
            operators = graph.y_dict[node_to_predict]
            
            # moving the features to device
            graph = graph.to(device)
            operators = operators.to(device)

            # prediction on the trained model results in logits
            pred_operators_logits = model(graph.x_dict, graph.edge_index_dict) # compute predictions, calls forward method under the hood
            # calculate and store validation loss
            loss = criterion(pred_operators_logits, operators)
            validation_running_loss += loss.item()
    
    # print information at the end of each epoch
    training_loss_epoch = training_running_loss / n_training_batches
    training_losses.append(training_loss_epoch)
    validation_loss_epoch = validation_running_loss / n_validation_batches
    validation_losses.append(validation_loss_epoch)
    
    print(f"epoch {epoch+1} finished, training loss: {training_loss_epoch:.3f}, validation loss: {validation_loss_epoch:.3f}")

end_time = time.perf_counter()
print(f"Training time: {end_time - start_time:,.3f} seconds")

epoch 1/10, step 1/10
epoch 1/10, step 2/10
epoch 1/10, step 3/10
epoch 1/10, step 4/10
epoch 1/10, step 5/10
epoch 1/10, step 6/10
epoch 1/10, step 7/10
epoch 1/10, step 8/10
epoch 1/10, step 9/10
epoch 1/10, step 10/10
epoch 1 finished, training loss: 501.149, validation loss: 6.862
epoch 2/10, step 1/10
epoch 2/10, step 2/10
epoch 2/10, step 3/10
epoch 2/10, step 4/10
epoch 2/10, step 5/10
epoch 2/10, step 6/10
epoch 2/10, step 7/10
epoch 2/10, step 8/10
epoch 2/10, step 9/10
epoch 2/10, step 10/10
epoch 2 finished, training loss: 3.137, validation loss: 0.714
epoch 3/10, step 1/10
epoch 3/10, step 2/10
epoch 3/10, step 3/10
epoch 3/10, step 4/10
epoch 3/10, step 5/10
epoch 3/10, step 6/10
epoch 3/10, step 7/10
epoch 3/10, step 8/10
epoch 3/10, step 9/10
epoch 3/10, step 10/10
epoch 3 finished, training loss: 0.712, validation loss: 0.711
epoch 4/10, step 1/10
epoch 4/10, step 2/10
epoch 4/10, step 3/10
epoch 4/10, step 4/10
epoch 4/10, step 5/10
epoch 4/10, step 6/10
epoch 4/10, st