### VERGE: Vector-mode Regional Geospatial Encoding
# VERGE model implementation


Here we build and train a "masked geospatial model".
This is a model in which each input is a set of encoded geospatial entities,
each consisting of a concatenation of a multi-point proximity encoding and a one-hot label vector.
Modeling consists of masking the labels for a random selection of entities and
passing the data through an encoder-based architecture to predict the labels of the masked entities.
The idea is that the encodings then capture information about the region.

## Version 2: Partial Autoencoder

I'm trying an approach here in which the model is trained to predict
a subset of its input features. I'm calling it a "partial autoencoder".
The idea is:
* Train an embedding on sets of features A and B, whose feature vectors are concatenated together.
* After the embedding part, add a "head" that predicts A based on B.
* If that works, then presumably the model understands something about the relationship between the two.

## Summary

After trying quite a few runs, this does not seem to be working at all. 
I am going to abandon this effort.



In [None]:
import pandas as pd
import numpy as np
import glob
import pickle
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import copy
import arrow

## Parameters

In [None]:
# This is the dimension of the (square) AOIs. Set this to match what was used
# when the tiles were created.
aoi_size = 1000

# This is the resolution of the MPP encoding.
resolution = 50
# Dimension of the geometric (MPP) encoding; passed to the model as
# `geo_encoding_dim`.
# NOTE(review): presumably (aoi_size / resolution) ** 2 = 400 -- confirm
# against the code that created the tiles.
geo_encoding_dim = 400

# Fraction of cases to use for training.
train_fraction = 0.7

# Fraction of each tile's entities to sample as prediction targets
# (passed to VergeDataset as `mask_fraction`).
mask_fraction = 0.5

## Preliminaries

In [None]:
# Load the label table; its row count is the number of classes.
fname = 'data/labels.csv'
labels = pd.read_csv(fname)
n_classes = len(labels)
print('%d labels in this dataset' % n_classes)

# Map each label name to its integer id.
label_id_lookup = dict(
    (record['label'], record['id'])
    for record in labels.to_dict('records')
)


In [None]:
# Get a list of input data files. Each file consists of a list of encodings for
# a number of square tiles in a particular AOI.
# glob returns matches in arbitrary order, so sort them: a later cell reads only
# the first few files, and sorting keeps that selection the same between runs.
globstring = 'data/encodings/*'
fnames = sorted(glob.glob(globstring))
print('%d input files' % len(fnames))

In [None]:
# Read the first few encoding files and pool their tiles into one list.
# NOTE: pickle.load can execute arbitrary code; only read files from a
# trusted source.
tile_data_list = []
for fname in fnames[:5]:
    print('reading', fname)
    with open(fname, 'rb') as source:
        tile_data_list.extend(pickle.load(source))

# Randomly assign each tile to the training or the validation set; a tile
# lands in training with probability `train_fraction`.
train_tiles, val_tiles = [], []
for t in tile_data_list:
    goes_to_train = np.random.random() < train_fraction
    (train_tiles if goes_to_train else val_tiles).append(t)

print('%d training instances' % len(train_tiles))
print('%d validation instances' % len(val_tiles))

## Modeling setup

In [None]:
# This class wraps a list of input tile data as a pytorch dataset.
# An instance consists of (1) the original features, which are concatenations of
# a one-hot class label (the first n_classes columns) and an MPP encoding; (2) MPP encodings
# for a sample of the elements in the tile; (3) labels of the entities in (2).

