### VERGE: Vector-mode Regional Geospatial Encoding
# Masked Geospatial Model Implementation

Here we build and train a "masked geospatial model".
This is a model in which each input is a set of encoded geospatial entities,
consisting of a concatenation of a multi-point proximity encoding and a one-hot label vector.
Modeling consists of masking the labels for a random selection of entities,
and passing the data through an encoder-based architecture to predict the labels of the masked entities.
The idea is that the encodings then capture information about the region.


## Colab setup

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')
# project_home = '/content/drive/MyDrive/Projects/verge'

In [None]:
# Root directory of the project; every data path below is built from it.
project_home = '..'

# Name of the region of interest (ROI) that this notebook works on.
roi_name = 'ne-dev'

# Data directory shared by all ROIs.
data_home = '/'.join((project_home, 'data'))

# Data directory holding files for the selected ROI only.
roi_home = '/'.join((project_home, 'data', roi_name))


## Setup

In [None]:
import pandas as pd
import numpy as np
import glob
import pickle
import os
import torch
import torch.nn as nn
import torch.utils
import torch.utils.data
from torch.utils.data import DataLoader
import copy
import json
from geo_encodings import MPPEncoder
from geo_transformer import VergeDataset, verge_collate_fn, GeospatialTransformer


## Parameters

In [None]:
# Load the JSON definition of the selected ROI.
fname = '%s/roi.json' % roi_home
with open(fname) as source:
    roi = json.load(source)

# Tile geometry and encoder resolution both come from the ROI definition.
tile_size, encoding_resolution = roi['tile_size'], roi['encoding_resolution']

# Build an encoder over a single tile; it is only used here to find out
# how many elements each encoding has.
tile_region = [0, 0, tile_size, tile_size]
encoder = MPPEncoder(
    region=tile_region,
    resolution=encoding_resolution,
    center=True,
)
geo_encoding_dim = len(encoder)
print('%d elements in encodings' % geo_encoding_dim)


In [None]:
# Identifier of this run; it is embedded in the names of the files
# written by this notebook.
run_id = '003'

# Train on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('using device', device)

# Probability with which an input file is assigned to the training set.
train_fraction = 0.8

# How many passes over the training data to make.
epoch_count = 100

## Preliminaries

In [None]:
# Load the label table; it has one row per class.
fname = '%s/labels.csv' % data_home
labels = pd.read_csv(fname)
n_classes = len(labels)
print('%d labels in this dataset' % n_classes)

# One dict per row, reused to build both lookup directions.
label_records = labels.to_dict('records')

# Label name -> numeric id.
label_id_lookup = {record['label']: record['id'] for record in label_records}

# Numeric id -> label name.
label_name_lookup = {record['id']: record['label'] for record in label_records}

## Load data

In [None]:
# Get a list of input data files. Each file consists of a list of encodings for
# a number of square tiles in a particular AOI.
globstring = '%s/encodings/*' % roi_home
# glob returns matches in arbitrary (filesystem-dependent) order; sort them so
# that the file order, and everything derived from it, is the same on every
# machine.
fnames = sorted(glob.glob(globstring))
print('%d input files with encodings' % len(fnames))

In [None]:
# Assign every input file (one file per AOI) to either the training or the
# validation set. Splitting by AOI rather than by tile reduces the chance that
# spatial autocorrelation biases the performance assessment.
splits = []
for fname in fnames:
    # One random draw per file decides which side of the split it lands on.
    split_type = 'train' if np.random.random() < train_fraction else 'val'
    splits.append({'fname': fname, 'type': split_type})

# Persist the assignment so that this run's split is on record.
fname = '%s/split-%s.csv' % (roi_home, run_id)
split_table = pd.DataFrame(splits)
split_table.to_csv(fname, index=False)
print('saved %s' % fname)


In [None]:
# Read all data. This input consists of the encoded geospatial features
# for individual tiles, concatenated with one-hot vectors indicating
# their class.
#
# Each input file contains data for all tiles in an AOI. Files are routed to
# the training or validation set according to the AOI-level split made above,
# so that tiles from one AOI never end up on both sides.
#

# File name -> 'train' / 'val', as decided by the split.
split_lookup = {entry['fname']: entry['type'] for entry in splits}

train_tiles = []
val_tiles = []

for fname in fnames:
    # NOTE: unpickling can execute arbitrary code; only read encoding files
    # that were produced by this project.
    with open(fname, 'rb') as source:
        data = pickle.load(source)

    # Append this file's tiles to whichever set the file was assigned to.
    destination = train_tiles if split_lookup[fname] == 'train' else val_tiles
    destination.extend(data)

