In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch_geometric.nn as pyg_nn
from torch.utils.data import Dataset, DataLoader
from torch_geometric.loader import DataLoader as PyGDataLoader

from model_components.vit import *

In [2]:
# defining a seed for reproducible results
# NumPy and PyTorch keep separate random generators, so both are seeded:
# randomly initialised torch layers draw from torch's generator, not NumPy's
np.random.seed(69)
torch.manual_seed(69)

In [3]:
# Pick the compute device: CUDA when it is available, CPU otherwise.
# The data root differs between the two environments, so it is chosen here as well.
if not torch.cuda.is_available():
    # no GPU -> local machine
    device = torch.device("cpu")
    multimodal_path = "../data.nosync/multimodal"
else:
    # GPU -> cluster; release cached GPU memory before starting
    device = torch.device("cuda")
    torch.cuda.empty_cache()
    multimodal_path = "../scratch/multimodal"

print(f"Device set to: {device}")

Device set to: cpu


In [4]:
# Canonical ordering of the generalization operators; positions in this tuple
# are used as indices into the label vector stored with each sample.
operator_order = (
    "elimination",
    "aggregation",
    "typification",
    "displacement",
    "enlargement",
    "simplification",
)

In [5]:
# Use the DIN font for figures, but only when no CUDA device is present
# (i.e. when the notebook runs locally rather than on the cluster).
if not torch.cuda.is_available():
    plt.rcParams.update({"font.family": "DIN Alternate"})

### Loading the data

In [6]:
class BuildingMultimodalDataset(Dataset):
    '''Dataset pairing a raster sample (.npz) with a vector/graph sample (.pt) per index.

    Raster and vector files are matched purely by their position in the sorted
    directory listings, so both directories must contain the same samples under
    names that sort identically.
    '''
    def __init__(self, raster_path, vector_path, operators, transform=None):
        '''Stores the directory and filenames of the individual raster (.npz) and vector (.pt) files.

        Args:
            raster_path: directory containing the raster samples.
            vector_path: directory containing the vector samples.
            operators: iterable of operator names; names that do not occur in
                operator_order are ignored.
            transform: stored on the instance; it is not applied anywhere in this class.
        '''
        # store the path to the raster and vector files
        self.raster_path = raster_path
        self.vector_path = vector_path

        # get filenames of the individual files, sort the filenames to make them line up
        self.raster_filenames = sorted(os.listdir(self.raster_path))
        self.vector_filenames = sorted(os.listdir(self.vector_path))

        # make sure that the samples line up (only the counts are compared)
        assert len(self.raster_filenames) == len(self.vector_filenames), \
            "Number of raster and vector samples differs."

        # store indices of the operators within operator_order for slicing in the __getitem__ method
        self.operators = sorted([operator_order.index(operator) for operator in operators if operator in operator_order])

        # store transformation
        self.transform = transform

    def __len__(self):
        '''Enables dataset length calculation.'''
        return len(self.raster_filenames)

    def __getitem__(self, index):
        '''Enables indexing, returns graph and raster representation and generalization operator as label.

        Returns:
            A tuple (raster_sample, vector_sample, operators): the stacked float raster
            tensor with the focal building, context buildings and roads as channels,
            the loaded graph object whose y has been replaced by the selected operators
            reshaped to (1, n_operators), and the selected operators as a flat tensor.
        '''
        # load the raster sample associated with the given index; the archive is opened
        # in a context manager so its file handle is closed as soon as the arrays are read
        raster_filename = self.raster_filenames[index]
        with np.load(os.path.join(self.raster_path, raster_filename)) as raster_sample_raw:
            # stack the rasters to shape (3, n_pixels, n_pixels)
            raster_sample = np.stack(
                [
                    raster_sample_raw["focal_building"],
                    raster_sample_raw["context_buildings"],
                    raster_sample_raw["roads"],
                ],
                axis=0,
            )

        # convert to float tensor
        raster_sample = torch.from_numpy(raster_sample).float()

        # load the vector sample associated with the given index
        vector_filename = self.vector_filenames[index]
        vector_sample = torch.load(os.path.join(self.vector_path, vector_filename))

        # extract the operators from the graph object
        operators = vector_sample.y[self.operators]

        # reshape the operators associated with the graph
        vector_sample.y = vector_sample.y[self.operators].reshape(1, -1)

        return raster_sample, vector_sample, operators