class VergeDataset(torch.utils.data.Dataset):
    """Wrap a list of per-tile feature arrays as a PyTorch dataset.

    Each tile is a 2-D array with one row per entity. The first ``n_classes``
    columns hold the one-hot class label and the remaining columns hold the
    geometric encoding.

    Indexing returns ``(features, sample_encodings, sample_labels)``:

    * ``features`` -- the tile's array, returned unchanged (labels included).
    * ``sample_encodings`` -- float32 tensor with the encodings of a random
      sample of the tile's entities.
    * ``sample_labels`` -- long tensor with the class index of each sampled
      entity.

    Entities are sampled with replacement, with probability proportional to
    the inverse frequency of their class across ``data_list``, to counter
    class imbalance.

    Args:
        data_list: non-empty sequence of 2-D arrays as described above.
        n_classes: number of leading one-hot label columns.
        mask_fraction: fraction of a tile's entities to sample (rounded up).
    """

    def __init__(self, data_list, n_classes, mask_fraction=0.15):
        self.data = data_list
        self.n_classes = n_classes
        self.mask_fraction = mask_fraction
        self.encoding_dim = data_list[0].shape[1] - self.n_classes

        # Count how often each class occurs over all tiles. bincount does the
        # per-entity tally in one call per tile.
        counts = np.zeros(self.n_classes)
        n = 0.0
        for d in data_list:
            true_labels = np.argmax(d[:, :self.n_classes], axis=1)
            counts += np.bincount(true_labels, minlength=self.n_classes)
            n += len(true_labels)

        # Class index -> fraction of all entities that carry that class.
        self.class_prob = {
            z: float(counts[z]) / n for z in range(self.n_classes)
        }

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        features = self.data[idx]
        encodings = features[:, self.n_classes:]
        true_labels = np.argmax(features[:, :self.n_classes], axis=1)
        n_entities = features.shape[0]

        # Inverse-frequency sampling weights. Index by label before dividing
        # so classes that never occur are not touched.
        class_prob = np.array(
            [self.class_prob[c] for c in range(self.n_classes)]
        )
        weights = 1.0 / class_prob[true_labels]
        weights = weights / np.sum(weights)

        # Select the entities whose labels the model must predict.
        sample_size = int(np.ceil(self.mask_fraction * n_entities))
        sample_indices = np.random.choice(
            n_entities, size=sample_size, replace=True, p=weights
        )

        sample_encodings = torch.tensor(
            encodings[sample_indices], dtype=torch.float32
        )
        sample_labels = torch.tensor(
            true_labels[sample_indices], dtype=torch.long
        )

        return (features, sample_encodings, sample_labels)

# dataset = VergeDataset(tile_data_list, n_classes, mask_fraction=mask_fraction)
# features, sample_encodings, sample_labels = dataset[0]
# print('features.shape', features.shape)
# print('sample_encodings.shape', sample_encodings.shape)
# print('sample_labels.shape', sample_labels.shape)


In [None]:
# Define the function that puts together a batch. The main thing we are handling here
# is padding, which only needs to be applied to the input features. 
def collate_fn(batch):
    """Pad a list of dataset items into fixed-size batch tensors.

    Each item is ``(features, sample_encodings, sample_labels)``. Items are
    padded along their first axis to the longest item in the batch.

    Returns a 5-tuple:
        padded_features: zero-padded features, one row block per item.
        feature_attention_mask: bool, True where a feature row is real.
        padded_sample_encodings: zero-padded sampled encodings.
        padded_sample_labels: long, padded with -100 (the "ignore" value).
        sample_attention_mask: bool, True where a sample row is real.
    """
    feature_arrays, encoding_tensors, label_tensors = zip(*batch)

    n_items = len(feature_arrays)
    feature_width = feature_arrays[0].shape[1]
    encoding_width = encoding_tensors[0].shape[1]
    longest_tile = max(arr.shape[0] for arr in feature_arrays)
    longest_sample = max(enc.shape[0] for enc in encoding_tensors)

    # Allocate the padded outputs up front; real rows are copied in below.
    padded_features = torch.zeros(n_items, longest_tile, feature_width)
    feature_attention_mask = torch.zeros(
        n_items, longest_tile, dtype=torch.bool
    )
    padded_sample_encodings = torch.zeros(
        n_items, longest_sample, encoding_width
    )
    padded_sample_labels = torch.full(
        (n_items, longest_sample), -100, dtype=torch.long
    )
    sample_attention_mask = torch.zeros(
        n_items, longest_sample, dtype=torch.bool
    )

    items = zip(feature_arrays, encoding_tensors, label_tensors)
    for row, (tile, encodings, labels) in enumerate(items):
        n_entities = tile.shape[0]
        padded_features[row, :n_entities] = torch.tensor(
            tile, dtype=torch.float32
        )
        feature_attention_mask[row, :n_entities] = 1

        n_samples = encodings.shape[0]
        padded_sample_encodings[row, :n_samples] = encodings
        padded_sample_labels[row, :n_samples] = labels
        sample_attention_mask[row, :n_samples] = 1

    return (
        padded_features, feature_attention_mask,
        padded_sample_encodings, padded_sample_labels, sample_attention_mask
    )

