In [None]:
import copy
import logging
import os
import warnings
from typing import Any, Callable, Optional, Union

import geopandas as gpd
import h3
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from examples.benchmark.task_modeling.coords import get_coords
from geopy.distance import great_circle
from srai.benchmark import BaseEvaluator, MobilityPredictionEvaluator
from srai.datasets import PortoTaxiDataset
from srai.embedders import Hex2VecEmbedder  # noqa: F401
from srai.h3 import h3_to_geoseries
from srai.joiners import IntersectionJoiner
from srai.loaders.osm_loaders import OSMPbfLoader
from srai.loaders.osm_loaders.filters import HEX2VEC_FILTER
from srai.neighbourhoods.h3_neighbourhood import H3Neighbourhood
from srai.regionalizers import H3Regionalizer, geocode_to_region_gdf
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm import tqdm

In [None]:
# Porto taxi dataset wrapper from srai; the data itself is loaded in the next cell.
porto_taxi = PortoTaxiDataset()

In [None]:
# Load the "HMP" version of the dataset; the result is indexed by split name.
ds = porto_taxi.load(version="HMP")
train, test = ds["train"], ds["test"]

In [None]:
# Run on GPU when available; `device` is also read as a global by the training code.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Hidden layer sizes for the Hex2Vec embedder constructed further below.
embedder_hidden_sizes = [150, 75, 25]
# Use the dataset's H3 resolution (presumably the one its h3_sequence_* columns use).
resolution = porto_taxi.resolution
regionalizer = H3Regionalizer(resolution=resolution)

In [None]:
# NOTE(review): repeats the assignment from the loading cell; `train` is overwritten
# again by the train/dev split below, so this cell is effectively redundant.
train, test = ds["train"], ds["test"]

In [None]:
# Name of the dataset's target column; passed as `trajectory_id_column` to the split below.
porto_taxi.target

In [None]:
# Optional: uncomment to evaluate on a 1% sample of the test set for a quick run.
# test = test.sample(frac=0.01, random_state=42)

In [None]:
# Split the training data into train and dev (validation) parts, holding out 10% for dev.
# NOTE(review): exact semantics of `n_bins` / `validation_split` come from srai — TODO confirm.
train, dev = porto_taxi.train_test_split(
    trajectory_id_column=porto_taxi.target,
    task="HMP",
    test_size=0.1,
    n_bins=3,
    validation_split=True,
)

In [None]:
# Sanity check of `porto_taxi.dev_gdf` (presumably populated by the split above).
type(porto_taxi.dev_gdf)

In [None]:
# Optional: uncomment to subsample train/dev (1%) for faster iteration.
# train = train.sample(frac=0.01, random_state=42)
# dev = dev.sample(frac=0.01, random_state=42)

In [None]:
# Categorical feature columns exposed by the dataset (inspection only).
porto_taxi.categorical_columns

In [None]:
# Numerical feature columns exposed by the dataset (inspection only).
porto_taxi.numerical_columns

Linestring embeddings: embed the H3 cells covered by each split's trajectories

In [None]:
# Inspect the trajectory geometries of the training split.
train.geometry

In [None]:
# Copies of the splits that are handed to the regionalizer below.
train_ = train.copy()
dev_ = dev.copy()
test_ = test.copy()

In [None]:
# H3 cells (at the dataset resolution) covering each split's geometries.
regions_train = regionalizer.transform(train_)
regions_dev = regionalizer.transform(dev_)
regions_test = regionalizer.transform(test_)

In [None]:
# H3 cells covering the geocoded "Porto, Portugal" area.
# NOTE(review): `area` and `regions` do not appear to be used later in this notebook.
area = geocode_to_region_gdf("Porto, Portugal")
regions = regionalizer.transform(area)

In [None]:
# Regionalize the convex hull of all training cells. These are the cells whose
# embeddings populate the global H3 -> embedding lookup, so dev/test cells outside
# this hull will not be in that lookup.
full_regions = regionalizer.transform(
    gpd.GeoDataFrame(
        ["full"], geometry=[regions_train.union_all().convex_hull]
    ).set_crs(regions_train.crs)
)

In [None]:
# NOTE(review): this import belongs in the top imports cell.
from srai.h3 import ring_buffer_h3_regions_gdf

# Extend each split's cells by 2 H3 rings before embedding.
buffered_regions_train = ring_buffer_h3_regions_gdf(regions_train, 2)
buffered_regions_dev = ring_buffer_h3_regions_gdf(regions_dev, 2)
buffered_regions_test = ring_buffer_h3_regions_gdf(regions_test, 2)


# OSM features for the whole hull (`full_regions`); reused for every split below.
osm_features = OSMPbfLoader().load(full_regions, HEX2VEC_FILTER)
region_intersect_train = IntersectionJoiner().transform(
    buffered_regions_train, osm_features
)

# NOTE(review): ContextualCountEmbedder, CountEmbedder and GeoVexEmbedder used in the
# commented alternatives below are not imported in the imports cell.
# # For CCE or CE usage
# neighbourhood = H3Neighbourhood(full_regions)
# embedder = ContextualCountEmbedder(neighbourhood=neighbourhood,
#                                    neighbourhood_distance=2,
#                                    expected_output_features=HEX2VEC_FILTER,
#                                     concatenate_vectors=True,
#                                     count_subcategories=True)
# embedder = CountEmbedder(expected_output_features=HEX2VEC_FILTER)

