# Core

> Benchmark specifications, dataset download handling, and training-data access for identibench

In [None]:
#| default_exp core

In [None]:
#| export
from typing import List, Optional, Callable, Dict, Any, Iterator, Tuple
from pathlib import Path
import os
import h5py
import numpy as np
import itertools
import identibench.metrics

In [None]:
#| hide
import shutil
import time # For testing modification times
from fastcore.test import test_eq, test_ne# Import nbdev testing functions

In [None]:
#| export
def get_default_data_root() -> Path:
    """
    Returns the default root directory for datasets.

    Checks the 'IDENTIBENCH_DATA_ROOT' environment variable first,
    otherwise defaults to '~/.identibench_data'.

    A leading '~' in the environment value is expanded to the user's home
    directory. An empty value is treated as unset; otherwise `Path('')` would
    silently resolve to the current working directory.
    """
    env_root = os.environ.get('IDENTIBENCH_DATA_ROOT')
    if not env_root:
        return Path.home() / '.identibench_data'
    return Path(env_root).expanduser()


In [None]:
#| exporti
# Forward declaration: binds the name `BenchmarkSpec` before the full class in the
# next cell is defined, so annotations that refer to it (e.g. for `test_func`) can
# be evaluated. The full definition rebinds the name.
class BenchmarkSpec:
    """Temporary stand-in for `BenchmarkSpec`; replaced by the full definition."""

In [None]:
#| export
class BenchmarkSpec:
    """
    Specification for a single, standardized benchmark dataset configuration.

    Defines fixed parameters for dataset loading, preprocessing, evaluation metric,
    and potentially a custom testing function. Specific evaluation logic 
    (simulation vs prediction, windowing) is handled by the benchmark execution function 
    or the custom test_func.
    """
    # Explicit __init__ for nbdev documentation compatibility. The class name inside the
    # `test_func` hint is quoted (a forward reference) so it resolves to this class rather
    # than to whatever object the name is bound to while the class body is executing.
    def __init__(self, 
                 name: str, # Unique name identifying this specific benchmark task (e.g., 'silverbox_sim_rmse').
                 dataset_id: str, # Identifier for the raw dataset source (e.g., 'silverbox'), corresponds to subdirectory name.
                 u_cols: List[str], # List of column names for the input signals (u).
                 y_cols: List[str], # List of column names for the output signals (y).
                 metric_func: Callable[[np.ndarray, np.ndarray], float], # Primary metric for final test evaluation. `func(y_true, y_pred)`.
                 x_cols: Optional[List[str]] = None, # Optional list of column names for state inputs (x).
                 sampling_time: Optional[float] = None, # Optional sampling time (in seconds) if constant for the dataset.
                 download_func: Optional[Callable[[Path, bool], None]] = None, # Function to download/prepare the raw dataset. `func(save_path, force_download)`
                 test_func: Optional[Callable[['BenchmarkSpec', Callable[[np.ndarray], np.ndarray]], Dict[str, Any]]] = None, # Optional custom function to perform testing and return results dict. Takes spec and predictor. Overrides default test logic.
                 init_window: Optional[int] = None, # Number of initial steps potentially used for model initialization (simulation or prediction).
                 pred_horizon: Optional[int] = None, # The 'k' in k-step ahead prediction, used if the benchmark function performs prediction.
                 pred_step: int = 1, # Step size for k-step ahead prediction, used if the benchmark function performs prediction.
                 data_root_func: Callable[[], Path] = get_default_data_root # Function that returns the root directory where datasets are stored.
                ):
        self.name = name
        self.dataset_id = dataset_id
        self.u_cols = u_cols
        self.y_cols = y_cols
        self.metric_func = metric_func
        self.x_cols = x_cols
        self.sampling_time = sampling_time
        self.download_func = download_func
        self.test_func = test_func
        self.init_window = init_window
        self.pred_horizon = pred_horizon
        self.pred_step = pred_step
        self.data_root_func = data_root_func

    def __repr__(self) -> str:
        return f"{type(self).__name__}(name={self.name!r}, dataset_id={self.dataset_id!r})"

    @property
    def data_root(self) -> Path:
        """Returns the evaluated data root path (`data_root_func` is called on every access)."""
        return self.data_root_func()

    @property
    def dataset_path(self) -> Path:
        """Returns the full path to the dataset directory (`data_root / dataset_id`)."""
        return self.data_root / self.dataset_id

    def ensure_dataset_exists(self, force_download: bool = False) -> None:
        """
        Checks if the dataset exists locally, downloads it if not or if forced.

        The dataset counts as present when `dataset_path` is an existing directory.
        If no `download_func` is configured, warnings are printed and nothing is downloaded.

        Args:
            force_download: If True, download the dataset even if it exists locally.

        Raises:
            Exception: Any exception raised by `download_func` is re-raised after an
                error message has been printed.
        """
        dataset_path = self.dataset_path
        download_func = self.download_func
        if download_func is None:
            print(f"Warning: No download function specified for benchmark '{self.name}'. Cannot ensure data exists at {dataset_path}")
            # Check existence even if we can't download
            if not dataset_path.is_dir():
                print(f"Warning: Dataset directory {dataset_path} not found.")
            return

        dataset_exists = dataset_path.is_dir()
        if dataset_exists and not force_download:
            return  # Already present and no re-download requested.

        print(f"Dataset for '{self.name}' {'not found' if not dataset_exists else 'download forced'}. Preparing dataset at {dataset_path}...")
        self.data_root.mkdir(parents=True, exist_ok=True)
        try:
            download_func(dataset_path, force_download)
        except Exception as e:
            print(f"Error preparing dataset '{self.name}': {e}")
            raise
        print(f"Dataset '{self.name}' prepared successfully.")