# dataset = VergeDataset(train_tiles, n_classes)
# batch = [dataset[k] for k in [5, 6, 7, 8]]
# features, feature_attention_mask, sample_encodings, sample_labels, sample_attention_mask = collate_fn(batch)
# print('features.shape', features.shape)
# print('feature_attention_mask.shape', feature_attention_mask.shape)
# print('sample_encodings.shape', sample_encodings.shape)
# print('sample_labels.shape', sample_labels.shape)
# print('sample_attention_mask.shape', sample_attention_mask.shape)


## Model definition

In [None]:
class Collector(torch.nn.Module):
    """Pool a variable-length set of entity vectors into a fixed number of
    region embeddings.

    Each of the ``head_count`` heads computes a softmax weighting over the
    entity axis and returns the weighted sum of the entities' projections
    into ``embed_dim`` dimensions.

    Args:
        feature_dim: size of the last axis of the input.
        embed_dim: size of each output embedding.
        head_count: number of embeddings produced per input set.
    """

    def __init__(self, feature_dim, embed_dim, head_count):

        super().__init__()

        self.feature_dim = feature_dim
        self.embed_dim = embed_dim
        self.head_count = head_count

        # Projection of each entity into the embedding space.
        self.weights_r = nn.Parameter(torch.empty(feature_dim, embed_dim))
        torch.nn.init.normal_(self.weights_r, mean=0.0, std=1.0)

        # Per-head scores that decide how much each entity contributes.
        self.weights_h = nn.Parameter(torch.empty(feature_dim, head_count))
        torch.nn.init.normal_(self.weights_h, mean=0.0, std=1.0)

    def forward(self, x, padding_mask=None):
        """Return region embeddings of shape [batch, head_count, embed_dim].

        Args:
            x: tensor of shape [batch, n_entities, feature_dim].
            padding_mask: optional bool tensor of shape [batch, n_entities],
                True at padded positions (the same convention as
                ``src_key_padding_mask``). Padded positions get zero pooling
                weight. When omitted, every position takes part in the
                pooling, which is the original behavior.

        Note:
            A row whose positions are all masked yields NaN, because the
            softmax then has no finite entries.
        """
        matrix_r = torch.matmul(x, self.weights_r)   # [batch, n, embed_dim]
        matrix_h = torch.matmul(x, self.weights_h)   # [batch, n, head_count]

        if padding_mask is not None:
            excluded = padding_mask.to(torch.bool).unsqueeze(-1)
            # -inf score gives exactly zero weight after the softmax.
            matrix_h = matrix_h.masked_fill(excluded, float('-inf'))
            # Zero the projections too, so a non-finite value at a padded
            # position cannot leak through a 0 * inf product.
            matrix_r = matrix_r.masked_fill(excluded, 0.0)

        # Normalize over the entity axis: each head's weights sum to 1.
        matrix_h = torch.nn.functional.softmax(matrix_h, dim=1)
        matrix_e = torch.matmul(torch.transpose(matrix_r, 1, 2), matrix_h)
        return torch.transpose(matrix_e, 1, 2)