# # For H2V usage
embedder = Hex2VecEmbedder(embedder_hidden_sizes)
# Neighbourhood over the buffered training cells only: the embedder is fitted on train.
neighbourhood = H3Neighbourhood(buffered_regions_train)


# # For GV usage
# embedder = GeoVexEmbedder(target_features=HEX2VEC_FILTER, neighbourhood_radius=2)
# neighbourhood = H3Neighbourhood(full_regions)

# Needed for H2V and GV. Comment fitting block out for CCE and CE
# All warnings are suppressed while fitting.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    embedder.fit(
        regions_gdf=buffered_regions_train,
        features_gdf=osm_features,
        joint_gdf=region_intersect_train,
        neighbourhood=neighbourhood,
        trainer_kwargs={"max_epochs": 10, "accelerator": device},
    )

In [None]:
# Embed every cell of the hull; this feeds the global H3 -> embedding lookup.
region_intersect = IntersectionJoiner().transform(full_regions, osm_features)
all_embeddings = embedder.transform(
    regions_gdf=full_regions, features_gdf=osm_features, joint_gdf=region_intersect
)
# Keep the H3 id as a regular column as well as the index.
all_embeddings["h3"] = all_embeddings.index

In [None]:
# Embeddings for the buffered training cells (reuses the training intersection).
embeddings_train = embedder.transform(
    regions_gdf=buffered_regions_train,
    features_gdf=osm_features,
    joint_gdf=region_intersect_train,
)
embeddings_train["h3"] = embeddings_train.index

In [None]:
def _embed_split_regions(buffered_regions: gpd.GeoDataFrame) -> tuple[Any, Any]:
    """
    Intersect `buffered_regions` with the shared OSM features and embed them.

    Returns:
        tuple: (joint_gdf, embeddings) — the intersection result and the fitted
        embedder's output, the latter with an extra "h3" column copied from its index.
    """
    joint_gdf = IntersectionJoiner().transform(buffered_regions, osm_features)
    embedded = embedder.transform(
        regions_gdf=buffered_regions,
        features_gdf=osm_features,
        joint_gdf=joint_gdf,
    )
    embedded["h3"] = embedded.index
    return joint_gdf, embedded


region_intersect_dev, embeddings_dev = _embed_split_regions(buffered_regions_dev)
region_intersect_test, embeddings_test = _embed_split_regions(buffered_regions_test)

In [None]:
def concat_columns(row: pd.Series) -> np.ndarray:
    """
    Concatenate the numeric embedding values of one row into a float32 vector.

    Intended for ``DataFrame.apply(concat_columns, axis=1)``, so ``row`` is a plain
    ``pd.Series`` holding one region's embedding columns (not a GeoSeries).

    Args:
        row (pd.Series): Row of embedding values; each value must accept ``float()``.

    Returns:
        np.ndarray: 1-D float32 array with one entry per value in ``row``
        (an empty float32 array for an empty row).
    """
    # Convert through Python float (float64) first, then cast once to float32, so the
    # rounding is the same as converting each value individually.
    return np.asarray(
        [float(val) for val in row.values], dtype=np.float64
    ).astype(np.float32)

In [None]:
# Embedding columns are the numeric ones (the "h3" id column is a string).
numeric_cols = all_embeddings.select_dtypes(include=[np.number]).columns
# Pack each region's numeric columns into one float32 vector, for every embedding frame.
for embedding_frame in (all_embeddings, embeddings_train, embeddings_dev, embeddings_test):
    embedding_frame["embedding"] = embedding_frame[numeric_cols].apply(
        concat_columns, axis=1
    )

