In [1]:
%load_ext autoreload
%autoreload 2

# Training GeoVex model


In [1]:
import os
import sys
from pathlib import Path

# add the
ROOT = Path(os.getcwd())
while not (ROOT / ".git").exists():
    ROOT = ROOT.parent

sys.path.append(str(ROOT))

from src.config import CargoBikeConfig, load_config
from src.osm_tags import build_tag_filter


import polars as pl
from h3ronpy.polars import grid_disk
from srai.h3 import ring_buffer_h3_indexes
from srai.neighbourhoods.h3_neighbourhood import H3Neighbourhood
from srai.embedders.geovex.model import GeoVexModel
from srai.embedders.geovex.dataset import HexagonalDataset

  from .autonotebook import tqdm as notebook_tqdm


## Load Config


In [3]:
config = load_config(ROOT / "config" / "paper.yaml")

target_tags = build_tag_filter(config)
target_tag_list = [f"{t}_{st}" for t, sts in target_tags.items() for st in sts]

## Load the Tag Data


In [4]:
count_df_file = ROOT / "data" / "geovex" / "count.parquet"

if not count_df_file.exists():
    count_df_file.parent.mkdir(parents=True, exist_ok=True)

    count_df = pl.concat(
        pl.scan_parquet(city.count_file).with_columns(pl.lit(city.name).alias("city"))
        for city in config.Cities
    ).collect()

    count_df.write_parquet(
        count_df_file,
        compression="snappy",
    )
else:
    count_df = pl.read_parquet(count_df_file)

count_df.head()

