In [None]:
import os
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [None]:
# seed every RNG this notebook uses (torch and numpy) so results are reproducible
SEED = 69
# torch.manual_seed also seeds the CUDA generators; it runs first so the cell's last
# expression (np.random.seed -> None) does not echo a Generator object
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Device preference: CUDA (cluster) > MPS (Apple silicon) > CPU.
# The data location only depends on whether we are on the CUDA cluster or a local machine.
if torch.cuda.is_available():
    device = torch.device("cuda")
    torch.cuda.empty_cache()
    multimodal_path = "../scratch/multimodal"  # cluster path
else:
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    multimodal_path = "../data.nosync/multimodal"  # local path

print(f"Device set to: {device}")

In [None]:
# canonical order of the generalization operators; label vectors follow this order
operator_order = (
    "elimination",
    "aggregation",
    "typification",
    "displacement",
    "enlargement",
    "simplification",
)

In [None]:
# Use the "DIN Alternate" font for plots when working locally (no CUDA device, i.e. not on the cluster).
# NOTE(review): `plt` is not imported in this notebook's import cell, so on a fresh kernel this line
# raises NameError — add `import matplotlib.pyplot as plt` to the imports. TODO confirm.
# NOTE(review): presumably the font is only installed on the local machine; verify it exists there.
if not torch.cuda.is_available():
    plt.rcParams["font.family"] = "DIN Alternate"

### Loading the data

In [None]:
class BuildingMultimodalDataset(Dataset):
    '''Paired raster / vector building samples labelled with generalization operators.

    ``path`` must contain a ``raster`` directory of ``.npz`` files and a ``vector`` directory of
    ``.pt`` files. Samples are paired by position after sorting both filename listings, so a raster
    file and its vector counterpart must sort into the same position.
    '''

    def __init__(self, path, operators, transform=None):
        '''Stores the directory and the sorted filenames of the raster (.npz) and vector (.pt) files.

        Args:
            path: root directory containing the ``raster`` and ``vector`` subdirectories.
            operators: names of the operators used as labels; names not in ``operator_order`` are ignored.
            transform: optional transformation; stored as ``self.transform`` but not applied in ``__getitem__``.

        Raises:
            ValueError: if the raster and vector directories hold a different number of samples.
        '''
        # store directory of individual files
        self.path = path
        # store the path to the raster and vector files
        self.raster_path = os.path.join(path, "raster")
        self.vector_path = os.path.join(path, "vector")

        # os.listdir returns entries in arbitrary order, so both listings are sorted to pair samples by
        # position; filtering on the extension skips stray files such as .DS_Store
        self.raster_filenames = sorted(name for name in os.listdir(self.raster_path) if name.endswith(".npz"))
        self.vector_filenames = sorted(name for name in os.listdir(self.vector_path) if name.endswith(".pt"))

        # make sure that every raster sample has a vector counterpart
        # NOTE(review): this only compares counts; if both file types share a stem, comparing
        # os.path.splitext stems would verify the pairing itself
        if len(self.raster_filenames) != len(self.vector_filenames):
            raise ValueError(
                f"Found {len(self.raster_filenames)} raster files but {len(self.vector_filenames)} "
                f"vector files in {path}."
            )

        # store indices of the operators within operator_order for slicing in the __getitem__ method
        self.operators = sorted([operator_order.index(operator) for operator in operators if operator in operator_order])

        # store transformation
        self.transform = transform

    def __len__(self):
        '''Returns the number of paired samples.'''
        return len(self.raster_filenames)

    def __getitem__(self, index):
        '''Returns the raster sample, the vector (graph) sample and its operator labels for ``index``.

        Returns:
            raster_sample: float tensor stacking the focal building, context buildings and roads
                rasters along a new leading axis of size 3.
            vector_sample: the object stored in the matching .pt file (presumably a graph object
                exposing ``.y``; TODO confirm).
            operators: ``vector_sample.y`` indexed at the selected operator positions.
        '''
        # load the raster sample; the context manager closes the .npz file handle after reading
        raster_filename = self.raster_filenames[index]
        with np.load(os.path.join(self.raster_path, raster_filename)) as raster_sample_raw:
            # np.stack copies the arrays, so they remain valid after the file is closed
            raster_sample = np.stack(
                [
                    raster_sample_raw["focal_building"],
                    raster_sample_raw["context_buildings"],
                    raster_sample_raw["roads"],
                ],
                axis=0,
            )
        raster_sample = torch.from_numpy(raster_sample).float()

        # load the vector sample associated with the given index
        # (torch.load unpickles the file — only use it on trusted data)
        vector_filename = self.vector_filenames[index]
        vector_sample = torch.load(os.path.join(self.vector_path, vector_filename))

        # extract the selected operators from the graph object
        operators = vector_sample.y[self.operators]

        return raster_sample, vector_sample, operators

