### VERGE: Vector-mode Regional Geospatial Encoding
# VERGE model implementation


Here we build and train a "masked geospatial model".
This is a model in which each input is a set of encoded geospatial entities,
each consisting of a concatenation of a multi-point proximity encoding and a one-hot label vector.
Modeling consists of masking the labels for a random selection of entities, then
passing the data through an encoder-based architecture to predict the labels of the masked entities.
The idea is that the encodings then capture information about the region.

## Version 2: Partial Autoencoder

I'm trying an approach here in which the model is trained to predict
a subset of its input features. I'm calling it a "partial autoencoder".
The idea is:
* Train an embedding on sets of features A and B, whose feature vectors are concatenated together.
* After the embedding part, add a "head" that predicts A based on B.
* If that works, then presumably the model understands something about the relationship between the two.



In [None]:
import pandas as pd
import numpy as np
import glob
import pickle
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import copy

## Parameters

In [None]:
# This is the dimension of the (square) AOIs. Set this to match what was used
# when the tiles were created.
aoi_size = 500

# This is the resolution of the multi-point proximity (MPP) encoding.
resolution = 50

# Fraction of cases to use for training; the remaining cases go to validation.
train_fraction = 0.7

## Preliminaries

In [None]:
# Load the label table; the number of classes is its row count.
fname = 'labels.csv'
labels = pd.read_csv(fname)
n_classes = len(labels.index)
print('%d labels in this dataset' % n_classes)

# Map each label name to its numeric id.
label_id_lookup = dict(zip(labels['label'].tolist(), labels['id'].tolist()))

In [None]:
# Get a list of input data files. Each file consists of a list of encodings for
# a number of square tiles in a particular AOI.
# glob.glob() returns its matches in arbitrary order, so the list is sorted:
# code that takes a prefix of it then picks the same files on every run.
globstring = 'data/encodings/*'
fnames = sorted(glob.glob(globstring))
print('%d input files' % len(fnames))

In [None]:
# Load the tiles stored in the first two input files into one flat list.
# NOTE: unpickling runs whatever code the file embeds, so only read encoding
# files that come from a trusted source.
tile_data_list = []
for fname in fnames[:2]:
    print('reading', fname)
    with open(fname, 'rb') as source:
        tile_data_list.extend(pickle.load(source))

# Assign every tile to the training or the validation set: one uniform draw
# per tile, training when the draw falls below train_fraction.
train_tiles, val_tiles = [], []
for t in tile_data_list:
    bucket = train_tiles if np.random.random() < train_fraction else val_tiles
    bucket.append(t)

print('%d training instances' % len(train_tiles))
print('%d validation instances' % len(val_tiles))

In [None]:
# This class wraps a list of input tile data as a pytorch dataset.
# The "getitem" method here parses apart the true labels and the encodings,
# and applies random masking to the labels.

