### VERGE: Vector-Mode Regional Geospatial Encoding
# Model evaluation


Elsewhere we trained a model to predict geospatial entity type from the encodings of the things
in its vicinity. Here we run an evaluation.


## Processing Setup

In [None]:
# Google Colab setup: mount Google Drive and work from the project directory.
import os
from google.colab import drive
drive.mount('/content/drive')  # Drive contents appear under /content/drive
project_home = '/content/drive/MyDrive/Projects/verge'
os.chdir(project_home)  # later cells also build their paths from project_home

In [None]:
# Install the geo_encodings package, which provides the MPPEncoder used below.
!pip install geo_encodings

In [None]:
# Local processing setup
# project_home = '..'

## Source

In [None]:
# Region of interest (ROI) evaluated by this notebook.
roi_name = 'newengland'

# Data directory shared by all ROIs, followed by the subdirectory that holds
# the files specific to this ROI.
data_home = '%s/data' % project_home
roi_home = '%s/data/%s' % (project_home, roi_name)


## Setup

In [None]:
import pandas as pd
import numpy as np
import glob
import pickle
import os
import torch
import torch.nn as nn
import torch.utils
import torch.utils.data
from torch.utils.data import DataLoader
import copy
import json
from geo_encodings import MPPEncoder

import sys
sys.path.append(project_home)
from utils.geo_transformer_mem import VergeDataset, verge_collate_fn, GeospatialTransformer


## Parameters

In [None]:
# Load the ROI definition; it supplies the tile size and encoding resolution.
fname = '%s/roi.json' % roi_home
with open(fname) as source:
    roi = json.loads(source.read())

tile_size, encoding_resolution = roi['tile_size'], roi['encoding_resolution']

# Build an encoder covering a single tile; its length is the encoding dimension.
encoder = MPPEncoder(region=[0, 0, tile_size, tile_size], resolution=encoding_resolution, center=True)
geo_encoding_dim = len(encoder)
print('%d elements in encodings' % geo_encoding_dim)


In [None]:
# Display the ROI definition loaded from roi.json.
roi

In [None]:
# Identifier of the training run; it selects the splits file and saved model.
run_id = '010'

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print('using device', device)

## Preliminaries

In [None]:
# Load the label table; each row pairs a label name with its id.
fname = '%s/labels.csv' % data_home
labels = pd.read_csv(fname)
n_classes = labels.shape[0]
print('%d labels in this dataset' % n_classes)

# Build both directions of the name <-> id mapping in one pass over the rows.
label_id_lookup = {}
label_name_lookup = {}
for record in labels.to_dict('records'):
    label_id_lookup[record['label']] = record['id']
    label_name_lookup[record['id']] = record['label']

In [None]:
# Load the per-class information for this ROI; it carries the class
# probabilities used when the validation dataset is built.
fname = '%s/class_info.csv' % roi_home
class_info = pd.read_csv(fname)
print('%d class info records' % class_info.shape[0])

## Load data
We determine which files to read by loading the associated "split" file.

In [None]:
# The splits file assigns each data file to a split; keep the validation ones.
splits_fname = '%s/splits-%s.csv' % (roi_home, run_id)
splits = pd.read_csv(splits_fname)
val_fnames = splits.loc[splits['split'] == 'val', 'aoi'].tolist()
print('%d files with validation data' % len(val_fnames))

# Load the tiles from the first three validation files only.
# NOTE: unpickling can execute arbitrary code; only read files you trust.
val_tiles = []
for fname in val_fnames[:3]:
    print('reading', fname)
    with open(fname, 'rb') as source:
        val_tiles.extend(pickle.load(source))

print('%d validation tiles' % len(val_tiles))

## Prep model and data

In [None]:
# VergeDataset is given a label -> probability lookup, built from class_info.
class_prob_lookup = {}
for record in class_info.to_dict('records'):
    class_prob_lookup[record['label']] = record['prob']

val_dataset = VergeDataset(
    val_fnames,
    n_classes,
    mask_fraction=0.15,
    class_prob=class_prob_lookup,
)

# batch_size can be tuned to the available GPU memory.
val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=True, collate_fn=verge_collate_fn, drop_last=False)