In [None]:
def attach_embeddings_to_trips(
    traj_df: pd.DataFrame, embedding_df: pd.DataFrame, trip_ids_column: str = "trip_id"
) -> pd.DataFrame:
    """
    Attach embedding sequences to H3-indexed trajectories.

    For each row of `traj_df`, every H3 index in "h3_sequence_x" is looked up in
    `embedding_df` and replaced by its embedding vector (a NumPy array), producing a
    new "embedding_sequence_x" column.

    H3 indices missing from `embedding_df` are replaced with a zero vector whose length
    matches the first entry of `embedding_df`. Its dtype matches that entry when the
    entry is floating point, and is float64 otherwise.

    Args:
        traj_df (pd.DataFrame): DataFrame with columns:
            - `trip_ids_column`: unique trip identifiers.
            - "h3_sequence_x": list of H3 indices (str) for the observed part of a trip.
            - "h3_sequence_y": list of H3 indices (str) for the part to be predicted.
        embedding_df (pd.DataFrame | pd.Series): Indexed by H3 index. Either a Series
            whose values are embedding vectors, or a DataFrame with one embedding per row.
        trip_ids_column (str): Column name for unique identifier of each trip.

    Returns:
        pd.DataFrame: A new DataFrame (the input is not modified) with columns
            `trip_ids_column`, "h3_sequence_x", "embedding_sequence_x", "h3_sequence_y".

    Raises:
        ValueError: If an H3 index is missing and `embedding_df` is empty, so the
            zero-vector length cannot be inferred.
    """
    known_h3 = embedding_df.index
    zero_vector: Optional[np.ndarray] = None

    def _missing_embedding() -> np.ndarray:
        # Build the zero vector once (lazily, so an empty `embedding_df` only fails if
        # a lookup actually misses) and hand out copies to avoid shared-array aliasing.
        nonlocal zero_vector
        if zero_vector is None:
            if len(embedding_df) == 0:
                raise ValueError(
                    "embedding_df is empty; cannot infer embedding size for missing H3 cells"
                )
            template = np.asarray(embedding_df.iloc[0])
            dtype = (
                template.dtype
                if np.issubdtype(template.dtype, np.floating)
                else np.float64
            )
            zero_vector = np.zeros(len(template), dtype=dtype)
        return zero_vector.copy()

    def get_embedding_sequence(h3_seq: list[str]) -> list[np.ndarray]:
        embeddings: list[np.ndarray] = []
        for h in h3_seq:
            if h in known_h3:
                emb: Any = embedding_df.loc[h]
                # DataFrame rows come back as Series; Series values are already arrays.
                embeddings.append(emb.values if hasattr(emb, "values") else emb)
            else:
                embeddings.append(_missing_embedding())
        return embeddings

    traj_df = traj_df.copy()
    traj_df["embedding_sequence_x"] = traj_df["h3_sequence_x"].apply(
        get_embedding_sequence
    )

    return traj_df[
        [trip_ids_column, "h3_sequence_x", "embedding_sequence_x", "h3_sequence_y"]
    ]

In [None]:
# Attach each split's own (buffered) region embeddings to its trips.
merged_train = attach_embeddings_to_trips(train, embeddings_train["embedding"])
merged_dev = attach_embeddings_to_trips(dev, embeddings_dev["embedding"])
merged_test = attach_embeddings_to_trips(test, embeddings_test["embedding"])

Define helper lookups: an H3 → embedding mapping, and each region's neighbourhood together with its transition labels

In [None]:
# Series mapping H3 index -> float32 embedding vector for every cell in the hull.
h3_embedding_lookup = all_embeddings["embedding"]

In [None]:
# Plain dict version of the lookup for fast per-step access in the training loops.
h3_embedding_lookup_dict = h3_embedding_lookup.to_dict()

In [None]:
def _get_neighbours(h3_idx: str, h3_neighbourhood_size: int) -> list[int]:
    return h3.grid_disk(h3_idx, k=h3_neighbourhood_size)

In [None]:
# For every hull cell, store its 1-ring neighbourhood and the class label of each
# neighbour as computed by `get_coords` (see coords.py). The first neighbour is
# passed to `get_coords` as the origin.
full_regions["neighbours"] = full_regions.index.map(
    lambda x: _get_neighbours(x, h3_neighbourhood_size=1)
)
full_regions["labels"] = full_regions["neighbours"].apply(
    lambda x: get_coords(x, x[0], 1)
)

In [None]:
# Lookup table (H3 index -> "neighbours", "labels") used for labelling and decoding.
label_lookup_df = full_regions.copy()

Helper function that maps the y-sequence to a sequence of spatially consistent class labels (see coords.py)

In [None]:
def _transition_label(
    source_h3: str, target_h3: str, label_lookup_df: pd.DataFrame
) -> Union[int, float]:
    """Label for a move `source_h3 -> target_h3`; 0 when not adjacent or not looked up."""
    try:
        entry = label_lookup_df.loc[source_h3]
        neighbours = entry["neighbours"]
        if target_h3 not in neighbours:
            return 0
        return entry["labels"][neighbours.index(target_h3)]
    except (KeyError, IndexError, TypeError):
        return 0


def label_h3_sequence(
    h3_seq_x: list[str],
    h3_seq_y: list[str],
    label_lookup_df: pd.DataFrame,
) -> list[Union[int, float]]:
    """
    Turn the future H3 path of a trip into a sequence of transition class labels.

    Starting from the last observed cell (`h3_seq_x[-1]`), each step to the next cell of
    `h3_seq_y` is labelled with the label that `label_lookup_df` stores for that
    neighbour of the current cell. The walk then continues from the true next cell.

    A step gets the fallback label 0 when the target is not a neighbour of the current
    cell, when the current cell is missing from the lookup, or when the lookup raises
    KeyError/IndexError/TypeError. NOTE(review): 0 is presumably also the label
    `get_coords` gives the centre cell ("stay in the same hexagon") — TODO confirm in
    coords.py.

    Args:
        h3_seq_x (list[str]): Observed H3 cells of the trip.
        h3_seq_y (list[str]): Future H3 cells of the trip.
        label_lookup_df (pd.DataFrame): Indexed by H3 cell, with columns "neighbours"
            (list of neighbour cells) and "labels" (label per neighbour, same order).

    Returns:
        list[Union[int, float]]: One label per element of `h3_seq_y`.
    """
    previous_h3 = h3_seq_x[-1]
    step_labels: list[Union[int, float]] = []
    for target_h3 in h3_seq_y:
        step_labels.append(_transition_label(previous_h3, target_h3, label_lookup_df))
        previous_h3 = target_h3
    return step_labels

