# Apply the trained embedding model to a single location

This notebook implements end-to-end VERGE encoding for a single location.
Its purpose is to lay out, step by step, exactly what processing needs to take place.
It is deliberately simple and therefore very inefficient; processing locations
in bulk is much faster.

## Processing Setup

In [None]:
# Google colab
import os
from google.colab import drive
drive.mount('/content/drive')
project_home = '/content/drive/MyDrive/Projects/verge'
os.chdir(project_home)
!pip install geo_encodings osmnx

In [None]:
# Local processing setup: when running outside Colab, skip the cell above and
# uncomment the line below so that project_home points at the project directory
# (all data, model and module paths below are built from it).
# project_home = '..'

## Notebook Setup

In [None]:
# Standard library.
import copy
import json
import pickle
import sys
from typing import List, Tuple, Optional

# Third-party. shapely.ops / shapely.affinity are submodules that
# `import shapely` does not load on its own, so import them explicitly.
import geopandas
import matplotlib.pyplot as plt
import numpy as np
import osmnx
import pandas as pd
import pyproj
import shapely
import shapely.affinity
import shapely.ops
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# Project-local modules (their directories must be on sys.path first).
sys.path.append('%s/03-embeddings' % project_home)
from embedderv5 import *

sys.path.append(project_home)
from utils.verge import rules


## Parameters

In [None]:
# Region of interest (ROI) to process.
roi_name = 'newengland'

# General-purpose data directory, and the ROI-specific directory beneath it.
data_home = f'{project_home}/data'
roi_home = f'{data_home}/{roi_name}'

# Run identifiers of the trained models to load.
transformer_run_id = '301b'
collector_run_id = '301b'

## Preliminaries

In [None]:
# Load the ROI definition. The tile size and encoding resolution it carries
# configure the MPP encoder further down.
fname = f'{roi_home}/roi.json'
with open(fname) as roi_file:
    roi = json.load(roi_file)

tile_size, encoding_resolution = roi['tile_size'], roi['encoding_resolution']

roi

In [None]:
# Read the file containing labels.
fname = '%s/labels.csv' % data_home
labels = pd.read_csv(fname)

# Make a lookup table to get a numerical label id from a text label.
label_lookup = {
    z['label']: z['id']
    for z in labels.to_dict('records')
}
label_count = len(label_lookup)

# The ids are used directly as one-hot indices, and label_count sizes the
# model input, so they must be unique and cover exactly 0 .. label_count - 1.
assert sorted(label_lookup.values()) == list(range(label_count)), \
    'label ids in %s must be unique and cover 0..%d' % (fname, label_count - 1)

label_lookup

In [None]:
# Build forward (lon/lat -> projected x/y) and inverse (x/y -> lon/lat)
# coordinate transforms for the ROI's map projection.
def get_projections(proj_def):
    """Return (forward, inverse) transform callables for a PROJ string.

    forward maps WGS84 (lon, lat) to the projection given by ``proj_def``;
    inverse maps back. Both use lon/lat (x/y) axis order (always_xy=True).
    """
    local_crs = pyproj.CRS.from_proj4(proj_def)
    lonlat_crs = pyproj.CRS.from_epsg(4326)
    to_local = pyproj.Transformer.from_crs(lonlat_crs, local_crs, always_xy=True)
    to_lonlat = pyproj.Transformer.from_crs(local_crs, lonlat_crs, always_xy=True)
    return to_local.transform, to_lonlat.transform

proj_forward, proj_inverse = get_projections(roi['proj_def'])

In [None]:
# Load the ROI's coastline polygons; get_land_water() merges them into a
# single geometry.
fname = f'{roi_home}/coastlines'
coastlines_gdf = geopandas.read_file(fname)
print(f'{len(coastlines_gdf)} coastline polygons')