In [None]:
# Load the trained model for this run.
# map_location remaps the stored tensors onto the selected device, so a model
# saved on a GPU machine also loads when only a CPU is available.
# NOTE: weights_only=False unpickles arbitrary objects; only load trusted files.
model_fname = '%s/model-%s' % (roi_home, run_id)
model = torch.load(model_fname, map_location=device, weights_only=False)
print('loaded %s' % model_fname)

## Validation Visualization

In [None]:
# Run the model over the validation dataset and record, for every labelled
# entity, its true label and the predicted class-probability vector.
model.to(device)
model.eval()
cases = []

# Inference only: no_grad avoids building the autograd graph for each batch.
with torch.no_grad():
    # The loop variable is called batch_labels so that it does not overwrite
    # the `labels` DataFrame read earlier in the notebook.
    for features, batch_labels, attention_mask in val_dataloader:
        features = features.to(device)
        attention_mask = attention_mask.to(device)
        logits = model(features, attention_mask)

        # Softmax over the last (class) axis, then a single transfer to the
        # host per batch instead of one per entity.
        batch_probs = torch.softmax(logits, dim=-1).cpu().numpy()
        batch_labels = batch_labels.cpu().numpy()

        # Negative labels (the "-100" entries) mark entities to skip.
        # Assumes labels are (batch, entity) and logits (batch, entity, class)
        # with matching leading dimensions -- TODO confirm against the
        # collate function.
        keep = batch_labels >= 0
        for true_label, prob_row in zip(batch_labels[keep].tolist(), batch_probs[keep]):
            cases.append({
                'true_label': true_label,
                'probs': prob_row
            })

print('compiled prediction probabilities for %d validation instances' % len(cases))

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

# Rows cover every true class up to the largest one observed; columns cover
# every entry of the predicted probability vectors.
class_count = max(d["true_label"] for d in cases) + 1
output_count = len(cases[0]["probs"])

# Group the probability vectors by true class.
probs_by_class = defaultdict(list)
for d in cases:
    probs_by_class[d["true_label"]].append(np.asarray(d["probs"]))

# For each true class, compute the mean probability vector.
mean_probs = []
for t in range(class_count):
    if probs_by_class[t]:
        mean = np.stack(probs_by_class[t]).mean(axis=0)
    else:
        # No samples for this class: use a zero row with the same length as a
        # probability vector so that the rows can still be stacked.
        mean = np.zeros(output_count)
    mean_probs.append(mean)

# Convert to 2D array: [true_class, predicted_class]
matrix = np.stack(mean_probs)

# Plot heatmap
fig, ax = plt.subplots(figsize=(9, 6))
im = ax.imshow(matrix, cmap='viridis', aspect='auto')

plt.colorbar(im, ax=ax, label='Avg Predicted Probability')
ax.set_title("Mean Predicted Probabilities by True Class")
ax.set_xlabel("Predicted Class")
ax.set_ylabel("True Class")
ax.set_xticks(range(output_count))  # one tick per predicted-class column
ax.set_yticks(range(class_count))
ax.set_yticklabels(['%s [%d]' % (label_name_lookup[i], i) for i in range(class_count)])
plt.tight_layout()
plt.show()



In [None]:
# Confusion matrix: rows are true classes, columns are predicted classes.

class_count = max(d["true_label"] for d in cases) + 1
print(class_count)

# Size the columns by the length of the probability vectors: argmax can return
# any of those positions, including one beyond the largest true label seen.
output_count = len(cases[0]["probs"])
cmat = np.zeros((class_count, output_count))

for d in cases:
    cmat[d["true_label"], np.argmax(d["probs"])] += 1

# Square-root scaling compresses the range so that sparsely populated cells
# remain visible next to the dominant ones.
cmat = np.sqrt(cmat)

# Plot heatmap
fig, ax = plt.subplots(figsize=(12, 8))
im = ax.imshow(cmat, cmap='viridis', aspect='auto')

# The plotted values are square roots of the counts, so label them as such.
plt.colorbar(im, ax=ax, label='sqrt(Number Of Cases)')
ax.set_title("Confusion Matrix")
ax.set_xlabel("Predicted Class")
ax.set_ylabel("True Class")
ax.set_xticks(range(output_count))
ax.set_yticks(range(class_count))
ax.set_yticklabels([label_name_lookup[i] for i in range(class_count)])
plt.tight_layout()
plt.show()


