### VERGE: Vector-mode Regional Geospatial Encoding
# VERGE model implementation


Here we build and train a "masked geospatial model".
This is a model in which each input is a set of encoded geospatial entities,
each consisting of a concatenation of a multi-point proximity encoding and a one-hot label vector.
Modeling consists of masking the labels for a random selection of entities, then
passing the data through an encoder-based architecture to predict the labels of the masked entities.
The idea is that the encodings then capture information about the region.

## Version 1:
This version uses a basic architecture proposed by ChatGPT. Of course, I have modified
the original suggestion heavily. This helped to establish the processing flow,
but the model architecture is not really what I wanted. The main issue is the
"permutation invariant" part, which was the GPT suggestion based on my request that
the model be invariant with respect to the ordering of its inputs (i.e. the rows of the
feature matrix). I am not convinced that this implementation really works.
Either way, I am abandoning this in favor of a different approach that is
more verifiably permutation-invariant.

In [None]:
import pandas as pd
import numpy as np
import glob
import pickle
import os
import torch
import copy

## Parameters

In [None]:
# This is the dimension of the (square) AOIs. Set this to match what was used
# when the tiles were created.
aoi_size = 500

# This is the resolution of the MPP (multi-point proximity) encoding.
resolution = 50

## Preliminaries

In [None]:
# Load the label table; it has one row per class, with 'label' and 'id' columns.
fname = 'labels.csv'
labels = pd.read_csv(fname)
n_classes = len(labels)
print('%d labels in this dataset' % n_classes)

# Map each label name to its integer id.
label_id_lookup = dict(
    (row['label'], row['id'])
    for row in labels.to_dict('records')
)

# Id of the special token that stands in for a masked label.
mask_label = label_id_lookup['token : mask']
print('mask label: %s' % mask_label)

# Id of the special padding token.
pad_label = label_id_lookup['token : pad']
print('pad label: %s' % pad_label)


In [None]:
# Get a list of input data files. Each file consists of a list of encodings for
# a square tile.
# glob returns matches in arbitrary (filesystem-dependent) order, so the list is
# sorted: any later selection of a prefix of `fnames` is then reproducible.
globstring = 'data/encodings/*'
fnames = sorted(glob.glob(globstring))
print('%d input files' % len(fnames))


In [None]:
# Load the instances stored in the first five input files into one flat list.
# NOTE(review): pickle.load can execute arbitrary code while deserializing;
# only point this at encoding files from a trusted source.
tile_data_list = []
for fname in fnames[:5]:
    print('reading', fname)
    with open(fname, 'rb') as source:
        tile_data_list.extend(pickle.load(source))

print('%d instances total' % len(tile_data_list))

In [None]:
# This class wraps a list of input tile data as a pytorch dataset.
# The "getitem" method here separates the true labels from the encodings,
# and applies random masking to the labels.

class VergeDataset(torch.utils.data.Dataset):
    """Wrap a list of per-tile feature arrays as a masked-label dataset.

    Each element of ``data_list`` is indexed as a 2-D array with one row per
    entity: the first ``n_classes`` columns hold the one-hot label and the
    remaining columns hold the encoding.

    Args:
        data_list: Sequence of per-tile feature arrays (must be non-empty,
            since the encoding width is read from the first element).
        n_classes: Number of label classes, i.e. the width of the one-hot part.
        mask_fraction: Probability with which each entity's label is masked.
        mask_label_index: Class index whose one-hot vector replaces the label
            of a masked entity.
    """

    def __init__(self, data_list, n_classes, mask_fraction=0.15, mask_label_index=-1):
        self.data = data_list
        self.n_classes = n_classes
        self.mask_fraction = mask_fraction
        self.encoding_dim = data_list[0].shape[1] - self.n_classes
        self.mask_label_index = mask_label_index

    def __len__(self):
        """Return the number of tiles."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(masked_features, labels)`` for tile ``idx``.

        ``masked_features`` is a float32 tensor with the (partly masked)
        one-hot labels followed by the encodings. ``labels`` is a long tensor
        holding the true class index at masked positions and -100 elsewhere.
        A fresh random mask is drawn on every call.
        """
        features = self.data[idx]
        true_labels_onehot = features[:, :self.n_classes]
        encodings = features[:, self.n_classes:]
        true_labels = np.argmax(true_labels_onehot, axis=1)
        n_entities = features.shape[0]

        # Select a few entities for which to assign masked labels.
        mask = np.random.rand(n_entities) < self.mask_fraction

        # Work on a copy so the stored tile data is never modified, then
        # overwrite every selected row with the one-hot "mask" label.
        masked_labels_onehot = true_labels_onehot.copy()
        masked_labels_onehot[mask] = 0.0
        masked_labels_onehot[mask, self.mask_label_index] = 1.0

        # The features to be returned are a concatenation of the masked labels
        # and the encodings.
        masked_features = torch.cat((
            torch.tensor(masked_labels_onehot, dtype=torch.float32),
            torch.tensor(encodings, dtype=torch.float32),
        ), dim=-1)

        # CrossEntropyLoss skips targets equal to its ignore index (-100 by
        # default), so only masked entities keep their true class index.
        labels = torch.tensor(np.where(mask, true_labels, -100), dtype=torch.long)

        return (masked_features, labels)