class VergeDataset(torch.utils.data.Dataset):
    """Masked-label dataset over a list of per-tile feature arrays.

    Every tile is a 2-D array with one row per entity. The first
    ``n_classes`` columns of a row hold the one-hot label; the remaining
    columns hold the geometric encoding.

    Args:
        data_list: Non-empty sequence of per-tile 2-D arrays.
        n_classes: Number of leading one-hot label columns.
        mask_fraction: Probability with which each entity's label is masked.
    """

    def __init__(self, data_list, n_classes, mask_fraction=0.15):
        self.data = data_list
        self.n_classes = n_classes
        self.mask_fraction = mask_fraction
        # Width of the geometric encoding, taken from the first tile.
        self.encoding_dim = data_list[0].shape[1] - self.n_classes

    def __len__(self):
        """Return the number of tiles."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(masked_features, labels)`` for tile ``idx``.

        Returns:
            masked_features: float32 tensor with the same shape as the tile;
                the label columns of masked entities are all zero.
            labels: long tensor with one entry per entity, holding the true
                class index for masked entities and -100 everywhere else.
        """
        features = self.data[idx]
        encodings = features[:, self.n_classes:]
        true_labels_onehot = features[:, :self.n_classes]
        true_labels = np.argmax(true_labels_onehot, axis=1)
        n_entities = features.shape[0]

        # Select the entities whose labels get masked (one draw per entity).
        mask = np.random.rand(n_entities) < self.mask_fraction

        # In the feature array, labels are one-hot vectors that get concatenated
        # with the geometric encodings. To "mask" a label we replace its one-hot
        # vector with an all-zero vector. Work on a copy so that the stored tile
        # is left untouched.
        masked_labels_onehot = np.array(true_labels_onehot, copy=True)
        masked_labels_onehot[mask] = 0.0

        # Re-concatenate the masked labels with the geometric encodings.
        masked_features = torch.cat((
            torch.tensor(masked_labels_onehot, dtype=torch.float32),
            torch.tensor(encodings, dtype=torch.float32),
        ), dim=-1)

        # During model training below, we will be using the "CrossEntropyLoss" function,
        # which has a built-in capability to ignore un-masked entities. To get it to work,
        # we pack an "ignore" token into any label slot that is not masked.
        # Pytorch's standard value for that token is -100.
        labels = torch.as_tensor(np.where(mask, true_labels, -100), dtype=torch.long)

        return (masked_features, labels)

# Smoke test: wrap every loaded tile in a dataset and inspect the first item.
# Note that this rebinds the module-level name `labels` (until here the table
# read from labels.csv) to the label tensor of that item.
dataset = VergeDataset(tile_data_list, n_classes, mask_fraction=0.2)
features, labels = dataset[0]
print('features.shape', features.shape)
print('labels.shape', labels.shape)


In [None]:
# Batch assembly. Tiles hold different numbers of entities, so every tile is
# padded up to the longest one in the batch, with the excess space filled
# with padding tokens.
def collate_fn(batch):
    """Pad a list of ``(features, labels)`` pairs into batch tensors.

    Returns:
        padded_features: ``[batch, max_len, feature_dim]``, zero in the padding.
        padded_labels: ``[batch, max_len]`` long tensor, -100 (the "ignore"
            value) in the padding.
        attention_mask: ``[batch, max_len]`` bool tensor, True on real
            entities and False on padding.
    """
    features, labels = zip(*batch)
    lengths = [item.shape[0] for item in features]
    longest = max(lengths)
    n_items = len(features)
    width = features[0].shape[1]

    padded_features = torch.zeros(n_items, longest, width)
    padded_labels = torch.full((n_items, longest), -100, dtype=torch.long)
    for row, (feat, lab, n) in enumerate(zip(features, labels, lengths)):
        padded_features[row, :n] = feat
        padded_labels[row, :n] = lab

    # Position j of row i is a real entity exactly when j < lengths[i].
    positions = torch.arange(longest).unsqueeze(0)
    attention_mask = positions < torch.tensor(lengths).unsqueeze(1)

    return padded_features, padded_labels, attention_mask


# Try the collation on four training tiles (indices 5 through 8) and report
# the shapes of the padded batch.
dataset = VergeDataset(train_tiles, n_classes)
batch = [dataset[k] for k in range(5, 9)]
batch_features, batch_labels, batch_attention_mask = collate_fn(batch)
print('batch_features.shape', batch_features.shape)
print('batch_labels.shape', batch_labels.shape)
print('batch_attention_mask.shape', batch_attention_mask.shape)


## Model definition