def get_land_water(bounds, features, coastlines=None):
    """Build the 'land' geometry for an area: coastline coverage minus OSM water.

    Parameters
    ----------
    bounds : shapely geometry
        Area of interest, in the same CRS as the coastline and feature
        geometries (the notebook passes a lon/lat rectangle).
        NOTE(review): assumes the coastline file is also in lon/lat -- confirm.
    features : DataFrame
        OSM features with a ``geometry`` column and, optionally, a ``natural``
        column. Polygonal rows with ``natural == 'water'`` are subtracted.
    coastlines : shapely geometry, optional
        Pre-merged coastline geometry. Defaults to the union of the global
        ``coastlines_gdf``.

    Returns
    -------
    shapely geometry
        ``bounds`` intersected with the coastlines, minus water polygons.
    """
    if coastlines is None:
        coastlines = shapely.union_all(coastlines_gdf['geometry'].values)

    # Shapely geometries are immutable, so intersecting `bounds` directly is
    # safe; no defensive copy is needed.
    landwater = bounds.intersection(coastlines)

    # The column-selection step only keeps 'natural' if OSM returned it; with
    # no such column there are no water features to subtract.
    if 'natural' not in features.columns:
        return landwater

    # Subtract out any polygonal water feature.
    for _, f in features.iterrows():
        geom = f['geometry']
        if geom.geom_type in ('Polygon', 'MultiPolygon') and f['natural'] == 'water':
            landwater = shapely.difference(landwater, geom)

    return landwater

## Processing


### Set a tile center location

In [None]:
# Tile center (Stratham Subaru).
center_lat = 43.000659
center_lon = -70.921196

### Pull OSM data for the area around this location

In [None]:
# Lon/lat bounding box for the OSM query: the tile plus a margin of 200
# projected units on every side, given as [lon0, lat0, lon1, lat1] from the
# lower-left and upper-right corners.
margin = 200
buffer = tile_size / 2 + margin
center_x, center_y = proj_forward(center_lon, center_lat)
x0, x1 = center_x - buffer, center_x + buffer
y0, y1 = center_y - buffer, center_y + buffer
lon0, lat0 = proj_inverse(x0, y0)
lon1, lat1 = proj_inverse(x1, y1)
query_bounds = [lon0, lat0, lon1, lat1]
print(query_bounds)


In [None]:
# Query OSM for every feature carrying any value of the tag keys below, within
# the lon/lat bounding box computed above. osmnx 2.x takes the box as a single
# (left, bottom, right, top) argument, which matches the order of query_bounds.
tags = {
    'landuse': True,
    'place': True,
    'highway': True,
    'railway': True,
    'bridge': True,
    'tunnel': True,
    'natural': True,
    'waterway': True,
    'landcover': True,
    'amenity': True,
    'shop': True,
    'leisure': True,
}
# Keys deliberately left out of the query: 'aeroway', 'power', 'building'.
features = osmnx.features.features_from_bbox(query_bounds, tags=tags).reset_index()
print('%d features from OSM' % len(features))


In [None]:
# Just retain the relevant columns: geometry plus the OSM keys listed here.
# Iterating features.columns (instead of a set intersection) keeps the column
# order deterministic; set order of strings changes between kernel sessions.
columns_in_rules = {'geometry', 'amenity', 'highway', 'landuse', 'railway', 'water', 'waterway', 'natural'}
columns_to_keep = [c for c in features.columns if c in columns_in_rules]
features = features[columns_to_keep]
features.head(3)

### Re-organize geo info for this tile

In [None]:
# Down-select and re-format any relevant geospatial entities ("gents").
# Every (feature, rule) pair whose geometry type, OSM key and OSM value all
# match yields one gent, so a single feature can produce several.
gents = []
for feature in features.to_dict('records'):
    geomxy = shapely.ops.transform(proj_forward, feature['geometry'])
    if geomxy.is_empty:
        continue
    gtype = geomxy.geom_type

    for rule in rules:
        if gtype != rule['gtype']:
            continue
        osm_key = rule['osm_key']
        if osm_key not in feature:
            continue
        osm_value = str(feature[osm_key])
        if osm_value not in rule['osm_values']:
            continue
        gents.append(dict(
            feature=feature,
            category=rule['gent_category'],
            label=rule['gent_label'],
            geomxy=geomxy,
            gtype=gtype,
        ))
