In [None]:
import pandas as pd
import torch

from srai.joiners import IntersectionJoiner
from srai.loaders.osm_loaders.filters import GEOFABRIK_LAYERS
from srai.plotting import plot_regions
from srai.regionalizers import S2Regionalizer, geocode_to_region_gdf

# Geocode the study area by name; the resulting regions GeoDataFrame is reused
# by the regionalization cell below.
area_gdf = geocode_to_region_gdf(
    "Metropolitan France",
)

# Last expression of the cell, so the plot returned by plot_regions is rendered
# as the cell output.
plot_regions(
    area_gdf,
    tiles_style="CartoDB positron",
)

In [8]:
# S2 levels used in this notebook: one for the image cells and a finer one that
# is later passed as the patch target level.
img_resolution, patch_resolution = 8, 12

# Split the study area into S2 cells at the image level. The area index is
# reset (and dropped) before it is handed to the regionalizer.
img_regionalizer = S2Regionalizer(
    resolution=img_resolution,
    buffer=False,
)
img_s2_regions = img_regionalizer.transform(
    area_gdf.reset_index(drop=True),
)

# Single geometry obtained by unioning every image region.
img_s2_geometry = img_s2_regions.union_all()

print("Image regions:", len(img_s2_regions.index))

Image regions: 598


In [None]:
# Show the image-level S2 regions on a light basemap; kept as the last
# expression so the plot is rendered as the cell output.
plot_regions(img_s2_regions, tiles_style="CartoDB positron")

In [None]:
from torch.utils.data import DataLoader, random_split

from srai.embedders.count_embedder import CountEmbedder
from srai.embedders.s2vec.dataset import S2VecDataset
from srai.embedders.s2vec.s2_utils import get_patches_from_img_gdf
from srai.loaders.osm_loaders.osm_pbf_loader import OSMPbfLoader

# Batch-wise preprocessing: for each slice of image regions, load the OSM
# features, split the images of that slice into patches at `patch_resolution`,
# and count the features per patch.
tags = GEOFABRIK_LAYERS
loader = OSMPbfLoader(pbf_file="france-latest.osm.pbf")

# Neither object depends on the batch, so they are created once and reused
# instead of being rebuilt on every iteration.
joiner = IntersectionJoiner()
count_embedder = CountEmbedder(tags)

preproc_batch_size = 1

# Ceiling division: the previous `len // size + 1` over-reported the total by
# one whenever the number of regions was an exact multiple of the batch size.
num_batches = (len(img_s2_regions) + preproc_batch_size - 1) // preproc_batch_size

for i in range(0, len(img_s2_regions), preproc_batch_size):
    batch = img_s2_regions.iloc[i : i + preproc_batch_size]
    print(f"Processing batch {i // preproc_batch_size + 1} of {num_batches}")

    # Features are loaded for the current batch of image regions only.
    batch_features_gdf = loader.load(batch, tags)

    # Build patches from the same batch the features were loaded for.
    # Previously the whole `img_s2_regions` frame was passed here, so every
    # iteration regenerated the patches of all images and joined them against
    # features that only cover the current batch.
    patches_gdf, img_patch_joint_gdf = get_patches_from_img_gdf(
        img_gdf=batch, target_level=patch_resolution
    )

    patch_feature_joint_gdf = joiner.transform(patches_gdf, batch_features_gdf)
    counts_df = count_embedder.transform(
        patches_gdf, batch_features_gdf, patch_feature_joint_gdf
    )
    print(counts_df)

    # NOTE(review): only the first batch is processed; this looks like a
    # debugging stop. Remove it (and collect `counts_df` per batch) to run the
    # full preprocessing.
    break