In [None]:
class Collector(torch.nn.Module):
    """Pool a set of entity vectors into one fixed-size matrix.

    The input ``x`` of shape ``[batch, n_entities, feature_dim]`` is projected
    twice: to values ``R = x @ weights_r`` and to scores ``x @ weights_h``.
    The scores are soft-maxed over the entity axis, separately for each head,
    and the output is ``R^T @ H`` with shape ``[batch, embed_dim, head_count]``.
    Each output column is therefore a weighted average of the entity values.

    Args:
        feature_dim: Size of the last axis of the input.
        embed_dim: Number of value features per head.
        head_count: Number of pooling heads.
        verbose: Print intermediate shapes and the per-head weight sums on
            every forward pass. Defaults to True, which keeps the diagnostic
            output this class has always produced; pass False to silence it.
    """

    def __init__(self, feature_dim, embed_dim, head_count, *, verbose=True):

        super().__init__()

        self.feature_dim = feature_dim
        self.embed_dim = embed_dim
        self.head_count = head_count
        self.verbose = verbose

        # Both projections start from a standard normal distribution.
        self.weights_r = nn.Parameter(torch.empty(feature_dim, embed_dim))
        torch.nn.init.normal_(self.weights_r, mean=0.0, std=1.0)

        self.weights_h = nn.Parameter(torch.empty(feature_dim, head_count))
        torch.nn.init.normal_(self.weights_h, mean=0.0, std=1.0)

    def forward(self, x, attention_mask=None):
        """Pool ``x`` over its entity axis.

        Args:
            x: Tensor of shape ``[batch, n_entities, feature_dim]``.
            attention_mask: Optional ``[batch, n_entities]`` tensor that is
                non-zero/True on real entities and zero/False on padding.
                Padded entities get zero pooling weight, so the result does
                not depend on how much padding a batch carries. When omitted,
                every row takes part in the pooling.

        Returns:
            Tensor of shape ``[batch, embed_dim, head_count]``. A sample whose
            entities are all masked out pools to zeros.
        """
        if self.verbose:
            print('x', x.shape)

        matrix_r = torch.matmul(x, self.weights_r)
        if self.verbose:
            print('matrix_r', matrix_r.shape)

        scores = torch.matmul(x, self.weights_h)
        valid = None
        if attention_mask is not None:
            valid = attention_mask.to(torch.bool).unsqueeze(-1)
            # Use the most negative finite value rather than -inf so that a
            # fully padded sample does not turn the softmax into NaN.
            scores = scores.masked_fill(~valid, torch.finfo(scores.dtype).min)

        # Normalise over the entity axis: each head's weights sum to one.
        matrix_h = torch.nn.functional.softmax(scores, dim=1)
        if valid is not None:
            # Force exact zeros on padding, including the fully padded case.
            matrix_h = matrix_h * valid.to(matrix_h.dtype)
        if self.verbose:
            print('matrix_h', matrix_h.shape)
            print('sum of columns of h:', torch.sum(matrix_h, dim=1))

        matrix_e = torch.matmul(torch.transpose(matrix_r, 1, 2), matrix_h)
        return matrix_e