In [None]:
# Internal dummy loader - needed for tests below
def _dummy_dataset_loader(
    save_path: Path, # Directory where the dummy dataset files will be written
    force_download: bool = False, # Argument for interface compatibility
    create_train_valid_dir: bool = False # If True, create a 'train_valid' subdir as well
    ):
    """
    Write a tiny HDF5 dataset layout for the tests.

    Produces `train/`, `valid/` and `test/` (plus `train_valid/` on request) under
    `save_path`. Each subset gets two files, except `train_valid/`, which gets one.
    Every file holds float32 datasets `u0`, `u1` and `y0` of length 50 with uniform
    random values, plus an `fs` attribute of 10.0. Nothing is written when
    `save_path` already exists unless `force_download` is set. Per-file write
    failures are printed rather than raised.
    """
    root = Path(save_path)
    if root.is_dir() and not force_download:
        return

    root.mkdir(parents=True, exist_ok=True)
    n_samples = 50
    subsets = ['train', 'valid', 'test'] + (['train_valid'] if create_train_valid_dir else [])

    for subset in subsets:
        subset_dir = root / subset
        subset_dir.mkdir(exist_ok=True)
        # 'train_valid' gets fewer files so tests can tell it apart from train + valid.
        file_count = 1 if subset == 'train_valid' else 2
        for idx in range(file_count):
            target = subset_dir / f'{subset}_{idx}.hdf5'
            try:
                with h5py.File(target, 'w') as f:
                    for column in ('u0', 'u1', 'y0'):
                        f.create_dataset(column, data=np.random.rand(n_samples).astype(np.float32))
                    f.attrs['fs'] = 10.0
            except Exception as e:
                print(f"Failed to create dummy file {target}: {e}")

In [None]:
# Shared setup for the BenchmarkSpec tests
_test_data_dir_spec = Path('./_temp_identibench_data_spec_test')
shutil.rmtree(_test_data_dir_spec, ignore_errors=True)  # start the tests from a clean directory

def _get_test_data_root_spec():
    """Data-root function pointing the BenchmarkSpec tests at the temporary directory."""
    return _test_data_dir_spec

In [None]:
# Test: BenchmarkSpec basic initialization and defaults
_spec_default = BenchmarkSpec(
    name='_spec_default',
    dataset_id='_dummy_default',
    u_cols=['u0'],
    y_cols=['y0'],
    metric_func=identibench.metrics.rmse,
    download_func=_dummy_dataset_loader,
    data_root_func=_get_test_data_root_spec,
)
# Prediction-related parameters fall back to the defaults in the signature.
test_eq(_spec_default.init_window, None)
test_eq(_spec_default.pred_horizon, None)
test_eq(_spec_default.pred_step, 1)
test_eq(_spec_default.name, '_spec_default')

In [None]:
# Test: BenchmarkSpec initialization with prediction-related parameters
_spec_pred_params = BenchmarkSpec(
    name='_spec_pred_params',
    dataset_id='_dummy_pred_params',
    u_cols=['u0'],
    y_cols=['y0'],
    metric_func=identibench.metrics.rmse,
    download_func=_dummy_dataset_loader,
    init_window=20,
    pred_horizon=5,
    pred_step=2,
    data_root_func=_get_test_data_root_spec,
)
# Explicitly passed prediction parameters are stored unchanged.
test_eq(_spec_pred_params.init_window, 20)
test_eq(_spec_pred_params.pred_horizon, 5)
test_eq(_spec_pred_params.pred_step, 2)

