### VERGE: Vector-mode Regional Geospatial Encoding
# Masked Geospatial Model Training

Here we build and train a "masked geospatial model".
This is a model in which each input is a set of encoded geospatial entities,
consisting of a concatenation of a multi-point proximity encoding and a one-hot label vector.
Modeling consists of masking the labels for a random selection of entities,
passing the data through an encoder-based architecutre to predicte the labels of masked entities.
The idea is that the encodings then capture information about the region.


## Processing Setup

In [1]:
# Google colab setup
import os
from google.colab import drive
drive.mount('/content/drive')
project_home = '/content/drive/MyDrive/Projects/verge'
os.chdir(project_home)

!pip install geo_encodings

Mounted at /content/drive
Collecting geo_encodings
  Downloading geo_encodings-1.0.4-py2.py3-none-any.whl.metadata (4.0 kB)
Downloading geo_encodings-1.0.4-py2.py3-none-any.whl (6.9 kB)
Installing collected packages: geo_encodings
Successfully installed geo_encodings-1.0.4


In [2]:
# Local processing setup
# project_home = '..'

## Notebook Setup

In [3]:
import pandas as pd
import numpy as np
import glob
import pickle
import os
import torch
import torch.nn as nn
import torch.utils
import torch.utils.data
from torch.utils.data import DataLoader
import copy
import json
from geo_encodings import MPPEncoder

import sys
sys.path.append(project_home)
from utils.geo_transformer_mem import VergeDataset, verge_collate_fn, GeospatialTransformer


## Parameters

In [4]:
# The name of the ROI to use.
roi_name = 'newengland'

# The name of the general-purpose data directory.
data_home = '%s/data' % (project_home)

# The name of the ROI-specific data directory.
roi_home = '%s/data/%s' % (project_home, roi_name)

# As always.
np.random.seed(5)

In [5]:
# A unique identifier for this run. This will be a component of any
# output file names.
run_id = '301b'

# If this is not None, then start the training with the weights from this
# version of the model.
input_run_id = '301a'

# Identifier of the "splits" file to use.
split_id = '201'

# Figure out what type of device to train on.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('using device', device)

# Fraction of cases to use for training.
train_fraction = 0.8

# Number of epochs to run.
epoch_count = 50

using device cuda


## Preliminaries

In [6]:
# Read the ROI definition.
fname = '%s/roi.json' % roi_home
with open(fname) as source:
    roi = json.load(source)

tile_size = roi['tile_size']
encoding_resolution = roi['encoding_resolution']

# We need the dimension of the encoding.
encoder = MPPEncoder(
    region=[0, 0, tile_size, tile_size],
    resolution=encoding_resolution,
    center=True
)
geo_encoding_dim = len(encoder)
print('%d elements in encodings' % geo_encoding_dim)


400 elements in encodings


In [7]:
# Read the list of labels.
fname = '%s/labels.csv' % data_home
labels = pd.read_csv(fname)
n_classes = len(labels)
print('%d labels in this dataset' % n_classes)

label_id_lookup = {
    z['label']: z['id']
    for z in labels.to_dict('records')
}

label_name_lookup = {
    z['id']: z['label']
    for z in labels.to_dict('records')
}

22 labels in this dataset


In [8]:
# Read the file that gives class probabilities.
fname = '%s/class_info.csv' % roi_home
class_info = pd.read_csv(fname)
print('%d class info records' % len(class_info))

22 class info records


## Load data


In [9]:
# Get a list of file names, and separate them into training and validation sets.
# Note the "splits" file lists individual tiles, but the input files for the model
# -- ie. the encodings -- are per AOI. So we use the splits file to assign AOI names
# to different splits.

fname = '%s/models/splits-%s.csv' % (roi_home, split_id)
splits = pd.read_csv(fname)
splits.head()

aoi_split_lookup = {
    z['aoi_tag']: z['split']
    for z in splits.to_dict('records')
}

train_fnames = []
val_fnames = []
for aoi_tag in aoi_split_lookup:
    aoi_encoding_fname = '%s/encodings/%s.pkl' % (roi_home, aoi_tag)
    s = aoi_split_lookup[aoi_tag]
    if s == 'train':
        train_fnames.append(aoi_encoding_fname)
    else:
        val_fnames.append(aoi_encoding_fname)
print('%d files with training data' % len(train_fnames))
print('%d files with validation data' % len(val_fnames))


207 files with training data
56 files with validation data


In [10]:
# The dataset constructor requires a lookup table for class probabilities.
class_prob_lookup = {
    z['label']: z['prob']
    for z in class_info.to_dict('records')
}

