# Pre-training hex2vec model

This notebook shows, step by step, how to pre-train a larger hex2vec model using the SRAI library.

## Data selection

This example uses Polish cities with 50k+ inhabitants.

In [None]:
# Query strings for the cities used in this example; each one is later passed to
# `geocode_to_region_gdf`. The order matches `num_people`: both lists are paired
# position-by-position when the DataFrame is built.
cities = [
    "Warszawa, Polska",
    "Kraków, Polska",
    "Łódź, Polska",
    "Wrocław, Polska",
    "Poznań, Polska",
    "Gdańsk, Polska",
    "Szczecin, Polska",
    "Bydgoszcz, Polska",
    "Lublin, Polska",
    "Białystok, Polska",
    "Katowice, Polska",
    "Gdynia, Polska",
    "Częstochowa, Polska",
    "Radom, Polska",
    "Toruń, Polska",
    "Rzeszów, Polska",
    "Sosnowiec, Polska",
    "Kielce, Polska",
    "Gliwice, Polska",
    "Olsztyn, Polska",
    "Zabrze, Polska",
    "Bielsko-Biała, Polska",
    "Bytom, Polska",
    "Zielona Góra, Polska",
    "Rybnik, Polska",
    "Ruda Śląska, Polska",
    "Opole, Polska",
    "Tychy, Polska",
    "Gorzów Wielkopolski, Polska",
    "Elbląg, Polska",
    "Dąbrowa Górnicza, Polska",
    "Płock, Polska",
    "Wałbrzych, Polska",
    "Włocławek, Polska",
    "Tarnów, Polska",
    "Chorzów, Polska",
    "Koszalin, Polska",
    # Cities below have fewer than 100k inhabitants according to `num_people`
    # (i.e. the 50k-100k bracket).
    "Kalisz, Polska",
    "Legnica, Polska",
    "Grudziądz, Polska",
    "Jaworzno, Polska",
    "Słupsk, Polska",
    "Jastrzębie-Zdrój, Polska",
    "Nowy Sącz, Polska",
    "Jelenia Góra, Polska",
    "Siedlce, Polska",
    "Mysłowice, Polska",
    "Konin, Polska",
    "Piła, Polska",
    "Piotrków Trybunalski, Polska",
    "Lubin, Polska",
    "Inowrocław, Polska",
    "Ostrów Wielkopolski, Polska",
    "Suwałki, Polska",
    "Stargard, Polska",
    "Gniezno, Polska",
    "Ostrowiec Świętokrzyski, Polska",
    "Siemianowice Śląskie, Polska",
    "Głogów, Polska",
    "Pabianice, Polska",
    "Leszno, Polska",
    "Żory, Polska",
    "Zamość, Polska",
    "Pruszków, Polska",
    "Łomża, Polska",
    "Ełk, Polska",
    "Tarnowskie Góry, Polska",
    "Tomaszów Mazowiecki, Polska",
    "Chełm, Polska",
    "Mielec, Polska",
    "Kędzierzyn-Koźle, Polska",
    "Przemyśl, Polska",
    "Stalowa Wola, Polska",
    "Tczew, Polska",
    "Biała Podlaska, Polska",
    "Bełchatów, Polska",
    "Świdnica, Polska",
    "Będzin, Polska",
    "Zgierz, Polska",
    "Piekary Śląskie, Polska",
    "Racibórz, Polska",
    "Legionowo, Polska",
    "Ostrołęka, Polska",
]

In [None]:
# Population of each city, listed in the same order as `cities` (the two lists are
# combined column-wise into one DataFrame). Values are in descending order.
# NOTE(review): the source and reference date of these figures are not stated here.
num_people = [
    1794166,
    779966,
    672185,
    641928,
    532048,
    470805,
    398255,
    344091,
    338586,
    296958,
    290553,
    244969,
    217530,
    209296,
    198613,
    197863,
    197586,
    193415,
    177049,
    171249,
    170924,
    169756,
    163255,
    140892,
    137128,
    136423,
    127839,
    126871,
    122589,
    118582,
    118285,
    118268,
    109971,
    108561,
    107498,
    106846,
    106235,
    99106,
    98436,
    93564,
    90368,
    89780,
    88038,
    83558,
    78335,
    77813,
    74559,
    72539,
    72527,
    72250,
    71710,
    71674,
    71560,
    69639,
    67579,
    67570,
    67404,
    66270,
    66120,
    63945,
    62854,
    62844,
    62785,
    62623,
    62573,
    61903,
    61756,
    61338,
    61135,
    60075,
    60021,
    59779,
    59623,
    59430,
    56942,
    56419,
    56222,
    56008,
    55673,
    54702,
    54259,
    53529,
    51656,
]

