In [None]:
import copy
import logging
import os
import warnings
from collections.abc import Iterator
from typing import Any, Union

import geopandas as gpd
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from datasets import Dataset
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
from torch.utils.data import DataLoader
from tqdm import tqdm

from srai.benchmark import TrajectoryRegressionEvaluator
from srai.datasets import PortoTaxiDataset
from srai.embedders import Hex2VecEmbedder  # noqa: F401
from srai.joiners import IntersectionJoiner
from srai.loaders.osm_loaders import OSMPbfLoader
from srai.loaders.osm_loaders.filters import HEX2VEC_FILTER
from srai.neighbourhoods.h3_neighbourhood import H3Neighbourhood
from srai.regionalizers import H3Regionalizer

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
resolution = 9  # H3 resolution passed to the regionalizer
embedder_hidden_sizes = [150, 75, 25]  # hidden layer sizes passed to Hex2VecEmbedder
seed = 42

# Seed the global RNGs so Hex2Vec fitting, LSTM initialisation and DataLoader shuffling
# repeat across "Restart & Run All" (some CUDA kernels may still be nondeterministic).
np.random.seed(seed)
torch.manual_seed(seed)

regionalizer = H3Regionalizer(resolution=resolution)

In [None]:
# Token is read from the environment (None when HF_TOKEN is not set).
hf_token = os.environ.get("HF_TOKEN")
porto_taxi = PortoTaxiDataset()

In [None]:
ds = porto_taxi.load(version="TTE", hf_token=hf_token)
train = ds["train"]
test = ds["test"]

In [None]:
train.shape[0]  # number of rows in the training split as loaded (before the dev split)

In [None]:
# Subsample from the untouched split stored in `ds`, so re-running this cell does not
# shrink the test set a second time.
test = ds["test"].sample(frac=0.2, random_state=42)

Create a dev (validation) split from the train split.


In [None]:
# Carve a dev split out of the training data; the call returns (train, dev).
split_kwargs = dict(
    trajectory_id_column="trip_id",
    task="TTE",
    test_size=0.1,
    bucket_number=7,
    dev=True,
)
train, dev = porto_taxi.train_test_split_bucket_trajectory(**split_kwargs)

In [None]:
# Inspect the type of `porto_taxi.dev_gdf` (presumably populated by the dev=True split
# above — confirm from the displayed output).
type(porto_taxi.dev_gdf)

In [None]:
# Keep a reproducible 20% subsample of both the train and dev splits.
train, dev = [split.sample(frac=0.2, random_state=42) for split in (train, dev)]

In [None]:
# Display the columns the dataset object reports as categorical.
porto_taxi.categorical_columns

In [None]:
# Display the columns the dataset object reports as numerical.
porto_taxi.numerical_columns

LineString embeddings: regionalize each trajectory into H3 cells and embed those cells with Hex2Vec.


In [None]:
# Display the active geometry column of the (subsampled) training split.
train.geometry

In [None]:
# Independent copies of each split (presumably to shield the originals from the
# regionalizer below — confirm whether this is needed).
train_, dev_, test_ = (split.copy() for split in (train, dev, test))

In [None]:
# Map each split's trajectories to H3 regions at the configured resolution.
regions_train, regions_dev, regions_test = (
    regionalizer.transform(split_gdf) for split_gdf in (train_, dev_, test_)
)

In [None]:
# Preview the first H3 regions produced for the training trajectories.
regions_train.head()

In [None]:
# Fit Hex2Vec on the training regions only; dev/test regions are embedded later with
# this fitted embedder.
osm_features = OSMPbfLoader().load(regions_train, HEX2VEC_FILTER)
region_intersect_train = IntersectionJoiner().transform(regions_train, osm_features)
neighbourhood = H3Neighbourhood(regions_train)
embedder = Hex2VecEmbedder(embedder_hidden_sizes)

fit_trainer_kwargs = {"max_epochs": 10, "accelerator": device}
with warnings.catch_warnings():
    # Silence every warning raised while fitting.
    warnings.simplefilter("ignore")
    embedder.fit(
        regions_gdf=regions_train,
        features_gdf=osm_features,
        joint_gdf=region_intersect_train,
        neighbourhood=neighbourhood,
        trainer_kwargs=fit_trainer_kwargs,
    )