In [None]:
# Accuracy stats: top-k accuracy for k = 1..5 and the macro-averaged F1 score.

from sklearn.metrics import top_k_accuracy_score, f1_score

y_true = np.array([d["true_label"] for d in cases])
y_prob = np.vstack([d["probs"] for d in cases])
y_pred = np.argmax(y_prob, axis=1)

# top_k_accuracy_score needs one label per score column, so derive the label
# list from the probability matrix instead of hard-coding the class count.
all_labels = list(range(y_prob.shape[1]))
for k in range(1, 6):
    print('top-%d accuracy: %.4f' % (k, top_k_accuracy_score(y_true, y_prob, k=k, labels=all_labels)))
print('f1 score: %.4f' % f1_score(y_true, y_pred, average='macro'))


## Check consistency with respect to randomization.

In [None]:
# Get a "batch" consisting of one instance.
# NOTE(review): the validation dataset earlier in the notebook is built from
# val_fnames with a class_prob lookup, whereas this one receives val_tiles and
# no class_prob -- confirm which form VergeDataset expects.
dataset = VergeDataset(val_tiles, n_classes, mask_fraction=0.15)
batch = [dataset[3]]
# verge_collate_fn is the collate function imported from utils; a bare
# `collate_fn` is not defined anywhere in this notebook.
batch_features, batch_labels, batch_attention_mask = verge_collate_fn(batch)
batch_labels


In [None]:
# Forward pass of the single-instance batch on the selected device.
model.to(device)
batch_features, batch_attention_mask = batch_features.to(device), batch_attention_mask.to(device)
logits = model(batch_features, batch_attention_mask)
print(logits.shape)

In [None]:
# Build a random reordering of the batch rows (axis 1 of the features).
# Index 0 stays in place so that it can serve as the reference row; only the
# remaining indices 1..n-1 are shuffled.
tail = torch.randperm(batch_features.shape[1] - 1) + 1
perm = torch.cat((torch.tensor([0]), tail))
print(perm)

In [None]:
# Apply the permutation to the features and re-run the model.
# Indexing with a tensor already returns a new tensor, so no clone() is
# needed, and .to() returns the moved tensor rather than acting in place, so
# its result has to be kept.
permuted_features = batch_features[:, perm].to(device)
# NOTE(review): the attention mask is passed unpermuted; that is only correct
# if it is the same for every row (e.g. a one-instance batch without padding)
# -- confirm.
permuted_logits = model(permuted_features, batch_attention_mask)
print(permuted_logits.shape)

In [None]:
# Pick the reference row in the original output and find where that same row
# ended up after the permutation (the position j with perm[j] == a_index).
a_index = 0
b_index = (perm == a_index).nonzero(as_tuple=True)[0][0].item()
print(a_index, b_index)

# Logits of the same row from the original and the permuted run.
a, b = logits[0, a_index], permuted_logits[0, b_index]
print(a)
print(b)

In [None]:
import plotly
from plotly.subplots import make_subplots
from plotly.graph_objects import Scatter

# Bring both logit vectors to the host as NumPy arrays.
aa = a.detach().cpu().numpy()
bb = b.detach().cpu().numpy()

n = aa.shape[0]
xx = np.arange(n)

# Overlay the two vectors on one set of axes so differences are easy to see.
fig = make_subplots(rows=1, cols=1)
for series_name, series in (('a', aa), ('b', bb)):
    fig.add_trace(Scatter(x=xx, y=series, name=series_name, mode='markers+lines'), row=1, col=1)

fig.show()

In [None]:
# Difference between the reference row's features before and after the
# permutation; b_index is where a_index landed, so this is expected to be zero.
batch_features[0, a_index, :] - permuted_features[0, b_index, :]

In [None]:
# Swap the first two entities of every sample and compare the model outputs.
x = batch_features
x_perm = x.clone()
x_perm[:, 0, :] = x[:, 1, :]
x_perm[:, 1, :] = x[:, 0, :]

logits_orig = model(x, batch_attention_mask)
logits_perm = model(x_perm, batch_attention_mask)

# NOTE(review): the outputs are compared position by position, so a model
# whose outputs simply follow the swap would also give a value > 0 -- confirm
# that this is the intended check.
print((logits_orig - logits_perm).abs().max())  # Should be > 0 for non-invariant models