In [None]:
# Test: BenchmarkSpec ensure_dataset_exists - first call (creation)
_spec_ensure = BenchmarkSpec(
    name='_spec_ensure',
    dataset_id='_dummy_ensure',
    u_cols=['u0'],
    y_cols=['y0'],
    metric_func=identibench.metrics.rmse,
    download_func=_dummy_dataset_loader,
    data_root_func=_get_test_data_root_spec,
)
# The directory does not exist yet, so this call must run the download function.
_spec_ensure.ensure_dataset_exists()
_dataset_path_ensure = _spec_ensure.dataset_path
test_eq(_dataset_path_ensure.is_dir(), True)
test_eq(_dataset_path_ensure.joinpath('train', 'train_0.hdf5').is_file(), True)

Dataset for '_spec_ensure' not found. Preparing dataset at _temp_identibench_data_spec_test/_dummy_ensure...
Dataset '_spec_ensure' prepared successfully.


In [None]:
# Test: BenchmarkSpec ensure_dataset_exists - second call (skip)
# Pin the file's mtime to a known past value instead of sleeping: any rewrite then
# shows up regardless of the filesystem's timestamp resolution.
_train0_file_skip = _dataset_path_ensure / 'train' / 'train_0.hdf5'
os.utime(_train0_file_skip, (1_000_000_000, 1_000_000_000))
_mtime_before_skip = _train0_file_skip.stat().st_mtime
_spec_ensure.ensure_dataset_exists()  # dataset already present -> must not rewrite
_mtime_after_skip = _train0_file_skip.stat().st_mtime
test_eq(_mtime_before_skip, _mtime_after_skip)

In [None]:
# Test: BenchmarkSpec ensure_dataset_exists - third call (force_download=True)
# Pin the file's mtime to a known past value instead of sleeping 0.1 s, which is not
# enough on filesystems with coarse (e.g. 1-2 s) timestamps and made test_ne flaky.
_train0_file_force = _dataset_path_ensure / 'train' / 'train_0.hdf5'
os.utime(_train0_file_force, (1_000_000_000, 1_000_000_000))
_mtime_before_force = _train0_file_force.stat().st_mtime
_spec_ensure.ensure_dataset_exists(force_download=True)  # must rewrite the files
_mtime_after_force = _train0_file_force.stat().st_mtime
test_ne(_mtime_before_force, _mtime_after_force)

Dataset for '_spec_ensure' download forced. Preparing dataset at _temp_identibench_data_spec_test/_dummy_ensure...
Dataset '_spec_ensure' prepared successfully.


In [None]:
#| hide
# Remove the temporary data directory used by the BenchmarkSpec tests.
shutil.rmtree(path=_test_data_dir_spec, ignore_errors=True)

In [None]:
#| exporti
# Internal helper function for loading raw sequences (no windowing)
def _load_raw_sequences_from_files(
    file_paths: List[Path], # List of HDF5 file paths to load from.
    u_cols: List[str], # Input column names.
    y_cols: List[str], # Output column names.
    x_cols: Optional[List[str]], # Optional state column names.
) -> Iterator[Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]]:
    """
    Loads and yields full sequences (u, y, x) from HDF5 files, one tuple per file.

    Each listed column is read fully into memory. The columns are stacked along a new
    last axis and cast to float32. `x` is None when `x_cols` is empty or None.
    A file is reported with a printed message and skipped if:

    - a column is missing,
    - u/y/x differ in length, or
    - opening or reading the file fails.

    The HDF5 file is closed *before* its tuple is yielded. A partially consumed
    iterator therefore does not keep a file handle open. Exceptions thrown into the
    generator are also not mistaken for per-file read errors.
    """
    for file_path in file_paths:
        try:
            with h5py.File(file_path, 'r') as f:
                try:
                    u_data = np.stack([f[col][()] for col in u_cols], axis=-1).astype(np.float32)
                    y_data = np.stack([f[col][()] for col in y_cols], axis=-1).astype(np.float32)
                    x_data = np.stack([f[col][()] for col in x_cols], axis=-1).astype(np.float32) if x_cols else None
                except KeyError as e:
                    print(f"Warning: Column {e} not found in file {file_path}. Skipping file.")
                    continue
        except Exception as e:
            print(f"Error reading or processing file {file_path}: {e}")
            continue

        seq_len = u_data.shape[0]
        if y_data.shape[0] != seq_len or (x_data is not None and x_data.shape[0] != seq_len):
            print(f"Warning: Column length mismatch in {file_path}. Skipping file.")
            continue

        # Yield outside the `with`/`try`: the file is already closed at this point.
        yield u_data, y_data, x_data