print('%d features selected' % len(gents))


In [None]:
# Build a "land/water" polygon over the lon/lat query rectangle, project it,
# and register it as one more geospatial entity.
lonlat_bounds = shapely.Polygon([
    (lon0, lat0), (lon1, lat0), (lon1, lat1), (lon0, lat1), (lon0, lat0),
])
landwater = get_land_water(lonlat_bounds, features)
landwaterxy = shapely.ops.transform(proj_forward, landwater)
land_gent = dict(
    category='waterway',
    label='land',
    geomxy=landwaterxy,
    gtype=landwaterxy.geom_type,
)
gents.append(land_gent)


In [None]:
# Tile bounds in projected coordinates: a tile_size square centered on the
# location. (x0, y0) is reused below as the tile's local origin.
buffer = tile_size / 2
center_x, center_y = proj_forward(center_lon, center_lat)
x0, x1 = center_x - buffer, center_x + buffer
y0, y1 = center_y - buffer, center_y + buffer
tile_bbox = shapely.Polygon([(x0, y0), (x1, y0), (x1, y1), (x0, y1), (x0, y0)])

In [None]:
# Clip every entity to the tile, then shift it so the tile's lower-left corner
# becomes the origin (the MPP encoder works on [0, tile_size] x [0, tile_size]).
tile_gents = []
for gent in gents:
    clipped = gent['geomxy'].intersection(tile_bbox)
    local_geom = shapely.affinity.translate(clipped, xoff=-x0, yoff=-y0)
    if local_geom.is_empty:
        continue
    tile_gents.append(dict(
        category=gent['category'],
        label=gent['label'],
        geometry=local_geom,
        # NOTE: geometry type as recorded before clipping; clipping may change
        # it (e.g. a Polygon becoming a MultiPolygon).
        gtype=gent['gtype'],
        xoff=x0,
        yoff=y0,
    ))
print('%d geospatial entities' % len(tile_gents))
pd.DataFrame(tile_gents).head(3)

### Apply MPP encoding to all entities in this tile

In [None]:
# MPP encoder defined over the tile's local frame [0, tile_size] x [0, tile_size].
from geo_encodings import MPPEncoder

encoder = MPPEncoder(
    region=[0, 0, tile_size, tile_size],
    resolution=encoding_resolution,
    center=True,
)
geo_encoding_dim = len(encoder)
print(f'{geo_encoding_dim} elements in encodings')


In [None]:
# Attach each entity's MPP encoding values to its record.
for gent in tile_gents:
    mpp = encoder.encode(gent['geometry'])
    gent['encoding'] = mpp.values()


### Get one-hot label vectors for each entity

In [None]:
# One-hot label vector for each entity, indexed by the numeric id that
# label_lookup assigns to its "category : label" name.
for gent in tile_gents:
    label_id = label_lookup['%s : %s' % (gent['category'], gent['label'])]
    onehot = np.zeros(label_count, dtype=float)
    onehot[label_id] = 1
    gent['onehot'] = onehot

In [None]:
# Stack per-entity one-hot labels and MPP encodings into the tile's feature
# matrix: one row per entity, label columns first, then encoding columns.
# np.vstack cannot stack zero arrays, so fail early with a clear message.
assert tile_gents, 'no geospatial entities in this tile; nothing to encode'
mpps = np.vstack([z['encoding'] for z in tile_gents])
print(mpps.shape)
onehots = np.vstack([z['onehot'] for z in tile_gents])
print(onehots.shape)

tile_encoding = np.hstack([onehots, mpps])
print(tile_encoding.shape)