In [None]:
embeddings_train = embedder.transform(
    regions_gdf=regions_train,
    features_gdf=osm_features,
    joint_gdf=region_intersect_train,
)
# Expose the H3 id (the index) as a regular column as well.
embeddings_train = embeddings_train.assign(h3=embeddings_train.index)

In [None]:
def _embed_regions(regions):
    """Load OSM features for `regions`, join them and embed with the fitted Hex2Vec model."""
    features = OSMPbfLoader().load(regions, HEX2VEC_FILTER)
    joint = IntersectionJoiner().transform(regions, features)
    embeddings = embedder.transform(regions_gdf=regions, features_gdf=features, joint_gdf=joint)
    # Expose the H3 id (the index) as a regular column as well.
    embeddings["h3"] = embeddings.index
    return features, joint, embeddings


osm_features_dev, region_intersect_dev, embeddings_dev = _embed_regions(regions_dev)
osm_features_test, region_intersect_test, embeddings_test = _embed_regions(regions_test)

In [None]:
def concat_columns(row: gpd.GeoSeries) -> np.ndarray:
    """
    Flatten one row of per-dimension embedding values into a single vector.

    Args:
        row (gpd.GeoSeries): row of embeddings, one scalar per embedding dimension

    Returns:
        np.ndarray: 1-D float array holding the row's values in column order
    """
    pieces = []
    for value in row.values:
        # Each scalar becomes a 1-element array so np.concatenate can join them.
        pieces.append(np.atleast_1d(float(value)))
    return np.concatenate(pieces)

In [None]:
numeric_cols = embeddings_test.select_dtypes(include=[np.number]).columns

# Collapse the per-dimension numeric columns into one vector-valued column per split.
for split_embeddings in (embeddings_train, embeddings_dev, embeddings_test):
    split_embeddings["embedding"] = split_embeddings[numeric_cols].apply(concat_columns, axis=1)

In [None]:
def attach_embeddings_to_trips(traj_df, embedding_df):
    """
    Adds a column to traj_df with a list of embedding vectors matching the h3 sequence.

    Args:
        traj_df (pd.DataFrame): A DataFrame with columns ["trip_id", "duration", "h3_sequence"].
        embedding_df (pd.DataFrame | pd.Series): Embeddings indexed by h3 cell id. Either a
            DataFrame with one column per embedding dimension, or a Series whose values are
            whole 1-D embedding vectors (e.g. an ``"embedding"`` column of vectors).

    Returns:
        pd.DataFrame: A copy of ``traj_df`` restricted to ["trip_id", "duration",
            "h3_sequence", "embedding_sequence"]. Cells absent from ``embedding_df`` are
            represented by a zero vector of the embedding width.
    """
    # Build an h3 -> vector dict once instead of an index membership test plus .loc per
    # cell. With duplicate h3 ids in the index, the last entry wins.
    if isinstance(embedding_df, pd.Series):
        lookup = {h3: np.asarray(vector) for h3, vector in embedding_df.items()}
        # A Series has shape (n,), so `shape[1]` does not exist; take the width from the
        # stored vectors instead.
        embedding_dim = next(iter(lookup.values())).size if lookup else 0
    else:
        lookup = dict(zip(embedding_df.index, embedding_df.to_numpy()))
        embedding_dim = embedding_df.shape[1]

    def get_embedding_sequence(h3_seq):
        # Unknown cells are padded with zeros so the sequence length matches h3_seq.
        return [lookup[h] if h in lookup else np.zeros(embedding_dim) for h in h3_seq]

    traj_df = traj_df.copy()
    traj_df["embedding_sequence"] = traj_df["h3_sequence"].apply(get_embedding_sequence)
    return traj_df[["trip_id", "duration", "h3_sequence", "embedding_sequence"]]