# Initialize training and validation datasets.
train_dataset = VergeDataset(train_fnames, n_classes, mask_fraction=0.15, class_prob=class_prob_lookup)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=128, # Tune depending on GPU memory
    shuffle=True,
    collate_fn=verge_collate_fn,
    drop_last=False
)

val_dataset = VergeDataset(val_fnames, n_classes, mask_fraction=0.15, class_prob=class_prob_lookup)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=128, # Tune depending on GPU memory
    shuffle=True,
    collate_fn=verge_collate_fn,
    drop_last=False
)

loaded 63 instances from /content/drive/MyDrive/Projects/verge/data/newengland/encodings/0735w-413n.pkl
loaded 63 instances from /content/drive/MyDrive/Projects/verge/data/newengland/encodings/0735w-414n.pkl
loaded 39 instances from /content/drive/MyDrive/Projects/verge/data/newengland/encodings/0735w-446n.pkl
loaded 63 instances from /content/drive/MyDrive/Projects/verge/data/newengland/encodings/0734w-414n.pkl
loaded 62 instances from /content/drive/MyDrive/Projects/verge/data/newengland/encodings/0733w-414n.pkl
loaded 41 instances from /content/drive/MyDrive/Projects/verge/data/newengland/encodings/0733w-423n.pkl
loaded 18 instances from /content/drive/MyDrive/Projects/verge/data/newengland/encodings/0733w-444n.pkl
loaded 44 instances from /content/drive/MyDrive/Projects/verge/data/newengland/encodings/0733w-445n.pkl
loaded 63 instances from /content/drive/MyDrive/Projects/verge/data/newengland/encodings/0732w-413n.pkl
loaded 60 instances from /content/drive/MyDrive/Projects/verge/d

## Model definition

In [11]:
model = GeospatialTransformer(
    feature_dim = geo_encoding_dim + n_classes,
    model_dim=128,
    num_heads=4,
    num_layers=4,
    num_classes=n_classes,
    dropout=0.2
)
n_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('%d trainable parameters in model' % n_param)

850070 trainable parameters in model


In [13]:
# If we specified an input model, load it.
if input_run_id is not None:
    model_fname = '%s/models/transformer-%s' % (roi_home, input_run_id)
    model = torch.load(model_fname, weights_only=False)
    print('loaded %s' % model_fname)
    n_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('%d trainable parameters in model' % n_param)
else:
    print('no input model specified')

loaded /content/drive/MyDrive/Projects/verge/data/newengland/models/transformer-301a
850070 trainable parameters in model


### Training loop

In [None]:
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=-100)

losses = []
min_val_loss = None