In [None]:
# Label every split's future path with neighbour-transition classes.
for split_frame in (merged_train, merged_dev, merged_test):
    split_frame["h3_sequence_y_labels"] = split_frame.apply(
        lambda row: label_h3_sequence(
            row["h3_sequence_x"], row["h3_sequence_y"], label_lookup_df
        ),
        axis=1,
    )

In [None]:
class H3SequenceDataset(torch.utils.data.Dataset):
    """
    Map-style dataset that serves one prepared trip per DataFrame row.

    The wrapped DataFrame must provide the columns "h3_sequence_x", "h3_sequence_y",
    "embedding_sequence_x" and "h3_sequence_y_labels", plus a trip identifier column.
    Each item is a dict with those four fields (values passed through unchanged) and
    the trip identifier under the key "trip_id".

    Args:
        df (pd.DataFrame): Prepared trips, one per row.
        trip_ids_column (str): Column holding the trip identifier. Defaults to "trip_id".

    Attributes:
        df (pd.DataFrame): The wrapped DataFrame.
        trip_id_column (str): Name of the trip identifier column.
    """

    # Columns copied verbatim into each item, in output order.
    _PASSTHROUGH_COLUMNS = (
        "h3_sequence_x",
        "h3_sequence_y",
        "embedding_sequence_x",
        "h3_sequence_y_labels",
    )

    def __init__(self, df: pd.DataFrame, trip_ids_column: str = "trip_id") -> None:
        """
        Wrap `df` without copying it.

        Args:
            df (pd.DataFrame): Prepared trips, one per row.
            trip_ids_column (str): Column holding the trip identifier.
        """
        self.df: pd.DataFrame = df
        self.trip_id_column: str = trip_ids_column

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """
        Build the sample dict for the trip at positional index `idx`.

        Args:
            idx (int): Positional row index.

        Returns:
            dict[str, Any]: The passthrough fields followed by "trip_id".
        """
        record = self.df.iloc[idx]
        sample = {column: record[column] for column in self._PASSTHROUGH_COLUMNS}
        sample["trip_id"] = record[self.trip_id_column]
        return sample

    def __len__(self) -> int:
        """
        Number of trips (rows) in the dataset.

        Returns:
            int: Row count of the wrapped DataFrame.
        """
        return self.df.shape[0]

In [None]:
def collate_fn(batch: list[dict[str, Any]]) -> dict[str, Any]:
    """
    Collate a single-trip batch for a DataLoader with ``batch_size=1``.

    Trips have variable lengths, so nothing is padded or stacked across trips. The one
    sample is returned as a new dict in which:
    - "embedding_sequence_x" is a float32 tensor (shape (seq_len, embed_dim) for a list
      of equal-length vectors);
    - "h3_sequence_y_labels" is an int64 (long) tensor.
    All other fields are passed through unchanged.

    Args:
        batch (list[dict[str, Any]]): List holding exactly one sample dict.

    Returns:
        dict[str, Any]: A new dict with the two fields converted to tensors. The input
        sample is not modified.

    Raises:
        ValueError: If ``batch`` does not contain exactly one sample. Previously,
        extra samples were silently dropped.
    """
    if len(batch) != 1:
        raise ValueError(
            f"collate_fn expects batch_size == 1, got a batch of {len(batch)} samples"
        )
    item = dict(batch[0])
    # Stack through NumPy first: torch.tensor on a list of ndarrays is very slow (and
    # warns). np.array copies, so the tensor never aliases the source DataFrame's arrays.
    item["embedding_sequence_x"] = torch.from_numpy(
        np.array(item["embedding_sequence_x"], dtype=np.float32)
    )
    item["h3_sequence_y_labels"] = torch.from_numpy(
        np.array(item["h3_sequence_y_labels"], dtype=np.int64)
    )
    return item

In [None]:
# batch_size must stay 1: collate_fn handles exactly one variable-length trip per batch.
dataloader_batch_size = 1
train_dataset = H3SequenceDataset(merged_train)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=dataloader_batch_size, shuffle=True, collate_fn=collate_fn
)
# NOTE(review): shuffling the dev loader does not change the averaged dev metrics
# (up to float summation order); shuffle=False would make dev passes reproducible.
dev_dataset = H3SequenceDataset(merged_dev)
dev_loader = torch.utils.data.DataLoader(
    dev_dataset, batch_size=dataloader_batch_size, shuffle=True, collate_fn=collate_fn
)
# Test order is kept fixed so predictions line up with the collected trip ids.
test_dataset = H3SequenceDataset(merged_test)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=dataloader_batch_size, shuffle=False, collate_fn=collate_fn
)

Model definition & helper functions related to training and modeling