In [None]:
merged_train, merged_dev, merged_test = (
    attach_embeddings_to_trips(trips, split_embeddings["embedding"])
    for trips, split_embeddings in (
        (train, embeddings_train),
        (dev, embeddings_dev),
        (test, embeddings_test),
    )
)

In [None]:
def generate_examples(df: pd.DataFrame) -> Iterator[dict[str, Any]]:
    """
    Generator function to yield training examples from a DataFrame.

    Args:
        df (pd.DataFrame): A DataFrame containing columns:
            - "embedding_sequence": A list of embedding vectors for each H3 hex in the trajectory
              (i.e., List[List[float]] or numpy.ndarray of shape (seq_len, embed_dim)).
            - "trip_id": Unique identifier for each trip (e.g., a string or integer).
            - "duration": Target variable for the trip duration (e.g., a float or int).

    Yields:
        Dict[str, Any]: A dictionary with keys:
            - "X": The embedding sequence representing the trajectory.
            - "trip_id": The unique trip identifier.
            - "y": The target duration for the trip.
    """
    # Iterate only the three needed columns: `iterrows` builds a Series per row, which is
    # slow and does not preserve dtypes across columns.
    for embedding_sequence, trip_id, duration in zip(
        df["embedding_sequence"], df["trip_id"], df["duration"]
    ):
        yield {
            "X": embedding_sequence,  # per-cell embedding vectors, (seq_len, embed_dim)
            "trip_id": trip_id,  # trip identifier
            "y": duration,  # regression target
        }


# Hand each frame to the generator through `gen_kwargs` instead of closing over it in a lambda.
train_dataset = Dataset.from_generator(generate_examples, gen_kwargs={"df": merged_train})
dev_dataset = Dataset.from_generator(generate_examples, gen_kwargs={"df": merged_dev})
test_dataset = Dataset.from_generator(generate_examples, gen_kwargs={"df": merged_test})

Model definition: an LSTM regressor over each trip's sequence of H3-cell embeddings.

In [None]:
"""
Trajectory model module.

This module contains implementation of base model of trajectory.
"""


class TravelTimePredictionBaseModel(nn.Module):  # type: ignore
    """
    Baseline travel-time regressor.

    A stacked LSTM reads a padded batch of per-step feature vectors; the final hidden state
    of the top layer goes through a linear head followed by ReLU.
    """

    def __init__(self, input_size: int, hidden_size: int, num_layers: int, output_size: int):
        """
        Build the LSTM encoder and the regression head.

        Args:
            input_size: dimensionality of each timestep's feature vector
            hidden_size: width of the LSTM hidden state
            num_layers: number of stacked LSTM layers
            output_size: number of regression outputs
        """
        super().__init__()
        self.hidden_size, self.num_layers = hidden_size, num_layers
        # Submodules are created in the order lstm -> fc -> relu; this fixes both the
        # parameter-initialisation order and the state_dict keys.
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """
        Predict one output vector per sequence in the batch.

        Args:
            x (torch.Tensor): Padded input of shape (batch_size, seq_len, input_size), where
                `seq_len` is the longest sequence in the batch.
            lengths (torch.Tensor): 1D tensor of shape (batch_size,) with each sequence's
                unpadded length.

        Returns:
            torch.Tensor: Non-negative predictions of shape (batch_size, output_size).
        """
        # Packing makes the LSTM stop at each sequence's true length; PyTorch requires
        # the lengths tensor on the CPU.
        packed = pack_padded_sequence(x, lengths.cpu(), batch_first=True, enforce_sorted=False)
        _, (h_n, _) = self.lstm(packed)
        # h_n is (num_layers, batch_size, hidden_size); regress from the top layer.
        return self.relu(self.fc(h_n[-1]))

Training parameters and model instantiation

In [None]:
# LSTM regressor hyperparameters.
batch_size = 32
hidden_size = 128
num_layers = 2
output_size = 1  # single regression target: total trip duration

# Infer the per-cell embedding width from the first cell of the first training trip.
sample_input = merged_train["embedding_sequence"].iloc[0]
input_size = sample_input[0].shape[0]

