In [1]:
import geopandas as gpd
import pandas as pd
from shapely import geometry

from srai.constants import REGIONS_INDEX, WGS84_CRS
from srai.embedders import CountEmbedder
from srai.joiners import IntersectionJoiner
from srai.loaders.osm_loaders import OSMOnlineLoader, OSMPbfLoader
from srai.loaders.osm_loaders.filters import HEX2VEC_FILTER
from srai.regionizers import H3Regionizer, SlippyMapRegionizer
from srai.utils import geocode_to_region_gdf

  warn(f"Failed to load image Python extension: {e}")


In [2]:

def get_embedding_for_region(region_name: str, zoom: int = 16) -> pd.DataFrame:
    """Compute per-tile OSM feature counts for a geocoded place.

    Pipeline:
      1. geocode ``region_name`` to an area GeoDataFrame,
      2. split that area into slippy-map tiles at ``zoom``,
      3. load OSM features for the area with ``OSMPbfLoader`` restricted to
         ``HEX2VEC_FILTER`` tags,
      4. intersect features with tiles (``IntersectionJoiner``),
      5. count features per tile (``CountEmbedder``).

    Args:
        region_name: Free-text place name handed to ``geocode_to_region_gdf``,
            e.g. ``"Tirana, Albania"``.
        zoom: Slippy-map zoom level of the tile grid. Defaults to 16, the
            value this notebook originally hard-coded.

    Returns:
        DataFrame indexed by region (tile) id with one count column per
        tag/value pair (columns such as ``amenity_bar``, ``waterway_river``).
    """
    # The geocoded area is used both to build the tile grid and to bound the
    # feature download, so it is computed once and reused.
    area_gdf = geocode_to_region_gdf(region_name)
    regions_gdf = SlippyMapRegionizer(z=zoom).transform(area_gdf)
    features_gdf = OSMPbfLoader().load(area_gdf, tags=HEX2VEC_FILTER)
    joint_gdf = IntersectionJoiner().transform(regions_gdf, features_gdf)
    return CountEmbedder().transform(regions_gdf, features_gdf, joint_gdf)

In [None]:
from pathlib import Path

from tqdm.auto import tqdm

# Batch job: compute a tile-count embedding for every city listed in the input
# file and store one pickle per city.
CITIES_PATH = Path("data/cities_v3.txt")
EMBEDDINGS_DIR = Path("data/embeddings")

# Read the whole list first so the progress bar knows its length and the file
# is not held open during slow network/download work. Decode as UTF-8 so
# non-ASCII city names read the same on every OS. `.strip()` also removes the
# "\r" of CRLF line endings, which would otherwise leak into the geocoded name
# and the output file name. Blank lines are skipped rather than geocoded.
with CITIES_PATH.open(encoding="utf-8") as cities_file:
    cities = [line.strip() for line in cities_file if line.strip()]

# DataFrame.to_pickle does not create missing parent directories.
EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True)

progress = tqdm(cities, desc="embedding cities")
for city in progress:
    # Show the current city on the bar instead of print(), which would break it.
    progress.set_postfix_str(city)
    df = get_embedding_for_region(city)
    df.to_pickle(EMBEDDINGS_DIR / f"{city}.pkl")

In [None]:
# Load one per-city embedding back from disk to inspect its shape and columns.
# Pickles execute code on load -- only read files this notebook produced.
pd.read_pickle(filepath_or_buffer="data/embeddings/Tirana, Albania.pkl")

Unnamed: 0_level_0,aeroway_hangar,aeroway_helipad,amenity_arts_centre,amenity_atm,amenity_bank,amenity_bar,amenity_bbq,amenity_bench,amenity_bicycle_parking,amenity_bicycle_rental,...,tourism_museum,tourism_picnic_site,tourism_viewpoint,water_lake,water_pond,water_reservoir,water_river,waterway_canal,waterway_river,waterway_stream
region_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
"(36364, 24482)",0,0,0,0,0,1,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
"(36365, 24482)",0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,1,0
"(36366, 24482)",0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
"(36364, 24483)",0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
"(36365, 24483)",0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
"(36377, 24499)",0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
"(36378, 24499)",0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
"(36379, 24499)",0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
"(36380, 24499)",0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,7,0,0


In [None]:
from srai.plotting.folium_wrapper import plot_regions

# Map one city's embedded tiles. Everything this cell needs is built here, so it
# runs on a fresh kernel: the tile regions are only a local variable inside
# get_embedding_for_region, so they are rebuilt with the same regionizer settings.
PLOT_CITY = "Tirana, Albania"

city_regions_gdf = SlippyMapRegionizer(z=16).transform(geocode_to_region_gdf(PLOT_CITY))
city_embedding_df = pd.read_pickle(f"data/embeddings/{PLOT_CITY}.pkl")

# The embedding is a plain DataFrame (no geometry, no .explore()); attach the
# tile polygons by joining on the shared region index. An inner join drops any
# tile missing from either side (e.g. if geocoding results changed since caching).
city_embedding_gdf = city_regions_gdf.join(city_embedding_df, how="inner")

folium_map = plot_regions(city_regions_gdf, tiles_style="CartoDB positron", colormap=["rgba(0,0,0,0)"])
# A named matplotlib colormap avoids depending on plotly for the palette.
city_embedding_gdf.reset_index().explore(m=folium_map, column=REGIONS_INDEX, cmap="tab20")