### Model design

In [7]:
# architecture of best performing raster model
class MultiTaskViT(nn.Module):
    """Vision transformer with one single-output linear head per task.

    The image is cut into patches, each patch is linearly embedded, a class token
    is prepended and a position embedding is added before the sequence is passed
    through the transformer. The pooled representation (class token or mean over
    all tokens, depending on `pool`) is fed to `num_classes` independent linear
    heads whose outputs are joined into a (batch, num_classes) tensor.
    """
    def __init__(self, *,
                 image_size=256,
                 patch_size=32,
                 num_classes,
                 dim=512,
                 depth=6,
                 heads=16,
                 mlp_dim=2048,
                 pool='cls',
                 channels,
                 dim_head=64,
                 dropout=0.,
                 emb_dropout=0.):
        super().__init__()
        img_h, img_w = pair(image_size)
        p_h, p_w = pair(patch_size)

        fits_vertically = img_h % p_h == 0
        assert fits_vertically and img_w % p_w == 0, 'Image dimensions must be divisible by the patch size.'

        # number of patches along each axis and flattened size of a single patch
        patches_down, patches_across = img_h // p_h, img_w // p_w
        num_patches = patches_down * patches_across
        patch_dim = channels * p_h * p_w
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # split into patches, flatten each patch and project it to the model dimension
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p_h, p2 = p_w),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # learned position embedding (one slot per patch plus the class token) and class token
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        # one linear head with a single output per task
        task_heads = []
        for _ in range(num_classes):
            task_heads.append(nn.Linear(dim, 1))
        self.classification_heads = nn.ModuleList(task_heads)

    def forward(self, img):
        """Return per-task outputs of shape (batch, num_classes).

        If `classification_heads` has been replaced by something other than an
        nn.ModuleList, the pooled representation is returned instead.
        """
        tokens = self.to_patch_embedding(img)
        batch_size, n_patches, _ = tokens.shape

        # prepend the class token to every sequence in the batch
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = batch_size)
        tokens = torch.cat((cls_tokens, tokens), dim=1)
        tokens = tokens + self.pos_embedding[:, :(n_patches + 1)]
        tokens = self.dropout(tokens)

        encoded = self.transformer(tokens)

        # pool either over all tokens or by taking the class token
        if self.pool == 'mean':
            pooled = encoded.mean(dim = 1)
        else:
            pooled = encoded[:, 0]

        latent = self.to_latent(pooled)

        # without a ModuleList of heads, the latent representation itself is the output
        if not isinstance(self.classification_heads, nn.ModuleList):
            return latent

        # apply each classification head and join the results along the final dimension
        per_task = [head(latent).squeeze(-1) for head in self.classification_heads]
        return torch.stack(per_task, dim=1)

    def get_n_parameters(self):
        """Return the total number of parameter elements of the model."""
        return sum(parameter.numel() for parameter in self.parameters())

    def __str__(self):
        return f"Multi-Task Vision transformer with {self.get_n_parameters():,} parameters"

