### VERGE: Vector-mode Regional Geospatial Encoding
# Masked Geospatial Model Implementation

Here we build and train a "masked geospatial model".
This is a model in which each input is a set of encoded geospatial entities,
consisting of a concatenation of a multi-point proximity encoding and a one-hot label vector.
Modeling consists of masking the labels for a random selection of entities,
and passing the data through an encoder-based architecture to predict the labels of masked entities.
The idea is that the encodings then capture information about the region.


## Colab setup

In [None]:
# Mount Google Drive at /content/drive.
from google.colab import drive
drive.mount('/content/drive')
# Root folder of the project on the mounted drive; the cells below build
# their input and output paths from it.
project_home = '/content/drive/MyDrive/Projects/verge'

## Setup

In [None]:
import pandas as pd
import numpy as np
import glob
import pickle
import os
import torch
import torch.nn as nn
import torch.utils
import torch.utils.data
from torch.utils.data import DataLoader
import copy

## Parameters

In [None]:
# Identifier for this run; it is embedded in the names of output files.
run_id = '003'

# Training device: use the GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'using device {device}')

# Side length of the (square) AOIs. Must match the value used when the tiles
# were created.
aoi_size = 1000

# Resolution of the MPP encoding.
resolution = 50

# Number of elements in one geometric encoding. This value is implied by the
# MPP encoding resolution and the AOI size, so keep the three in agreement.
geo_encoding_dim = 400

# Probability that an input file is assigned to the training set.
train_fraction = 0.8

## Preliminaries

In [None]:
# Load the label table and report how many classes it defines.
fname = f'{project_home}/data/labels.csv'
labels = pd.read_csv(fname)
n_classes = len(labels)
print(f'{n_classes} labels in this dataset')

# Build both lookup directions (name -> id and id -> name) from one pass
# over the table rows.
label_records = labels.to_dict('records')
label_id_lookup = {record['label']: record['id'] for record in label_records}
label_name_lookup = {record['id']: record['label'] for record in label_records}

## Load data

In [None]:
# Collect the input data files. Each file holds a list of encodings for a
# number of square tiles in a particular AOI.
globstring = f'{project_home}/data/encodings/*'
fnames = glob.glob(globstring)
print(f'{len(fnames)} input files')

In [None]:
# Assign every input file (i.e. every AOI) to either the training or the
# validation set. Splitting by AOI rather than by tile minimizes the
# possibility of spatial autocorrelation biasing performance assessments.
# One random draw is made per file, in file order.
splits = [
    {
        'fname': path,
        'type': 'train' if np.random.random() < train_fraction else 'val',
    }
    for path in fnames
]

# Persist the assignment so the run can be traced back to its split.
fname = f'{project_home}/splits/split-{run_id}.csv'
pd.DataFrame(splits).to_csv(fname, index=False)
print(f'saved {fname}')


In [None]:
# Read all data. The input consists of the encoded geospatial features
# for individual tiles, concatenated with one-hot vectors indicating
# their class.
#
# Each input file contains data for all tiles in an AOI. Files are routed to
# the training or validation list according to the per-AOI split made above,
# to avoid spatial autocorrelation between the two sets.

split_lookup = {entry['fname']: entry['type'] for entry in splits}

train_tiles = []
val_tiles = []

for fname in fnames:
    print('reading', fname)
    with open(fname, 'rb') as source:
        data = pickle.load(source)
    # Anything not explicitly marked 'train' goes to validation.
    target = train_tiles if split_lookup[fname] == 'train' else val_tiles
    target.extend(data)

print(f'{len(train_tiles)} training instances')
print(f'{len(val_tiles)} validation instances')

In [None]:
# This class wraps a list of input tile data as a pytorch dataset.
# The "getitem" method splits a tile into its true labels and its geometric
# encodings, and applies random masking to the labels (the encodings are
# passed through unchanged).