print('%d training tiles' % len(train_tiles))
print('%d validation tiles' % len(val_tiles))

In [None]:
# # Test that.
# dataset = VergeDataset(train_tiles, n_classes, mask_fraction=0.15)
# batch = [dataset[k] for k in [0, 12, 17, 23]]
# batch_features, batch_labels, batch_attention_mask = verge_collate_fn(batch)
# print('test:')
# print('batch_features.shape', batch_features.shape)
# print('batch_labels.shape', batch_labels.shape)
# print('batch_attention_mask.shape', batch_attention_mask.shape)


In [None]:
from torch.utils.data import DataLoader

# Build the training and validation datasets and their loaders.

# Options shared by both loaders.
loader_options = dict(
    batch_size=16,  # Tune depending on GPU memory
    shuffle=True,
    collate_fn=verge_collate_fn,
    drop_last=False,
)

# In both datasets, labels are masked with mask_fraction=0.15.
train_dataset = VergeDataset(train_tiles, n_classes, mask_fraction=0.15)
train_dataloader = DataLoader(train_dataset, **loader_options)

val_dataset = VergeDataset(val_tiles, n_classes, mask_fraction=0.15)
val_dataloader = DataLoader(val_dataset, **loader_options)

## Model definition

In [None]:
# Each entity vector is a geospatial encoding concatenated with a one-hot
# label vector, so the input width is the sum of the two sizes.
feature_dim = geo_encoding_dim + n_classes

model = GeospatialTransformer(
    feature_dim=feature_dim,
    model_dim=128,
    num_heads=4,
    num_layers=5,
    num_classes=n_classes,
    dropout=0.2,
)

# Count only the parameters that will actually be updated by training.
n_param = 0
for param in model.parameters():
    if param.requires_grad:
        n_param += param.numel()
print('%d trainable parameters in model' % n_param)

In [None]:
# Testing
# dataset = VergeDataset(train_tiles, n_classes, mask_fraction=0.15)
# dataloader = DataLoader(
#     dataset,
#     batch_size=2,            # Tune depending on GPU memory
#     shuffle=True,
#     collate_fn=verge_collate_fn,   # Key for padding variable-length instances
#     drop_last=False
# )

# features, labels, attention_mask = dataloader.__iter__().__next__()
# print(features.shape, labels.shape, attention_mask.shape)

In [None]:
# model(features, attention_mask)

### Training loop

In [None]:
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Positions labelled -100 do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=-100)

# One record per epoch: {'epoch', 'train_loss', 'val_loss'}, where each loss is
# the mean of the per-batch losses seen during that epoch.
losses = []