In [None]:
class ClassifierHead(torch.nn.Module):
    """Predict class logits for sampled entities from region embeddings.

    Each sampled geometric encoding is projected into the region-embedding
    space, compared with every region embedding by dot product, and the
    resulting per-embedding scores are mapped linearly to class logits.

    Args:
        region_embedding_count: number of region embeddings per instance.
        region_embedding_dim: size of each region embedding.
        geo_encoding_dim: size of each sampled geometric encoding.
        class_count: number of output classes.
    """

    def __init__(self, region_embedding_count, region_embedding_dim, geo_encoding_dim, class_count):

        super().__init__()

        # Projects a geometric encoding into the region-embedding space.
        self.weights_g = nn.Parameter(
            torch.empty(geo_encoding_dim, region_embedding_dim)
        )
        torch.nn.init.normal_(self.weights_g, mean=0.0, std=1.0)

        # Maps the per-embedding similarity scores to class logits.
        self.weights_c = nn.Parameter(
            torch.empty(region_embedding_count, class_count)
        )
        torch.nn.init.normal_(self.weights_c, mean=0.0, std=1.0)

    def forward(self, region_embeddings, sample_encodings):
        """Return logits of shape [batch, n_samples, class_count].

        Args:
            region_embeddings: [batch, region_embedding_count,
                region_embedding_dim].
            sample_encodings: [batch, n_samples, geo_encoding_dim].
        """
        matrix_g = torch.matmul(sample_encodings, self.weights_g)
        matrix_r = torch.transpose(region_embeddings, 1, 2)
        matrix_x = torch.matmul(matrix_g, matrix_r)
        matrix_c = torch.matmul(matrix_x, self.weights_c)

        # Return raw logits. The previous GELU on the output bounded every
        # logit below at about -0.17 and gave near-zero gradient for negative
        # pre-activations, which a cross-entropy loss cannot train through.
        return matrix_c


In [None]:
class GeospatialTransformer(nn.Module):
    """Transformer encoder over a tile's entities with a label-prediction head.

    Pipeline: linear projection of the entity features -> transformer
    encoder -> Collector (pools entities into a fixed number of region
    embeddings) -> ClassifierHead (predicts class logits for sampled
    encodings from those region embeddings).

    Args:
        geo_encoding_dim: dimension of the geometric encodings.
        class_count: number of classes; the per-entity feature width is
            ``geo_encoding_dim + class_count``.
        transformer_dim: model width of the transformer encoder.
        region_embedding_dim: size of each region embedding.
        region_embedding_count: number of region embeddings per instance.
    """

    def __init__(
        self,
        geo_encoding_dim, # the dimension of the geometric encodings
        class_count=n_classes,
        transformer_dim=128,
        region_embedding_dim=64,
        region_embedding_count=32,
    ):
        super().__init__()

        self.geo_encoding_dim = geo_encoding_dim
        self.class_count = class_count
        self.feature_dim = geo_encoding_dim + class_count
        self.region_embedding_dim = region_embedding_dim
        self.region_embedding_count = region_embedding_count
        self.transformer_dim = transformer_dim

        self.input_proj = nn.Linear(self.feature_dim, self.transformer_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=transformer_dim,
            nhead=4,
            dim_feedforward=4 * transformer_dim,
            dropout=0.1,
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=4)
        self.collector = Collector(self.transformer_dim, self.region_embedding_dim, self.region_embedding_count)
        self.classifier = ClassifierHead(
            self.region_embedding_count, self.region_embedding_dim,
            self.geo_encoding_dim, self.class_count
        )

    def _encode(self, features, feature_attention_mask):
        """Run projection, encoder and collector; shared by forward and embed."""
        x = self.input_proj(features)

        # Transformer expects padding mask: True for PAD tokens
        pad_mask = (feature_attention_mask == 0)
        x = self.encoder(x, src_key_padding_mask=pad_mask)

        # NOTE(review): the collector is called on every position, including
        # padded ones, so padding can influence the region embeddings.
        return self.collector(x)

    def forward(self, features, feature_attention_mask, sample_encodings, sample_attention_mask):
        """Return class logits for the sampled encodings.

        Args:
            features: [batch_size, n_entities, geo_encoding_dim + class_count].
            feature_attention_mask: [batch_size, n_entities], nonzero/True
                for valid entities, 0/False for padding.
            sample_encodings: [batch_size, n_samples, geo_encoding_dim].
            sample_attention_mask: accepted for interface symmetry; not used.

        Returns:
            The classifier head's output for ``sample_encodings``.
        """
        region_embeddings = self._encode(features, feature_attention_mask)
        return self.classifier(region_embeddings, sample_encodings)

    def embed(self, features, feature_attention_mask, sample_encodings, sample_attention_mask):
        """
        Returns the region embeddings for the input features.

        ``sample_encodings`` and ``sample_attention_mask`` are accepted so the
        signature matches ``forward``; they are not used.
        """
        return self._encode(features, feature_attention_mask)