Select the H3 resolution.

In [None]:
# H3 resolution used throughout the notebook: passed to `get_regions` and also
# embedded in the checkpoint directory and the W&B project name.
RESOLUTION = 9

We want to split the data into train and validation sets by city. Stratifying the split by population bucket ensures that cities of different sizes are evenly distributed between the two sets.

In [None]:
import pandas as pd
# One row per city: its name and its population, paired by position.
df = pd.DataFrame(
    data={
        "city": cities,
        "num_people": num_people,
    }
)
def get_stratify_index(num_people) -> int:
    """
    Map a population count to the index of its size bucket.

    Buckets, from largest to smallest:
        0: at least 500,000
        1: at least 200,000 (and below 500,000)
        2: at least 100,000 (and below 200,000)
        3: everything else

    Args:
        num_people: Population of a city.

    Returns:
        The bucket index, an integer from 0 to 3.
    """
    # Lower bounds in descending order; the first one met determines the bucket.
    lower_bounds = (500_000, 200_000, 100_000)
    for bucket, bound in enumerate(lower_bounds):
        if num_people >= bound:
            return bucket
    return len(lower_bounds)
# Label every city with its population bucket, then display the table.
df["stratify"] = df["num_people"].map(get_stratify_index)
df

## Downloading data

In [None]:
from srai.regionalizers import H3Regionalizer
from tqdm.auto import tqdm
from srai.regionalizers import geocode_to_region_gdf
import pandas as pd

Loading city boundaries

In [None]:
# Look up every city with `geocode_to_region_gdf` (with a progress bar),
# then plot all the results together.
areas = list(map(geocode_to_region_gdf, tqdm(cities)))
pd.concat(areas).explore()

Split cities into H3 regions and create train and val datasets

In [None]:
import geopandas as gpd
from typing import List

def get_regions(cities: List[str], resolution: int) -> gpd.GeoDataFrame:
    """
    Build the H3 regions covering the given cities.

    Args:
        cities: Names looked up with `geocode_to_region_gdf`, one call per entry.
        resolution: Resolution handed to `H3Regionalizer`.

    Returns:
        The result of `H3Regionalizer.transform` applied to the concatenated city areas.
    """
    city_areas = []
    for city in tqdm(cities):
        city_areas.append(geocode_to_region_gdf(city))
    merged_areas = pd.concat(city_areas)
    return H3Regionalizer(resolution=resolution).transform(merged_areas)

In [None]:
from sklearn.model_selection import train_test_split

# Stratified 80/20 split of the cities by population bucket; each part is then
# ordered from the most to the least populous city.
train_df, val_df = (
    part.sort_values(by="num_people", ascending=False)
    for part in train_test_split(
        df,
        test_size=0.2,
        random_state=42,
        stratify=df["stratify"],
    )
)

In [None]:
# Plain Python lists of city names for each split.
train_cities, val_cities = (split["city"].tolist() for split in (train_df, val_df))

In [None]:
# H3 regions for each split, both at the notebook-wide resolution.
train_regions_gdf = get_regions(cities=train_cities, resolution=RESOLUTION)
val_regions_gdf = get_regions(cities=val_cities, resolution=RESOLUTION)

hex2vec training requires neighbourhood information; we use the SRAI implementation of H3 neighbourhoods for this.

In [None]:
from srai.neighbourhoods import H3Neighbourhood

# One neighbourhood object per split, each built from that split's regions.
train_neighbourhood, val_neighbourhood = map(
    H3Neighbourhood,
    (train_regions_gdf, val_regions_gdf),
)

Downloading OSM data

In [None]:
from srai.loaders.osm_loaders.filters import HEX2VEC_FILTER
from srai.embedders.hex2vec import Hex2VecEmbedder
from functional import seq