model.train()
for epoch in range(epoch_count):

    # Training.
    model.train()
    ibatch = 0
    for features, labels, attention_mask, idents in train_dataloader:
        ibatch += 1

        features = features.to(device)
        labels = labels.to(device)
        attention_mask = attention_mask.to(device)

        logits = model(features, attention_mask)
        loss = criterion(
            logits.view(-1, n_classes),
            labels.view(-1)
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Validation loss
    model.eval()
    with torch.no_grad():
        for features, labels, attention_mask, idents in val_dataloader:
            features = features.to(device)
            labels = labels.to(device)
            attention_mask = attention_mask.to(device)
            logits = model(features, attention_mask)
            val_loss = criterion(
                logits.view(-1, n_classes),
                labels.view(-1)
            )

    losses.append({
        'epoch': epoch,
        'train_loss': loss.item(),
        'val_loss': val_loss.item()
    })

    print(f"Epoch {epoch+1}, train loss: {loss.item():.4f}, val_loss: {val_loss.item():.4f}")

    if min_val_loss is None or val_loss.item() < min_val_loss:
        min_val_loss = val_loss.item()
        model_fname = '%s/models/transformer-%s' % (roi_home, run_id)
        torch.save(model, model_fname)
        print('saved %s' % model_fname)


  output = torch._nested_tensor_from_mask(


Epoch 1, train loss: 1.1084, val_loss: 1.1040
saved /content/drive/MyDrive/Projects/verge/data/newengland/models/transformer-301b
Epoch 2, train loss: 1.1324, val_loss: 1.0317
saved /content/drive/MyDrive/Projects/verge/data/newengland/models/transformer-301b
Epoch 3, train loss: 0.9956, val_loss: 1.1552
Epoch 4, train loss: 1.0665, val_loss: 1.1322
Epoch 5, train loss: 1.1917, val_loss: 1.0694
Epoch 6, train loss: 1.1885, val_loss: 1.1471
Epoch 7, train loss: 1.1567, val_loss: 1.0370
Epoch 8, train loss: 1.0119, val_loss: 1.0734
Epoch 9, train loss: 1.0855, val_loss: 1.0099
saved /content/drive/MyDrive/Projects/verge/data/newengland/models/transformer-301b
Epoch 10, train loss: 1.0553, val_loss: 1.0114
Epoch 11, train loss: 1.1319, val_loss: 1.0368
Epoch 12, train loss: 1.1132, val_loss: 1.0519
Epoch 13, train loss: 1.0402, val_loss: 1.0521
Epoch 14, train loss: 1.1585, val_loss: 1.0998
Epoch 15, train loss: 1.1522, val_loss: 1.0522
Epoch 16, train loss: 1.0488, val_loss: 0.9923
saved

## Loss history

In [None]:
import plotly
from plotly.subplots import make_subplots
from plotly.graph_objects import Scatter

epochs = [d['epoch'] for d in losses]
train_losses = [d['train_loss'] for d in losses]
val_losses = [d['val_loss'] for d in losses]

fig = make_subplots(rows=1, cols=1)
trace = Scatter(
    x=epochs, y=train_losses, name='training loss',
    mode='markers+lines', marker_color='blue'
)
fig.append_trace(trace, 1, 1)

trace = Scatter(
    x=epochs, y=val_losses, name='validation loss',
    mode='markers+lines', marker_color='green'
)
fig.append_trace(trace, 1, 1)

lo = fig['layout']
lo['xaxis1']['title'] = 'epoch'
lo['yaxis1']['title'] = 'loss'
lo['width'] = 900
lo['height'] = 400

fig

## Validation Visualization

In [None]:
# Process the validation dataset, getting the class probability predictions
# for every instance.
model.to(device)
cases = []
model.eval()
for features, labels, attention_mask, idents in val_dataloader:

    features = features.to(device)
    attention_mask = attention_mask.to(device)
    labels = labels.to(device)

    logits = model(features, attention_mask)

    batch_size = logits.shape[0]
    for i in range(batch_size):
        case_logits = logits[i]
        case_probs = torch.softmax(case_logits, dim=1)
        case_labels = labels[i]
        entity_count = len(case_labels)
        for k in range(entity_count):
            if case_labels[k].item() >= 0: # Skips the "-100" labels.
                cases.append({
                    'true_label': case_labels[k].item(),
                    'probs': torch.Tensor.cpu(case_probs[k, :]).detach().numpy()
                })
    break

print('compiled prediction probabilities for %d validation instances' % len(cases))

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

class_count = max(d["true_label"] for d in cases) + 1
probs_by_class = defaultdict(list)

for d in cases:
    label = d["true_label"]
    probs = np.array(d["probs"])
    probs_by_class[label].append(probs)

# For each true class, compute the mean probability vector
mean_probs = []
for t in range(class_count):
    if probs_by_class[t]:
        mean = np.stack(probs_by_class[t]).mean(axis=0)
    else:
        mean = np.zeros(class_count)  # if no samples for this class
    mean_probs.append(mean)

# Convert to 2D array: [true_class, predicted_class]
matrix = np.stack(mean_probs)  # shape [C, C]

# Plot heatmap
fig, ax = plt.subplots(figsize=(9, 6))
im = ax.imshow(matrix, cmap='viridis', aspect='auto')

plt.colorbar(im, ax=ax, label='Avg Predicted Probability')
ax.set_title("Mean Predicted Probabilities by True Class")
ax.set_xlabel("Predicted Class")
ax.set_ylabel("True Class")
ax.set_xticks(range(class_count))
ax.set_yticks(range(class_count))
ax.set_yticklabels([label_name_lookup[i] for i in range(class_count)])
plt.tight_layout()
plt.show()

In [None]:
# class_count = max(d["true_label"] for d in cases) + 1
# print(class_count)
# cmat = np.zeros((class_count, class_count))

# for d in cases:
#     true_label = d["true_label"]
#     pred_label = np.argmax(d["probs"])
#     cmat[true_label, pred_label] += 1

# cmat = np.sqrt(cmat)

# # Plot heatmap
# fig, ax = plt.subplots(figsize=(12, 8))
# im = ax.imshow(cmat, cmap='viridis', aspect='auto')

# plt.colorbar(im, ax=ax, label='Number Of Cases')
# ax.set_title("Confusion Matrix")
# ax.set_xlabel("Predicted Class")
# ax.set_ylabel("True Class")
# ax.set_xticks(range(class_count))
# ax.set_yticks(range(class_count))
# ax.set_yticklabels(['%s [%d]' % (label_name_lookup[i], i) for i in range(class_count)])
# plt.tight_layout()
# plt.show()