model = TravelTimePredictionBaseModel(
    input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, output_size=output_size
)
evaluator = TrajectoryRegressionEvaluator()
evaluator = TrajectoryRegressionEvaluator()

In [None]:
input_size  # width of a single H3-cell embedding, as computed in the setup cell

In [None]:
def collate_fn(batch: list[dict[str, Any]]) -> dict[str, Any]:
    """
    Turn a list of variable-length examples into one padded, model-ready batch.

    Args:
        batch (List[dict[str, Any]]): Examples, each a dictionary with
            - "X": embedding sequence of shape (seq_len, embed_dim), as nested lists or tensor
            - "y": scalar target value
            - "trip_id": trip identifier

    Returns:
        Dict[str, Any]: A dictionary with
            - "X": float32 tensor (batch_size, max_seq_len, embed_dim), zero-padded
            - "y": float32 tensor (batch_size,) of targets
            - "trip_id": list of trip identifiers in batch order
            - "lengths": int64 tensor (batch_size,) of unpadded sequence lengths
    """
    sequences, targets, trip_ids = [], [], []
    for example in batch:
        sequences.append(torch.tensor(example["X"], dtype=torch.float32))
        targets.append(example["y"])
        trip_ids.append(example["trip_id"])

    # True lengths are recorded before padding so the model can pack the sequences.
    seq_lengths = torch.tensor([seq.size(0) for seq in sequences], dtype=torch.long)
    return {
        "X": pad_sequence(sequences, batch_first=True),
        "y": torch.tensor(targets, dtype=torch.float32),
        "trip_id": trip_ids,
        "lengths": seq_lengths,
    }

In [None]:
train_dataloader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
)
# Evaluation loaders are not shuffled: order does not affect the dev loss, and shuffling
# would only change the batch composition behind batch-averaged metrics each epoch and
# consume the global torch RNG between training epochs.
dev_dataloader = DataLoader(
    dev_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
)
test_dataloader = DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
)

In [None]:
def _train_one_epoch(
    model: nn.Module,
    train_dataloader: DataLoader,
    optimizer: optim.Optimizer,
    loss_fn: nn.Module,
    device: Union[str, torch.device],
    desc: str,
) -> float:
    """Run one optimisation pass over `train_dataloader`; return the mean batch loss."""
    model.train()
    batch_losses = []
    for batch in tqdm(train_dataloader, desc=desc):
        inputs = batch["X"].to(device)
        lengths = batch["lengths"].to(device)
        labels = batch["y"].to(device).reshape(-1, 1)

        outputs = model(inputs, lengths)
        loss = loss_fn(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())
    return float(np.mean(batch_losses)) if batch_losses else float("nan")


def _evaluate(
    model: nn.Module,
    dev_dataloader: DataLoader,
    loss_fn: nn.Module,
    evaluator: object,
    device: Union[str, torch.device],
) -> tuple[float, dict[str, float]]:
    """Return the sample-weighted dev loss and metrics computed over the whole dev set."""
    model.eval()
    loss_sum = 0.0
    n_samples = 0
    preds, targets = [], []
    with torch.no_grad():
        for batch in tqdm(dev_dataloader, desc="Evaluation", total=len(dev_dataloader)):
            inputs = batch["X"].to(device)
            lengths = batch["lengths"].to(device)
            labels = batch["y"].to(device).reshape(-1, 1)

            outputs = model(inputs, lengths)
            n_batch = labels.shape[0]
            # loss_fn uses mean reduction (nn.L1Loss default); re-weight by batch size so
            # a short final batch does not skew the epoch loss.
            loss_sum += loss_fn(outputs, labels).item() * n_batch
            n_samples += n_batch
            preds.append(outputs.cpu().numpy())
            targets.append(labels.cpu().numpy())

    if n_samples == 0:
        raise ValueError("dev_dataloader yielded no samples; cannot evaluate or early-stop.")
    # Metrics over the full dev set rather than an unweighted mean of per-batch metrics.
    metrics = evaluator._compute_metrics(np.concatenate(preds), np.concatenate(targets))
    return loss_sum / n_samples, dict(metrics)