### Model design

In [None]:
# architecture of best performing raster model
class RasterModel(nn.Module):
    '''Placeholder — TODO: paste the architecture of the best-performing raster model.'''

# architecture of best performing vector model
class VectorModel(nn.Module):
    '''Placeholder — TODO: paste the architecture of the best-performing vector model.'''


def load_trained_model(model, state_dict_path):
    '''Loads trained weights into ``model``, moves it to ``device`` and switches it to evaluation mode.

    Args:
        model: freshly constructed model whose architecture matches the saved state dict.
        state_dict_path: path to the saved state dict.

    Returns:
        The same ``model`` instance, loaded and in eval mode.

    Raises:
        ValueError: if ``state_dict_path`` is empty (path not configured yet).
    '''
    if not state_dict_path:
        raise ValueError(f"No state dict path set for {type(model).__name__}.")
    # map_location lets weights saved on the CUDA cluster be loaded on a local MPS/CPU machine
    state_dict = torch.load(state_dict_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


# load the trained raster model
raster_model_path = ""
raster_model = load_trained_model(RasterModel(), raster_model_path)

# load the trained vector model
vector_model_path = ""
vector_model = load_trained_model(VectorModel(), vector_model_path)

In [None]:
class MultimodalModel(nn.Module):
    '''Late-fusion model combining a pre-trained raster model and a pre-trained vector model.

    Both sub-models are frozen; only ``fusion_layer`` (a single linear layer over the concatenated
    sub-model outputs) is trainable.
    '''

    def __init__(self, raster_model, vector_model, n_raster_features, n_vector_features, n_classes):
        '''
        Args:
            raster_model: trained raster model; its output is concatenated along dim 1 and must have
                ``n_raster_features`` columns.
            vector_model: trained vector model called as ``vector_model(x_dict, edge_index_dict)``; its
                output must have ``n_vector_features`` columns.
            n_raster_features: width of the raster model output.
            n_vector_features: width of the vector model output.
            n_classes: number of outputs of the fusion layer.
        '''
        super().__init__()
        self.raster_model = raster_model
        self.vector_model = vector_model

        # both models are already trained and only the fusion layer requires gradients
        self.raster_model.requires_grad_(False)
        self.vector_model.requires_grad_(False)

        # TODO: remove last linear layers (https://discuss.pytorch.org/t/custom-ensemble-approach/52024/4)

        # fusion layer
        self.fusion_layer = nn.Linear(n_raster_features + n_vector_features, n_classes)

    def train(self, mode=True):
        '''Sets the training mode of the model but keeps the frozen sub-models in evaluation mode.

        Without this override, ``model.train()`` would re-enable dropout and batch-norm running
        statistic updates inside the frozen sub-models.
        '''
        super().train(mode)
        self.raster_model.eval()
        self.vector_model.eval()
        return self

    def forward(self, graph, raster):
        '''Runs both sub-models and fuses their concatenated outputs.

        Args:
            graph: graph sample exposing ``x_dict`` and ``edge_index_dict`` (passed to the vector model).
            raster: raster input for the raster model.

        Returns:
            Output of ``fusion_layer`` with ``n_classes`` columns.
        '''
        raster_output = self.raster_model(raster)
        vector_output = self.vector_model(graph.x_dict, graph.edge_index_dict)

        # concatenate along feature dimension
        combined_features = torch.cat((raster_output, vector_output), dim=1)
        return self.fusion_layer(combined_features)

### Elimination model

In [None]:
# training and validation splits of the elimination data
path_to_training_data, path_to_validation_data = (
    os.path.join(multimodal_path, "training_data", "elimination", split)
    for split in ("training", "validation")
)

### Selection model

In [None]:
# training and validation splits of the selection data
path_to_training_data, path_to_validation_data = (
    os.path.join(multimodal_path, "training_data", "selection", split)
    for split in ("training", "validation")
)