### For testing

In [None]:
# Build a small model (reduced widths and embedding count) for the
# "For testing" cells.
model = GeospatialTransformer(
    geo_encoding_dim = geo_encoding_dim, 
    class_count = n_classes,
    transformer_dim = 32, 
    region_embedding_dim = 16, 
    region_embedding_count = 4,
)

# Count only the parameters that receive gradients.
n_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('%d trainable parameters in model' % n_param)

In [None]:
# Wrap the training tiles and pull a single batch to inspect tensor shapes.
dataset = VergeDataset(train_tiles, n_classes)
dataloader = DataLoader(
    dataset,
    batch_size=1,            # Tune depending on GPU memory
    shuffle=True,
    collate_fn=collate_fn,   # Key for padding variable-length instances
    drop_last=False
)

first_batch = next(iter(dataloader))
features, feature_attention_mask, sample_encodings, sample_labels, sample_attention_mask = first_batch
print('features.shape', features.shape)
print('feature_attention_mask.shape', feature_attention_mask.shape)
print('sample_encodings.shape', sample_encodings.shape)
print('sample_labels.shape', sample_labels.shape)
print('sample_attention_mask.shape', sample_attention_mask.shape)


In [None]:
# Run the embedding path on the batch pulled above and show the output shape.
# NOTE(review): the model has not been switched to eval mode here, so dropout
# is presumably still active -- confirm if exact values matter.
embeddings = model.embed(features, feature_attention_mask, sample_encodings, sample_attention_mask)
embeddings.shape

In [None]:
embeddings

In [None]:
# logits = model.forward(features, feature_attention_mask, sample_encodings, sample_attention_mask)
# logits.shape

### Real training loop

In [None]:
# Build the full-size model used by the training loop. This rebinds `model`,
# replacing the small test instance.
model = GeospatialTransformer(
    geo_encoding_dim = geo_encoding_dim, 
    class_count = n_classes,
    transformer_dim = 128, 
    region_embedding_dim = 64, 
    region_embedding_count = 8,
)

# Count only the parameters that receive gradients.
n_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('%d trainable parameters in model' % n_param)

In [None]:
from torch.utils.data import DataLoader

# Training data: sample `mask_fraction` of each tile's entities as targets.
dataset = VergeDataset(train_tiles, n_classes, mask_fraction=mask_fraction)
dataloader = DataLoader(
    dataset,
    batch_size=10,            # Tune depending on GPU memory
    shuffle=True,
    collate_fn=collate_fn,    # pads variable-length tiles within a batch
    drop_last=True            # a final batch smaller than batch_size is skipped
)

In [None]:
# Train the model and record one loss value per epoch in `losses`.
device = 'cpu'
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
# Padded label slots carry -100 and are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=-100, label_smoothing=0.1)