In [None]:
# Heat map of the tile's feature matrix. It is typically much wider (label +
# encoding columns) than tall (one row per entity), so aspect='auto' stretches
# the rows to a readable height; 'nearest' keeps individual cells crisp.
fig, ax = plt.subplots(figsize=(12, 4))
im = ax.imshow(tile_encoding, cmap='viridis', aspect='auto', interpolation='nearest')
fig.colorbar(im, ax=ax, label='Value')
ax.set_title('Full Encoding For Tile')
ax.set_xlabel('Encoding Index')
ax.set_ylabel('Entity Number')
plt.show()


### Initial embedding for this tile

In [None]:
# Keep this import even though nothing below calls it directly: the model file
# is a pickled module, and its class may need to be resolvable when unpickling.
from utils.geo_transformer_mem import VergeDataset, verge_collate_fn, GeospatialTransformer

# Load the trained transformer that produces the initial embedding for this
# tile. The file stores the whole module (architecture + weights), so
# torch.load returns a ready-to-use model and no constructor call is needed.
# NOTE: weights_only=False unpickles arbitrary objects -- only load model
# files from a trusted source.
model_fname = '%s/models/transformer-%s' % (roi_home, transformer_run_id)
transformer = torch.load(model_fname, weights_only=False, map_location=torch.device('cpu'))
# Inference only: eval mode disables dropout (and other train-time behavior)
# regardless of the mode the model was saved in.
transformer.eval()
print('loaded %s' % model_fname)

n_param = sum(p.numel() for p in transformer.parameters() if p.requires_grad)
print('%d trainable parameters in model' % n_param)


In [None]:
# A batch of one tile: the feature matrix gains a leading batch dimension, and
# the attention mask is all ones because a single tile has no padding entities.
input_features = torch.tensor(tile_encoding, dtype=torch.float32).unsqueeze(0)
print(input_features.shape)

input_attention_mask = torch.ones(1, tile_encoding.shape[0], dtype=torch.float32)
print(input_attention_mask.shape)


In [None]:
# Inference only: skip building an autograd graph.
with torch.no_grad():
    transformed = transformer.embed(input_features, input_attention_mask)
print(transformed.shape)

### Get the final embedding for this tile

In [None]:
import sys
sys.path.append('%s/03-embeddings' % project_home)
from embedderv5 import ContrastivePairDataset, PermutationInvariantModel, TripletContrastiveLoss, triplet_collate_fn

In [None]:
# Build the permutation-invariant "collector" model and load its trained
# weights. Its input_dim must equal the feature width of the transformer
# output (`transformed`) -- TODO(review): confirm against transformed.shape[-1].
embedding_dim = 128
model = PermutationInvariantModel(
    input_dim=embedding_dim,
    hidden_dim=128,
    embedding_dim=embedding_dim,
    num_attention_heads=8,
    num_linear_layers=3,
    dropout=0.1
)

model_fname = '%s/models/collector-%s.pth' % (roi_home, collector_run_id)
state_dict = torch.load(model_fname, map_location='cpu')
model.load_state_dict(state_dict)
# nn.Module instances start in training mode and load_state_dict does not
# change that. Switch to eval mode so the dropout configured above is inactive
# (assuming it is implemented with standard dropout layers) and repeated runs
# give the same embedding.
model.eval()
print('loaded %s' % model_fname)

n_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('%d trainable parameters in model' % n_param)


In [None]:
# Shape of the transformer output that is fed to the collector model.
transformed.shape

In [None]:
# Boolean mask over the tile's entities for the collector model; every
# position is set because a single tile has no padding.
masks = torch.ones((1, tile_encoding.shape[0]), dtype=torch.bool)
print(masks.shape)

In [None]:

# Compute the tile's final embedding from the transformer output.
# Inference only, so skip autograd bookkeeping.
with torch.no_grad():
    emb = model(transformed, masks)

In [None]:
# Shape of the final embedding for this tile.
emb.shape

In [None]:
# Display the final embedding for this tile.
emb