In [None]:
class GeospatialTransformer(nn.Module):
    """Transformer encoder over a padded set of entities, pooled per tile.

    Pipeline: linear input projection -> transformer encoder (no positional
    encoding is added) -> ``Collector`` pooling -> flatten -> linear head.

    Args:
        feature_dim: Width of each entity's feature vector.
        model_dim: Width of the transformer.
        num_heads: Attention heads per encoder layer.
        num_layers: Number of encoder layers.
        num_classes: Number of output logits.
        dropout: Dropout rate inside the encoder layers.
        collector_embed_dim: Value width of the collector (previously the
            literal 20).
        collector_heads: Number of collector heads (previously the literal 13).
        verbose: Print the tensor shape after every stage of ``forward``.
            Defaults to True, matching the trace ``forward`` has always
            printed; pass False to silence it. The collector's own printing
            is not controlled by this flag.
    """

    def __init__(self, feature_dim, model_dim=256, num_heads=4, num_layers=2, num_classes=10, dropout=0.1,
                 *, collector_embed_dim=20, collector_heads=13, verbose=True):
        super().__init__()
        self.verbose = verbose
        self.input_proj = nn.Linear(feature_dim, model_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=model_dim,
            nhead=num_heads,
            dim_feedforward=4 * model_dim,
            dropout=dropout,
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.collector = Collector(model_dim, collector_embed_dim, collector_heads)

        # The head consumes the flattened collector output, so its input width
        # is tied to the collector's size.
        self.output_head = nn.Linear(collector_embed_dim * collector_heads, num_classes)

    def _encode(self, x, attention_mask, trace=False):
        """Run projection, encoder and collector; shared by forward() and embed()."""

        def report(tag, tensor):
            if trace:
                print(tag, tensor.shape)

        report('input', x)

        x = self.input_proj(x)
        report('projected', x)

        # Transformer expects padding mask: True for PAD tokens
        pad_mask = (attention_mask == 0)
        x = self.encoder(x, src_key_padding_mask=pad_mask)
        report('transformed', x)

        # NOTE(review): the collector is called without the padding mask, so
        # padded rows of `x` take part in its pooling.
        x = self.collector(x)
        report('collected', x)

        return x

    def forward(self, x, attention_mask):
        """
        x: Tensor of shape [batch_size, n_entities, feature_dim]
        attention_mask: Tensor of shape [batch_size, n_entities], with 1 for valid, 0 for padding

        Returns logits of shape [batch_size, num_classes]: one row per tile,
        not one row per entity.
        """
        x = self._encode(x, attention_mask, trace=self.verbose)

        x = torch.flatten(x, start_dim=1)
        if self.verbose:
            print('flattened', x.shape)

        logits = self.output_head(x)
        if self.verbose:
            print('logits', logits.shape)

        return logits

    def embed(self, x, attention_mask):
        """
        Returns an embedding for the input features: the collector output,
        before it is flattened and passed to the output head.
        """
        return self._encode(x, attention_mask)


In [None]:
# Size the model from the data instead of from literals: the input width is
# the column count of a tile (one-hot label plus encoding) and the output
# width is the number of labels, which is also what the loss computation
# reshapes the logits to.
feature_dim = tile_data_list[0].shape[1]
model = GeospatialTransformer(
    feature_dim=feature_dim, model_dim=64, num_heads=4, num_layers=2,
    num_classes=n_classes, dropout=0.0,
)
n_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('%d trainable parameters in model' % n_param)

### For testing

In [None]:
# Single-tile batches over the training set, used to poke at the model.
dataset = VergeDataset(train_tiles, n_classes)
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=collate_fn,  # pads variable-length tiles
    drop_last=False,
)

# Draw one batch and show its shapes.
features, labels, attention_mask = next(iter(dataloader))
print(features.shape, labels.shape, attention_mask.shape)

In [None]:
# Call the module itself rather than .forward(): that is the supported entry
# point and it also runs any registered hooks.
model(features, attention_mask)

### Check permutation invariance

In [None]:
# Embed the batch in its original entity order.
device = 'cpu'
model.to(device)
features, labels, attention_mask = (
    tensor.to(device) for tensor in (features, labels, attention_mask)
)
e0 = model.embed(features, attention_mask)
print(e0.shape)
print(features[0, :, -1])
print(e0[0, :, 0])


In [None]:
# Sum the original-order embedding over axis 1.
e0.sum(dim=1)

In [None]:
# Embed the same batch with its entities in a random order. The attention
# mask is permuted together with the features so that every entity keeps its
# own valid/padding flag.
perm = torch.randperm(features.size(1))
print(perm)
permuted_features = features[:, perm, :].to(device)
permuted_mask = attention_mask[:, perm]
e1 = model.embed(permuted_features, permuted_mask)
print(e1.shape)
print(permuted_features[0, :, -1])
print(e1[0, :, 0])


In [None]:
# Sum the permuted-order embedding over axis 1.
e1.sum(dim=1)

In [None]:
# Difference of the two sums; values near zero mean the order did not matter.
e0.sum(dim=1) - e1.sum(dim=1)

### Real training loop

In [None]:
from torch.utils.data import DataLoader