losses = []

model.train()
for epoch in range(100):
    epoch_loss = 0.0
    n_batches = 0
    for features, feature_attention_mask, sample_encodings, sample_labels, sample_attention_mask in dataloader:
        features = features.to(device)
        feature_attention_mask = feature_attention_mask.to(device)
        sample_encodings = sample_encodings.to(device)
        sample_labels = sample_labels.to(device)
        sample_attention_mask = sample_attention_mask.to(device)

        logits = model(features, feature_attention_mask, sample_encodings, sample_attention_mask)
        loss = criterion(
            logits.reshape(-1, n_classes),
            sample_labels.reshape(-1)
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        n_batches += 1

    if n_batches == 0:
        # With drop_last=True this happens when there are fewer tiles than
        # one batch; fail with a clear message instead of a NameError.
        raise RuntimeError('dataloader produced no batches; nothing to train on')

    # Report the mean over the epoch rather than the last batch's loss,
    # which is dominated by batch-to-batch noise.
    mean_loss = epoch_loss / n_batches
    losses.append(mean_loss)

    timestring = arrow.get().isoformat()
    print(f"{timestring}: epoch {epoch+1}, loss: {mean_loss:.4f}")


In [None]:
import plotly
from plotly.subplots import make_subplots
from plotly.graph_objects import Scatter

# Plot the recorded loss against the epoch index.
fig = make_subplots(rows=1, cols=1)
trace = Scatter(
    x=np.arange(len(losses)), y=losses, name='loss',
    mode='markers+lines'
)
# add_trace with row/col replaces the deprecated append_trace.
fig.add_trace(trace, row=1, col=1)
fig

## Performance visualization

In [None]:
# Cue up the validation dataset.
# Use the held-out tiles: evaluating on train_tiles would measure fit to the
# training data, not validation performance.
dataset = VergeDataset(val_tiles, n_classes)
dataloader = DataLoader(
    dataset,
    batch_size=100,
    shuffle=True,
    collate_fn=collate_fn,
    drop_last=False
)

In [None]:
# Collect (true label, predicted probabilities) pairs for at least 1000
# sampled entities, or for everything the dataloader provides.
device = 'cpu'
model.to(device)
criterion = nn.CrossEntropyLoss(ignore_index=-100)

cases = []

# Evaluation mode turns dropout off; no_grad avoids building the autograd graph.
model.eval()
with torch.no_grad():
    for features, feature_attention_mask, sample_encodings, sample_labels, sample_attention_mask in dataloader:

        print(features.shape)
        features = features.to(device)
        feature_attention_mask = feature_attention_mask.to(device)
        sample_encodings = sample_encodings.to(device)
        sample_labels = sample_labels.to(device)

        logits = model(features, feature_attention_mask, sample_encodings, sample_attention_mask)
        # Computed for inspection; not used further in this cell.
        loss = criterion(
            logits.reshape(-1, n_classes),
            sample_labels.reshape(-1)
        )

        probs = torch.softmax(logits, dim=-1)

        # Padded slots carry -100. Class 0 is a real class, so the test must
        # be >= 0; the previous "> 0" dropped every class-0 case.
        valid = sample_labels >= 0
        valid_labels = sample_labels[valid].tolist()
        valid_probs = probs[valid].numpy()
        for true_label, case_probs in zip(valid_labels, valid_probs):
            cases.append({
                'true_label': true_label,
                'probs': case_probs
            })

        if len(cases) >= 1000:
            break

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

# Rows: true classes 0..max observed label. Columns: every class the model
# predicts, i.e. the width of a probability vector.
class_count = max(d["true_label"] for d in cases) + 1
pred_class_count = len(cases[0]["probs"])
probs_by_class = defaultdict(list)

for d in cases:
    label = d["true_label"]
    probs = np.array(d["probs"])
    probs_by_class[label].append(probs)

# For each true class, compute the mean probability vector
mean_probs = []
for t in range(class_count):
    if probs_by_class[t]:
        mean = np.stack(probs_by_class[t]).mean(axis=0)
    else:
        # No samples for this class. The placeholder must have the same width
        # as a probability vector or np.stack below fails.
        mean = np.zeros(pred_class_count)
    mean_probs.append(mean)

# Convert to 2D array: [true_class, predicted_class]
matrix = np.stack(mean_probs)  # shape [class_count, pred_class_count]

# Plot heatmap
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(matrix, cmap='viridis', aspect='auto')

plt.colorbar(im, ax=ax, label='Avg Predicted Probability')
ax.set_title("Mean Predicted Probabilities by True Class")
ax.set_xlabel("Predicted Class")
ax.set_ylabel("True Class")
ax.set_xticks(range(pred_class_count))
ax.set_yticks(range(class_count))
plt.tight_layout()
plt.show()



In [None]:
# Confusion matrix of argmax predictions: rows are true classes, columns are
# predicted classes.
class_count = max(d["true_label"] for d in cases) + 1
print(class_count)
# Size the columns by the width of a probability vector: a predicted class
# can exceed the largest true label seen, which previously raised IndexError.
pred_class_count = len(cases[0]["probs"])
cmat = np.zeros((class_count, pred_class_count))

for d in cases:
    true_label = d["true_label"]
    pred_label = np.argmax(d["probs"])
    cmat[true_label, pred_label] += 1

# Cap the counts so one dominant cell does not wash out the color scale.
cmat = cmat.clip(0, 500)

# Plot heatmap
fig, ax = plt.subplots(figsize=(11, 8))
im = ax.imshow(cmat, cmap='viridis', aspect='auto')

plt.colorbar(im, ax=ax, label='Number Of Cases')
ax.set_title("Confusion Matrix")
ax.set_xlabel("Predicted Class")
ax.set_ylabel("True Class")
ax.set_xticks(range(pred_class_count))
ax.set_yticks(range(class_count))
plt.tight_layout()
plt.show()


In [None]:
# Map each integer class id back to its label name.
label_name_lookup = dict(
    (record['id'], record['label'])
    for record in labels.to_dict('records')
)

# Average predicted probability per predicted class, over all true classes,
# printed as a percentage.
means = np.mean(matrix, axis=0)
for i, mean_prob in enumerate(means):
    print('%40s [%2d] : %6.1f' % (label_name_lookup[i], i, mean_prob * 100))

In [None]:
# Show the mean predicted distribution for one true class (row `ix` of the
# mean-probability matrix), as percentages.
ix = 4
label = label_name_lookup[ix]
probs = matrix[ix, :]
for i, prob in enumerate(probs):
    print('true: %s | pred: %-40s [%2d] : prob: %6.2f' % (label, label_name_lookup[i], i, prob * 100))

## Check out class imbalance

In [None]:
# Wrap the validation tiles so the classes of the sampled targets can be
# tallied. mask_fraction is not passed, so the dataset's default applies.
dataset = VergeDataset(val_tiles, n_classes)
dataloader = DataLoader(
    dataset,
    batch_size=100,
    shuffle=True,
    collate_fn=collate_fn,   # pads variable-length tiles within a batch
    drop_last=False
)

In [None]:
# Tally how many sampled targets fall in each class over one pass of the
# dataloader. Negative labels are padding and are skipped.
sample_counts = dict.fromkeys(range(n_classes), 0)
for features, feature_attention_mask, sample_encodings, sample_labels, sample_attention_mask in dataloader:
    for s in torch.flatten(sample_labels).tolist():
        if s < 0:
            continue
        sample_counts[s] += 1
            

In [None]:
 # Print the number of sampled targets per class, in ascending class-id order.
 for s in sorted(sample_counts.keys()):
     print('[%2d] %-40s : %6.0f' % (s, label_name_lookup[s], sample_counts[s])) 