In [None]:
def get_predicted_h3(
    current_h3: Union[str, int],
    predicted_class: Any,
    label_lookup_df: pd.DataFrame,
) -> Union[str, int]:
    """
    Decode a predicted class label into the neighbouring H3 cell it refers to.

    The neighbours and their labels are read from `label_lookup_df`. If `current_h3` is
    not in the lookup, they are computed on the fly with the same calls used to build
    the lookup: `_get_neighbours(..., 1)` and `get_coords(neighbours, neighbours[0], 1)`.

    Args:
        current_h3 (Union[str, int]): The cell the trip is currently in.
        predicted_class (Any): Predicted label, compared with `==` against the labels.
        label_lookup_df (pd.DataFrame): Indexed by H3 cell, with "neighbours" and
            "labels" columns in matching order.

    Returns:
        Union[str, int]: The neighbour whose label equals `predicted_class`, or
        `current_h3` itself when no label matches.
    """
    if current_h3 in label_lookup_df.index:
        row = label_lookup_df.loc[current_h3]
        neighbours = row["neighbours"]
        labels = row["labels"]
    else:
        neighbours = _get_neighbours(current_h3, h3_neighbourhood_size=1)
        labels = get_coords(neighbours, neighbours[0], 1)

    # Labels may be a list or an array; a list supports both `in` and `.index`.
    label_list = list(labels)
    if predicted_class in label_list:
        return neighbours[label_list.index(predicted_class)]
    return current_h3

Helper function to handle an H3 cell whose embedding is not present in the lookup

In [None]:
def embed_h3(h3_index: Union[str, int]) -> pd.Series:
    """
    Compute an embedding for a single H3 cell with the fitted `embedder`.

    The cell is turned into a one-row regions GeoDataFrame (EPSG:4326) indexed by
    `h3_index`. OSM features are loaded for it, intersected with the region and passed
    to `embedder.transform`. The numeric output columns are then packed into one
    float32 vector with `concat_columns`.

    NOTE(review): this loads OSM data per call, so callers should cache the result.

    Args:
        h3_index (Union[str, int]): The H3 index to embed.

    Returns:
        pd.Series: Indexed by region id, with the embedding vector (np.ndarray) as value.
    """
    # assumes h3_to_geoseries yields lon/lat geometry, matching the CRS set below
    geometry = h3_to_geoseries(h3_index).geometry.iloc[0]

    regions = gpd.GeoDataFrame(
        geometry=[geometry],
        index=pd.Index([h3_index], name="region_id"),
        crs="EPSG:4326",
    )
    osm_features = OSMPbfLoader().load(regions, HEX2VEC_FILTER)
    region_intersect = IntersectionJoiner().transform(regions, osm_features)

    single_embedding = embedder.transform(
        regions_gdf=regions, features_gdf=osm_features, joint_gdf=region_intersect
    )

    # Select the embedding columns from the embedder output only. The previous version
    # added an "h3" id column first, which would be counted as numeric (and leak into
    # the vector) whenever the H3 index is an int.
    numeric_cols = single_embedding.select_dtypes(include=[np.number]).columns
    single_embedding["embedding"] = single_embedding[numeric_cols].apply(
        concat_columns, axis=1
    )

    return single_embedding["embedding"]

Haversine (great-circle) distance, used as an additional loss term