# Flatten the filter into "<key>_<value>" names, one per (key, value) pair,
# preserving the filter's own ordering.
expected_output_features = [
    f"{tag_key}_{tag_value}"
    for tag_key, tag_values in HEX2VEC_FILTER.items()
    for tag_value in tag_values
]
expected_output_features

This takes a long time, so we save the features to `parquet` files and load them later

In [None]:
from srai.loaders import OSMPbfLoader
from pathlib import Path

# Directory holding one parquet file of OSM features per city.
save_dir = Path("data/raw").resolve()
save_dir.mkdir(parents=True, exist_ok=True)

loader = OSMPbfLoader()

for city in tqdm(cities, desc="Downloading OSM features"):
    city_path = save_dir / f"{city}.parquet"
    # Loading is slow, so skip cities whose features are already saved; this keeps
    # "Restart & Run All" feasible. Delete the file to force a fresh download.
    if city_path.exists():
        continue
    features = loader.load(geocode_to_region_gdf(city), HEX2VEC_FILTER)
    features.to_parquet(city_path)

In [None]:
def get_features(cities: List[str]) -> gpd.GeoDataFrame:
    """
    Load the saved features of the given cities into a single frame.

    Reads `data/raw/<city>.parquet` for every entry of `cities`, concatenates the
    results and drops rows whose index value has already appeared, keeping the
    first occurrence.

    Args:
        cities: City names; each must have a matching parquet file in `data/raw`.

    Returns:
        The concatenated features with a duplicate-free index.
    """
    per_city = [gpd.read_parquet(f"data/raw/{city}.parquet") for city in tqdm(cities)]
    combined = pd.concat(per_city)
    is_repeat = combined.index.duplicated(keep='first')
    return combined[~is_repeat]

In [None]:
# Saved features for each split, loaded in train/val order.
train_features, val_features = map(get_features, (train_cities, val_cities))

In [None]:
# Display the H3 regions built for the training cities.
train_regions_gdf

Matching features with H3 regions

In [None]:
from srai.joiners import IntersectionJoiner

# A single joiner instance is reused to join regions with features for both splits.
joiner = IntersectionJoiner()
train_joint, val_joint = (
    joiner.transform(regions_part, features_part)
    for regions_part, features_part in (
        (train_regions_gdf, train_features),
        (val_regions_gdf, val_features),
    )
)

## Train model

In [None]:
# Passed to `Hex2VecEmbedder` as `encoder_sizes`; also embedded in the checkpoint
# directory and the W&B run name.
model_size = [256, 128, 64]
# Passed to the trainer as `max_epochs`.
EPOCHS = 20

In [None]:
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import wandb
from srai.embedders import Hex2VecEmbedder

# Single source of truth for the batch size: used both in the W&B project name
# and in the training call below.
BATCH_SIZE = 1024

# Keep the two checkpoints with the highest `val_f1`.
checkpoint_callback = ModelCheckpoint(
    save_top_k=2,
    monitor="val_f1",
    mode="max",
    filename="{epoch}-{val_f1:.2f}",
    dirpath=f"models_{RESOLUTION}/{model_size}",
)
embedder = Hex2VecEmbedder(
    encoder_sizes=model_size,
    expected_output_features=expected_output_features,
)
logger = WandbLogger(
    project=f"hex2vec_pl_r{RESOLUTION}_b{BATCH_SIZE}_50k",
    name=f"model_sizes={model_size}",
)
try:
    # Train data first, then validation data, in the same (regions, features,
    # joint, neighbourhood) order.
    # NOTE(review): training is pinned to a GPU via `accelerator="gpu"`.
    embedder.fit_transform(
        train_regions_gdf,
        train_features,
        train_joint,
        train_neighbourhood,
        val_regions_gdf,
        val_features,
        val_joint,
        val_neighbourhood,
        batch_size=BATCH_SIZE,
        trainer_kwargs={
            "max_epochs": EPOCHS,
            "accelerator": "gpu",
            "logger": logger,
            "callbacks": [checkpoint_callback],
        },
    )
finally:
    # Close the W&B run even when training raises, so it is not left open.
    wandb.finish()

And done ;) We have successfully pre-trained our model!