class ClassificationHead(nn.Module):
    """Small MLP head: linear layer halving the feature size, ReLU, linear layer to n_classes."""
    def __init__(self, n_input_features, n_classes):
        super().__init__()

        n_hidden_features = n_input_features // 2
        layers = [
            pyg_nn.Linear(n_input_features, n_hidden_features),
            nn.ReLU(inplace=True),
            pyg_nn.Linear(n_hidden_features, n_classes),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the head to the input features."""
        head_output = self.fc(x)
        return head_output

# architecture of best performing vector model
class MultiTaskHGT(nn.Module):
    """Heterogeneous graph transformer with one classification head per task.

    Every node type is first projected to `hidden_channels` by its own linear
    layer, then `num_layers` HGTConv layers are applied. The embeddings of the
    node type `node_to_predict` are passed to `out_channels` independent
    ClassificationHead modules, each producing one output per node.

    Args:
        hidden_channels: feature size used throughout the network.
        out_channels: number of tasks, i.e. number of classification heads.
        num_heads: number of attention heads passed to each HGTConv.
        num_layers: number of HGTConv layers.
        node_types: iterable of node type names; one input projection is created per type.
        metadata: graph metadata passed to each HGTConv.
        node_to_predict: name of the node type whose embeddings are classified.
    """
    def __init__(self, hidden_channels, out_channels, num_heads, num_layers, node_types, metadata, node_to_predict):
        super().__init__()

        # keep the target node type on the instance so that forward() does not
        # depend on a module-level variable that happens to have the same name
        self.node_to_predict = node_to_predict

        # per-node-type input projection; in_channels=-1 leaves the input size to be
        # determined on the first forward pass
        self.lin_dict = nn.ModuleDict()
        for node_type in node_types:
            self.lin_dict[node_type] = pyg_nn.Linear(-1, hidden_channels)

        self.convs = nn.ModuleList()
        for _ in range(num_layers):
            conv = pyg_nn.HGTConv(hidden_channels, hidden_channels, metadata, num_heads)
            self.convs.append(conv)

        self.classification_heads = nn.ModuleList([ClassificationHead(n_input_features=hidden_channels, n_classes=1) for _ in range(out_channels)])

    def forward(self, x_dict, edge_index_dict):
        """Return per-task outputs of shape (n_target_nodes, out_channels).

        If `classification_heads` has been replaced by something other than an
        nn.ModuleList, the embeddings of the target node type are returned instead.
        """
        x_dict = {
            node_type: self.lin_dict[node_type](x).relu_()
            for node_type, x in x_dict.items()
        }

        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)

        target_embedding = x_dict[self.node_to_predict]

        if isinstance(self.classification_heads, nn.ModuleList):
            # each head yields one value per node; join them along the final dimension
            outputs = torch.cat([head(target_embedding).squeeze(-1).unsqueeze(1) for head in self.classification_heads], dim=1)
            return outputs
        else:
            return target_embedding

    def get_n_parameters(self):
        """Return the total number of parameter elements of the model."""
        n_parameters = sum(p.numel() for p in self.parameters())
        return n_parameters

    def __str__(self):
        return f"Multi-Task Heterogenous Graph Transformer with {self.get_n_parameters():,} parameters"

In [8]:
class MultimodalModel(nn.Module):
    """Fusion model combining a frozen raster model and a frozen vector model.

    The classification heads of both models are replaced by nn.Identity so that
    they output feature vectors. These are concatenated along the feature
    dimension and mapped to `n_classes` outputs by a single trainable linear layer.

    Note that the passed-in models are modified in place: their parameters are
    frozen, their `classification_heads` attribute is replaced and they are put
    into evaluation mode.

    Args:
        raster_model: trained raster model exposing a `classification_heads` attribute.
        vector_model: trained vector model exposing a `classification_heads` attribute;
            called as vector_model(x_dict, edge_index_dict).
        dummy_raster_sample: single raster sample without batch dimension, used only
            to determine the number of raster features.
        dummy_vector_sample: single graph sample providing `x_dict` and
            `edge_index_dict`, used only to determine the number of vector features.
        n_classes: number of outputs of the fusion layer.
    """
    def __init__(self, raster_model, vector_model, dummy_raster_sample, dummy_vector_sample, n_classes):
        super(MultimodalModel, self).__init__()
        self.raster_model = raster_model
        self.vector_model = vector_model

        # both models are already trained and only require gradient for fusion layers
        for param in self.raster_model.parameters():
            param.requires_grad = False
        for param in self.vector_model.parameters():
            param.requires_grad = False

        # remove classification heads
        self.raster_model.classification_heads = nn.Identity()
        self.vector_model.classification_heads = nn.Identity()

        # the frozen backbones act as fixed feature extractors
        self.raster_model.eval()
        self.vector_model.eval()

        # pass dummy raster and dummy vector samples through the networks to determine the number of output features
        # when the classification heads are missing; no gradients are needed for this shape probe
        with torch.no_grad():
            out_raster = self.raster_model(dummy_raster_sample.unsqueeze(0))
            out_vector = self.vector_model(dummy_vector_sample.x_dict, dummy_vector_sample.edge_index_dict)
        n_raster_features = out_raster.shape[1]
        n_vector_features = out_vector.shape[1]

        # fusion layer
        self.fusion_layer = nn.Linear(n_raster_features + n_vector_features, n_classes)

    def train(self, mode=True):
        """Set the training mode of the fusion layer while keeping the frozen backbones in eval mode.

        nn.Module.train() recurses into all submodules, which would otherwise switch
        the frozen raster and vector models back to training mode.
        """
        super().train(mode)
        self.raster_model.eval()
        self.vector_model.eval()
        return self

    def forward(self, raster, graph):
        """Return the fused output of shape (batch, n_classes).

        Args:
            raster: batch of raster samples for the raster model.
            graph: graph object providing `x_dict` and `edge_index_dict`.
        """
        raster_output = self.raster_model(raster)
        vector_output = self.vector_model(graph.x_dict, graph.edge_index_dict)

        # concatenate along feature dimension
        combined_features = torch.cat((raster_output, vector_output), dim=1)
        result = self.fusion_layer(combined_features)
        return result

### Elimination model

In [9]:
# define path to training, validation and test data for both raster and vector
path_to_raster_training_data = "../data.nosync/raster/training_data/elimination/training"
path_to_vector_training_data = "../data.nosync/vector/training_data/elimination/training"
path_to_raster_validation_data = "../data.nosync/raster/training_data/elimination/validation"
path_to_vector_validation_data = "../data.nosync/vector/training_data/elimination/validation"
path_to_raster_test_data = "../data.nosync/raster/training_data/elimination/test"
path_to_vector_test_data = "../data.nosync/vector/training_data/elimination/test"

# define input parameters
elimination_operators = ["elimination"]
n_classes = len(elimination_operators)

batch_size = 16

# Each sample is a (raster tensor, graph object, label tensor) tuple. The default collate
# function of torch.utils.data.DataLoader cannot batch graph objects, so the PyTorch
# Geometric DataLoader is used: it collates tuples element-wise and merges the graphs
# into one batch object.

# construct training DataLoader
training_set = BuildingMultimodalDataset(path_to_raster_training_data, path_to_vector_training_data, operators=elimination_operators)
training_loader = PyGDataLoader(dataset=training_set, batch_size=batch_size, shuffle=True)

# construct validation DataLoader (no shuffling)
validation_set = BuildingMultimodalDataset(path_to_raster_validation_data, path_to_vector_validation_data, operators=elimination_operators)
validation_loader = PyGDataLoader(dataset=validation_set, batch_size=batch_size, shuffle=False)

# construct test DataLoader (no shuffling)
test_set = BuildingMultimodalDataset(path_to_raster_test_data, path_to_vector_test_data, operators=elimination_operators)
test_loader = PyGDataLoader(dataset=test_set, batch_size=batch_size, shuffle=False)

print(f"{len(training_set):,} samples in the training set.")
print(f"{len(validation_set):,} samples in the validation set.")
print(f"{len(test_set):,} samples in the test set.")

1,000 samples in the training set.
250 samples in the validation set.
250 samples in the test set.


In [10]:
# extracting a dummy raster sample and a dummy graph sample;
# the first training sample is read from disk once and both parts are taken from it
dummy_raster_sample, dummy_vector_sample = training_set[0][:2]

# extracting the relevant metadata from the data to set up the graph model
node_types = dummy_vector_sample.node_types

# number of input features per node type
node_features = {
    node_type: dummy_vector_sample[node_type]["x"].shape[1]
    for node_type in node_types
}

# number of operators stored with the graph (second dimension of y)
n_classes = dummy_vector_sample["y"].shape[1]

print(f"Number of node features: {node_features}, {n_classes} operators")

metadata = dummy_vector_sample.metadata()
node_to_predict = "focal_building"

Number of node features: {'focal_building': 10, 'context_building': 10, 'road': 2}, 1 operators


In [11]:
# load the trained raster model
# map_location="cpu" makes loading independent of the device the checkpoint was saved
# from; the models in this cell are constructed on the CPU
raster_model_path = "../data.nosync/raster/models/elimination"
raster_model_name = "MultiTaskViT_eli_attachRoadsTrue_26812417p_1000s_10ep_bs16.pth"
raster_model = MultiTaskViT(channels=3, num_classes=n_classes)
raster_checkpoint = torch.load(os.path.join(raster_model_path, raster_model_name), map_location="cpu")
raster_model.load_state_dict(raster_checkpoint["model_state_dict"])
raster_model.eval()

# load the trained vector model
vector_model_path = "../data.nosync/vector/models/elimination"
vector_model_name = "MultiTaskHGT_eli_attachRoadsTrue_637219p_10000s_10ep_bs16.pth"
vector_model = MultiTaskHGT(hidden_channels=128, out_channels=n_classes, num_heads=2, num_layers=2, node_types=node_types,
                            metadata=metadata, node_to_predict=node_to_predict)

# initialize lazy modules: one forward pass fixes the input sizes so that the
# state dict can be loaded afterwards; the output itself is not needed
with torch.no_grad():
    vector_model(dummy_vector_sample.x_dict, dummy_vector_sample.edge_index_dict)

vector_checkpoint = torch.load(os.path.join(vector_model_path, vector_model_name), map_location="cpu")
vector_model.load_state_dict(vector_checkpoint["model_state_dict"])
vector_model.eval()

print("Models successfully loaded.")

Models successfully loaded.


In [12]:
# build the fusion model on top of the two trained models
multimodal_model = MultimodalModel(
    raster_model=raster_model,
    vector_model=vector_model,
    dummy_raster_sample=dummy_raster_sample,
    dummy_vector_sample=dummy_vector_sample,
    n_classes=n_classes,
)

In [13]:
# sanity check: run the dummy sample (with an added batch dimension) through the fusion model
multimodal_model(raster=dummy_raster_sample[None], graph=dummy_vector_sample)

tensor([[0.3177]], grad_fn=<AddmmBackward0>)

### Selection model

In [14]:
# define path to training, validation and test data for both raster and vector
path_to_raster_training_data = "../data.nosync/raster/training_data/selection/training"
path_to_vector_training_data = "../data.nosync/vector/training_data/selection/training"
path_to_raster_validation_data = "../data.nosync/raster/training_data/selection/validation"
path_to_vector_validation_data = "../data.nosync/vector/training_data/selection/validation"
path_to_raster_test_data = "../data.nosync/raster/training_data/selection/test"
path_to_vector_test_data = "../data.nosync/vector/training_data/selection/test"

# define input parameters
selection_operators = ["aggregation", "typification", "displacement", "enlargement"]
n_classes = len(selection_operators)

batch_size = 16

# Each sample is a (raster tensor, graph object, label tensor) tuple. The default collate
# function of torch.utils.data.DataLoader cannot batch graph objects, so the PyTorch
# Geometric DataLoader is used: it collates tuples element-wise and merges the graphs
# into one batch object.

# construct training DataLoader
training_set = BuildingMultimodalDataset(path_to_raster_training_data, path_to_vector_training_data, operators=selection_operators)
training_loader = PyGDataLoader(dataset=training_set, batch_size=batch_size, shuffle=True)

# construct validation DataLoader (no shuffling)
validation_set = BuildingMultimodalDataset(path_to_raster_validation_data, path_to_vector_validation_data, operators=selection_operators)
validation_loader = PyGDataLoader(dataset=validation_set, batch_size=batch_size, shuffle=False)

# construct test DataLoader (no shuffling)
test_set = BuildingMultimodalDataset(path_to_raster_test_data, path_to_vector_test_data, operators=selection_operators)
test_loader = PyGDataLoader(dataset=test_set, batch_size=batch_size, shuffle=False)

print(f"{len(training_set):,} samples in the training set.")
print(f"{len(validation_set):,} samples in the validation set.")
print(f"{len(test_set):,} samples in the test set.")

1,000 samples in the training set.
250 samples in the validation set.
250 samples in the test set.


In [15]:
# extracting a dummy raster sample and a dummy graph sample;
# the first training sample is read from disk once and both parts are taken from it
dummy_raster_sample, dummy_vector_sample = training_set[0][:2]

# extracting the relevant metadata from the data to set up the graph model
node_types = dummy_vector_sample.node_types

# number of input features per node type
node_features = {
    node_type: dummy_vector_sample[node_type]["x"].shape[1]
    for node_type in node_types
}

# number of operators stored with the graph (second dimension of y)
n_classes = dummy_vector_sample["y"].shape[1]

print(f"Number of node features: {node_features}, {n_classes} operators")

metadata = dummy_vector_sample.metadata()
node_to_predict = "focal_building"

Number of node features: {'focal_building': 10, 'context_building': 10, 'road': 2}, 4 operators


In [16]:
# load the trained raster model
# map_location="cpu" makes loading independent of the device the checkpoint was saved
# from; the models in this cell are constructed on the CPU
raster_model_path = "../data.nosync/raster/models/selection"
raster_model_name = "MultiTaskViT_sel_attachRoadsTrue_26813956p_1000s_10ep_bs16.pth"
# the number of heads follows n_classes, exactly as for the vector model below
raster_model = MultiTaskViT(channels=3, num_classes=n_classes)
raster_checkpoint = torch.load(os.path.join(raster_model_path, raster_model_name), map_location="cpu")
raster_model.load_state_dict(raster_checkpoint["model_state_dict"])
raster_model.eval()

# load the trained vector model
vector_model_path = "../data.nosync/vector/models/selection"
vector_model_name = "MultiTaskHGT_sel_attachRoadsTrue_662182p_10000s_10ep_bs16.pth"
vector_model = MultiTaskHGT(hidden_channels=128, out_channels=n_classes, num_heads=2, num_layers=2, node_types=node_types,
                            metadata=metadata, node_to_predict=node_to_predict)

# initialize lazy modules: one forward pass fixes the input sizes so that the
# state dict can be loaded afterwards; the output itself is not needed
with torch.no_grad():
    vector_model(dummy_vector_sample.x_dict, dummy_vector_sample.edge_index_dict)

vector_checkpoint = torch.load(os.path.join(vector_model_path, vector_model_name), map_location="cpu")
vector_model.load_state_dict(vector_checkpoint["model_state_dict"])
vector_model.eval()

print("Models successfully loaded.")

Models successfully loaded.


In [17]:
# build the fusion model on top of the two trained models
multimodal_model = MultimodalModel(
    raster_model=raster_model,
    vector_model=vector_model,
    dummy_raster_sample=dummy_raster_sample,
    dummy_vector_sample=dummy_vector_sample,
    n_classes=n_classes,
)

In [18]:
# sanity check: run the dummy sample (with an added batch dimension) through the fusion model
multimodal_model(raster=dummy_raster_sample[None], graph=dummy_vector_sample)

tensor([[-0.8012, -0.2070,  0.4219, -0.3853]], grad_fn=<AddmmBackward0>)