In [None]:
def haversine_distance(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
    """
    Great-circle distance in kilometres between two (lat, lon) points.

    Delegates to geopy's `great_circle` (spherical Earth model); despite the function
    name, geopy's own formula is used rather than a hand-written haversine.

    Args:
        lat1 (float): Latitude of the first point, decimal degrees.
        lon1 (float): Longitude of the first point, decimal degrees.
        lat2 (float): Latitude of the second point, decimal degrees.
        lon2 (float): Longitude of the second point, decimal degrees.

    Returns:
        float: Distance between the points in kilometres.
    """
    origin = (lat1, lon1)
    destination = (lat2, lon2)
    return great_circle(origin, destination).kilometers

In [None]:
def _mean_metrics(per_trip_metrics: list[dict[str, float]]) -> dict[str, float]:
    """Average per-trip metric dicts key by key; an empty list gives an empty dict."""
    if not per_trip_metrics:
        return {}
    return {
        metric: float(np.mean([m[metric] for m in per_trip_metrics]))
        for metric in per_trip_metrics[0]
    }


def _next_embedding_tensor(
    h3_index: str, embedding_lookup: dict[str, Any], embed_dim: int
) -> Tensor:
    """
    Return the embedding of `h3_index` as a [1, 1, embed_dim] float tensor on `device`.

    Cells missing from `embedding_lookup` (e.g. a predicted neighbour outside the
    embedded area) get a zero vector instead of raising KeyError.
    """
    vector = embedding_lookup.get(h3_index)
    if vector is None:
        vector = np.zeros(embed_dim, dtype=np.float32)
    return (
        torch.tensor(vector, dtype=torch.float, device=device)
        .unsqueeze(0)
        .unsqueeze(0)
    )


def _step_loss(
    logits: Tensor,
    target_label: int,
    current_h3: str,
    true_h3: str,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    geo_loss_weight: float,
    label_lookup_df: pd.DataFrame,
) -> tuple[Tensor, str]:
    """Weighted CE + log1p(haversine km) loss for one step, plus the decoded prediction."""
    target = torch.tensor([target_label], dtype=torch.long, device=device)
    # CrossEntropyLoss is trained with the label itself as the class index, so the
    # argmax class index IS the predicted label (the old "+ 1" shifted every decode).
    pred_class = logits.argmax(dim=1).item()
    pred_h3 = get_predicted_h3(current_h3, pred_class, label_lookup_df)

    true_lat, true_lon = h3.cell_to_latlng(true_h3)
    pred_lat, pred_lon = h3.cell_to_latlng(pred_h3)
    distance_km = haversine_distance(true_lat, true_lon, pred_lat, pred_lon)
    # The geo term comes from the argmax decode and carries no gradient. It only changes
    # the reported loss; the CE gradient is effectively scaled by (1 - geo_loss_weight).
    geo_term = torch.log1p(torch.tensor(distance_km, dtype=torch.float, device=device))
    loss = (1 - geo_loss_weight) * loss_fn(logits, target) + geo_loss_weight * geo_term
    return loss, pred_h3


def _rollout(
    model: nn.Module,
    item: dict[str, Any],
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    geo_loss_weight: float,
    embedding_lookup: dict[str, Any],
    label_lookup_df: pd.DataFrame,
    teacher_forcing: bool,
) -> tuple[Tensor, list[str]]:
    """
    Predict one trip step by step, accumulating the loss.

    After each step, the next input embedding is appended to the model input. With
    `teacher_forcing` it is the true next cell's embedding (training); otherwise it is
    the model's own prediction (evaluation). Returns the summed loss and the predicted
    H3 sequence.
    """
    h3_seq_y = item["h3_sequence_y"]
    label_seq = item["h3_sequence_y_labels"]
    h3_embed_seq = item["embedding_sequence_x"].unsqueeze(0).to(device)  # [1, seq_len, embed_dim]
    embed_dim = h3_embed_seq.shape[-1]
    current_h3 = item["h3_sequence_x"][-1]

    total_loss = torch.zeros((), device=device)
    predicted: list[str] = []
    for t in range(len(label_seq)):
        logits, _ = model(h3_embed_seq, None)
        step_loss, pred_h3 = _step_loss(
            logits,
            int(label_seq[t]),
            current_h3,
            h3_seq_y[t],
            loss_fn,
            geo_loss_weight,
            label_lookup_df,
        )
        total_loss = total_loss + step_loss
        predicted.append(pred_h3)

        next_h3 = h3_seq_y[t] if teacher_forcing else pred_h3
        h3_embed_seq = torch.cat(
            [h3_embed_seq, _next_embedding_tensor(next_h3, embedding_lookup, embed_dim)],
            dim=1,
        )
        current_h3 = next_h3
    return total_loss, predicted


def train(
    model: nn.Module,
    train_loader: DataLoader,
    dev_loader: DataLoader,
    evaluator: BaseEvaluator,
    optimizer: optim.Optimizer,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    geo_loss_weight: float,
    embedding_lookup: dict[str, Any],
    label_lookup_df: pd.DataFrame,
    epochs: int = 10,
    save_dir: str = "./",
    early_stopping_patience: int = 5,
) -> tuple[nn.Module, dict[str, float]]:
    """
    Train a next-cell sequence model on H3 embedding sequences, with early stopping.

    Training uses teacher forcing. Evaluation on `dev_loader` is autoregressive (the
    model's own predictions are fed back in). The weights with the lowest dev loss are
    restored at the end and saved to
    `<save_dir>/best_mobility_prediction_model.pkl`; `save_dir` is created if missing.

    NOTE(review): defining this function rebinds the notebook-global name `train`,
    which earlier cells use for the training split DataFrame.

    Args:
        model (nn.Module): Model called as `model(embeddings, hidden)` -> `(logits, hidden)`.
        train_loader (DataLoader): Training trips (one trip per batch).
        dev_loader (DataLoader): Validation trips (one trip per batch).
        evaluator (BaseEvaluator): Evaluator exposing `_compute_metrics(true_sequences=...,
            pred_sequences=...)`.
        optimizer (torch.optim.Optimizer): Optimizer for the model weights.
        loss_fn (Callable): Classification loss (e.g. nn.CrossEntropyLoss).
        geo_loss_weight (float): Weight of the log1p(haversine km) term; the
            classification loss gets `1 - geo_loss_weight`.
        embedding_lookup (dict[str, Any]): H3 cell -> embedding vector. Missing cells
            are embedded as zero vectors.
        label_lookup_df (pd.DataFrame): Neighbour/label table used to decode predictions.
        epochs (int, optional): Maximum number of epochs. Defaults to 10.
        save_dir (str): Directory in which to save the best model weights.
        early_stopping_patience (int, optional): Epochs without dev-loss improvement
            before stopping. Defaults to 5.

    Returns:
        tuple[nn.Module, dict[str, float]]: The model with the best weights loaded, and
        the mean dev metrics of the epoch those weights come from. When no epoch
        improved, the current weights and the last epoch's metrics are returned.
    """

    def run(item: dict[str, Any], teacher_forcing: bool) -> tuple[Tensor, list[str]]:
        return _rollout(
            model,
            item,
            loss_fn,
            geo_loss_weight,
            embedding_lookup,
            label_lookup_df,
            teacher_forcing,
        )

    model.to(device)
    best_model_state = None
    best_eval_loss = float("inf")
    best_metrics: dict[str, float] = {}
    epoch_metrics: dict[str, float] = {}
    stop_counter = 0

    for epoch in range(epochs):
        # ---- Training (teacher forcing) ----
        model.train()
        total_train_loss = 0.0
        n_train = 0
        train_bar = tqdm(
            train_loader, desc=f"Epoch {epoch + 1}/{epochs} [Training]", leave=False
        )
        for item in train_bar:
            n_steps = len(item["h3_sequence_y_labels"])
            if n_steps == 0:
                # Nothing to predict: the loss would have no graph to backpropagate.
                continue
            optimizer.zero_grad()
            loss, _ = run(item, teacher_forcing=True)
            loss.backward()
            optimizer.step()
            mean_step_loss = loss.item() / n_steps
            total_train_loss += mean_step_loss
            n_train += 1
            train_bar.set_postfix(loss=mean_step_loss)

        avg_train_loss = total_train_loss / max(n_train, 1)
        print(f"Epoch {epoch + 1}, Train Loss: {avg_train_loss:.4f}")

        # ---- Evaluation (autoregressive) ----
        model.eval()
        total_eval_loss = 0.0
        n_eval = 0
        per_trip_metrics: list[dict[str, float]] = []
        with torch.no_grad():
            eval_bar = tqdm(dev_loader, desc=f"Epoch {epoch + 1} [Eval]", leave=False)
            for item in eval_bar:
                n_steps = len(item["h3_sequence_y_labels"])
                if n_steps == 0:
                    continue
                loss, pred_h3_sequence = run(item, teacher_forcing=False)
                per_trip_metrics.append(
                    evaluator._compute_metrics(
                        true_sequences=[item["h3_sequence_y"]],
                        pred_sequences=[pred_h3_sequence],
                    )
                )
                mean_step_loss = loss.item() / n_steps
                total_eval_loss += mean_step_loss
                n_eval += 1
                eval_bar.set_postfix(loss=mean_step_loss)

        avg_eval_loss = total_eval_loss / n_eval if n_eval else float("inf")
        print(f"Epoch {epoch + 1}, Dev Loss: {avg_eval_loss:.4f}")
        epoch_metrics = _mean_metrics(per_trip_metrics)

        if avg_eval_loss < best_eval_loss:
            best_eval_loss = avg_eval_loss
            best_model_state = copy.deepcopy(model.state_dict())
            best_metrics = epoch_metrics
            stop_counter = 0
        else:
            stop_counter += 1
            print(
                f"No improvement. Early stopping counter: "
                f"{stop_counter}/{early_stopping_patience}"
            )
            if stop_counter >= early_stopping_patience:
                print(f"Early stopping triggered at epoch {epoch + 1}")
                break

    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    else:
        # No epoch improved on the initial inf loss: keep current weights.
        best_metrics = epoch_metrics

    os.makedirs(save_dir, exist_ok=True)
    torch.save(
        model.state_dict(), os.path.join(save_dir, "best_mobility_prediction_model.pkl")
    )
    logging.info("Best model saved.")
    return model, best_metrics

In [None]:
class AttentiveH3Predictor(nn.Module):
    """
    Sequence model combining a stacked LSTM with multi-head self-attention.

    The LSTM encodes the sequence of H3 embeddings. Unmasked self-attention then mixes
    information across all LSTM outputs, and a linear head classifies the attended
    representation at the last timestep into `num_classes` next-cell labels.

    Args:
        embedding_dim (int): Size of each input H3 embedding. Default 64.
        hidden_dim (int): LSTM hidden size and attention embedding size. Must be
            divisible by `num_heads`. Default 128.
        num_classes (int): Number of output classes (e.g. cells of a 1-ring). Default 7.
        num_layers (int): Number of stacked LSTM layers. Default 2.
        num_heads (int): Number of attention heads. Default 4.
        dropout (float): Dropout between stacked LSTM layers; ignored when
            `num_layers == 1`. Default 0.1.
    """

    def __init__(
        self,
        embedding_dim: int = 64,
        hidden_dim: int = 128,
        num_classes: int = 7,
        num_layers: int = 2,
        num_heads: int = 4,
        dropout: float = 0.1,
    ) -> None:
        """Build the LSTM encoder, self-attention block and classification head."""
        super().__init__()
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            # nn.LSTM applies dropout only between stacked layers (and warns otherwise).
            dropout=dropout if num_layers > 1 else 0.0,
        )
        self.attn = nn.MultiheadAttention(hidden_dim, num_heads=num_heads, batch_first=True)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(
        self, x: Tensor, hidden: Optional[tuple[Tensor, Tensor]] = None
    ) -> tuple[Tensor, tuple[Tensor, Tensor]]:
        """
        Compute next-cell class logits for a batch of embedding sequences.

        Args:
            x (Tensor): Input of shape (batch_size, sequence_length, embedding_dim).
            hidden (Optional[tuple[Tensor, Tensor]]): Initial LSTM (h_0, c_0), each of
                shape (num_layers, batch_size, hidden_dim), or None for zero states.

        Returns:
            tuple[Tensor, tuple[Tensor, Tensor]]:
                - logits of shape (batch_size, num_classes) at the final timestep;
                - the final LSTM (h_n, c_n), each (num_layers, batch_size, hidden_dim).
        """
        out, hidden = self.lstm(x, hidden)  # out: (B, T, H)
        attn_out, _ = self.attn(out, out, out)  # self-attention: (B, T, H)
        logits = self.classifier(attn_out[:, -1, :])
        return logits, hidden