amenity_bar,amenity_biergarten,amenity_cafe,amenity_fast_food,amenity_food_court,amenity_ice_cream,amenity_pub,amenity_restaurant,amenity_college,amenity_driving_school,amenity_kindergarten,amenity_language_school,amenity_library,amenity_toy_library,amenity_music_school,amenity_school,amenity_university,amenity_bicycle_parking,amenity_bicycle_repair_station,amenity_bicycle_rental,amenity_boat_rental,amenity_boat_sharing,amenity_bus_station,amenity_car_rental,amenity_car_sharing,amenity_car_wash,amenity_vehicle_inspection,amenity_charging_station,amenity_ferry_terminal,amenity_fuel,amenity_grit_bin,amenity_motorcycle_parking,amenity_parking,amenity_parking_entrance,amenity_parking_space,amenity_taxi,amenity_atm,…,sport_zurkhaneh_sport,water_river,water_oxbow,water_canal,water_ditch,water_lock,water_fish_pass,water_lake,water_reservoir,water_pond,water_basin,water_lagoon,water_stream_pool,water_reflecting_pool,water_moat,water_wastewater,waterway_river,waterway_riverbank,waterway_stream,waterway_tidal_channel,waterway_canal,waterway_drain,waterway_ditch,waterway_pressurised,waterway_fairway,waterway_dock,waterway_boatyard,waterway_dam,waterway_weir,waterway_waterfall,waterway_lock_gate,waterway_soakhole,waterway_turning_point,waterway_water_point,waterway_fuel,region_id,city
i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,…,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,str,str
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,"""892a3029a33fff…","""Boston, USA"""
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,"""892a3065c8ffff…","""Boston, USA"""
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,"""892a302b077fff…","""Boston, USA"""
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,"""892a3067227fff…","""Boston, USA"""
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,"""892a339146ffff…","""Boston, USA"""


### Clean the Dataset


#### Drop Hexagons with No Building / Highway

We do this b.c. we are only focused on hexagons that could have deliveries


In [5]:
keep_h3s = ring_buffer_h3_indexes(
    count_df.filter(
        (pl.sum_horizontal(pl.col("^building_.*$")) > 0)
        | (pl.sum_horizontal(pl.col("^highway_.*$")) > 0),
    )["region_id"],
    config.GeoVex.radius,
)

count_df = count_df.filter(pl.col("region_id").is_in(keep_h3s))

#### Drop Columns that Sum to Zero


In [6]:
# drop columns that sum to 0
count_df = count_df.drop(
    count_df.select([pl.col(target_tag_list).sum() <= 0])
    .transpose(
        include_header=True,
        column_names=["drop"],
    )
    .filter(pl.col("drop"))["column"]
)

target_tag_list = sorted(
    set(count_df.columns).intersection(set(target_tag_list))
)

## Create the Dataset

In [7]:
from torch.utils.data import DataLoader

train_df = count_df.select(target_tag_list + ["region_id", ]).to_pandas().set_index("region_id")

dataset = HexagonalDataset(
    train_df,
    neighbourhood=H3Neighbourhood(),
    neighbor_k_ring=config.GeoVex.radius,
)

100%|██████████| 57067/57067 [00:32<00:00, 1737.00it/s]


In [8]:
dataloader = DataLoader(
    dataset,
    batch_size=1024,
    shuffle=False,
    num_workers=10,
)

## Train the Model

In [9]:
model = GeoVexModel(
    k_dim=train_df.shape[1],
    radius=config.GeoVex.radius,
    conv_layers=3,
    emb_size=50,
    learning_rate=0.00025,
)

In [10]:
import os
import torch
import wandb
from pytorch_lightning.loggers.wandb import WandbLogger
import pytorch_lightning as pl_lightning
from pytorch_lightning.callbacks import LambdaCallback, ModelCheckpoint
from pytorch_lightning import seed_everything

seed_everything(42)

# this is required for some reason to get the data to save correctly
wandb.finish()

epochs = 100

os.environ["WANDB_NOTEBOOK_NAME"] = "cluster-word2vec.ipynb"

logger = WandbLogger(
    project="GeoVex",
    log_model="all",
)

checkpoint_callback = ModelCheckpoint(
    monitor="train_loss",
    mode="min",
)

trainer = pl_lightning.Trainer(
    accelerator="cuda" if torch.cuda.is_available() else "cpu",
    max_epochs=epochs,
    logger=logger,
    callbacks=[checkpoint_callback],
    log_every_n_steps=20
)

Seed set to 42


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


### Train

In [11]:
trainer.fit(model, dataloader)

logger.experiment.finish()

run = logger.experiment

train_df.columns.to_series(name="column").to_csv(
    Path(run.dir) / "columns.csv", index=False
)


/home/shadeform/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
You are using a CUDA device ('NVIDIA A100-SXM4-80GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


[34m[1mwandb[0m: Currently logged in as: [33m_max_[0m ([33mgreen-last-mile[0m). Use [1m`wandb login --relogin`[0m to force relogin


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 8.3 M 
1 | decoder | Sequential | 7.1 M 
2 | _loss   | GeoVeXLoss | 0     
---------------------------------------
15.4 M    Trainable params
0         Non-trainable params
15.4 M    Total params
61.597    Total estimated model params size (MB)


Epoch 99: 100%|██████████| 40/40 [00:18<00:00,  2.14it/s, v_num=u8ry]

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 99: 100%|██████████| 40/40 [00:19<00:00,  2.07it/s, v_num=u8ry]




0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss_epoch,█▄▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss_step,█▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███

0,1
epoch,99.0
train_loss_epoch,1235.62671
train_loss_step,99.62657
trainer/global_step,3999.0


## Create the Embeddings

In [12]:
import pandas as pd
import numpy as np
from srai.constants import REGIONS_INDEX

embeddings = [
    model.cuda().encoder(batch.cuda()).cpu().detach().numpy() for batch in dataloader  # type: ignore
]

df = pd.DataFrame(np.concatenate(embeddings), index=dataset.get_ordered_index())
df.index.name = REGIONS_INDEX

df.to_parquet(Path(run.dir) / "embeddings.parquet", compression="snappy")

In [14]:
(Path(run.dir) / "embeddings.parquet").absolute()

PosixPath('/home/shadeform/cargo-bike-analysis/notebooks/embeddings/wandb/run-20231015_230811-ffs9u8ry/files/embeddings.parquet')

In [14]:
df

Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,...,40,41,42,43,44,45,46,47,48,49
region_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
892a339a5afffff,-12.543304,-36.702934,-35.427788,62.237679,14.633109,85.349991,-23.133249,45.920357,-51.708145,45.033699,...,-16.903761,26.290535,29.346359,-15.618069,20.286274,1.167608,63.781994,57.566757,-23.694431,-4.005977
892a3066a3bffff,31.329262,-41.410572,-55.827526,78.853081,25.929245,49.997467,-38.882084,-5.730484,-58.445007,-24.405167,...,-35.929619,45.780430,36.663208,-16.449844,32.866051,-2.159680,78.207001,33.529587,-30.632816,-9.345604
892a3066e17ffff,17.920731,-50.485878,-39.077431,93.122025,24.191360,27.859776,-39.975128,-1.498533,-70.378983,-16.668730,...,-39.374737,9.126648,42.700592,3.464363,12.610657,-5.984903,88.849754,48.797073,-30.830357,4.599929
892a3066b3bffff,44.426163,-43.210972,-60.217773,71.468353,24.534119,67.819855,-35.414818,-4.939835,-55.482632,-30.228302,...,-33.145702,71.767830,35.279186,-21.912142,40.395546,-6.421528,71.696129,37.463757,-32.746212,-9.986138
892a3066803ffff,-1.315850,-37.682495,-35.074394,74.233047,17.954483,61.830872,-30.686661,21.605570,-57.101604,32.648930,...,-30.076464,15.182814,35.788933,-17.001144,0.092901,-2.930601,71.126785,51.737259,-32.164906,2.179149
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
891fa441d47ffff,28.075670,-25.931828,-40.101681,97.436829,37.152378,33.094837,-40.951519,-9.537457,-62.798843,-22.340813,...,-31.757479,10.416734,44.635113,3.639849,-25.776157,-0.614910,86.354729,31.092669,-11.332111,0.074496
891fa44f093ffff,-4.758158,0.949384,-22.122681,74.865120,21.037386,58.749435,-26.209946,16.376972,-44.618019,42.069408,...,-4.049875,7.629431,24.520319,-7.363722,-1.790652,-10.719803,69.532257,17.530775,2.926747,2.576683
891fa4401c7ffff,10.786926,-25.158751,-29.395390,86.569366,35.706261,56.238113,-37.394676,34.304871,-54.556118,47.895363,...,-2.857403,-1.355015,36.963276,0.742130,-58.049046,-7.530038,78.168404,37.680874,-11.993789,20.742382
891fa44f4abffff,22.865328,-10.863411,-30.502930,84.571190,29.671062,53.162815,-37.400650,5.071507,-52.505394,13.697501,...,-7.959752,26.386868,32.800007,-12.766939,-5.552060,-15.301873,78.653915,24.157524,-9.235600,2.773456