# Dataset over the training tiles, using VergeDataset's default mask_fraction.
dataset = VergeDataset(train_tiles, n_classes)

# Shuffled batches of 16 tiles; collate_fn pads every tile in a batch to the
# length of the longest one.
dataloader = DataLoader(
    dataset,
    batch_size=16,            # Tune depending on GPU memory
    shuffle=True,
    collate_fn=collate_fn,   # Key for padding variable-length instances
    drop_last=False
)

In [None]:
device = 'cpu'
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=-100)

# Mean training loss of every epoch (one entry per epoch).
losses = []

model.train()
for epoch in range(2):
    epoch_loss = 0.0
    n_batches = 0

    for features, labels, attention_mask in dataloader:
        features = features.to(device)
        labels = labels.to(device)
        attention_mask = attention_mask.to(device)

        # A batch in which no label was masked has only "ignore" targets; the
        # mean-reduced loss then has nothing to average over and comes back as
        # NaN, which would poison the weights. Skip such batches.
        if not (labels != -100).any():
            continue

        # NOTE(review): GeospatialTransformer.forward flattens the collector
        # output and returns one row of logits per tile ([batch, num_classes]),
        # while `labels` holds one entry per entity ([batch, max_len]). Unless
        # max_len == 1 the two flattened sizes differ and the loss call below
        # rejects them - confirm whether the head should predict per entity.
        logits = model(features, attention_mask)
        loss = criterion(
            logits.view(-1, n_classes),
            labels.view(-1)
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        n_batches += 1

    if n_batches == 0:
        # Nothing was trained on; do not report a stale or undefined loss.
        print(f"Epoch {epoch+1}: no batch contained masked labels")
        continue

    # Report the epoch average rather than whatever the last batch produced.
    mean_loss = epoch_loss / n_batches
    losses.append(mean_loss)

    print(f"Epoch {epoch+1}, Loss: {mean_loss:.4f}")


In [None]:
import plotly
from plotly.subplots import make_subplots
from plotly.graph_objects import Scatter

# Plot the recorded loss values against their position in the list.
fig = make_subplots(rows=1, cols=1)
trace = Scatter(
    x=np.arange(len(losses)), y=losses, name='loss',
    mode='markers+lines'
)
# add_trace with explicit row/col replaces the deprecated append_trace.
fig.add_trace(trace, row=1, col=1)
fig

In [None]:
import torch
import torch.nn as nn

# Standalone check: without positional encodings, a transformer encoder should
# map permuted inputs to identically permuted outputs.

# Fix the RNG so that the check is repeatable.
torch.manual_seed(42)

# Problem size.
batch_size = 1
num_entities = 5
embedding_dim = 8

# Random input of shape [1, 5, 8].
x = torch.randn(batch_size, num_entities, embedding_dim)

# One encoder layer, no positional encoding, no dropout.
encoder_layer = nn.TransformerEncoderLayer(
    d_model=embedding_dim, nhead=2, dropout=0.0, batch_first=True
)
transformer = nn.TransformerEncoder(encoder_layer, num_layers=1)

# Every entity is valid, so the padding mask handed to the encoder is all False.
attention_mask = torch.ones(batch_size, num_entities, dtype=torch.bool)

# Reference pass in the original order.
out_orig = transformer(x, src_key_padding_mask=torch.logical_not(attention_mask))

# Shuffle the entities (and their mask) and run again.
perm = torch.randperm(num_entities)
x_permuted = x[:, perm, :]
mask_permuted = attention_mask[:, perm]
out_perm = transformer(x_permuted, src_key_padding_mask=torch.logical_not(mask_permuted))

# Row j of the permuted output belongs to original entity perm[j]; indexing
# with the inverse permutation puts the rows back in the original order.
out_perm_reordered = out_perm[:, torch.argsort(perm)]

# Largest element-wise deviation between the two passes.
diff = (out_orig - out_perm_reordered).abs().max().item()
print("Max difference after permutation:", diff)


#  based on A. 