# Build the dataset using the mask-token id from the label table, then pull one
# sample and print its shapes as a quick sanity check.
# NOTE(review): the unpacking below rebinds the module-level name `labels`,
# which previously held the table read from labels.csv. Code that still
# expects that table after this point would need it re-read - confirm intent.
dataset = VergeDataset(tile_data_list, n_classes, mask_label_index=mask_label)
features, labels = dataset[0]
print('features.shape', features.shape)
print('labels.shape', labels.shape)


In [None]:
# New ChatGPT suggestion, using integer labels. This will need some editing.
def collate_fn(batch):
    """Pad a batch of variable-length samples to a common length and stack them.

    Args:
        batch: Sequence of ``(features, labels)`` pairs, where ``features`` is
            a 2-D tensor (entities x feature width) and ``labels`` is a 1-D
            tensor with one entry per entity.

    Returns:
        A tuple ``(padded_features, padded_labels, attention_mask)``:
        features are zero-padded, labels are padded with -100 (the "ignore"
        value), and the boolean mask is True at real (non-padded) positions.
    """
    feature_list, label_list = zip(*batch)
    n_samples = len(feature_list)
    longest = max(sample.shape[0] for sample in feature_list)
    feature_width = feature_list[0].shape[1]

    # Pre-fill every output with its padding value; real data is copied in below.
    padded_features = torch.zeros(n_samples, longest, feature_width)
    padded_labels = torch.full((n_samples, longest), -100, dtype=torch.long)
    attention_mask = torch.zeros(n_samples, longest, dtype=torch.bool)

    for row, (sample_features, sample_labels) in enumerate(zip(feature_list, label_list)):
        length = sample_features.shape[0]
        padded_features[row, :length] = sample_features
        padded_labels[row, :length] = sample_labels
        attention_mask[row, :length] = True

    return padded_features, padded_labels, attention_mask


# Try the collate function on a small hand-picked batch and print the shapes.
# The mask-token id is passed explicitly; without it VergeDataset falls back to
# its default of -1, i.e. the last class, rather than the 'token : mask' label.
dataset = VergeDataset(tile_data_list, n_classes, mask_label_index=mask_label)
batch = [dataset[k] for k in [5, 6, 7, 8]]
batch_features, batch_labels, batch_attention_mask = collate_fn(batch)
print('batch_features.shape', batch_features.shape)
print('batch_labels.shape', batch_labels.shape)
print('batch_attention_mask.shape', batch_attention_mask.shape)


## Model definition

In [None]:
import torch
import torch.nn as nn

