In [None]:
import copy
import logging
import os
import random
import warnings
from typing import Callable, Optional, Union

import geopandas as gpd
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from datasets import Dataset
from shapely.geometry import Polygon, box
from sklearn.preprocessing import MinMaxScaler
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from tqdm import tqdm

from srai.benchmark import BaseEvaluator, HexRegressionEvaluator
from srai.datasets import (
    ChicagoCrimeDataset,
)
from srai.embedders import Hex2VecEmbedder
from srai.h3 import h3_to_geoseries, ring_buffer_h3_regions_gdf
from srai.joiners import IntersectionJoiner
from srai.loaders.osm_loaders import OSMPbfLoader
from srai.loaders.osm_loaders.filters import HEX2VEC_FILTER
from srai.neighbourhoods.h3_neighbourhood import H3Neighbourhood
from srai.plotting import plot_numeric_data
from srai.regionalizers import H3Regionalizer

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
resolution = 8
embedder_hidden_sizes = [150, 75, 50]

# Seed the Python, NumPy and PyTorch RNGs so weight initialisation, data shuffling and
# embedder training are more repeatable between runs (some GPU kernels may remain
# nondeterministic).
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# H3 regionalizer at the notebook-wide resolution, and the scaler used for per-hex crime counts.
regionalizer = H3Regionalizer(resolution=resolution)
scaler = MinMaxScaler()

In [None]:
crimes = ChicagoCrimeDataset()
# The dataset version is passed as a string; presumably one version per H3 resolution.
ds = crimes.load(version=str(resolution))
train = ds["train"]
test = ds["test"]

In [None]:
# First rows of the loaded train split (before the dev split is carved out).
train.head(n=5)

Create a dev split from the train split (`test_size=0.1`).

In [None]:
# Re-split the training data into train and dev parts at the notebook's H3 resolution.
train, dev = crimes.train_test_split_bucket_regression(
    resolution=resolution,
    test_size=0.1,
    dev=True,
)

In [None]:
# Name of the target column used as `y` further below.
target_column = crimes.target
target_column

In [None]:
# NOTE(review): neither value is used later in this notebook — confirm the call is still needed.
train_h3, test_h3 = crimes.get_h3_with_labels()

Get information about available categorical and numerical columns

In [None]:
column_types = (crimes.categorical_columns, crimes.numerical_columns)
column_types

In [None]:
# Copies of each split, used for the H3 joins below.
train_, dev_, test_ = train.copy(), dev.copy(), test.copy()

In [None]:
# First rows of the train split after the dev part was removed.
train.head(n=5)

Get h3 indexes for data points

In [None]:
regions_train = regionalizer.transform(train_)

# Area covered by the training hexagons: convex hull padded by 0.1 (in CRS units).
# It is regionalized again to get the region set for which OSM features are loaded.
# NOTE: geopandas >= 1.0 deprecates `unary_union` in favour of `union_all()`.
train_hull = regions_train.unary_union.convex_hull
full_geometry = train_hull.buffer(0.1)
full_area_gdf = gpd.GeoDataFrame(["full"], geometry=[full_geometry]).set_crs(regions_train.crs)
full_regions = regionalizer.transform(full_area_gdf)

In [None]:
def _join_points_to_h3(points_gdf: gpd.GeoDataFrame) -> tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]:
    """
    Regionalize a split and attach the H3 index of the containing hexagon to every point.

    Args:
        points_gdf (gpd.GeoDataFrame): points of one data split.

    Returns:
        tuple[gpd.GeoDataFrame, gpd.GeoDataFrame]: the H3 regions covering the split, and
            the points left-joined to them with the hexagon id in an ``h3_index`` column.
    """
    regions = regionalizer.transform(points_gdf)
    # geopandas < 1.0 always names the right index column "index_right"; newer versions
    # reuse the right index's name if it has one. Naming it "h3_index" up front and
    # renaming "index_right" yields an `h3_index` column with either version.
    joined = gpd.sjoin(
        points_gdf, regions.rename_axis("h3_index"), how="left", predicate="within"
    ).rename(columns={"index_right": "h3_index"})
    return regions, joined


regions_train, joined_train = _join_points_to_h3(train_)
regions_dev, joined_dev = _join_points_to_h3(dev_)
regions_test, joined_test = _join_points_to_h3(test_)

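Aggregate crimes into per-hexagon counts and min-max scale them. The scaler is fit on the train counts only; dev and test counts are transformed with it and clipped to [0, 1].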

In [None]:
def _count_per_hex(joined):
    """Number of joined crime records per H3 index, as an ``h3_index`` / ``count`` frame."""
    return joined.groupby("h3_index").size().reset_index(name="count")


train_counts_per_hex = _count_per_hex(joined_train)
dev_counts_per_hex = _count_per_hex(joined_dev)
test_counts_per_hex = _count_per_hex(joined_test)

# Fit the min-max scaler on train counts only; held-out splits reuse it and are clipped
# to [0, 1] since their counts may fall outside the training range.
train_counts_per_hex["count"] = scaler.fit_transform(train_counts_per_hex[["count"]])
for held_out_counts in (dev_counts_per_hex, test_counts_per_hex):
    held_out_counts["count"] = np.clip(scaler.transform(held_out_counts[["count"]]), 0, 1)

Embed H3 regions as vectors. Use the srai library to train spatial embeddings on the train split with the chosen embedder type (e.g. Hex2Vec, GeoVex), then use the fitted embedder to get embeddings for the hexagons in the train, dev and test splits.

In [None]:
# Extend every split's regions with a 2-ring buffer of neighbouring hexagons.
# (`ring_buffer_h3_regions_gdf` is imported in the top import cell.)
buffered_regions_train = ring_buffer_h3_regions_gdf(regions_train, 2)
buffered_regions_dev = ring_buffer_h3_regions_gdf(regions_dev, 2)
buffered_regions_test = ring_buffer_h3_regions_gdf(regions_test, 2)


# OSM features for the whole padded training area; the embedder is fit on these.
osm_features = OSMPbfLoader().load(full_regions, HEX2VEC_FILTER)
region_intersect_train = IntersectionJoiner().transform(buffered_regions_train, osm_features)

# Alternative embedders below are not imported in the top cell — add their imports
# before uncommenting.

# # For CCE or CE usage
# neighbourhood = H3Neighbourhood(full_regions)
# embedder = ContextualCountEmbedder(neighbourhood=neighbourhood,
#                                    neighbourhood_distance=2,
#                                    expected_output_features=HEX2VEC_FILTER,
#                                     concatenate_vectors=True,
#                                     count_subcategories=True)
# embedder = CountEmbedder(expected_output_features=HEX2VEC_FILTER)

# # For H2V usage
embedder = Hex2VecEmbedder(embedder_hidden_sizes)
neighbourhood = H3Neighbourhood(buffered_regions_train)


# # For GV usage
# embedder = GeoVexEmbedder(target_features=HEX2VEC_FILTER, neighbourhood_radius=2)
# neighbourhood = H3Neighbourhood(full_regions)

# Needed for H2V and GV. Comment fitting block out for CCE and CE
with warnings.catch_warnings():
    # Suppress all warnings raised while fitting the embedder.
    warnings.simplefilter("ignore")
    embedder.fit(
        regions_gdf=buffered_regions_train,
        features_gdf=osm_features,
        joint_gdf=region_intersect_train,
        neighbourhood=neighbourhood,
        trainer_kwargs={"max_epochs": 10, "accelerator": device},
    )

In [None]:
# Embed the (buffered) training hexagons with the fitted embedder and keep the H3 id as a column.
embeddings_train = embedder.transform(
    joint_gdf=region_intersect_train,
    features_gdf=osm_features,
    regions_gdf=buffered_regions_train,
)
embeddings_train = embeddings_train.assign(h3=embeddings_train.index)

In [None]:
def _embed_buffered_regions(buffered_regions):
    """
    Load OSM features for a split's buffered regions and embed them with the fitted embedder.

    Returns:
        tuple: (OSM features, region/feature intersection, embeddings with an ``h3`` column).
    """
    split_features = OSMPbfLoader().load(buffered_regions, HEX2VEC_FILTER)
    split_intersection = IntersectionJoiner().transform(buffered_regions, split_features)
    split_embeddings = embedder.transform(
        regions_gdf=buffered_regions,
        features_gdf=split_features,
        joint_gdf=split_intersection,
    )
    split_embeddings["h3"] = split_embeddings.index
    return split_features, split_intersection, split_embeddings


osm_features_dev, region_intersect_dev, embeddings_dev = _embed_buffered_regions(
    buffered_regions_dev
)
osm_features_test, region_intersect_test, embeddings_test = _embed_buffered_regions(
    buffered_regions_test
)

In [None]:
def _merge_with_counts(embeddings, counts_per_hex):
    """Inner-join embeddings (keyed by ``region_id``) with per-hex counts (keyed by ``h3_index``)."""
    return embeddings.merge(
        counts_per_hex, how="inner", left_on="region_id", right_on="h3_index"
    )


merged_train = _merge_with_counts(embeddings_train, train_counts_per_hex)
merged_dev = _merge_with_counts(embeddings_dev, dev_counts_per_hex)
merged_test = _merge_with_counts(embeddings_test, test_counts_per_hex)

# Feature columns: everything except the id columns and the target.
excluded_columns = ("h3", crimes.target, "h3_index")
merge_columns = [col for col in merged_train.columns if col not in excluded_columns]

Helper that flattens the feature columns of a row (embedding dimensions and any array-valued cells) into a single float32 vector.

In [None]:
def concat_columns(row: pd.Series) -> np.ndarray:
    """
    Concatenate the values of a row into one flat float32 vector.

    Each cell may hold a scalar or a 1-D array-like (e.g. a whole embedding); scalars are
    promoted to length-1 arrays so that all values can be concatenated in column order.

    Args:
        row (pd.Series): one row of a DataFrame, as passed by ``DataFrame.apply(..., axis=1)``.

    Returns:
        np.ndarray: 1-D float32 array with all values of ``row``; empty if ``row`` is empty.
    """
    parts = [np.atleast_1d(val) for val in row.values]
    if not parts:
        # np.concatenate raises on an empty list; an empty row maps to an empty vector.
        return np.empty(0, dtype=np.float32)
    return np.concatenate(parts).astype(np.float32)

Build the final data splits (`X` – feature vector, `X_h3_idx` – H3 index, `y` – target value).

In [None]:
def _to_torch_dataset(merged) -> Dataset:
    """
    Turn a merged split into a Hugging Face ``Dataset`` in torch format.

    Columns: ``X`` (flattened feature vector), ``X_h3_idx`` (H3 id), ``y`` (target).
    """
    columns = {
        "X": merged[merge_columns].apply(concat_columns, axis=1).values,
        "X_h3_idx": merged["h3"].values,
        "y": merged[crimes.target].values,
    }
    split_dataset = Dataset.from_dict(columns)
    split_dataset.set_format(type="torch", columns=["X", "X_h3_idx", "y"])
    return split_dataset


train_dataset = _to_torch_dataset(merged_train)
dev_dataset = _to_torch_dataset(merged_dev)
test_dataset = _to_torch_dataset(merged_test)

In [None]:
# Width of the feature vector = input size of the regression model.
n_train_rows, embedding_size = train_dataset["X"].shape
embedding_size

Model definition


In [None]:
"""
Regression model

Contains implementation of base model of regression.
"""


class RegressionBaseModel(nn.Module):  # type: ignore
    """
    Regression base model.

    A multilayer perceptron: for each entry of ``linear_sizes`` a Linear layer followed by
    ``activation_function`` (plus Dropout after every second hidden layer), then a final
    Linear layer to a single output squashed into (0, 1) by a Sigmoid.
    """

    def __init__(
        self,
        embeddings_size: int,
        linear_sizes: Optional[list[int]] = None,
        activation_function: Optional[nn.Module] = None,
        dropout_p: float = 0.2,
    ):
        """
        Initialization of regression module.

        Args:
            embeddings_size (int): size of input embedding
            linear_sizes (Optional[list[int]], optional): sizes of the hidden linear layers. \
                Defaults to [500, 1000].
            activation_function (Optional[nn.Module], optional): activation module from torch.nn \
                placed after every hidden linear layer; the same instance is reused for all \
                layers. Defaults to ReLU.
            dropout_p (float, optional): probability of the Dropout layers inserted after \
                every second hidden layer. Defaults to 0.2.
        """
        super().__init__()
        if linear_sizes is None:
            linear_sizes = [500, 1000]
        if activation_function is None:
            activation_function = nn.ReLU()
        self.model = torch.nn.Sequential()
        previous_size = embeddings_size
        for cnt, size in enumerate(linear_sizes):
            self.model.add_module(f"linear_{cnt}", nn.Linear(previous_size, size))
            # Named generically: the activation is configurable (e.g. Sigmoid), not always ReLU.
            # Activation modules in torch.nn like ReLU/Sigmoid hold no parameters, so this
            # name does not appear in state_dict keys.
            self.model.add_module(f"activation_{cnt}", activation_function)
            previous_size = size
            # Dropout after every second hidden layer (odd indices).
            if cnt % 2:
                self.model.add_module(f"dropout_{cnt}", nn.Dropout(p=dropout_p))
        self.model.add_module("linear_final", nn.Linear(previous_size, 1))
        # Single output constrained to (0, 1).
        self.model.add_module("sigmoid_output", nn.Sigmoid())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the model.

        Args:
            x (torch.Tensor): input vectors, last dimension equal to ``embeddings_size``.

        Returns:
            torch.Tensor: predicted target value(s), last dimension of size 1.
        """
        return self.model(x)

Training parameters

In [None]:
# Optimisation hyper-parameters and checkpoint location.
epochs = 70
batch_size = 32
save_dir = os.getcwd()

# MLP with sigmoid hidden activations, trained with L1 (MAE) loss and Adam.
regression_model = RegressionBaseModel(
    embedding_size,
    linear_sizes=[50, 100, 50],
    activation_function=nn.Sigmoid(),
)
loss_fn = nn.L1Loss()
optimizer = optim.Adam(regression_model.parameters(), lr=1e-3)

In [None]:
# Only the training data is shuffled. Dev and test keep a fixed order, so validation
# loss/metrics (and therefore early stopping) are reproducible and test predictions
# line up with the dataset's row order.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
dev_dataloader = DataLoader(dev_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Regression evaluator for per-hexagon predictions; also used for validation metrics in train().
evaluator = HexRegressionEvaluator()

In [None]:
def _train_one_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    optimizer: Optimizer,
    device: Union[str, torch.device],
    epoch: int,
) -> float:
    """
    Run one optimisation pass over ``dataloader``.

    Returns:
        float: mean of the per-batch training losses, or NaN if the loader is empty.
    """
    model.train()
    batch_losses: list[float] = []
    for batch in tqdm(dataloader, desc=f"Epoch: {epoch}", total=len(dataloader)):
        inputs = batch["X"].to(device)
        # Targets reshaped to (batch, 1) to match an (N, 1) model output.
        labels = batch["y"].to(device).reshape(-1, 1)

        outputs = model(inputs)
        loss = loss_fn(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())
    return float(np.mean(batch_losses)) if batch_losses else float("nan")


def _evaluate_model(
    model: nn.Module,
    dataloader: DataLoader,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    evaluator: BaseEvaluator,
    device: Union[str, torch.device],
) -> tuple[float, dict[str, float]]:
    """
    Compute the loss and metrics of ``model`` over the whole ``dataloader``.

    Metrics are computed once on all predictions rather than averaged over batches, which
    is biased for metrics that are not simple per-sample means (e.g. R²-style scores).

    Returns:
        tuple[float, dict[str, float]]: sample-weighted mean loss and the evaluator's metrics.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    n_samples = 0
    predictions: list[np.ndarray] = []
    targets: list[np.ndarray] = []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluation", total=len(dataloader)):
            inputs = batch["X"].to(device)
            labels = batch["y"].to(device).reshape(-1, 1)

            outputs = model(inputs)
            # Weight each batch by its size so a smaller final batch does not skew the mean
            # (assumes loss_fn uses mean reduction, as nn.L1Loss does by default).
            batch_size = labels.shape[0]
            loss_sum += float(loss_fn(outputs, labels).item()) * batch_size
            n_samples += batch_size
            predictions.append(outputs.cpu().numpy())
            targets.append(labels.cpu().numpy())

    if n_samples == 0:
        raise ValueError("dev_dataloader yielded no samples; cannot evaluate.")

    metrics = evaluator._compute_metrics(np.concatenate(predictions), np.concatenate(targets))
    return loss_sum / n_samples, metrics


def train(
    model: nn.Module,
    train_dataloader: DataLoader,
    dev_dataloader: DataLoader,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    optimizer: Optimizer,
    evaluator: BaseEvaluator,
    device: Union[str, torch.device] = "cuda",
    save_dir: str = "./",
    epochs: int = 50,
    early_stopping_patience: int = 5,
) -> tuple[float, nn.Module, dict[str, float]]:
    """
    Trains a PyTorch model with early stopping and evaluation.

    After every epoch the model is evaluated on ``dev_dataloader``. The weights with the
    lowest validation loss seen so far are kept. Training stops once the validation loss
    has not improved on that best value for ``early_stopping_patience`` consecutive epochs.
    The best weights are loaded back into ``model`` and saved to
    ``<save_dir>/CAP_best_model.pkl``.

    Args:
        model (nn.Module): The PyTorch model to train.
        train_dataloader (DataLoader): DataLoader for training data.
        dev_dataloader (DataLoader): DataLoader for validation data.
        loss_fn (Callable): Loss function used for training (e.g., nn.MSELoss).
        optimizer (torch.optim.Optimizer): Optimizer for updating model weights.
        evaluator (Any): Object with a `_compute_metrics` method that accepts predicted and
            target values (as numpy arrays) and returns a dictionary of metric results.
            It is called once per epoch on the whole validation set.
        device (str or torch.device): Device to run the model on ('cpu' or 'cuda').
        save_dir (str): Directory where the best model will be saved.
        epochs (int, optional): Number of training epochs. Defaults to 50.
        early_stopping_patience (int, optional): Number of consecutive epochs without
            improvement over the best validation loss before early stopping. Defaults to 5.

    Returns:
        Tuple[float, nn.Module, Dict[str, float]]:
            - Validation loss of the best epoch (inf if no epoch improved or none ran),
            - Trained model (with best weights loaded),
            - Evaluation metrics of the best epoch ({} if none).

    Raises:
        ValueError: if ``dev_dataloader`` yields no samples.
    """
    best_val_loss = np.inf
    best_metrics: dict[str, float] = {}
    best_weights: Optional[dict] = None
    epochs_without_improvement = 0

    model.to(device)

    for epoch in range(epochs):
        train_loss = _train_one_epoch(model, train_dataloader, loss_fn, optimizer, device, epoch)
        logging.info(f"Epoch [{epoch+1}/{epochs}], avg_loss: {train_loss:.4f}")

        val_loss, metrics = _evaluate_model(model, dev_dataloader, loss_fn, evaluator, device)
        logging.info(f"Evaluation loss: {val_loss:.4f}")

        # Early stopping against the best loss so far (not merely the previous epoch);
        # weights are snapshotted only when they improve on it.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_metrics = metrics
            best_weights = copy.deepcopy(model.state_dict())
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= early_stopping_patience:
                logging.info(f"Early stopping at epoch {epoch+1}")
                break

    # Restore the best snapshot; otherwise the model keeps the last (possibly worse) weights.
    if best_weights is not None:
        model.load_state_dict(best_weights)

    torch.save(model.state_dict(), os.path.join(save_dir, "CAP_best_model.pkl"))
    return float(best_val_loss), model, best_metrics

In [None]:
# Train with early stopping on the dev split; returns a validation loss, the model and metrics.
val_loss, regression_model, metrics = train(
    regression_model,
    train_dataloader,
    dev_dataloader,
    loss_fn,
    optimizer,
    evaluator,
    device=device,
    save_dir=save_dir,
    epochs=epochs,
    early_stopping_patience=5,
)

In [None]:
regression_model.to(device)
regression_model.eval()
h3_indexes = []
xy_points = []
all_predictions = []

# Collect predictions, H3 ids and (optional) point labels for the whole test split.
with torch.no_grad():
    for test_batch in tqdm(test_dataloader, desc="Predicting...", total=len(test_dataloader)):
        batch_h3 = test_batch["X_h3_idx"]
        if "point" in test_batch:
            batch_points = test_batch["point"]
        else:
            batch_points = ["" for _ in batch_h3]
        batch_preds = regression_model(test_batch["X"].to(device))
        h3_indexes.extend(batch_h3)
        xy_points.extend(batch_points)
        all_predictions.extend(batch_preds.cpu().numpy())

In [None]:
# Resolution reported by the dataset instance; expected to match `resolution` above.
crimes.resolution

In [None]:
# Score the test-set predictions, presumably against the labels held by `crimes`.
# NOTE(review): the model predicts min-max scaled counts (scaler fit on train). Confirm
# evaluate() compares against labels on the same scale; otherwise apply
# scaler.inverse_transform to the predictions first.
evaluator.evaluate(
    dataset=crimes,
    region_ids=h3_indexes,
    predictions=all_predictions,
    log_metrics=False,
)

Results visualisation

In [None]:
# Ground-truth target and H3 id for every test row, in dataset order.
test_rows = [test_dataset[i] for i in range(len(test_dataset))]
original_label = [row["y"] for row in test_rows]
original_hexes = [row["X_h3_idx"] for row in test_rows]

In [None]:
chicago_bbox = box(-87.9401, 41.6445, -87.5237, 42.0230)
# Bounding boxes for other cities:
# philadelphia_bbox = box(-75.2803, 39.8670, -74.9558, 40.1376)
# sf_box=box(-123.173825, 37.639830, -122.281780, 37.929824)

# Predicted intensity per hexagon. "EPSG:4326" replaces the deprecated
# {"init": "epsg:4326"} PROJ syntax; allow_override keeps this working whether or not
# h3_to_geoseries already attaches a CRS.
polygons = h3_to_geoseries(h3_indexes)
preds_gdf = gpd.GeoDataFrame(geometry=polygons).set_crs("EPSG:4326", allow_override=True)
preds_gdf["intensity"] = [t.item() for t in all_predictions]
preds_gdf["region_id"] = h3_indexes
preds_gdf.index = preds_gdf["region_id"]

# Original labeled hexes (same construction as the predictions)
original_polygons = h3_to_geoseries(original_hexes)
original_gdf = gpd.GeoDataFrame(geometry=original_polygons).set_crs(
    "EPSG:4326", allow_override=True
)
original_gdf["intensity"] = [t.item() for t in original_label]
original_gdf["region_id"] = original_hexes
original_gdf.index = original_gdf["region_id"]

# Generate H3 regions (reusing the notebook's regionalizer at `resolution`)
regions = regionalizer.transform(original_gdf)

# Filter to chosen region using bbox
regions_to_plot = regions[regions.intersects(chicago_bbox)]
original_gdf = original_gdf[original_gdf.geometry.intersects(chicago_bbox)]
preds_gdf = preds_gdf[preds_gdf.geometry.intersects(chicago_bbox)]

In [None]:
# Ground-truth (scaled) crime intensity per hexagon.
ground_truth_plot = plot_numeric_data(regions_to_plot, "intensity", original_gdf)
ground_truth_plot

In [None]:
# Predicted crime intensity per hexagon.
prediction_plot = plot_numeric_data(regions_to_plot, "intensity", preds_gdf)
prediction_plot