In [None]:
#| export
class TrainingContext:
    """
    Context object handed to the user's training function (`build_predictor`).

    Bundles the benchmark specification, the hyperparameters and an optional seed.
    Exposes lazy iterators over the raw, full-length training and validation sequences.
    Any windowing or batching for training is the responsibility of the user's
    `build_predictor` function.
    """
    # Explicit __init__ for nbdev documentation compatibility
    def __init__(self, 
                 spec: BenchmarkSpec, # The benchmark specification.
                 hyperparameters: Dict[str, Any], # User-provided dictionary containing model and training hyperparameters.
                 seed: Optional[int] = None # Optional random seed for reproducibility.
                ):
        self.spec, self.hyperparameters, self.seed = spec, hyperparameters, seed

    # --- Data access ---

    def _get_file_paths(self, subset: str) -> List[Path]:
        """Return the sorted `*.hdf5` files of `<dataset_path>/<subset>`, or [] if that directory is missing."""
        subset_dir = self.spec.dataset_path / subset
        return sorted(subset_dir.glob('*.hdf5')) if subset_dir.is_dir() else []

    def _get_sequences_from_subset(self, subset: str
                                  ) -> Iterator[Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]]:
        """Lazily load the raw sequences of one subset directory (empty iterator plus a warning if it has no files)."""
        paths = self._get_file_paths(subset)
        if paths:
            return _load_raw_sequences_from_files(
                file_paths=paths,
                u_cols=self.spec.u_cols,
                y_cols=self.spec.y_cols,
                x_cols=self.spec.x_cols,
            )
        print(f"Warning: No HDF5 files found in {self.spec.dataset_path / subset}. Returning empty iterator.")
        return iter([])

    def get_train_sequences(self) -> Iterator[Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]]:
        """Lazy iterator of raw (u, y, x) tuples from the 'train' subset."""
        return self._get_sequences_from_subset('train')

    def get_valid_sequences(self) -> Iterator[Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]]:
        """Lazy iterator of raw (u, y, x) tuples from the 'valid' subset."""
        return self._get_sequences_from_subset('valid')

    def get_train_valid_sequences(self) -> Iterator[Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]]:
        """
        Lazy iterator of raw (u, y, x) tuples for combined training and validation.

        A dedicated 'train_valid' subset directory takes precedence when it contains
        HDF5 files. Otherwise the 'train' sequences are followed by the 'valid' ones.
        """
        combined_paths = self._get_file_paths('train_valid')
        if not combined_paths:
            # No dedicated directory: stream 'train' first, then 'valid'.
            return itertools.chain(self._get_sequences_from_subset('train'),
                                   self._get_sequences_from_subset('valid'))
        return _load_raw_sequences_from_files(
            file_paths=combined_paths,
            u_cols=self.spec.u_cols,
            y_cols=self.spec.y_cols,
            x_cols=self.spec.x_cols,
        )

In [None]:
# Shared setup for the TrainingContext tests
_test_data_dir_ctx = Path('./_temp_identibench_data_ctx_test')
_test_data_dir_ctx_tv = Path('./_temp_identibench_data_ctx_tv_test')
shutil.rmtree(_test_data_dir_ctx, ignore_errors=True)
shutil.rmtree(_test_data_dir_ctx_tv, ignore_errors=True)

def _get_test_data_root_ctx():
    """Data root used by the test dataset without a 'train_valid' directory."""
    return _test_data_dir_ctx

def _get_test_data_root_ctx_tv():
    """Data root used by the test dataset that has a 'train_valid' directory."""
    return _test_data_dir_ctx_tv

In [None]:
# Create base dummy data (no train_valid dir) 
_dummy_spec_ctx_base = BenchmarkSpec(
    name='_dummy_ctx_base',
    dataset_id='_dummy_dataset_ctx_base',
    u_cols=['u0', 'u1'],
    y_cols=['y0'],
    metric_func=identibench.metrics.rmse,
    download_func=lambda path, force: _dummy_dataset_loader(
        path, force_download=force, create_train_valid_dir=False),
    data_root_func=_get_test_data_root_ctx,
    init_window=10,
)
_dummy_spec_ctx_base.ensure_dataset_exists()

Dataset for '_dummy_ctx_base' not found. Preparing dataset at _temp_identibench_data_ctx_test/_dummy_dataset_ctx_base...
Dataset '_dummy_ctx_base' prepared successfully.


