In [None]:
import pandas as pd
import torch

from srai.joiners import IntersectionJoiner
from srai.loaders.osm_loaders.filters import GEOFABRIK_LAYERS
from srai.plotting import plot_regions
from srai.regionalizers import S2Regionalizer, geocode_to_region_gdf

# Resolve the study area by name and preview its outline.
# The plot call stays the last expression so the notebook renders the map.
area_gdf = geocode_to_region_gdf(
    "Grand Est",
)
plot_regions(
    area_gdf,
    tiles_style="CartoDB positron",
)

In [2]:
# S2 levels used by this notebook: `img_resolution` drives the regionalizer
# below; `patch_resolution` is consumed later when images are split into patches.
img_resolution, patch_resolution = 8, 12

# Cover the study area with S2 cells at the image level. The area index is
# reset to a plain RangeIndex before it is handed to the regionalizer.
img_regionalizer = S2Regionalizer(buffer=False, resolution=img_resolution)
img_s2_regions = img_regionalizer.transform(
    area_gdf.reset_index(drop=True),
)

# Single geometry covering every image cell.
img_s2_geometry = img_s2_regions.union_all()

print("Image regions:", len(img_s2_regions))

Image regions: 34


In [None]:
# Preview the image-level S2 cells produced by the regionalizer.
plot_regions(img_s2_regions, tiles_style="CartoDB positron")

In [None]:
from torch.utils.data import DataLoader, random_split

from srai.embedders.count_embedder import CountEmbedder
from srai.embedders.s2vec.dataset import S2VecDataset
from srai.embedders.s2vec.s2_utils import get_patches_from_img_gdf
from srai.loaders.osm_loaders.osm_pbf_loader import OSMPbfLoader

# Helpers shared by every batch. None of them takes per-batch constructor
# arguments, so they are built once instead of on every loop iteration.
tags = GEOFABRIK_LAYERS
loader = OSMPbfLoader()
joiner = IntersectionJoiner()
count_embedder = CountEmbedder(tags)

# Number of image regions handled per load / join / count pass.
preproc_batch_size = 1

num_regions = len(img_s2_regions)
# Ceiling division: exact whether or not the batch size divides the region
# count (the previous `// ... + 1` over-counted by one when it did).
num_batches = (num_regions + preproc_batch_size - 1) // preproc_batch_size

# Per-batch results are collected in lists and concatenated once after the
# loop; concatenating onto a growing frame re-copies all earlier rows each time.
counts_per_batch = []
joints_per_batch = []

for batch_no, start in enumerate(range(0, num_regions, preproc_batch_size), start=1):
    batch = img_s2_regions.iloc[start : start + preproc_batch_size]
    print(f"Processing batch {batch_no} of {num_batches}")

    # OSM features for the image cells in this batch.
    batch_features_gdf = loader.load(batch, tags)

    # Split each image cell into patches at `patch_resolution`; the second
    # return value is the image-to-patch relation for this batch.
    patches_gdf, batch_img_patch_joint_gdf = get_patches_from_img_gdf(
        img_gdf=batch, target_level=patch_resolution
    )

    # Count the features intersecting each patch.
    patch_feature_joint_gdf = joiner.transform(patches_gdf, batch_features_gdf)
    counts_df = count_embedder.transform(
        patches_gdf, batch_features_gdf, patch_feature_joint_gdf
    )

    counts_per_batch.append(counts_df)
    joints_per_batch.append(batch_img_patch_joint_gdf)

# `pd.concat` rejects an empty list, so fall back to an empty frame when there
# were no regions to process (matching the previous behaviour in that case).
data = pd.concat(counts_per_batch) if counts_per_batch else pd.DataFrame()
img_patch_joint_gdf = (
    pd.concat(joints_per_batch) if joints_per_batch else pd.DataFrame()
)

In [5]:
# Build the training dataset from the per-patch feature counts (`data`) and
# the image-to-patch relation collected during preprocessing.
ds = S2VecDataset(
    data=data,
    img_patch_joint_gdf=img_patch_joint_gdf,
)

100%|██████████| 34/34 [00:00<00:00, 5524.17it/s]


In [6]:
# 80/20 train/validation split; the validation set takes the remainder so the
# two sizes always sum to len(ds).
train_size = int(0.8 * len(ds))
val_size = len(ds) - train_size

# Seed the split so the same samples land in train/validation on every run;
# an unseeded random_split reshuffles membership each time the cell executes.
split_seed = 42
train_ds, val_ds = random_split(
    ds,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)

batch_size = 16

# Pinned host memory only helps CUDA transfers; requesting it without a CUDA
# device just produces a warning, so enable it only when one is present.
use_pinned_memory = torch.cuda.is_available()

train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,  # reshuffle training samples every epoch
    num_workers=2,  # Increase if your system supports it
    pin_memory=use_pinned_memory,
)
val_loader = DataLoader(
    val_ds,
    batch_size=batch_size,
    shuffle=False,  # keep validation order fixed
    num_workers=2,
    pin_memory=use_pinned_memory,
)

In [9]:
from srai.embedders.s2vec import S2VecModel

# Model hyperparameters.
# NOTE(review): img_size=16 presumably equals 2 ** (patch_resolution - img_resolution)
# patches per image side, and in_ch=347 presumably equals the number of feature
# columns in `data` - confirm both against the dataset before changing resolutions or tags.
model = S2VecModel(img_size=16, patch_size=1, in_ch=347, lr=1e-5)

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping
import wandb

# Log metrics to the "s2vec-embedding" Weights & Biases project.
wandb_logger = WandbLogger(project="s2vec-embedding")

# Stop once the monitored metric has not improved for 5 validation checks.
# NOTE(review): "validation_loss" must match the metric name logged by
# S2VecModel - confirm against the model's validation step.
early_stopping = EarlyStopping(monitor="validation_loss", patience=5, mode="min")

trainer = Trainer(
    max_epochs=20,
    accelerator="auto",
    logger=wandb_logger,
    callbacks=[early_stopping],
)

# Close the W&B run even if training raises or is interrupted; otherwise the
# run stays open in the kernel and the next execution of this cell reuses it.
try:
    trainer.fit(model, train_loader, val_loader)
finally:
    wandb.finish()