### VERGE: Vector-Mode Regional Geospatial Encoding
# Initial Embeddings

Back in the "02" folder, we trained a Masked Geospatial Model.
That model can compute a set of embeddings for any tile,
for which the inputs are a set of geospatial entities.
These embeddings are permutation-equivariant ("perm-e")
with respect to the input features,
which is not what we want for a regional embedding.
In this folder we are building
a fully permutation-invariant ("perm-i") aggregation of the perm-e
outputs.
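A toy example shows the difference between the two properties. This is a minimal sketch and does not reproduce the architecture of the model from the 02 folder. It uses a single, randomly initialized self-attention layer with no positional encoding as a stand-in for the perm-e model, and mean pooling as the simplest perm-i aggregation. The functions `self_attention` and `mean_pool` are illustrative only.

```python
import numpy as np

def self_attention(x, w_q, w_k, w_v):
    """Single-head self-attention with no positional encoding.

    x has shape (n_entities, d). Permuting the rows of x permutes the
    rows of the output in the same way (permutation-equivariance).
    """
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(q.shape[1])
    scores -= scores.max(axis=1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)  # row-wise softmax
    return weights @ v

def mean_pool(h):
    """Average over the entity axis: a simple permutation-invariant aggregation."""
    return h.mean(axis=0)

rng = np.random.default_rng(0)
d = 8
x = rng.normal(size=(5, d))  # 5 hypothetical geospatial entities in one tile
w_q, w_k, w_v = (rng.normal(size=(d, d)) for _ in range(3))
perm = rng.permutation(5)

h = self_attention(x, w_q, w_k, w_v)
h_perm = self_attention(x[perm], w_q, w_k, w_v)

# perm-e: permuting the inputs permutes the per-entity outputs.
print(np.allclose(h_perm, h[perm]))  # True
# perm-i: the pooled vector does not change at all.
print(np.allclose(mean_pool(h_perm), mean_pool(h)))  # True
```

Mean pooling is only the simplest choice here. Any symmetric aggregation (sum, max, attention pooling with a learned query) would also be perm-i.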

In this notebook, we compute those perm-e embeddings for all instances
in our training and validation datasets.
That is, we run the model that was trained back in the 02 folder.


## Processing Setup

In [None]:
# Google colab setup
import os
from google.colab import drive
drive.mount('/content/drive')
project_home = '/content/drive/MyDrive/Projects/verge'
os.chdir(project_home)
!pip install geo_encodings

In [None]:
# Local processing setup
# project_home = '..'

## Notebook Setup

In [None]:
import pandas as pd
import numpy as np
import glob
import pickle
import os
import torch
import torch.nn as nn
import torch.utils
import torch.utils.data
from torch.utils.data import DataLoader
import copy
import json
from geo_encodings import MPPEncoder

import sys
sys.path.append(project_home)
from utils.geo_transformer_mem import VergeDataset, verge_collate_fn, GeospatialTransformer


## Parameters

In [None]:
# The name of the ROI to use.
roi_name = 'newengland'

# The name of the general-purpose data directory.
data_home = '%s/data' % (project_home)

# The name of the ROI-specific data directory.
roi_home = '%s/data/%s' % (project_home, roi_name)

# The unique identifier of the model to be used for initial embeddings.
transformer_model_id = '301b'

# Identifier of the splits file to use.
splits_id = '201'

# The device to run inference on.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('using device', device)


## Preliminaries

In [None]:
# Read the ROI definition.
fname = '%s/roi.json' % roi_home
with open(fname) as source:
    roi = json.load(source)

tile_size = roi['tile_size']
encoding_resolution = roi['encoding_resolution']

# We need the dimension of the encoding.
encoder = MPPEncoder(
    region=[0, 0, tile_size, tile_size],
    resolution=encoding_resolution,
    center=True
)
geo_encoding_dim = len(encoder)
print('%d elements in encodings' % geo_encoding_dim)


In [None]:
# Read the list of labels.
fname = '%s/labels.csv' % data_home
labels = pd.read_csv(fname)
n_classes = len(labels)
print('%d labels in this dataset' % n_classes)

label_id_lookup = {
    z['label']: z['id']
    for z in labels.to_dict('records')
}

label_name_lookup = {
    z['id']: z['label']
    for z in labels.to_dict('records')
}
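`label_id_lookup` and `label_name_lookup` are meant to be inverses of each other. Duplicate labels or ids in `labels.csv` would silently break that, because a dict comprehension keeps only the last duplicate key. Below is a minimal sketch of a guarded version. It assumes only that each record has `label` and `id` keys, as above. `build_label_lookups` is a hypothetical helper and is not part of the project's `utils`.

```python
def build_label_lookups(records):
    """Build label->id and id->label dicts, failing loudly on duplicates.

    `records` is a list of dicts with 'label' and 'id' keys, i.e. the
    shape returned by labels.to_dict('records').
    """
    label_id = {r['label']: r['id'] for r in records}
    id_label = {r['id']: r['label'] for r in records}
    # A duplicated key collapses to one dict entry, so the lengths shrink
    # and the two lookups would no longer be inverses.
    if len(label_id) != len(records) or len(id_label) != len(records):
        raise ValueError('labels.csv contains duplicate labels or ids')
    return label_id, id_label

# Toy records for illustration; the real ones come from labels.csv.
records = [{'label': 'building', 'id': 0}, {'label': 'road', 'id': 1}]
label_id, id_label = build_label_lookups(records)
print(label_id['road'], id_label[0])  # 1 building
```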

In [None]:
# Read the file that gives class probabilities.
fname = '%s/class_info.csv' % roi_home
class_info = pd.read_csv(fname)
print('%d class info records' % len(class_info))

## Load data
We determine which files to read by loading the associated "split" file.

In [None]:
# Get the list of AOI tags. They can be found in the splits file.
fname = '%s/models/splits-%s.csv' % (roi_home, splits_id)
splits = pd.read_csv(fname)
aoi_tags = np.unique(splits['aoi_tag'])
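The embedding loop later in this notebook takes a while to run. It can be worth confirming up front that every AOI tag in the splits file has an encoding file. Below is a minimal sketch that assumes the `<roi_home>/encodings/<aoi_tag>.pkl` layout used in that loop. `encoding_paths` is a hypothetical helper, and the temporary directory only demonstrates the behavior.

```python
import os
import tempfile

def encoding_paths(roi_home, aoi_tags):
    """Map each AOI tag to its encoding file and list tags whose file is missing.

    Assumes the '<roi_home>/encodings/<aoi_tag>.pkl' layout.
    """
    paths = {tag: os.path.join(roi_home, 'encodings', '%s.pkl' % tag) for tag in aoi_tags}
    missing = sorted(tag for tag, path in paths.items() if not os.path.exists(path))
    return paths, missing

# Demonstration on a throwaway directory with one of two expected files present.
with tempfile.TemporaryDirectory() as tmp:
    os.makedirs(os.path.join(tmp, 'encodings'))
    open(os.path.join(tmp, 'encodings', 'aoi-a.pkl'), 'wb').close()
    paths, missing = encoding_paths(tmp, ['aoi-a', 'aoi-b'])
    print(missing)  # ['aoi-b']
```

In the notebook, you would call it as `paths, missing = encoding_paths(roi_home, aoi_tags)` and stop if `missing` is non-empty.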

## Prep model and data

In [None]:
# The dataset constructor requires a lookup table for class probabilities.
class_prob_lookup = {
    z['label']: z['prob']
    for z in class_info.to_dict('records')
}

In [None]:
# Load the model. map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model_fname = '%s/models/transformer-%s' % (roi_home, transformer_model_id)
model = torch.load(model_fname, map_location=device, weights_only=False)
print('loaded %s' % model_fname)

model.to(device)
model.eval()

In [None]:
# Loop over encoded input files. For each one, define a dataset,
# run it through the model, and generate an initial set of embeddings.
for k, aoi_tag in enumerate(aoi_tags, start=1):

    encoding_fname = '%s/encodings/%s.pkl' % (roi_home, aoi_tag)
    print('\n%s [%d / %d]' % (encoding_fname, k, len(aoi_tags)))

    # Define a dataset and dataloader for this input file.
    # Note that we set the batch size to 1. This effectively removes all
    # padding, as the dataloader pads to the largest object
    # in the batch. No masking and no shuffling are needed for inference.
    dataset = VergeDataset([encoding_fname], n_classes, mask_fraction=0.0, class_prob=class_prob_lookup)
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        collate_fn=verge_collate_fn,
        drop_last=False
    )

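    # Get embeddings for every tile in this AOI. Gradients are not needed for
    # inference, and moving results to the CPU keeps the pickles loadable on
    # machines without a GPU. The loop variable _labels avoids shadowing the
    # labels DataFrame read earlier; labels are not used here.
    embeddings_for_this_aoi = []
    with torch.no_grad():
        for features, _labels, attention_mask, idents in dataloader:
            features = features.to(device)
            attention_mask = attention_mask.to(device)
            embeddings = model.embed(features, attention_mask)
            embeddings_for_this_aoi.append({
                'aoi_tag': idents[0].split(':')[0],
                'tile_tag': idents[0].split(':')[1],
                'embedding': embeddings.cpu()
            })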

    # Save those.
    ofname = '%s/initials/%s.pkl' % (roi_home, aoi_tag)
    os.makedirs(os.path.dirname(ofname), exist_ok=True)
    with open(ofname, 'wb') as dest:
        pickle.dump(embeddings_for_this_aoi, dest)
    print('wrote %s' % ofname)
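The batch-size-1 choice in the loop above relies on pad-to-longest collation. Below is a minimal NumPy sketch of that idea. `pad_collate` is hypothetical and only illustrates the behavior described in the comment; it is not `verge_collate_fn`. The mask convention is an assumption here (True marks a real entity), and the project's collate function may use the opposite convention (True marks padding, as in PyTorch's `key_padding_mask`).

```python
import numpy as np

def pad_collate(samples):
    """Pad variable-length (n_entities, d) arrays to the longest in the batch.

    Returns the padded batch (B, n_max, d) and a boolean mask (B, n_max)
    that is True for real entities and False for padding.
    """
    n_max = max(s.shape[0] for s in samples)
    d = samples[0].shape[1]
    batch = np.zeros((len(samples), n_max, d), dtype=samples[0].dtype)
    mask = np.zeros((len(samples), n_max), dtype=bool)
    for i, s in enumerate(samples):
        batch[i, :s.shape[0]] = s
        mask[i, :s.shape[0]] = True
    return batch, mask

rng = np.random.default_rng(0)
tiles = [rng.normal(size=(n, 4)) for n in (3, 7, 5)]  # toy tiles with 3, 7, 5 entities

batch, mask = pad_collate(tiles)
print(batch.shape, int((~mask).sum()))  # (3, 7, 4) 6

single, single_mask = pad_collate(tiles[:1])
print(single.shape, bool(single_mask.all()))  # (1, 3, 4) True
```

With three tiles, six of the 21 entity slots are padding. With a single tile there is none. A correctly applied attention mask should make padding irrelevant anyway, but batch size 1 trades throughput for not having to depend on that.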


## QA
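One QA option is to read each file in `initials/` back and check it against what the loop was supposed to write. Below is a minimal sketch that assumes the record layout used above (`aoi_tag`, `tile_tag`, `embedding`). It also assumes that the last axis of each embedding is the model's embedding dimension, so every record should agree on it. `summarize_initials` is a hypothetical helper. The toy records and the shape `(1, n_entities, 16)` are made up for illustration. Reading the real files requires `torch` to be importable, because the embeddings are pickled tensors.

```python
import os
import pickle
import tempfile
from collections import Counter

import numpy as np

def summarize_initials(records):
    """Summarize one AOI's initial-embedding records.

    Returns the number of tiles, the number of duplicated tile tags, and
    the set of distinct embedding sizes along the last axis (expected: one).
    """
    counts = Counter(r['tile_tag'] for r in records)
    n_duplicates = sum(c - 1 for c in counts.values() if c > 1)
    dims = {r['embedding'].shape[-1] for r in records}
    return len(records), n_duplicates, dims

# Toy records standing in for one AOI's pickle; shapes are hypothetical.
toy = [
    {'aoi_tag': 'aoi-a', 'tile_tag': 't0', 'embedding': np.zeros((1, 3, 16))},
    {'aoi_tag': 'aoi-a', 'tile_tag': 't1', 'embedding': np.zeros((1, 5, 16))},
]
with tempfile.TemporaryDirectory() as tmp:
    fname = os.path.join(tmp, 'aoi-a.pkl')
    with open(fname, 'wb') as dest:
        pickle.dump(toy, dest)
    with open(fname, 'rb') as source:
        loaded = pickle.load(source)

n_tiles, n_dup, dims = summarize_initials(loaded)
print(n_tiles, n_dup, dims)  # 2 0 {16}
```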