In [None]:
# Create dummy data WITH train_valid dir 
_dummy_spec_ctx_tv = BenchmarkSpec(
    name='_dummy_ctx_tv',
    dataset_id='_dummy_dataset_ctx_tv',
    u_cols=['u0', 'u1'],
    y_cols=['y0'],
    metric_func=identibench.metrics.rmse,
    download_func=lambda path, force: _dummy_dataset_loader(
        path, force_download=force, create_train_valid_dir=True),
    data_root_func=_get_test_data_root_ctx_tv,
    init_window=10,
)
_dummy_spec_ctx_tv.ensure_dataset_exists()

Dataset for '_dummy_ctx_tv' not found. Preparing dataset at _temp_identibench_data_ctx_tv_test/_dummy_dataset_ctx_tv...
Dataset '_dummy_ctx_tv' prepared successfully.


In [None]:
#| hide
# Shared constants for the TrainingContext tests. They mirror what _dummy_dataset_loader
# writes: 50-sample sequences, two files per train/valid subset, one file in train_valid.
_seq_len_ctx = 50
_n_files_train_valid_ctx = 2
_n_files_tv_dir_ctx = 1
_hyperparams_ctx = {'lr': 0.01, 'hidden': 64}
_seed_ctx = 42

# --- Test: TrainingContext initialization ---
_ctx = TrainingContext(spec=_dummy_spec_ctx_base, hyperparameters=_hyperparams_ctx, seed=_seed_ctx)
test_eq(_ctx.spec, _dummy_spec_ctx_base)
test_eq(_ctx.hyperparameters, _hyperparams_ctx)
test_eq(_ctx.seed, _seed_ctx)

# --- Test: get_train_sequences yields one float32 (u, y, x) tuple per file ---
_ctx = TrainingContext(spec=_dummy_spec_ctx_base, hyperparameters=_hyperparams_ctx, seed=_seed_ctx)
_train_sequences = list(_ctx.get_train_sequences())
test_eq(len(_train_sequences), _n_files_train_valid_ctx)
_u_train, _y_train, _x_train = _train_sequences[0]
test_eq(_u_train.shape, (_seq_len_ctx, len(_dummy_spec_ctx_base.u_cols)))
test_eq(_y_train.shape, (_seq_len_ctx, len(_dummy_spec_ctx_base.y_cols)))
test_eq(_x_train, None)
test_eq(_u_train.dtype, np.float32)

# --- Test: get_valid_sequences ---
_ctx = TrainingContext(spec=_dummy_spec_ctx_base, hyperparameters=_hyperparams_ctx, seed=_seed_ctx)
_valid_sequences = list(_ctx.get_valid_sequences())
test_eq(len(_valid_sequences), _n_files_train_valid_ctx)
_u_valid, _y_valid, _x_valid = _valid_sequences[0]
test_eq(_u_valid.shape, (_seq_len_ctx, len(_dummy_spec_ctx_base.u_cols)))

# --- Test: get_train_valid_sequences falls back to train + valid (no train_valid dir) ---
_ctx_tv_fallback = TrainingContext(spec=_dummy_spec_ctx_base, hyperparameters=_hyperparams_ctx, seed=_seed_ctx)
_tv_sequences_fallback = list(_ctx_tv_fallback.get_train_valid_sequences())
test_eq(len(_tv_sequences_fallback), 2 * _n_files_train_valid_ctx)
_u_tv_fb_train, _y_tv_fb_train, _ = _tv_sequences_fallback[0]  # first 'train' sequence
test_eq(_u_tv_fb_train.shape[0], _seq_len_ctx)
_u_tv_fb_valid, _y_tv_fb_valid, _ = _tv_sequences_fallback[_n_files_train_valid_ctx]  # first 'valid' sequence
test_eq(_u_tv_fb_valid.shape[0], _seq_len_ctx)

# --- Test: get_train_valid_sequences reads the train_valid dir directly when present ---
_ctx_tv_direct = TrainingContext(spec=_dummy_spec_ctx_tv, hyperparameters=_hyperparams_ctx, seed=_seed_ctx)
_tv_sequences_direct = list(_ctx_tv_direct.get_train_valid_sequences())
test_eq(len(_tv_sequences_direct), _n_files_tv_dir_ctx)
_u_tv_direct, _y_tv_direct, _ = _tv_sequences_direct[0]
test_eq(_u_tv_direct.shape[0], _seq_len_ctx)

In [None]:
#| hide
# Remove both temporary data directories used by the TrainingContext tests.
shutil.rmtree(path=_test_data_dir_ctx, ignore_errors=True)
shutil.rmtree(path=_test_data_dir_ctx_tv, ignore_errors=True)

In [None]:
#| hide
import nbdev
nbdev.nbdev_export()  # write the notebooks' exported cells out to the Python package