class VergeDataset(torch.utils.data.Dataset):
    """Masked-label dataset over a list of per-tile feature arrays.

    Each tile is a 2-D array with one row per entity. The first ``n_classes``
    columns of a row are its one-hot label and the remaining columns are its
    geometric encoding.

    Args:
        data_list: List of per-tile arrays as described above.
        n_classes: Width of the one-hot label part of every row.
        mask_fraction: Fraction of the entities of a tile whose labels are
            masked each time the tile is accessed.
    """

    def __init__(self, data_list, n_classes, mask_fraction):
        self.data = data_list
        self.n_classes = n_classes
        self.mask_fraction = mask_fraction
        self.encoding_dim = data_list[0].shape[1] - self.n_classes

        # When accessing any item, we also sample from its available classes.
        # The dataset has a big class imbalance, so sampling is done according
        # to inverse class probability. Here we compute the empirical class
        # distribution over all entities of all tiles.
        counts = np.zeros(self.n_classes, dtype=np.float64)
        for tile in data_list:
            tile_labels = np.argmax(tile[:, :self.n_classes], axis=1)
            counts += np.bincount(tile_labels, minlength=self.n_classes)
        total = counts.sum()
        probs = counts / total if total > 0 else counts

        # Public mapping {class index: probability}.
        self.class_prob = {label: float(p) for label, p in enumerate(probs)}

        # Unnormalized masking weight per class. The small constant keeps the
        # weight finite for classes that never occur.
        self._mask_weight_by_class = 1.0 / (probs + 0.001)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(masked_features, labels)`` for tile ``idx``.

        ``masked_features`` is a float32 tensor with the same layout as the
        stored tile, except that the one-hot label of every masked entity is
        replaced by zeros. ``labels`` is a long tensor holding the true class
        of masked entities and -100 everywhere else. Both are returned under
        the same random row permutation.
        """
        features = self.data[idx]
        true_labels_onehot = features[:, :self.n_classes]
        encodings = features[:, self.n_classes:]
        true_labels = np.argmax(true_labels_onehot, axis=1)
        n_entities = features.shape[0]

        # Sample entities to mask. Sampling is weighted by the inverse
        # frequency of each entity's class in the dataset, i.e. it addresses
        # class imbalance. Sampling is done WITHOUT replacement so that
        # exactly `sample_size` distinct entities are masked; drawing with
        # replacement repeats indices and masks fewer entities than intended.
        if n_entities > 0:
            weights = self._mask_weight_by_class[true_labels]
            weights = weights / weights.sum()
            sample_size = min(
                n_entities, int(np.ceil(self.mask_fraction * n_entities))
            )
            mask_indices = np.random.choice(
                n_entities, size=sample_size, replace=False, p=weights
            )
        else:
            mask_indices = np.empty(0, dtype=np.int64)

        # In the feature array, labels are one-hot vectors that get
        # concatenated with the geometric encodings. To "mask" a label we
        # replace its one-hot vector with an all-zero vector. Work on a copy
        # so the stored tile is never modified.
        masked_labels_onehot = true_labels_onehot.copy()
        masked_labels_onehot[mask_indices] = 0

        # Re-concatenate the masked labels with the geometric encodings.
        masked_features = torch.cat(
            [
                torch.tensor(masked_labels_onehot, dtype=torch.float32),
                torch.tensor(encodings, dtype=torch.float32),
            ],
            dim=1,
        )

        # The training loss (CrossEntropyLoss) can ignore targets equal to a
        # special value, -100 by default in pytorch. Only masked entities
        # should contribute to the loss, so start from all "ignore" tokens and
        # fill in the true class for the masked entities only.
        labels_array = np.full(n_entities, -100, dtype=np.int64)
        labels_array[mask_indices] = true_labels[mask_indices]
        labels = torch.from_numpy(labels_array)

        # Shuffle the features and labels with one shared permutation.
        perm = torch.randperm(n_entities)
        return (masked_features[perm], labels[perm])

# Smoke test: wrap the training tiles and look at the shapes of one item.
dataset = VergeDataset(train_tiles, n_classes, mask_fraction=0.2)
features, labels = dataset[0]
print(f'features.shape {features.shape}')
print(f'labels.shape {labels.shape}')


In [None]:
# Batch assembly. Tiles contain different numbers of entities, so every item
# is padded up to the longest item in the batch and a mask records which
# positions hold real data.
def collate_fn(batch):
    """Pad a list of ``(features, labels)`` items into batch tensors.

    Returns a tuple ``(padded_features, padded_labels, attention_mask)``:
    features are padded with zeros, labels with the "ignore" value -100, and
    the boolean attention mask is True at real positions and False at padding.
    """
    features, labels = zip(*batch)
    lengths = [item.shape[0] for item in features]
    batch_size = len(features)
    max_len = max(lengths)
    feature_dim = features[0].shape[1]

    padded_features = torch.zeros(batch_size, max_len, feature_dim)
    padded_labels = torch.full((batch_size, max_len), -100, dtype=torch.long)

    # Copy each item into the leading positions of its row.
    for row, (length, item_features, item_labels) in enumerate(
        zip(lengths, features, labels)
    ):
        padded_features[row, :length] = item_features
        padded_labels[row, :length] = item_labels

    # Position p of row r is real data exactly when p < length of item r.
    positions = torch.arange(max_len).unsqueeze(0)
    attention_mask = positions < torch.tensor(lengths).unsqueeze(1)

    return padded_features, padded_labels, attention_mask


# Smoke test: collate a handful of tiles and look at the batch shapes.
dataset = VergeDataset(train_tiles, n_classes, mask_fraction=0.15)
batch = [dataset[k] for k in (0, 12, 17, 23)]
batch_features, batch_labels, batch_attention_mask = collate_fn(batch)
print(f'batch_features.shape {batch_features.shape}')
print(f'batch_labels.shape {batch_labels.shape}')
print(f'batch_attention_mask.shape {batch_attention_mask.shape}')


## Model definition

In [None]:
class GeospatialTransformer(nn.Module):
    """Transformer encoder over the entities of a tile with a per-entity
    classification head.

    Input features are linearly projected to ``model_dim``, passed through a
    stack of ``num_layers`` encoder layers, and mapped to ``num_classes``
    logits per entity.
    """

    def __init__(self, feature_dim, model_dim, num_heads, num_layers, num_classes, dropout):
        super().__init__()

        # Input projection: feature_dim -> model_dim.
        self.input_proj = nn.Linear(feature_dim, model_dim)

        # Encoder stack; the feed-forward width is four times the model width.
        layer_config = dict(
            d_model=model_dim,
            nhead=num_heads,
            dim_feedforward=4 * model_dim,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_config),
            num_layers=num_layers,
        )

        # Classification head: model_dim -> num_classes.
        self.output_head = nn.Linear(model_dim, num_classes)

    def forward(self, x, attention_mask):
        """
        x: Tensor of shape [batch_size, n_entities, feature_dim]
        attention_mask: Tensor of shape [batch_size, n_entities], with 1 for valid, 0 for padding

        Returns per-entity class logits of shape [batch_size, n_entities, num_classes].
        """
        return self.output_head(self.embed(x, attention_mask))

    def embed(self, x, attention_mask):
        """
        Returns the encoder output (one model_dim vector per entity) for the
        input features, i.e. everything except the classification head.
        """
        projected = self.input_proj(x)
        # The encoder expects the opposite convention: True marks PAD positions.
        is_padding = attention_mask.eq(0)
        return self.encoder(projected, src_key_padding_mask=is_padding)


In [None]:
# Build the model. Its input width is the geometric encoding plus the one-hot
# label vector, and it predicts one of n_classes labels per entity.
model = GeospatialTransformer(
    feature_dim=geo_encoding_dim + n_classes,
    model_dim=128,
    num_heads=4,
    num_layers=5,
    num_classes=n_classes,
    dropout=0.2,
)

# Report the model size, counting trainable parameters only.
n_param = sum(
    p.numel() for p in filter(lambda p: p.requires_grad, model.parameters())
)
print(f'{n_param} trainable parameters in model')

### Testing

In [None]:
# dataset = VergeDataset(train_tiles, n_classes, mask_fraction=0.15)
# dataloader = DataLoader(
#     dataset,
#     batch_size=2,            # Tune depending on GPU memory
#     shuffle=True,
#     collate_fn=collate_fn,   # Key for padding variable-length instances
#     drop_last=False
# )

# features, labels, attention_mask = dataloader.__iter__().__next__()
# print(features.shape, labels.shape, attention_mask.shape)

In [None]:
# model(features, attention_mask)

### Real training loop

In [None]:
from torch.utils.data import DataLoader

# Initialize training and validation datasets and their loaders.
#
# Each loader must wrap its own dataset. Passing the leftover smoke-test
# `dataset` (built from the training tiles) to both loaders would make the
# "validation" loss a second pass over training data.

train_dataset = VergeDataset(train_tiles, n_classes, mask_fraction=0.15)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=16, # Tune depending on GPU memory
    shuffle=True,
    collate_fn=collate_fn, # Pads variable-length tiles to a common length
    drop_last=False
)

val_dataset = VergeDataset(val_tiles, n_classes, mask_fraction=0.15)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=16, # Tune depending on GPU memory
    shuffle=True,
    collate_fn=collate_fn, # Pads variable-length tiles to a common length
    drop_last=False
)

In [None]:
# Train the model and record one training and one validation loss per epoch.
#
# The recorded losses are means over ALL batches of the epoch, weighted by
# the number of scored (masked) entities in each batch. Recording only the
# loss of the last batch gives a noisy curve, and leaves the validation loss
# undefined when the validation loader yields no batches.
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=-100)

losses = []

for epoch in range(80):

    # Training.
    model.train()
    train_loss_sum, train_count = 0.0, 0
    for features, labels, attention_mask in train_dataloader:
        features = features.to(device)
        labels = labels.to(device)
        attention_mask = attention_mask.to(device)

        # Number of entities that contribute to the loss in this batch.
        n_scored = int((labels != -100).sum())
        if n_scored == 0:
            # Mean-reduced cross entropy over zero targets is NaN; a backward
            # pass on it would poison the weights.
            continue

        logits = model(features, attention_mask)
        loss = criterion(
            logits.view(-1, n_classes),
            labels.view(-1)
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss_sum += loss.item() * n_scored
        train_count += n_scored

    # Validation loss (dropout disabled, no gradients).
    model.eval()
    val_loss_sum, val_count = 0.0, 0
    with torch.no_grad():
        for features, labels, attention_mask in val_dataloader:
            features = features.to(device)
            labels = labels.to(device)
            attention_mask = attention_mask.to(device)

            n_scored = int((labels != -100).sum())
            if n_scored == 0:
                continue

            logits = model(features, attention_mask)
            batch_val_loss = criterion(
                logits.view(-1, n_classes),
                labels.view(-1)
            )
            val_loss_sum += batch_val_loss.item() * n_scored
            val_count += n_scored

    # Per-entity mean losses for the epoch; NaN marks "nothing was scored".
    train_loss = train_loss_sum / train_count if train_count else float('nan')
    val_loss = val_loss_sum / val_count if val_count else float('nan')

    losses.append({
        'epoch': epoch,
        'train_loss': train_loss,
        'val_loss': val_loss
    })

    print(f"Epoch {epoch+1}, train loss: {train_loss:.4f}, val_loss: {val_loss:.4f}")


In [None]:
# Persist the trained model. Note that this serializes the whole module
# object (not just a state dict), under a name that includes the run id.
model_fname = f'{project_home}/models/model-{run_id}'
torch.save(model, model_fname)
print(f'saved {model_fname}')

In [None]:
import plotly
from plotly.subplots import make_subplots
from plotly.graph_objects import Scatter

# Plot the per-epoch training and validation losses on one set of axes.
epochs = [d['epoch'] for d in losses]
train_losses = [d['train_loss'] for d in losses]
val_losses = [d['val_loss'] for d in losses]

fig = make_subplots(rows=1, cols=1)

# `add_trace(..., row=, col=)` is the supported replacement for the
# deprecated `append_trace`.
trace = Scatter(
    x=epochs, y=train_losses, name='training loss',
    mode='markers+lines', marker_color='blue'
)
fig.add_trace(trace, row=1, col=1)

trace = Scatter(
    x=epochs, y=val_losses, name='validation loss',
    mode='markers+lines', marker_color='green'
)
fig.add_trace(trace, row=1, col=1)

# Last expression of the cell: displays the figure in the notebook.
fig

## Validation Visualization

In [None]:
# Process the validation dataset, getting the class probability predictions
# for every masked entity (the only entities that carry a true label here).
#
# The model is put in eval mode so that dropout is disabled, and gradients
# are turned off since nothing is being trained. Running this in train mode
# would make the reported probabilities depend on random dropout.
model.to(device)

cases = []

model.eval()
with torch.no_grad():
    for features, labels, attention_mask in val_dataloader:

        features = features.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        logits = model(features, attention_mask)
        probs = torch.softmax(logits, dim=-1)

        # Keep only scored entities; unmasked and padded positions carry the
        # "-100" ignore label. Boolean indexing preserves (batch, entity)
        # order, and the device-to-host copy happens once per batch.
        scored = labels >= 0
        scored_labels = labels[scored].cpu().numpy()
        scored_probs = probs[scored].cpu().numpy()

        for true_label, prob_vector in zip(scored_labels, scored_probs):
            cases.append({
                'true_label': int(true_label),
                'probs': prob_vector
            })
        # if len(cases) >= 1000:
        #     break

print('compiled prediction probabilities for %d validation instances' % len(cases))

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

# Heatmap of the mean predicted probability vector for each true class.
#
# The matrix is sized by the number of classes the model predicts
# (n_classes), not by the largest true label that happens to occur in
# `cases`. Each probability vector has n_classes entries, so sizing by the
# observed labels breaks the stacking below, and the axis ticks, whenever
# some class is missing from the validation cases.
class_count = n_classes
probs_by_class = defaultdict(list)

for d in cases:
    label = d["true_label"]
    probs = np.array(d["probs"])
    probs_by_class[label].append(probs)

# For each true class, compute the mean probability vector
mean_probs = []
for t in range(class_count):
    if probs_by_class[t]:
        mean = np.stack(probs_by_class[t]).mean(axis=0)
    else:
        mean = np.zeros(class_count)  # if no samples for this class
    mean_probs.append(mean)

# Convert to 2D array: [true_class, predicted_class]
matrix = np.stack(mean_probs)  # shape [C, C]

# Plot heatmap
fig, ax = plt.subplots(figsize=(9, 6))
im = ax.imshow(matrix, cmap='viridis', aspect='auto')

plt.colorbar(im, ax=ax, label='Avg Predicted Probability')
ax.set_title("Mean Predicted Probabilities by True Class")
ax.set_xlabel("Predicted Class")
ax.set_ylabel("True Class")
ax.set_xticks(range(class_count))
ax.set_yticks(range(class_count))
# Fall back to the numeric index if an id is missing from the label table.
ax.set_yticklabels([label_name_lookup.get(i, str(i)) for i in range(class_count)])
plt.tight_layout()
plt.show()



In [None]:
# Confusion matrix of true class vs. most probable predicted class.
#
# The matrix is sized by the number of classes the model predicts
# (n_classes). Sizing it by the largest observed true label raises an
# IndexError as soon as the model predicts a class with a higher index.
class_count = n_classes
print(class_count)
cmat = np.zeros((class_count, class_count))

for d in cases:
    true_label = d["true_label"]
    pred_label = np.argmax(d["probs"])
    cmat[true_label, pred_label] += 1

# Square-root scaling keeps rare classes visible next to dominant ones.
cmat = np.sqrt(cmat)

# Plot heatmap
fig, ax = plt.subplots(figsize=(12, 8))
im = ax.imshow(cmat, cmap='viridis', aspect='auto')

# The plotted values are square roots of counts, so label them as such.
plt.colorbar(im, ax=ax, label='sqrt(Number Of Cases)')
ax.set_title("Confusion Matrix")
ax.set_xlabel("Predicted Class")
ax.set_ylabel("True Class")
ax.set_xticks(range(class_count))
ax.set_yticks(range(class_count))
# Fall back to the numeric index if an id is missing from the label table.
ax.set_yticklabels(['%s [%d]' % (label_name_lookup.get(i, str(i)), i) for i in range(class_count)])
plt.tight_layout()
plt.show()