def train_with_early_stopping(
    model: nn.Module,
    train_dataloader: DataLoader,
    dev_dataloader: DataLoader,
    evaluator: object,
    device: Union[str, torch.device] = "cuda",
    epochs: int = 30,
    lr: float = 1e-3,
    save_dir: str = "./",
    patience: int = 5,
) -> tuple[nn.Module, list[dict[str, float]]]:
    """
    Train a model with Adam and L1 loss, early-stopping on the validation loss.

    After every epoch the model is evaluated on `dev_dataloader`. Whenever the dev loss beats
    the best value seen so far, the weights are snapshotted. After `patience` consecutive
    epochs without improvement training stops. The best snapshot is then restored, saved to
    `<save_dir>/best_travel_time_model.pkl` and returned.

    Args:
        model (nn.Module): The PyTorch model to train; called as `model(X, lengths)`.
        train_dataloader (DataLoader): Training batches (dicts with "X", "lengths", "y").
        dev_dataloader (DataLoader): Validation batches with the same structure.
        evaluator (object): An evaluator object with `_compute_metrics(preds, targets)` method.
        device (Union[str, torch.device]): Device to train the model on ('cuda' or 'cpu').
        epochs (int): Maximum number of training epochs.
        lr (float): Learning rate for the optimizer.
        save_dir (str): Directory to save the best model weights (created if missing).
        patience (int): Consecutive non-improving epochs tolerated before stopping.

    Returns:
        tuple[nn.Module, list[dict[str, float]]]: The model with the best weights loaded, and
            one dict of dev-set metrics per evaluated epoch.

    Raises:
        ValueError: If `dev_dataloader` yields no samples.
    """
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.L1Loss()

    best_weights = copy.deepcopy(model.state_dict())
    best_eval_loss = np.inf
    stop_counter = 0
    metrics_results = []

    for epoch in range(epochs):
        avg_train_loss = _train_one_epoch(
            model, train_dataloader, optimizer, loss_fn, device, desc=f"Epoch: {epoch+1}"
        )
        logging.info(f"Epoch [{epoch+1}/{epochs}], Train Loss: {avg_train_loss:.4f}")

        avg_eval_loss, epoch_metrics = _evaluate(
            model, dev_dataloader, loss_fn, evaluator, device
        )
        logging.info(f"Evaluation Loss: {avg_eval_loss:.4f}")
        metrics_results.append(epoch_metrics)

        # Compare with the best loss so far (not just the previous epoch), so a partial
        # recovery after a bad epoch cannot overwrite genuinely better weights.
        if avg_eval_loss < best_eval_loss:
            best_eval_loss = avg_eval_loss
            best_weights = copy.deepcopy(model.state_dict())
            stop_counter = 0
        else:
            stop_counter += 1
            logging.info(f"No improvement. Early stop counter: {stop_counter}/{patience}")
            if stop_counter >= patience:
                logging.info(f"Early stopping at epoch {epoch+1}")
                break

    model.load_state_dict(best_weights)
    os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(save_dir, "best_travel_time_model.pkl"))
    logging.info("Best model saved.")
    return model, metrics_results

In [None]:
model, metrics = train_with_early_stopping(
    model,
    train_dataloader,
    dev_dataloader,
    evaluator,
    device=device,
    epochs=15,
)

In [None]:
model.eval()
trip_indexes = []
all_predictions = []

# Inference only: no autograd graph is needed.
with torch.no_grad():
    for batch in tqdm(test_dataloader, desc="Predicting...", total=len(test_dataloader)):
        inputs, lengths = batch["X"].to(device), batch["lengths"].to(device)
        indexes = batch["trip_id"]
        outputs = model(inputs, lengths)
        trip_indexes += indexes
        # One (output_size,) array per trip, in the same order as trip_indexes.
        all_predictions += list(outputs.cpu().numpy())

In [None]:
# Score the test-set predictions, matched to trips via their ids.
evaluator.evaluate(
    dataset=porto_taxi,
    predictions=all_predictions,
    trip_ids=trip_indexes,
    log_metrics=False,
)