for epoch in range(epoch_count):

    # Training pass.
    model.train()
    train_loss_total = 0.0
    train_batch_count = 0
    for features, labels, attention_mask in train_dataloader:
        features = features.to(device)
        labels = labels.to(device)
        attention_mask = attention_mask.to(device)

        logits = model(features, attention_mask)
        # Flatten so that every entity is scored as an independent
        # classification case.
        loss = criterion(
            logits.view(-1, n_classes),
            labels.view(-1)
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss_total += loss.item()
        train_batch_count += 1

    # Validation pass: no gradients needed, and the model is put in eval mode.
    model.eval()
    val_loss_total = 0.0
    val_batch_count = 0
    with torch.no_grad():
        for features, labels, attention_mask in val_dataloader:
            features = features.to(device)
            labels = labels.to(device)
            attention_mask = attention_mask.to(device)
            logits = model(features, attention_mask)
            batch_val_loss = criterion(
                logits.view(-1, n_classes),
                labels.view(-1)
            )
            val_loss_total += batch_val_loss.item()
            val_batch_count += 1

    # Report epoch means rather than whatever the last batch happened to give.
    # An empty loader yields NaN instead of failing on an undefined variable.
    train_loss = train_loss_total / train_batch_count if train_batch_count else float('nan')
    val_loss = val_loss_total / val_batch_count if val_batch_count else float('nan')

    losses.append({
        'epoch': epoch,
        'train_loss': train_loss,
        'val_loss': val_loss
    })

    print(f"Epoch {epoch+1}, train loss: {train_loss:.4f}, val_loss: {val_loss:.4f}")


In [None]:
# Write the trained model object to the ROI directory, tagged with the run id.
# This saves the whole model object, not just a state dict.
model_fname = '%s/model-%s' % (roi_home, run_id)
torch.save(obj=model, f=model_fname)
print('saved %s' % model_fname)

## Loss history

In [None]:
import plotly
from plotly.subplots import make_subplots
from plotly.graph_objects import Scatter

# Unpack the per-epoch loss records into parallel lists for plotting.
epochs = [record['epoch'] for record in losses]
train_losses = [record['train_loss'] for record in losses]
val_losses = [record['val_loss'] for record in losses]

# Both curves share a single panel.
fig = make_subplots(rows=1, cols=1)
curves = (
    (train_losses, 'training loss', 'blue'),
    (val_losses, 'validation loss', 'green'),
)
for values, curve_name, color in curves:
    trace = Scatter(
        x=epochs, y=values, name=curve_name,
        mode='markers+lines', marker_color=color
    )
    fig.append_trace(trace, 1, 1)

# Last expression of the cell, so the notebook renders the figure.
fig

## Validation Visualization

In [None]:
# Process the validation dataset, getting the class probability predictions
# for every instance whose label was masked (i.e. has a real target label).
model.to(device)
model.eval()

# One record per masked entity: {'true_label': int, 'probs': 1-D numpy array}.
cases = []

# Inference only: skip autograd bookkeeping to save memory and time.
with torch.no_grad():
    for features, labels, attention_mask in val_dataloader:

        features = features.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        logits = model(features, attention_mask)

        # Softmax over the class axis (the last one) for every entity at once.
        probs = torch.softmax(logits, dim=-1)

        # Keep only entities with a real target; this skips the "-100" labels.
        keep = labels >= 0
        kept_labels = labels[keep].cpu().numpy()
        kept_probs = probs[keep].cpu().numpy()

        # Boolean indexing walks the batch row by row, so the cases are still
        # ordered by batch item and then by entity.
        for true_label, class_probs in zip(kept_labels, kept_probs):
            cases.append({
                'true_label': int(true_label),
                'probs': class_probs
            })

print('compiled prediction probabilities for %d validation instances' % len(cases))

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

# Every probability vector has one entry per model class, so size the matrix
# by the model's class count. Taking it from the largest label observed in
# validation gives a non-square matrix (or mismatched row lengths) whenever
# some classes never occur among the validation cases.
class_count = n_classes
probs_by_class = defaultdict(list)

# Group the predicted probability vectors by their true class.
for d in cases:
    label = d["true_label"]
    probs = np.asarray(d["probs"])
    probs_by_class[label].append(probs)

# For each true class, compute the mean probability vector
mean_probs = []
for t in range(class_count):
    class_probs = probs_by_class.get(t)
    if class_probs:
        mean = np.stack(class_probs).mean(axis=0)
    else:
        mean = np.zeros(class_count)  # if no samples for this class
    mean_probs.append(mean)

# Convert to 2D array: [true_class, predicted_class]
matrix = np.stack(mean_probs)  # shape [C, C]

# Plot heatmap
fig, ax = plt.subplots(figsize=(9, 6))
im = ax.imshow(matrix, cmap='viridis', aspect='auto')

plt.colorbar(im, ax=ax, label='Avg Predicted Probability')
ax.set_title("Mean Predicted Probabilities by True Class")
ax.set_xlabel("Predicted Class")
ax.set_ylabel("True Class")
ax.set_xticks(range(class_count))
ax.set_yticks(range(class_count))
# Fall back to the numeric id if a class index has no entry in the label table.
ax.set_yticklabels([label_name_lookup.get(i, str(i)) for i in range(class_count)])
plt.tight_layout()
plt.show()

In [None]:
# class_count = max(d["true_label"] for d in cases) + 1
# print(class_count)
# cmat = np.zeros((class_count, class_count))

# for d in cases:
#     true_label = d["true_label"]
#     pred_label = np.argmax(d["probs"])
#     cmat[true_label, pred_label] += 1

# cmat = np.sqrt(cmat)

# # Plot heatmap
# fig, ax = plt.subplots(figsize=(12, 8))
# im = ax.imshow(cmat, cmap='viridis', aspect='auto')

# plt.colorbar(im, ax=ax, label='Number Of Cases')
# ax.set_title("Confusion Matrix")
# ax.set_xlabel("Predicted Class")
# ax.set_ylabel("True Class")
# ax.set_xticks(range(class_count))
# ax.set_yticks(range(class_count))
# ax.set_yticklabels(['%s [%d]' % (label_name_lookup[i], i) for i in range(class_count)])
# plt.tight_layout()
# plt.show()