In [None]:
# Model input size = length of one embedding vector.
embedding_dim = h3_embedding_lookup.iloc[0].shape[0]
model = AttentiveH3Predictor(embedding_dim=embedding_dim)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Weight of the log1p-haversine term; cross-entropy gets (1 - geo_loss_weight).
geo_loss_weight = 0.7
loss_fn = nn.CrossEntropyLoss()
evaluator = MobilityPredictionEvaluator()

# Train
# NOTE(review): epochs=1 looks like a quick-run setting; raise it for a real run.
model, metrics = train(
    model,
    train_loader,
    dev_loader,
    evaluator,
    optimizer,
    loss_fn,
    geo_loss_weight,
    h3_embedding_lookup_dict,
    label_lookup_df,
    epochs=1,
    early_stopping_patience=3,
)

In [None]:
model.eval()
all_true_seqs = []
all_pred_seqs = []
trip_ids = []
# Embeddings computed on the fly with `embed_h3` for predicted cells outside the
# precomputed lookup, cached so each such cell is loaded and embedded only once.
fallback_embeddings = {}

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Predicting...", total=len(test_loader)):
        h3_seq_x = batch["h3_sequence_x"]
        h3_seq_y = batch["h3_sequence_y"]
        embed_seq_x = batch["embedding_sequence_x"]
        label_seq = batch["h3_sequence_y_labels"]
        trip_id = batch.get("trip_id")

        hidden = None
        h3_embed_seq = embed_seq_x.unsqueeze(0).to(device)  # [1, seq_len, embed_dim]
        current_h3 = h3_seq_x[-1]

        pred_h3_sequence = []

        # Autoregressive decoding: one step per future cell, feeding predictions back in.
        for _ in range(len(label_seq)):
            logits, _ = model(h3_embed_seq, hidden)

            # The model is trained with CrossEntropyLoss using the label itself as the
            # class index, so the argmax class index is the predicted label (no +1 shift).
            pred_class = logits.argmax(dim=1).item()
            pred_h3 = get_predicted_h3(current_h3, pred_class, label_lookup_df)
            pred_h3_sequence.append(pred_h3)

            if pred_h3 in h3_embedding_lookup_dict:
                next_embedding = h3_embedding_lookup_dict[pred_h3]
            else:
                if pred_h3 not in fallback_embeddings:
                    fallback_embeddings[pred_h3] = embed_h3(pred_h3).get(pred_h3)
                next_embedding = fallback_embeddings[pred_h3]
            next_embedding_tensor = (
                torch.tensor(next_embedding, dtype=torch.float, device=device)
                .unsqueeze(0)
                .unsqueeze(0)
            )

            h3_embed_seq = torch.cat([h3_embed_seq, next_embedding_tensor], dim=1)
            current_h3 = pred_h3

        all_true_seqs.append(h3_seq_y)
        all_pred_seqs.append(pred_h3_sequence)

        if trip_id is not None:
            trip_ids.append(trip_id)