class GeospatialTransformer(nn.Module):
    """Transformer encoder that predicts a class for every entity in a tile.

    The pipeline is: linear projection of the input features to ``model_dim``,
    a stack of ``num_layers`` transformer encoder layers, a per-entity
    LayerNorm -> Linear -> ReLU block, and a linear head producing
    ``num_classes`` logits per entity.
    """

    def __init__(self, feature_dim, model_dim=256, num_heads=4, num_layers=2, num_classes=10, dropout=0.1):
        super().__init__()

        # Lift raw per-entity features into the model's working dimension.
        self.input_proj = nn.Linear(feature_dim, model_dim)

        # batch_first=True selects the [batch, seq, dim] tensor layout.
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=model_dim,
                nhead=num_heads,
                dim_feedforward=model_dim * 4,
                dropout=dropout,
                batch_first=True,
            ),
            num_layers=num_layers,
        )

        # Block labelled "permutation invariant"; each of its layers acts on
        # the last dimension only, so it transforms every entity independently.
        self.perm_invariant = nn.Sequential(
            nn.LayerNorm(model_dim),
            nn.Linear(model_dim, model_dim),
            nn.ReLU(),
        )

        self.output_head = nn.Linear(model_dim, num_classes)

    def forward(self, x, attention_mask):
        """Compute per-entity class logits.

        x: Tensor of shape [batch_size, n_entities, encoding_dim]
        attention_mask: Tensor of shape [batch_size, n_entities], with 1 for valid, 0 for padding

        Returns a tensor of shape [batch_size, n_entities, num_classes].
        """
        hidden = self.input_proj(x)

        # The encoder wants the opposite convention: True marks padded slots.
        is_padding = attention_mask == 0
        hidden = self.encoder(hidden, src_key_padding_mask=is_padding)

        return self.output_head(self.perm_invariant(hidden))


In [None]:
# Size the model from the data instead of hard-coding it: the input width is
# the full row width of a tile array (one-hot label columns plus encoding
# columns), and the output width is the number of label classes, which is what
# the loss computation expects.
model = GeospatialTransformer(
    feature_dim=tile_data_list[0].shape[1],
    model_dim=64,
    num_heads=4,
    num_layers=2,
    num_classes=n_classes,
    dropout=0.1,
)

In [None]:
from torch.utils.data import DataLoader

# Initialize dataset. The mask-token id is passed explicitly; otherwise
# VergeDataset uses its default index of -1 (the last class) as the mask label.
dataset = VergeDataset(tile_data_list, n_classes, mask_label_index=mask_label)

# Create DataLoader
dataloader = DataLoader(
    dataset,
    batch_size=16,            # Tune depending on GPU memory
    shuffle=True,
    collate_fn=collate_fn,   # Key for padding variable-length instances
    drop_last=False
)

In [None]:
# Train the model to predict the true labels of masked entities.
device = 'cpu'
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Targets equal to -100 (un-masked or padded entities) do not contribute.
criterion = nn.CrossEntropyLoss(ignore_index=-100)

# One entry per epoch: the mean loss over the batches used in that epoch.
losses = []

model.train()
for epoch in range(50):
    epoch_loss_total = 0.0
    n_batches_used = 0

    for features, labels, attention_mask in dataloader:
        features = features.to(device)
        labels = labels.to(device)
        attention_mask = attention_mask.to(device)

        # Random masking can leave a batch with no masked entity at all; the
        # mean loss over zero targets is undefined, so skip such batches.
        if not (labels != -100).any():
            continue

        logits = model(features, attention_mask)
        # Flatten [batch, n_entities, n_out] -> [batch * n_entities, n_out],
        # taking the class count from the logits themselves.
        loss = criterion(
            logits.reshape(-1, logits.size(-1)),
            labels.reshape(-1)
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss_total += loss.item()
        n_batches_used += 1

    # Report the epoch average rather than whichever batch happened to be last.
    epoch_loss = epoch_loss_total / n_batches_used if n_batches_used else float('nan')
    losses.append(epoch_loss)

    print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}")


In [None]:
import plotly
from plotly.subplots import make_subplots
from plotly.graph_objects import Scatter

# Plot the recorded loss values against their index in `losses`.
fig = make_subplots(rows=1, cols=1)
trace = Scatter(
    x=np.arange(len(losses)), y=losses, name='loss',
    mode='markers+lines'
)
# add_trace with row/col replaces the deprecated append_trace.
fig.add_trace(trace, row=1, col=1)
fig