Evaluate for different sequence lengths (values of `k`)

In [None]:
# Score the same test predictions with evaluators configured for several values of `k`.
# NOTE(review): this rebinds the `evaluator` and `metrics` names from the training cell.
for k in [1, 3, 5, 7, 10]:
    evaluator = MobilityPredictionEvaluator(k=k)
    print(f"Evaluating with k={k}...")
    metrics = evaluator.evaluate(
        porto_taxi, predictions=all_pred_seqs, trip_ids=trip_ids, log_metrics=False
    )

    print(f"Metrics for k={k}: {metrics}")

Visualize predicted vs. true H3 sequences

In [None]:
# Only plot trips with at least this many future cells.
MIN_PLOT_SEQ_LEN = 5
# Filter true/predicted sequences as pairs, so index `ID` below refers to the same
# trip in both lists even if their lengths ever differ.
long_pairs = [
    (true_seq, pred_seq)
    for true_seq, pred_seq in zip(all_true_seqs, all_pred_seqs)
    if len(true_seq) >= MIN_PLOT_SEQ_LEN
]
t_seqs = [true_seq for true_seq, _ in long_pairs]
p_seqs = [pred_seq for _, pred_seq in long_pairs]

In [None]:
# Display the long true sequences selected above (the full list; slice it for large test sets).
t_seqs

In [None]:
# Index (into the filtered lists above) of the trip to visualize.
ID = 0

# One row per H3 cell, with the cell geometry, for the true and predicted paths.
s_true = gpd.GeoDataFrame(
    t_seqs[ID], geometry=h3_to_geoseries(t_seqs[ID]), columns=["h3"]
)
s_pred = gpd.GeoDataFrame(
    p_seqs[ID], geometry=h3_to_geoseries(p_seqs[ID]), columns=["h3"]
)

# Both lengths should match: one prediction per true future cell.
print(len(s_true), len(s_pred))

# True path in the default colour, prediction overlaid in red on the same map.
m = s_true.explore()
s_pred.explore(m=m, color="red", name="Predicted H3")