From bf5b837bbdc310750d54c37c263c3183e4722a88 Mon Sep 17 00:00:00 2001 From: kylmcgr Date: Fri, 22 Aug 2025 10:28:37 -0400 Subject: [PATCH 01/16] First commit of generic sklearn unit --- pyproject.toml | 1 + src/ezmsg/learn/process/sklearn.py | 274 +++++++++++++++++++++++ tests/integration/test_sklearn_system.py | 87 +++++++ tests/unit/test_sklearn.py | 228 +++++++++++++++++++ 4 files changed, 590 insertions(+) create mode 100644 src/ezmsg/learn/process/sklearn.py create mode 100644 tests/integration/test_sklearn_system.py create mode 100644 tests/unit/test_sklearn.py diff --git a/pyproject.toml b/pyproject.toml index 31a752d..633a9b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ lint = [ "ruff>=0.12.9", ] test = [ + "hmmlearn>=0.3.3", "pytest>=8.4.1", ] diff --git a/src/ezmsg/learn/process/sklearn.py b/src/ezmsg/learn/process/sklearn.py new file mode 100644 index 0000000..8a2f281 --- /dev/null +++ b/src/ezmsg/learn/process/sklearn.py @@ -0,0 +1,274 @@ +import importlib +import pickle +import typing + +import ezmsg.core as ez +import numpy as np +import pandas as pd +from ezmsg.sigproc.base import ( + BaseAdaptiveTransformer, + BaseAdaptiveTransformerUnit, + processor_state, +) +from ezmsg.sigproc.sampler import SampleMessage +from ezmsg.util.messages.axisarray import AxisArray +from ezmsg.util.messages.util import replace + + +class SklearnModelSettings(ez.Settings): + model_class: str + """ + Full path to the sklearn model class + Example: 'sklearn.linear_model.LinearRegression' + """ + model_kwargs: dict[str, typing.Any] = None + """ + Additional keyword arguments to pass to the model constructor. + Example: {'fit_intercept': True, 'normalize': False} + """ + checkpoint_path: str | None = None + """ + Path to a checkpoint file to load the model from. + If provided, the model will be initialized from this checkpoint. + Example: 'path/to/checkpoint.pkl' + """ + partial_fit_classes: np.ndarray | None = None + """ + The full list of classes to use for partial_fit calls. + This must be provided to use `partial_fit` with classifiers. + """ + + +@processor_state +class SklearnModelState: + model: typing.Any = None + chan_ax: AxisArray.CoordinateAxis | None = None + + +class SklearnModelProcessor( + BaseAdaptiveTransformer[ + SklearnModelSettings, AxisArray, AxisArray, SklearnModelState + ] +): + """ + Processor that wraps a scikit-learn, River, or HMMLearn model for use in the ezmsg framework. + + This processor supports: + - `fit`, `partial_fit`, or River's `learn_many`/`learn_one` for training. + - `predict`, River's `predict_many`, or `predict_one` for inference. + - Optional model checkpoint loading and saving. + + The processor expects and outputs `AxisArray` messages with a `"ch"` (channel) axis. + + Settings: + --------- + model_class : str + Full path to the sklearn or River model class to use. + Example: "sklearn.linear_model.SGDClassifier" or "river.linear_model.LogisticRegression" + + model_kwargs : dict[str, typing.Any], optional + Additional keyword arguments passed to the model constructor. + + checkpoint_path : str, optional + Path to a pickle file to load a previously saved model. If provided, the model will + be restored from this path at startup. + + partial_fit_classes : np.ndarray, optional + For classifiers that require all class labels to be specified during `partial_fit`. + + Example: + ----------------------------- + ```python + processor = SklearnModelProcessor( + settings=SklearnModelSettings( + model_class='sklearn.linear_model.SGDClassifier', + model_kwargs={'loss': 'log_loss'}, + partial_fit_classes=np.array([0, 1]), + ) + ) + ``` + """ + + def _init_model(self) -> None: + module_path, class_name = self.settings.model_class.rsplit(".", 1) + model_cls = getattr(importlib.import_module(module_path), class_name) + kwargs = self.settings.model_kwargs or {} + self._state.model = model_cls(**kwargs) + + def save_checkpoint(self, path: str) -> None: + with open(path, "wb") as f: + pickle.dump(self._state.model, f) + + def load_checkpoint(self, path: str) -> None: + try: + with open(path, "rb") as f: + self._state.model = pickle.load(f) + except Exception as e: + ez.logger.error(f"Failed to load model from {path}: {str(e)}") + raise RuntimeError(f"Failed to load model from {path}: {str(e)}") from e + + def _reset_state(self, message: AxisArray) -> None: + # Try loading from checkpoint first + if self.settings.checkpoint_path: + self.load_checkpoint(self.settings.checkpoint_path) + n_input = message.data.shape[message.get_axis_idx("ch")] + if hasattr(self._state.model, "n_features_in_"): + expected = self._state.model.n_features_in_ + if expected != n_input: + raise ValueError( + f"Model expects {expected} features, but got {n_input}" + ) + else: + # No checkpoint, initialize from scratch + self._init_model() + + def partial_fit(self, message: SampleMessage) -> None: + X = message.sample.data + y = message.trigger.value + if self._state.model is None: + self._reset_state(message.sample) + if hasattr(self._state.model, "partial_fit"): + kwargs = {} + if self.settings.partial_fit_classes is not None: + kwargs["classes"] = self.settings.partial_fit_classes + self._state.model.partial_fit(X, y, **kwargs) + elif hasattr(self._state.model, "learn_many"): + df_X = pd.DataFrame( + { + k: v + for k, v in zip( + message.sample.axes["ch"].data, message.sample.data.T + ) + } + ) + name = ( + message.trigger.value.axes["ch"].data[0] + if hasattr(message.trigger.value, "axes") + and "ch" in message.trigger.value.axes + else "target" + ) + ser_y = pd.Series( + data=np.asarray(message.trigger.value.data).flatten(), + name=name, + ) + self._state.model.learn_many(df_X, ser_y) + elif hasattr(self._state.model, "learn_one"): + # river's random forest does not support learn_many + for xi, yi in zip(X, y): + features = {f"f{i}": xi[i] for i in range(len(xi))} + self._state.model.learn_one(features, yi) + else: + raise NotImplementedError( + "Model does not support partial_fit or learn_many" + ) + + def fit(self, X: np.ndarray, y: np.ndarray) -> None: + if self._state.model is None: + dummy_msg = AxisArray( + data=X, + dims=["time", "ch"], + axes={ + "time": AxisArray.TimeAxis(fs=1.0), + "ch": AxisArray.CoordinateAxis( + data=np.array([f"ch_{i}" for i in range(X.shape[1])]), + dims=["ch"], + ), + }, + ) + self._reset_state(dummy_msg) + if hasattr(self._state.model, "fit"): + self._state.model.fit(X, y) + elif hasattr(self._state.model, "learn_many"): + df_X = pd.DataFrame(X, columns=[f"f{i}" for i in range(X.shape[1])]) + ser_y = pd.Series(y.flatten(), name="target") + self._state.model.learn_many(df_X, ser_y) + elif hasattr(self._state.model, "learn_one"): + # river's random forest does not support learn_many + for xi, yi in zip(X, y): + features = {f"f{i}": xi[i] for i in range(len(xi))} + self._state.model.learn_one(features, yi) + else: + raise NotImplementedError("Model does not support fit or learn_many") + + def _process(self, message: AxisArray) -> AxisArray: + if self._state.model is None: + raise RuntimeError( + "Model has not been fit yet. Call `fit()` or `partial_fit()` before processing." + ) + X = message.data + original_shape = X.shape + n_input = X.shape[message.get_axis_idx("ch")] + + # Ensure X is 2D + X = X.reshape(-1, n_input) + if hasattr(self._state.model, "n_features_in_"): + expected = self._state.model.n_features_in_ + if expected != n_input: + raise ValueError( + f"Model expects {expected} features, but got {n_input}" + ) + + if hasattr(self._state.model, "predict"): + y_pred = self._state.model.predict(X) + elif hasattr(self._state.model, "predict_many"): + df_X = pd.DataFrame(X, columns=[f"f{i}" for i in range(X.shape[1])]) + y_pred = self._state.model.predict_many(df_X) + y_pred = np.array(list(y_pred)) + elif hasattr(self._state.model, "predict_one"): + # river's random forest does not support predict_many + y_pred = np.array( + [ + self._state.model.predict_one( + {f"f{i}": xi[i] for i in range(len(xi))} + ) + for xi in X + ] + ) + else: + raise NotImplementedError("Model does not support predict or predict_many") + + # For scalar outputs, ensure the output is 2D + if y_pred.ndim == 1: + y_pred = y_pred[:, np.newaxis] + + output_shape = original_shape[:-1] + (y_pred.shape[-1],) + y_pred = y_pred.reshape(output_shape) + + if self._state.chan_ax is None: + self._state.chan_ax = AxisArray.CoordinateAxis( + data=np.arange(output_shape[1]), dims=["ch"] + ) + + return replace( + message, + data=y_pred, + axes={**message.axes, "ch": self._state.chan_ax}, + ) + + +class SklearnModelUnit( + BaseAdaptiveTransformerUnit[ + SklearnModelSettings, AxisArray, AxisArray, SklearnModelProcessor + ] +): + """ + Unit wrapper for the `SklearnModelProcessor`. + + This unit provides a plug-and-play interface for using a scikit-learn or River model + in an ezmsg graph-based system. It takes in `AxisArray` inputs and outputs predictions + in the same format, optionally performing training via `partial_fit` or `fit`. + + Example: + -------- + ```python + unit = SklearnModelUnit( + settings=SklearnModelSettings( + model_class='sklearn.linear_model.SGDClassifier', + model_kwargs={'loss': 'log_loss'}, + partial_fit_classes=np.array([0, 1]), + ) + ) + ``` + """ + + SETTINGS = SklearnModelSettings diff --git a/tests/integration/test_sklearn_system.py b/tests/integration/test_sklearn_system.py new file mode 100644 index 0000000..b82a4dc --- /dev/null +++ b/tests/integration/test_sklearn_system.py @@ -0,0 +1,87 @@ +import os +import pickle +import tempfile +from pathlib import Path + +import ezmsg.core as ez +import numpy as np +import pandas as pd +from ezmsg.sigproc.synth import Counter +from ezmsg.util.messagecodec import message_log +from ezmsg.util.messagelogger import MessageLogger +from ezmsg.util.messages.axisarray import AxisArray +from ezmsg.util.terminate import TerminateOnTotal +from river.linear_model import LinearRegression + +from ezmsg.learn.process.sklearn import SklearnModelUnit + + +def test_sklearn_model_unit_system(): + fs = 10.0 + block_size = 4 + duration = 2.0 # seconds + input_size = 3 + output_size = 1 # For most sklearn regressors, output is single dim + + # Create temporary checkpoint file path + checkpoint_path = Path(tempfile.gettempdir()) / "sklearn_checkpoint.pkl" + + # Fit model and save checkpoint + model = LinearRegression() + X = pd.DataFrame( + np.random.randn(block_size, input_size), + columns=[f"f{i}" for i in range(input_size)], + ) + y = pd.Series(np.random.randn(block_size), name="target") + model.learn_many(X, y) + with open(checkpoint_path, "wb") as f: + pickle.dump(model, f) + + test_filename = Path(tempfile.gettempdir()) / "test_sklearn_system.txt" + with open(test_filename, "w"): + pass + ez.logger.info(f"Logging to {test_filename}") + + comps = { + "SRC": Counter( + fs=fs, + n_ch=input_size, + n_time=block_size, + dispatch_rate=duration, + mod=None, + ), + "MODEL": SklearnModelUnit( + model_class="river.linear_model.LinearRegression", + model_kwargs={}, + checkpoint_path=str(checkpoint_path), + partial_fit_classes=None, + ), + "LOG": MessageLogger(output=test_filename), + "TERM": TerminateOnTotal(total=int(duration * fs / block_size)), + } + + conns = ( + (comps["SRC"].OUTPUT_SIGNAL, comps["MODEL"].INPUT_SIGNAL), + (comps["MODEL"].OUTPUT_SIGNAL, comps["LOG"].INPUT_MESSAGE), + (comps["LOG"].OUTPUT_MESSAGE, comps["TERM"].INPUT_MESSAGE), + ) + + # Run the pipeline + ez.run(components=comps, connections=conns) + + # Read logged messages + messages: list[AxisArray] = list(message_log(test_filename)) + + # Clean up the temporary files + os.remove(test_filename) + os.remove(checkpoint_path) + + # Assertions + # Output shape should have same number of samples, channels = output_size + assert all(msg.data.shape[0] == block_size for msg in messages) + assert all(msg.data.shape[1] == output_size for msg in messages) + + # Dimensions and axes presence checks + assert all("time" in msg.dims and "ch" in msg.dims for msg in messages) + assert all("ch" in msg.axes for msg in messages) + assert messages[0].axes["ch"].data.shape[0] == output_size diff --git a/tests/unit/test_sklearn.py b/tests/unit/test_sklearn.py new file mode 100644 index 0000000..398d335 --- /dev/null +++ b/tests/unit/test_sklearn.py @@ -0,0 +1,228 @@ +import numpy as np +import pytest +from ezmsg.sigproc.sampler import SampleMessage, SampleTriggerMessage +from ezmsg.util.messages.axisarray import AxisArray + +from ezmsg.learn.process.sklearn import SklearnModelProcessor + + +@pytest.fixture +def input_axisarray(): + data = np.random.randn(10, 4) + return AxisArray( + data=data, + dims=["time", "ch"], + axes={ + "time": AxisArray.TimeAxis(fs=100.0), + "ch": AxisArray.CoordinateAxis(data=np.arange(4), dims=["ch"]), + }, + ) + + +@pytest.fixture +def labels_classification(): + return np.random.randint(0, 2, size=(10,)) + + +@pytest.fixture +def labels_regression(): + return np.random.randn(10) + + +@pytest.mark.parametrize( + "model_class,is_classifier", + [ + ("river.linear_model.LinearRegression", False), + ("river.linear_model.LogisticRegression", True), + ("sklearn.linear_model.Ridge", False), + ("sklearn.discriminant_analysis.LinearDiscriminantAnalysis", True), + ], +) +def test_output_shape_inference( + model_class, + is_classifier, + input_axisarray, + labels_classification, + labels_regression, +): + proc = SklearnModelProcessor(model_class=model_class) + proc._reset_state(input_axisarray) + + # Fit the model before prediction + labels = labels_classification if is_classifier else labels_regression + proc.fit(input_axisarray.data, labels) + + output = proc._process(input_axisarray) + assert output.data.shape[0] == input_axisarray.data.shape[0] + assert output.data.ndim == 2 + + +@pytest.mark.parametrize( + "model_class,is_classifier", + [ + ("river.linear_model.LinearRegression", False), + ("river.linear_model.LogisticRegression", True), + ("sklearn.linear_model.SGDClassifier", True), + ("sklearn.linear_model.SGDRegressor", False), + ], +) +def test_partial_fit_supported_models( + model_class, + is_classifier, + input_axisarray, + labels_classification, + labels_regression, +): + labels = labels_classification if is_classifier else labels_regression + + settings_kwargs = {"model_class": model_class} + + if is_classifier: + settings_kwargs["partial_fit_classes"] = np.array([0, 1]) + + proc = SklearnModelProcessor(**settings_kwargs) + proc._reset_state(input_axisarray) + + sample_msg = SampleMessage( + sample=input_axisarray, + trigger=SampleTriggerMessage(timestamp=0.0, value=labels), + ) + + proc.partial_fit(sample_msg) + output = proc._process(input_axisarray) + assert output.data.shape[0] == input_axisarray.data.shape[0] + + +def test_partial_fit_unsupported_model(input_axisarray, labels_regression): + proc = SklearnModelProcessor(model_class="sklearn.linear_model.Ridge") + proc._reset_state(input_axisarray) + sample_msg = SampleMessage( + sample=input_axisarray, + trigger=SampleTriggerMessage(timestamp=0.0, value=labels_regression), + ) + with pytest.raises(NotImplementedError, match="partial_fit"): + proc.partial_fit(sample_msg) + + +def test_partial_fit_changes_model(input_axisarray, labels_regression): + proc = SklearnModelProcessor(model_class="sklearn.linear_model.SGDRegressor") + proc._reset_state(input_axisarray) + + sample_msg = SampleMessage( + sample=input_axisarray, + trigger=SampleTriggerMessage(timestamp=0.0, value=labels_regression), + ) + + proc.partial_fit(sample_msg) + output_before = proc._process(input_axisarray).data.copy() + proc.partial_fit(sample_msg) + output_after = proc._process(input_axisarray).data + assert not np.allclose(output_before, output_after) + + +def test_model_save_and_load(tmp_path, input_axisarray): + proc = SklearnModelProcessor(model_class="sklearn.linear_model.Ridge") + proc._reset_state(input_axisarray) + + checkpoint_path = tmp_path / "model_checkpoint.pkl" + proc.save_checkpoint(str(checkpoint_path)) + + new_proc = SklearnModelProcessor( + model_class="sklearn.linear_model.Ridge", checkpoint_path=str(checkpoint_path) + ) + new_proc._reset_state(input_axisarray) + assert new_proc._state.model is not None + + +def test_input_shape_mismatch_raises(input_axisarray, labels_regression): + proc = SklearnModelProcessor(model_class="sklearn.linear_model.Ridge") + proc._reset_state(input_axisarray) + + # Fit the model first so `n_features_in_` is set + proc.fit(input_axisarray.data, labels_regression) + + # Create mismatched message with 3 input channels instead of 4 + bad_data = np.random.randn(10, 3) + bad_msg = AxisArray( + data=bad_data, + dims=["time", "ch"], + axes={ + "time": AxisArray.TimeAxis(fs=100.0), + "ch": AxisArray.CoordinateAxis(data=np.arange(3), dims=["ch"]), + }, + ) + + with pytest.raises(ValueError, match="Model expects .* features, but got .*"): + proc._process(bad_msg) + + +@pytest.mark.parametrize( + "model_class,is_classifier", + [ + ("river.forest.ARFClassifier", True), + ("river.forest.ARFRegressor", False), + ], +) +def test_random_forest_inference_shape( + model_class, + is_classifier, + input_axisarray, + labels_classification, + labels_regression, +): + labels = labels_classification if is_classifier else labels_regression + + proc = SklearnModelProcessor(model_class=model_class) + proc._reset_state(input_axisarray) + + # Simulate one-time fitting (not partial) + proc.fit(input_axisarray.data, labels) + + output = proc._process(input_axisarray) + expected_output_dim = 1 + + assert output.data.shape == (input_axisarray.data.shape[0], expected_output_dim) + + +def test_random_forest_save_load(tmp_path, input_axisarray, labels_classification): + proc = SklearnModelProcessor( + model_class="river.forest.ARFClassifier", + ) + proc._reset_state(input_axisarray) + proc.fit(input_axisarray.data, labels_classification) + + ckpt_path = tmp_path / "rf_ckpt.pkl" + proc.save_checkpoint(str(ckpt_path)) + + # Load new processor + new_proc = SklearnModelProcessor( + model_class="river.forest.ARFClassifier", + checkpoint_path=str(ckpt_path), + ) + new_proc._reset_state(input_axisarray) + assert new_proc._state.model is not None + + # Check outputs still work + output = new_proc._process(input_axisarray) + assert output.data.shape[0] == input_axisarray.data.shape[0] + + +def test_hmmlearn_gaussianhmm_predict(input_axisarray): + # Ensure data has no NaNs and is usable for fitting + X = input_axisarray.data + + proc = SklearnModelProcessor( + model_class="hmmlearn.hmm.GaussianHMM", + model_kwargs={"n_components": 2, "n_iter": 10}, + ) + + # hmmlearn expects (n_samples, n_features), so we combine time axis + proc._reset_state(input_axisarray) + proc.fit(X, None) # HMM doesn't use labels + + output = proc._process(input_axisarray) + + # Output should be a sequence of state predictions, shape (timesteps, 1) + assert output.data.shape[0] == input_axisarray.data.shape[0] + assert output.data.ndim == 2 + assert output.data.shape[1] == 1 From 4c71e0f59a3ba721a6ac72739dd557f7e34f5325 Mon Sep 17 00:00:00 2001 From: kylmcgr Date: Fri, 22 Aug 2025 11:56:01 -0400 Subject: [PATCH 02/16] First commit of ModelInitMixin to help generic models init with weights checkpoint --- src/ezmsg/learn/process/base.py | 173 ++++++++++++++++++++++++++++++++ 1 file changed, 173 insertions(+) create mode 100644 src/ezmsg/learn/process/base.py diff --git a/src/ezmsg/learn/process/base.py b/src/ezmsg/learn/process/base.py new file mode 100644 index 0000000..4afa058 --- /dev/null +++ b/src/ezmsg/learn/process/base.py @@ -0,0 +1,173 @@ +import inspect +import json +from pathlib import Path +import typing + +import ezmsg.core as ez +import torch + + +class ModelInitMixin: + """ + Mixin class to support model initialization from: + 1. Setting parameters + 2. Config file + 3. Checkpoint file + """ + + @staticmethod + def _merge_config(model_kwargs: dict, config) -> None: + """ + Mutate the model_kwargs dictionary with the config parameters. + Args: + model_kwargs: Original to-be-mutated model kwargs. + config: Update config parameters. + + Returns: + None because model_kwargs is mutated in place. + """ + if "model_params" in config: + config = config["model_params"] + # Update model_kwargs with config parameters + for key, value in config.items(): + if key in model_kwargs: + if model_kwargs[key] != value: + ez.logger.warning( + f"Config parameter {key} ({value}) differs from settings ({model_kwargs[key]})." + ) + else: + ez.logger.warning(f"Config parameter {key} is not in model_kwargs.") + model_kwargs[key] = value + + def _filter_model_kwargs(self, model_class, kwargs: dict) -> dict: + valid_params = inspect.signature(model_class.__init__).parameters + filtered_out = set(kwargs.keys()) - {k for k in valid_params if k != "self"} + if filtered_out: + ez.logger.warning( + f"Ignoring unexpected model parameters not accepted by {model_class.__name__} constructor: {sorted(filtered_out)}" + ) + # Keep all valid parameters, including None values, so checkpoint-inferred values can overwrite them + return {k: v for k, v in kwargs.items() if k in valid_params and k != "self"} + + def _init_model( + self, + model_class, + params: dict[str, typing.Any] | None = None, + config_path: str | None = None, + checkpoint_path: str | None = None, + device: str = "cpu", + state_dict_prefix: str | None = None, + weights_only: bool | None = None, + ) -> torch.nn.Module: + """ + Args: + model_class: The class of the model to be initialized. + params: A dictionary of setting parameters to be used for model initialization. + config_path: Path to a JSON config file to update model parameters. + checkpoint_path: Path to a checkpoint file to load model weights and possibly config. + + Returns: + The initialized model. + The model will be initialized with the correct config and weights. + + """ + # Model parameters are taken from multiple sources, in ascending priority: + # 1. Setting parameters + # 2. Config file if provided + # 3. "config" entry in checkpoint file if checkpoint file provided and config present + # 4. Sizes of weights in checkpoint file if provided + + # Get configs from setting params. + model_kwargs = params or {} + state_dict = None + + # Check if a config file is provided and if so use that to update kwargs (with warnings). + if config_path: + config_path = Path(config_path) + if not config_path.exists(): + ez.logger.error(f"Config path {config_path} does not exist.") + raise FileNotFoundError(f"Config path {config_path} does not exist.") + try: + with open(config_path, "r") as f: + config = json.load(f) + self._merge_config(model_kwargs, config) + except Exception as e: + raise RuntimeError( + f"Failed to load config from {config_path}: {str(e)}" + ) + + # If a checkpoint file is provided, load it. + if checkpoint_path: + checkpoint_path = Path(checkpoint_path) + if not checkpoint_path.exists(): + ez.logger.error(f"Checkpoint path {checkpoint_path} does not exist.") + raise FileNotFoundError( + f"Checkpoint path {checkpoint_path} does not exist." + ) + try: + checkpoint = torch.load( + checkpoint_path, map_location=device, weights_only=weights_only + ) + + if "config" in checkpoint: + config = checkpoint["config"] + self._merge_config(model_kwargs, config) + + # Load the model weights and infer the config. + state_dict = checkpoint + if "model_state_dict" in checkpoint: + state_dict = checkpoint["model_state_dict"] + elif "state_dict" in checkpoint: + # This is for backward compatibility with older checkpoints + # that used "state_dict" instead of "model_state_dict" + state_dict = checkpoint["state_dict"] + infer_config = getattr( + model_class, + "infer_config_from_state_dict", + lambda _state_dict: {}, # Default to empty dict if not defined + ) + infer_kwargs = ( + {"rnn_type": model_kwargs["rnn_type"]} + if "rnn_type" in model_kwargs + else {} + ) + self._merge_config( + model_kwargs, + infer_config(state_dict, **infer_kwargs), + ) + + except Exception as e: + raise RuntimeError( + f"Failed to load checkpoint from {checkpoint_path}: {str(e)}" + ) + + # Filter model_kwargs to only include valid parameters for the model class + filtered_kwargs = self._filter_model_kwargs(model_class, model_kwargs) + + # Remove None values from filtered_kwargs to avoid passing them to the model constructor + # This should only happen for parameters that weren't inferred from the checkpoint + final_kwargs = {k: v for k, v in filtered_kwargs.items() if v is not None} + + # Create the model with the final kwargs + model = model_class(**final_kwargs) + + # Finally, load the weights. + if state_dict: + if state_dict_prefix: + # If a prefix is provided, filter the state_dict keys + state_dict = { + k[len(state_dict_prefix) :]: v + for k, v in state_dict.items() + if k.startswith(state_dict_prefix) + } + # Load the model weights + missing, unexpected = model.load_state_dict( + state_dict, strict=False, assign=True + ) + if missing or unexpected: + ez.logger.warning( + f"Partial load: missing keys: {missing}, unexpected keys: {unexpected}" + ) + + model.to(device) + return model From 765cacc6571f5b487e37844c19733aaf84c69fd2 Mon Sep 17 00:00:00 2001 From: kylmcgr Date: Fri, 22 Aug 2025 11:57:40 -0400 Subject: [PATCH 03/16] First commit of TorchModelProcessor for generic torch models --- src/ezmsg/learn/process/torch.py | 378 +++++++++++++++++++++++++ tests/integration/test_torch_system.py | 68 +++++ tests/unit/test_torch.py | 349 +++++++++++++++++++++++ 3 files changed, 795 insertions(+) create mode 100644 src/ezmsg/learn/process/torch.py create mode 100644 tests/integration/test_torch_system.py create mode 100644 tests/unit/test_torch.py diff --git a/src/ezmsg/learn/process/torch.py b/src/ezmsg/learn/process/torch.py new file mode 100644 index 0000000..60137b0 --- /dev/null +++ b/src/ezmsg/learn/process/torch.py @@ -0,0 +1,378 @@ +import importlib +import typing + +import ezmsg.core as ez +import numpy as np +import torch +from ezmsg.sigproc.base import ( + BaseAdaptiveTransformer, + BaseAdaptiveTransformerUnit, + BaseStatefulTransformer, + BaseTransformerUnit, + processor_state, +) +from ezmsg.sigproc.sampler import SampleMessage +from ezmsg.sigproc.util.profile import profile_subpub +from ezmsg.util.messages.axisarray import AxisArray +from ezmsg.util.messages.util import replace + +from .base import ModelInitMixin + + +class TorchSimpleSettings(ez.Settings): + model_class: str + """ + Fully qualified class path of the model to be used. + Example: "my_module.MyModelClass" + This class should inherit from `torch.nn.Module`. + """ + + checkpoint_path: str | None = None + """ + Path to a checkpoint file containing model weights. + If None, the model will be initialized with random weights. + If parameters inferred from the weight sizes conflict with the settings / config, + then the the inferred parameters will take priority and a warning will be logged. + """ + + config_path: str | None = None + """ + Path to a config file containing model parameters. + Parameters loaded from the config file will take priority over settings. + If settings differ from config parameters then a warning will be logged. + If `checkpoint_path` is provided then any parameters inferred from the weights + will take priority over the config parameters. + """ + + single_precision: bool = True + """Use single precision (float32) instead of double precision (float64)""" + + device: str | None = None + """ + Device to use for the model. If None, the device will be determined automatically, + with preference for cuda > mps > cpu. + """ + + model_kwargs: dict[str, typing.Any] | None = None + """ + Additional keyword arguments to pass to the model constructor. + This can include parameters like `input_size`, `output_size`, etc. + If a config file is provided, these parameters will be updated with the config values. + If a checkpoint file is provided, these parameters will be updated with the inferred values + from the model weights. + """ + + +class TorchModelSettings(TorchSimpleSettings): + learning_rate: float = 0.001 + """Learning rate for the optimizer""" + + weight_decay: float = 0.0001 + """Weight decay for the optimizer""" + + loss_fn: torch.nn.Module | dict[str, torch.nn.Module] | None = None + """ + Loss function(s) for the decoder. If using multiple heads, this should be a dictionary + mapping head names to loss functions. The keys must match the output head names. + """ + + loss_weights: dict[str, float] | None = None + """ + Weights for each loss function if using multiple heads. + The keys must match the output head names. + If None or missing/mismatched keys, losses will be unweighted. + """ + + scheduler_gamma: float = 0.999 + """Learning scheduler decay rate. Set to 0.0 to disable the scheduler.""" + + +@processor_state +class TorchSimpleState: + model: torch.nn.Module | None = None + device: torch.device | None = None + chan_ax: dict[str, AxisArray.CoordinateAxis] | None = None + + +class TorchModelState(TorchSimpleState): + optimizer: torch.optim.Optimizer | None = None + scheduler: torch.optim.lr_scheduler.LRScheduler | None = None + + +P = typing.TypeVar("P", bound=BaseStatefulTransformer) + + +class TorchProcessorMixin: + """Mixin with shared functionality for torch processors.""" + + def _import_model(self, class_path: str) -> type[torch.nn.Module]: + """Dynamically import model class from string.""" + if class_path is None: + raise ValueError("Model class path must be provided in settings.") + module_path, class_name = class_path.rsplit(".", 1) + module = importlib.import_module(module_path) + return getattr(module, class_name) + + def _infer_output_sizes( + self: P, model: torch.nn.Module, n_input: int + ) -> dict[str, int]: + """Simple inference to get output channel size. Override if needed.""" + dummy_input = torch.zeros(1, 1, n_input, device=self._state.device) + with torch.no_grad(): + output = model(dummy_input) + + if isinstance(output, dict): + return {k: v.shape[-1] for k, v in output.items()} + else: + return {"output": output.shape[-1]} + + def _init_optimizer(self) -> None: + self._state.optimizer = torch.optim.AdamW( + self._state.model.parameters(), + lr=self.settings.learning_rate, + weight_decay=self.settings.weight_decay, + ) + self._state.scheduler = ( + torch.optim.lr_scheduler.ExponentialLR( + self._state.optimizer, gamma=self.settings.scheduler_gamma + ) + if self.settings.scheduler_gamma > 0.0 + else None + ) + + def _validate_loss_keys(self, output_keys: list[str]): + if isinstance(self.settings.loss_fn, dict): + missing = [k for k in output_keys if k not in self.settings.loss_fn] + if missing: + raise ValueError(f"Missing loss function(s) for output keys: {missing}") + + def _to_tensor(self: P, data: np.ndarray) -> torch.Tensor: + dtype = torch.float32 if self.settings.single_precision else torch.float64 + if isinstance(data, torch.Tensor): + return data.detach().clone().to(device=self._state.device, dtype=dtype) + return torch.tensor(data, dtype=dtype, device=self._state.device) + + def save_checkpoint(self: P, path: str) -> None: + """Save the current model state to a checkpoint file.""" + if self._state.model is None: + raise RuntimeError("Model must be initialized before saving a checkpoint.") + + checkpoint = { + "model_state_dict": self._state.model.state_dict(), + "config": self.settings.model_kwargs or {}, + } + + # Add optimizer state if available + if hasattr(self._state, "optimizer") and self._state.optimizer is not None: + checkpoint["optimizer_state_dict"] = self._state.optimizer.state_dict() + + torch.save(checkpoint, path) + + def _ensure_batched(self, tensor: torch.Tensor) -> tuple[torch.Tensor, bool]: + """ + Ensure tensor has a batch dimension. + Returns the potentially modified tensor and a flag indicating whether a dimension was added. + """ + if tensor.ndim == 2: + return tensor.unsqueeze(0), True + return tensor, False + + def _common_process(self: P, message: AxisArray) -> list[AxisArray]: + data = message.data + data = self._to_tensor(data) + + # Add batch dimension if missing + data, added_batch_dim = self._ensure_batched(data) + + with torch.no_grad(): + output = self._state.model(data) + + if isinstance(output, dict): + output_messages = [ + replace( + message, + data=value.cpu().numpy().squeeze(0) + if added_batch_dim + else value.cpu().numpy(), + axes={ + **message.axes, + "ch": self._state.chan_ax[key], + }, + key=key, + ) + for key, value in output.items() + ] + return output_messages + + return [ + replace( + message, + data=output.cpu().numpy().squeeze(0) + if added_batch_dim + else output.cpu().numpy(), + axes={ + **message.axes, + "ch": self._state.chan_ax["output"], + }, + ) + ] + + def _common_reset_state(self: P, message: AxisArray, model_kwargs: dict) -> None: + n_input = message.data.shape[message.get_axis_idx("ch")] + + if "input_size" in model_kwargs: + if model_kwargs["input_size"] != n_input: + raise ValueError( + f"Mismatch between model_kwargs['input_size']={model_kwargs['input_size']} " + f"and input data channels={n_input}" + ) + else: + model_kwargs["input_size"] = n_input + + device = ( + "cuda" + if torch.cuda.is_available() + else ("mps" if torch.mps.is_available() else "cpu") + ) + device = self.settings.device or device + self._state.device = torch.device(device) + + model_class = self._import_model(self.settings.model_class) + + self._state.model = self._init_model( + model_class=model_class, + params=model_kwargs, + config_path=self.settings.config_path, + checkpoint_path=self.settings.checkpoint_path, + device=device, + ) + + self._state.model.eval() + + output_sizes = self._infer_output_sizes(self._state.model, n_input) + self._state.chan_ax = { + head: AxisArray.CoordinateAxis( + data=np.array([f"{head}_ch{_}" for _ in range(size)]), + dims=["ch"], + ) + for head, size in output_sizes.items() + } + + +class TorchSimpleProcessor( + BaseStatefulTransformer[ + TorchSimpleSettings, AxisArray, AxisArray, TorchSimpleState + ], + TorchProcessorMixin, + ModelInitMixin, +): + def _reset_state(self, message: AxisArray) -> None: + model_kwargs = dict(self.settings.model_kwargs or {}) + self._common_reset_state(message, model_kwargs) + + def _process(self, message: AxisArray) -> list[AxisArray]: + """Process the input message and return the output messages.""" + return self._common_process(message) + + +class TorchSimpleUnit( + BaseTransformerUnit[ + TorchSimpleSettings, + AxisArray, + AxisArray, + TorchSimpleProcessor, + ] +): + SETTINGS = TorchSimpleSettings + + @ez.subscriber(BaseTransformerUnit.INPUT_SIGNAL, zero_copy=True) + @ez.publisher(BaseTransformerUnit.OUTPUT_SIGNAL) + @profile_subpub(trace_oldest=False) + async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator: + results = await self.processor.__acall__(message) + for result in results: + yield self.OUTPUT_SIGNAL, result + + +class TorchModelProcessor( + BaseAdaptiveTransformer[TorchModelSettings, AxisArray, AxisArray, TorchModelState], + TorchProcessorMixin, + ModelInitMixin, +): + def _reset_state(self, message: AxisArray) -> None: + model_kwargs = dict(self.settings.model_kwargs or {}) + self._common_reset_state(message, model_kwargs) + self._init_optimizer() + self._validate_loss_keys(list(self._state.chan_ax.keys())) + + def _process(self, message: AxisArray) -> list[AxisArray]: + return self._common_process(message) + + def partial_fit(self, message: SampleMessage) -> None: + self._state.model.train() + + X = self._to_tensor(message.sample.data) + X, batched = self._ensure_batched(X) + + y_targ = message.trigger.value + if not isinstance(y_targ, dict): + y_targ = {"output": y_targ} + y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()} + if batched: + for key in y_targ: + y_targ[key] = y_targ[key].unsqueeze(0) + + loss_fns = self.settings.loss_fn + if loss_fns is None: + raise ValueError("loss_fn must be provided in settings to use partial_fit") + if not isinstance(loss_fns, dict): + loss_fns = {k: loss_fns for k in y_targ.keys()} + + weights = self.settings.loss_weights or {} + + with torch.set_grad_enabled(True): + y_pred = self._state.model(X) + if not isinstance(y_pred, dict): + y_pred = {"output": y_pred} + + losses = [] + for key in y_targ.keys(): + loss_fn = loss_fns.get(key) + if loss_fn is None: + raise ValueError( + f"Loss function for key '{key}' is not defined in settings." + ) + if isinstance(loss_fn, torch.nn.CrossEntropyLoss): + loss = loss_fn(y_pred[key].permute(0, 2, 1), y_targ[key].long()) + else: + loss = loss_fn(y_pred[key], y_targ[key]) + weight = weights.get(key, 1.0) + losses.append(loss * weight) + total_loss = sum(losses) + + self._state.optimizer.zero_grad() + total_loss.backward() + self._state.optimizer.step() + if self._state.scheduler is not None: + self._state.scheduler.step() + + self._state.model.eval() + + +class TorchModelUnit( + BaseAdaptiveTransformerUnit[ + TorchModelSettings, + AxisArray, + AxisArray, + TorchModelProcessor, + ] +): + SETTINGS = TorchModelSettings + + @ez.subscriber(BaseAdaptiveTransformerUnit.INPUT_SIGNAL, zero_copy=True) + @ez.publisher(BaseAdaptiveTransformerUnit.OUTPUT_SIGNAL) + @profile_subpub(trace_oldest=False) + async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator: + results = await self.processor.__acall__(message) + for result in results: + yield self.OUTPUT_SIGNAL, result diff --git a/tests/integration/test_torch_system.py b/tests/integration/test_torch_system.py new file mode 100644 index 0000000..efb9523 --- /dev/null +++ b/tests/integration/test_torch_system.py @@ -0,0 +1,68 @@ +import os +import tempfile +from pathlib import Path + +import ezmsg.core as ez +from ezmsg.sigproc.synth import Counter, CounterSettings +from ezmsg.util.messagecodec import message_log +from ezmsg.util.messagelogger import MessageLogger, MessageLoggerSettings +from ezmsg.util.messages.axisarray import AxisArray +from ezmsg.util.terminate import TerminateOnTotal, TerminateOnTotalSettings + +from ezmsg.learn.process.torch import TorchModelUnit + + +def test_torch_model_unit_system(): + fs = 10.0 + block_size = 4 + duration = 2.0 # seconds + input_size = 3 + output_size = 2 + + test_filename = Path(tempfile.gettempdir()) + test_filename = test_filename / Path("test_torch_system.txt") + with open(test_filename, "w"): + pass + ez.logger.info(f"Logging to {test_filename}") + + comps = { + "SRC": Counter( + CounterSettings( + fs=fs, + n_ch=input_size, + n_time=block_size, + dispatch_rate=duration, + mod=None, + ) + ), + "MODEL": TorchModelUnit( + model_class="tests.unit.test_torch.DummyModel", + model_kwargs={ + "input_size": input_size, + "output_size": output_size, + }, + device="cpu", + ), + "LOG": MessageLogger(MessageLoggerSettings(output=test_filename)), + "TERM": TerminateOnTotal( + TerminateOnTotalSettings(total=int(duration * fs / block_size)) + ), + } + + conns = ( + (comps["SRC"].OUTPUT_SIGNAL, comps["MODEL"].INPUT_SIGNAL), + (comps["MODEL"].OUTPUT_SIGNAL, comps["LOG"].INPUT_MESSAGE), + (comps["LOG"].OUTPUT_MESSAGE, comps["TERM"].INPUT_MESSAGE), + ) + + ez.run(components=comps, connections=conns) + + # Read from message log + messages: list[AxisArray] = [_ for _ in message_log(test_filename)] + os.remove(test_filename) + + # Check basic structure + assert all(msg.data.shape[-1] == output_size for msg in messages) + assert all("time" in msg.dims and "ch" in msg.dims for msg in messages) + assert all("ch" in msg.axes for msg in messages) + assert messages[0].axes["ch"].data.shape[0] == output_size diff --git a/tests/unit/test_torch.py b/tests/unit/test_torch.py new file mode 100644 index 0000000..e1ace17 --- /dev/null +++ b/tests/unit/test_torch.py @@ -0,0 +1,349 @@ +from pathlib import Path + +import numpy as np +import pytest +import torch +from ezmsg.sigproc.sampler import SampleMessage, SampleTriggerMessage +from ezmsg.util.messages.axisarray import AxisArray + +from ezmsg.learn.process.torch import TorchModelProcessor + +DUMMY_MODEL_CLASS = "tests.unit.test_torch.DummyModel" +MULTIHEAD_MODEL_CLASS = "tests.unit.test_torch.MultiHeadModel" + + +class DummyModel(torch.nn.Module): + def __init__(self, input_size=4, output_size=2, dropout=0.0): + super().__init__() + self.linear = torch.nn.Linear(input_size, output_size) + self.dropout = torch.nn.Dropout(dropout) if dropout > 0 else None + + def forward(self, x): + if self.dropout: + x = self.dropout(x) + return self.linear(x) + + @classmethod + def infer_config_from_state_dict(cls, state_dict): + weight = ( + state_dict["linear.weight"] + if isinstance(state_dict, dict) + else state_dict.state_dict()["linear.weight"] + ) + out_features, in_features = weight.shape + return { + "input_size": in_features, + "output_size": out_features, + } + + +class MultiHeadModel(torch.nn.Module): + def __init__(self, input_size=4): + super().__init__() + self.head_a = torch.nn.Linear(input_size, 2) + self.head_b = torch.nn.Linear(input_size, 3) + + def forward(self, x): + return { + "head_a": self.head_a(x), + "head_b": self.head_b(x), + } + + @classmethod + def infer_config_from_state_dict(cls, state_dict): + return {"input_size": state_dict["head_a.weight"].shape[1]} + + +@pytest.fixture +def batch_message(): + input_dim = 6 + batch_size = 10 + data = np.random.randn(batch_size, input_dim) + return AxisArray( + data=data, + dims=["time", "ch"], + axes={ + "time": AxisArray.TimeAxis(fs=100.0), + "ch": AxisArray.CoordinateAxis(data=np.arange(input_dim), dims=["ch"]), + }, + ) + + +@pytest.mark.parametrize("input_size,output_size", [(4, 2), (6, 3), (8, 1)]) +def test_inference_shapes(input_size, output_size): + data = np.random.randn(12, input_size) + msg = AxisArray( + data=data, + dims=["time", "ch"], + axes={ + "time": AxisArray.TimeAxis(fs=100.0), + "ch": AxisArray.CoordinateAxis(data=np.arange(input_size), dims=["ch"]), + }, + ) + proc = TorchModelProcessor( + model_class=DUMMY_MODEL_CLASS, + model_kwargs={ + "input_size": input_size, + "output_size": output_size, + }, + ) + out = proc(msg)[0] + # Check output last dim matches output_size + assert out.data.shape[-1] == output_size + # Check ch axis size + assert out.get_axis("ch").data.shape[0] == output_size + + +def test_checkpoint_loading_and_weights(batch_message): + proc = TorchModelProcessor( + model_class=DUMMY_MODEL_CLASS, + model_kwargs={ + "output_size": 2, + }, + device="cpu", + ) + proc(batch_message) # initialize model + + checkpoint_filename = "test_torch_checkpoint.pth" + proc.save_checkpoint(checkpoint_filename) + + proc2 = TorchModelProcessor( + model_class=DUMMY_MODEL_CLASS, + checkpoint_path=checkpoint_filename, + device="cpu", + ) + proc2(batch_message) + + model_state = proc._state.model.state_dict() + loaded_state = proc2._state.model.state_dict() + + # Check all keys and values + for key in model_state.keys(): + assert key in loaded_state, f"Key '{key}' missing in loaded model state" + assert torch.allclose(model_state[key], loaded_state[key], atol=1e-6), ( + f"Mismatch for key '{key}'" + ) + + # Clean up checkpoint file + for _ in range(5): + Path(checkpoint_filename).unlink(missing_ok=True) + + +@pytest.mark.parametrize("dropout", [0.0, 0.1, 0.5]) +def test_model_kwargs_propagation(dropout, batch_message): + proc = TorchModelProcessor( + model_class=DUMMY_MODEL_CLASS, + model_kwargs={ + "output_size": 2, + "dropout": dropout, + }, + ) + proc(batch_message) + model = proc._state.model + if dropout > 0: + assert isinstance(model.dropout, torch.nn.Dropout) + assert model.dropout.p == dropout + else: + assert model.dropout is None + + +def test_partial_fit_changes_weights(batch_message): + proc = TorchModelProcessor( + model_class=DUMMY_MODEL_CLASS, + loss_fn=torch.nn.MSELoss(), + learning_rate=0.1, + model_kwargs={ + "output_size": 2, + }, + ) + x = batch_message.data[:1] + y = np.random.randn(1, 2) + + sample = AxisArray( + data=x, + dims=["time", "ch"], + axes={ + "time": AxisArray.TimeAxis(fs=100.0), + "ch": AxisArray.CoordinateAxis(data=np.arange(x.shape[1]), dims=["ch"]), + }, + ) + + msg = SampleMessage( + sample=sample, + trigger=SampleTriggerMessage(timestamp=0.0, value=y), + ) + + proc(sample) # run forward pass once to init model + before = proc._state.model.linear.weight.clone() + + proc.partial_fit(msg) + + after = proc._state.model.linear.weight + assert not torch.allclose(before, after) + + # Expect error if no loss function provided + bad_proc = TorchModelProcessor( + model_class=DUMMY_MODEL_CLASS, + loss_fn=None, + learning_rate=0.1, + model_kwargs={ + "input_size": x.shape[-1], + "output_size": 2, + }, + ) + bad_proc(sample) + with pytest.raises(ValueError): + bad_proc.partial_fit(msg) + + +@pytest.mark.parametrize("device", ["cpu", "mps", "cuda"]) +def test_model_runs_on_devices(device, batch_message): + # Skip unavailable devices + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if device == "mps" and not torch.backends.mps.is_available(): + pytest.skip("MPS not available") + + proc = TorchModelProcessor( + model_class=DUMMY_MODEL_CLASS, + device=device, + model_kwargs={ + "output_size": 2, + }, + ) + proc(batch_message) + model = proc._state.model + for param in model.parameters(): + assert param.device.type == device + + +@pytest.mark.parametrize("batch_size", [1, 5, 10]) +def test_batch_processing(batch_size): + input_dim = 4 + output_dim = 2 + data = np.random.randn(batch_size, input_dim) + + msg = AxisArray( + data=data, + dims=["time", "ch"], + axes={ + "time": AxisArray.TimeAxis(fs=100.0), + "ch": AxisArray.CoordinateAxis(data=np.arange(input_dim), dims=["ch"]), + }, + ) + proc = TorchModelProcessor( + model_class=DUMMY_MODEL_CLASS, + model_kwargs={ + "input_size": input_dim, + "output_size": output_dim, + }, + ) + out = proc(msg)[0] + assert out.data.shape[0] == batch_size + assert out.data.shape[-1] == output_dim + + +def test_input_size_mismatch_raises_error(): + input_dim = 6 + wrong_input_dim = 4 + data = np.random.randn(10, input_dim) + msg = AxisArray( + data=data, + dims=["time", "ch"], + axes={ + "time": AxisArray.TimeAxis(fs=100.0), + "ch": AxisArray.CoordinateAxis(data=np.arange(input_dim), dims=["ch"]), + }, + ) + with pytest.raises(ValueError, match="Mismatch.*input_size.*"): + TorchModelProcessor( + model_class=DUMMY_MODEL_CLASS, + model_kwargs={ + "input_size": wrong_input_dim, + "output_size": 2, + }, + )(msg) + + +def test_multihead_output(batch_message): + proc = TorchModelProcessor( + model_class=MULTIHEAD_MODEL_CLASS, + model_kwargs={"input_size": batch_message.data.shape[1]}, + ) + results = proc(batch_message) + + keys = {r.key for r in results} + assert keys == {"head_a", "head_b"} + for r in results: + assert r.data.ndim == 2 + + +def test_multihead_partial_fit_with_loss_dict(batch_message): + proc = TorchModelProcessor( + model_class=MULTIHEAD_MODEL_CLASS, + loss_fn={ + "head_a": torch.nn.MSELoss(), + "head_b": torch.nn.L1Loss(), + }, + model_kwargs={"input_size": batch_message.data.shape[1]}, + ) + + proc(batch_message) # initialize model + + y_targ = { + "head_a": np.random.randn(1, 2), + "head_b": np.random.randn(1, 3), + } + sample = AxisArray( + data=batch_message.data[:1], + dims=["time", "ch"], + axes=batch_message.axes, + ) + msg = SampleMessage( + sample=sample, + trigger=SampleTriggerMessage(timestamp=0.0, value=y_targ), + ) + + before_a = proc._state.model.head_a.weight.clone() + before_b = proc._state.model.head_b.weight.clone() + + proc.partial_fit(msg) + + after_a = proc._state.model.head_a.weight + after_b = proc._state.model.head_b.weight + + assert not torch.allclose(before_a, after_a) + assert not torch.allclose(before_b, after_b) + + +def test_partial_fit_with_loss_weights(batch_message): + proc = TorchModelProcessor( + model_class=MULTIHEAD_MODEL_CLASS, + loss_fn={ + "head_a": torch.nn.MSELoss(), + "head_b": torch.nn.MSELoss(), + }, + loss_weights={ + "head_a": 2.0, + "head_b": 0.5, + }, + model_kwargs={"input_size": batch_message.data.shape[1]}, + ) + proc(batch_message) + + y_targ = { + "head_a": np.random.randn(1, 2), + "head_b": np.random.randn(1, 3), + } + sample = AxisArray( + data=batch_message.data[:1], + dims=["time", "ch"], + axes=batch_message.axes, + ) + msg = SampleMessage( + sample=sample, + trigger=SampleTriggerMessage(timestamp=0.0, value=y_targ), + ) + + # Expect no error, and just run once + proc.partial_fit(msg) From fa8d03d29f075cd22e000ae5e83a1eb698fb079a Mon Sep 17 00:00:00 2001 From: kylmcgr Date: Fri, 22 Aug 2025 11:58:40 -0400 Subject: [PATCH 04/16] First commit of MLP for use with generic torch processor --- src/ezmsg/learn/model/mlp.py | 133 ++++++++++++++++++ tests/integration/test_mlp_system.py | 68 +++++++++ tests/unit/test_mlp.py | 199 +++++++++++++++++++++++++++ 3 files changed, 400 insertions(+) create mode 100644 src/ezmsg/learn/model/mlp.py create mode 100644 tests/integration/test_mlp_system.py create mode 100644 tests/unit/test_mlp.py diff --git a/src/ezmsg/learn/model/mlp.py b/src/ezmsg/learn/model/mlp.py new file mode 100644 index 0000000..194709a --- /dev/null +++ b/src/ezmsg/learn/model/mlp.py @@ -0,0 +1,133 @@ +import torch +import torch.nn + + +class MLP(torch.nn.Module): + """ + A simple Multi-Layer Perceptron (MLP) model. Adapted from Ezmsg MLP. + + Attributes: + feature_extractor (torch.nn.Sequential): The sequential feature extractor part of the MLP. + heads (torch.nn.ModuleDict): A dictionary of output linear layers for each output head. + """ + + def __init__( + self, + input_size: int, + hidden_size: int | list[int], + num_layers: int | None = None, + output_heads: int | dict[str, int] = 2, + norm_layer: str | None = None, + activation_layer: str | None = "ReLU", + inplace: bool | None = None, + bias: bool = True, + dropout: float = 0.0, + ): + """ + Initialize the MLP model. + Args: + input_size (int): The size of the input features. + hidden_size (int | list[int]): The sizes of the hidden layers. If a list, num_layers must be None or the length + of the list. If a single integer, num_layers must be specified and determines the number of hidden layers. + num_layers (int, optional): The number of hidden layers. Length of hidden_size if None. Default is None. + output_heads (int | dict[str, int], optional): Number of output features or classes if single head output or a + dictionary mapping head names to output sizes if multi-head output. Default is 2 (single head). + norm_layer (str, optional): A normalization layer to be applied after each linear layer. Default is None. + Common choices are "BatchNorm1d" or "LayerNorm". + activation_layer (str, optional): An activation function to be applied after each normalization + layer. Default is "ReLU". + inplace (bool, optional): Whether the activation function is performed in-place. Default is None. + bias (bool, optional): Whether to use bias in the linear layers. Default is True. + dropout (float, optional): The dropout rate to be applied after each linear layer. Default is 0.0. + """ + super().__init__() + if isinstance(hidden_size, int): + if num_layers is None: + raise ValueError( + "If hidden_size is an integer, num_layers must be specified." + ) + hidden_size = [hidden_size] * num_layers + if len(hidden_size) == 0: + raise ValueError("hidden_size must have at least one element") + if any(not isinstance(x, int) for x in hidden_size): + raise ValueError("hidden_size must contain only integers") + if num_layers is not None and len(hidden_size) != num_layers: + raise ValueError( + "Length of hidden_size must match num_layers if num_layers is specified." + ) + + params = {} if inplace is None else {"inplace": inplace} + + layers = [] + in_dim = input_size + + def _get_layer_class(layer_name: str): + if layer_name is not None and "torch.nn" in layer_name: + return getattr(torch.nn, layer_name.rsplit(".", 1)[1]) + return None + + norm_layer_class = _get_layer_class(norm_layer) + activation_layer_class = _get_layer_class(activation_layer) + for hidden_dim in hidden_size[:-1]: + layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias)) + if norm_layer_class is not None: + layers.append(norm_layer_class(hidden_dim)) + if activation_layer_class is not None: + layers.append(activation_layer_class(**params)) + layers.append(torch.nn.Dropout(dropout, **params)) + in_dim = hidden_dim + + layers.append(torch.nn.Linear(in_dim, hidden_size[-1], bias=bias)) + + self.feature_extractor = torch.nn.Sequential(*layers) + + if isinstance(output_heads, int): + output_heads = {"output": output_heads} + self.heads = torch.nn.ModuleDict( + { + name: torch.nn.Linear(hidden_size[-1], output_size) + for name, output_size in output_heads.items() + } + ) + + @classmethod + def infer_config_from_state_dict(cls, state_dict: dict) -> dict[str, int | float]: + """ + Infer the configuration from the state dict. + + Args: + state_dict: The state dict of the model. + + Returns: + dict[str, int | float]: A dictionary containing the inferred configuration. + """ + input_size = state_dict["feature_extractor.0.weight"].shape[1] + hidden_size = [ + param.shape[0] + for key, param in state_dict.items() + if key.startswith("feature_extractor.") and key.endswith(".weight") + ] + output_heads = { + key.split(".")[1]: param.shape[0] + for key, param in state_dict.items() + if key.startswith("heads.") and key.endswith(".bias") + } + + return { + "input_size": input_size, + "hidden_size": hidden_size, + "output_heads": output_heads, + } + + def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]: + """ + Forward pass through the MLP. + + Args: + x (torch.Tensor): Input tensor of shape (batch, seq_len, input_size). + + Returns: + dict[str, torch.Tensor]: A dictionary mapping head names to output tensors. + """ + x = self.feature_extractor(x) + return {name: head(x) for name, head in self.heads.items()} diff --git a/tests/integration/test_mlp_system.py b/tests/integration/test_mlp_system.py new file mode 100644 index 0000000..2673155 --- /dev/null +++ b/tests/integration/test_mlp_system.py @@ -0,0 +1,68 @@ +import os +import tempfile +from pathlib import Path + +import ezmsg.core as ez +from ezmsg.sigproc.synth import Counter, CounterSettings +from ezmsg.util.messagecodec import message_log +from ezmsg.util.messagelogger import MessageLogger, MessageLoggerSettings +from ezmsg.util.messages.axisarray import AxisArray +from ezmsg.util.terminate import TerminateOnTotal, TerminateOnTotalSettings + +from ezmsg.learn.process.torch import TorchModelUnit + + +def test_torch_model_unit_system(): + fs = 10.0 + block_size = 4 + duration = 2.0 # seconds + input_size = 3 + output_size = 2 + + test_filename = Path(tempfile.gettempdir()) + test_filename = test_filename / Path("test_mlp_system.txt") + with open(test_filename, "w"): + pass + ez.logger.info(f"Logging to {test_filename}") + + comps = { + "SRC": Counter( + CounterSettings( + fs=fs, + n_ch=input_size, + n_time=block_size, + dispatch_rate=duration, + mod=None, + ) + ), + "MODEL": TorchModelUnit( + model_class="ezmsg.learn.model.mlp.MLP", + model_kwargs={ + "input_size": input_size, + "hidden_size": [5, output_size], + }, + device="cpu", + ), + "LOG": MessageLogger(MessageLoggerSettings(output=test_filename)), + "TERM": TerminateOnTotal( + TerminateOnTotalSettings(total=int(duration * fs / block_size)) + ), + } + + conns = ( + (comps["SRC"].OUTPUT_SIGNAL, comps["MODEL"].INPUT_SIGNAL), + (comps["MODEL"].OUTPUT_SIGNAL, comps["LOG"].INPUT_MESSAGE), + (comps["LOG"].OUTPUT_MESSAGE, comps["TERM"].INPUT_MESSAGE), + ) + + ez.run(components=comps, connections=conns) + + # Read from message log + messages: list[AxisArray] = [_ for _ in message_log(test_filename)] + os.remove(test_filename) + + # Check basic structure + assert all(msg.data.shape[-1] == output_size for msg in messages) + assert all("time" in msg.dims and "ch" in msg.dims for msg in messages) + assert all("ch" in msg.axes for msg in messages) + assert messages[0].axes["ch"].data.shape[0] == output_size diff --git a/tests/unit/test_mlp.py b/tests/unit/test_mlp.py new file mode 100644 index 0000000..1479068 --- /dev/null +++ b/tests/unit/test_mlp.py @@ -0,0 +1,199 @@ +import numpy as np +import pytest +import torch +from ezmsg.util.messages.axisarray import AxisArray + +from ezmsg.learn.process.torch import TorchModelProcessor + + +@pytest.fixture +def mlp_settings(): + return { + "input_size": 8, + "hidden_size": [16, 32], + "output_heads": 5, + "activation_layer": "ReLU", + "norm_layer": None, + "dropout": 0.1, + } + + +@pytest.fixture +def sample_input(): + data = np.random.randn(64, 8) + return AxisArray( + data=data, + dims=["time", "ch"], + axes={ + "time": AxisArray.TimeAxis(fs=50.0), + "ch": AxisArray.CoordinateAxis(data=np.arange(data.shape[1]), dims=["ch"]), + }, + key="test_input", + ) + + +@pytest.fixture +def mlp_processor(mlp_settings): + return TorchModelProcessor( + model_class="ezmsg.learn.model.mlp.MLP", + model_kwargs=mlp_settings, + device="cpu", + ) + + +def test_mlp_forward_output_shape(sample_input, mlp_processor): + result = mlp_processor(sample_input)[0] + assert isinstance(result, AxisArray) + assert result.data.shape[0] == sample_input.data.shape[0] + assert result.data.shape[1] == 5 + assert "ch" in result.axes + assert result.get_axis("ch").data.shape[0] == 5 + + +def test_mlp_checkpoint_io(tmp_path, sample_input, mlp_settings): + ckpt_file = tmp_path / "mlp_checkpoint.pth" + + proc1 = TorchModelProcessor( + model_class="ezmsg.learn.model.mlp.MLP", + model_kwargs=mlp_settings, + device="cpu", + ) + proc1(sample_input) + proc1.save_checkpoint(str(ckpt_file)) + + proc2 = TorchModelProcessor( + model_class="ezmsg.learn.model.mlp.MLP", + checkpoint_path=str(ckpt_file), + model_kwargs=mlp_settings, + device="cpu", + ) + proc2(sample_input) + + state1 = proc1._state.model.state_dict() + state2 = proc2._state.model.state_dict() + + for key in state1: + assert torch.allclose(state1[key], state2[key], atol=1e-6) + + +def test_mlp_partial_fit_learns(sample_input, mlp_settings): + from ezmsg.sigproc.sampler import SampleMessage, SampleTriggerMessage + + proc = TorchModelProcessor( + model_class="ezmsg.learn.model.mlp.MLP", + model_kwargs=mlp_settings, + loss_fn=torch.nn.MSELoss(), + learning_rate=0.01, + device="cpu", + ) + proc(sample_input) + + sample = AxisArray( + data=sample_input.data[:1], dims=["time", "ch"], axes=sample_input.axes + ) + target = np.random.randn(1, 5) + + msg = SampleMessage( + sample=sample, trigger=SampleTriggerMessage(timestamp=0.0, value=target) + ) + + before = [p.detach().clone() for p in proc.state.model.parameters()] + proc.partial_fit(msg) + after = [p.detach().clone() for p in proc.state.model.parameters()] + + assert not all(torch.allclose(b, a) for b, a in zip(before, after)) + + +@pytest.mark.parametrize("device", ["cpu", "cuda", "mps"]) +def test_mlp_runs_on_available_devices(device, sample_input, mlp_settings): + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if device == "mps" and not torch.backends.mps.is_available(): + pytest.skip("MPS not available") + + proc = TorchModelProcessor( + model_class="ezmsg.learn.model.mlp.MLP", + model_kwargs=mlp_settings, + device=device, + ) + proc(sample_input) + for p in proc._state.model.parameters(): + assert p.device.type == device + + +def test_mlp_hidden_size_integer(sample_input): + proc = TorchModelProcessor( + model_class="ezmsg.learn.model.mlp.MLP", + model_kwargs={ + "input_size": 8, + "hidden_size": 32, + "num_layers": 3, + "output_heads": 5, + "activation_layer": "ReLU", + "dropout": 0.1, + }, + device="cpu", + ) + proc(sample_input) + hidden_layers = [ + m for m in proc._state.model.modules() if isinstance(m, torch.nn.Linear) + ][:-1] # Exclude the output head + assert len(hidden_layers) == 3 # num_layers = 3 + assert hidden_layers[0].in_features == 8 + assert all(layer.out_features == 32 for layer in hidden_layers[:-1]) + assert hidden_layers[-1].out_features == 32 + + +def test_mlp_rnn_style_missing_num_layers_raises(sample_input): + with pytest.raises(ValueError, match="num_layers must be specified"): + proc = TorchModelProcessor( + model_class="ezmsg.learn.model.mlp.MLP", + model_kwargs={ + "input_size": 8, + "hidden_size": 64, + "activation_layer": "torch.nn.ReLU", + }, + ) + proc(sample_input) + + +def test_mlp_list_hidden_size_with_num_layers_mismatch(sample_input): + with pytest.raises(ValueError, match="Length of hidden_size must match num_layers"): + proc = TorchModelProcessor( + model_class="ezmsg.learn.model.mlp.MLP", + model_kwargs={ + "input_size": 8, + "hidden_size": [32, 64, 10], + "num_layers": 2, # Mismatch: len(hidden_size) = 3 + }, + ) + proc(sample_input) + + +def test_mlp_empty_hidden_size_list(sample_input): + with pytest.raises(ValueError, match="hidden_size must have at least one element"): + proc = TorchModelProcessor( + model_class="ezmsg.learn.model.mlp.MLP", + model_kwargs={ + "input_size": 8, + "hidden_size": [], + }, + ) + proc(sample_input) + + +def test_mlp_multihead_output_keys(sample_input): + proc = TorchModelProcessor( + model_class="ezmsg.learn.model.mlp.MLP", + model_kwargs={ + "input_size": 8, + "hidden_size": [32], + "output_heads": {"a": 3, "b": 2}, + }, + device="cpu", + ) + outputs = proc(sample_input) + keys = {o.key for o in outputs} + assert keys == {"a", "b"} + for o in outputs: + assert o.data.shape[0] == sample_input.data.shape[0] From 30dedb82e2565f189c6e52aff15f36870de4e6c8 Mon Sep 17 00:00:00 2001 From: kylmcgr Date: Fri, 22 Aug 2025 11:59:47 -0400 Subject: [PATCH 05/16] First commit of RNN model and custom RNN processor --- src/ezmsg/learn/model/rnn.py | 160 ++++++++++++ src/ezmsg/learn/process/rnn.py | 266 ++++++++++++++++++++ tests/integration/test_rnn_system.py | 75 ++++++ tests/unit/test_rnn.py | 361 +++++++++++++++++++++++++++ 4 files changed, 862 insertions(+) create mode 100644 src/ezmsg/learn/model/rnn.py create mode 100644 src/ezmsg/learn/process/rnn.py create mode 100644 tests/integration/test_rnn_system.py create mode 100644 tests/unit/test_rnn.py diff --git a/src/ezmsg/learn/model/rnn.py b/src/ezmsg/learn/model/rnn.py new file mode 100644 index 0000000..2184987 --- /dev/null +++ b/src/ezmsg/learn/model/rnn.py @@ -0,0 +1,160 @@ +from typing import Optional + +import torch + + +class RNNModel(torch.nn.Module): + """ + Recurrent neural network supporting GRU, LSTM, and vanilla RNN (tanh/relu). + + Attributes: + input_size (int): Number of input features per time step. + hidden_size (int): Number of hidden units in the RNN cell. + num_layers (int, optional): Number of RNN layers. Default is 1. + output_size (int | dict[str, int], optional): Number of output features or classes if single head output or a + dictionary mapping head names to output sizes if multi-head output. Default is 2 (single head). + dropout (float, optional): Dropout rate applied after input and RNN output. Default is 0.3. + rnn_type (str, optional): Type of RNN cell to use: 'GRU', 'LSTM', 'RNN-Tanh', 'RNN-ReLU'. Default is 'GRU'. + + Returns: + dict[str, torch.Tensor]: Dictionary of decoded predictions mapping head names to tensors of shape + (batch, seq_len, output_size). If single head output, the key is "output". + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int = 1, + output_size: int | dict[str, int] = 2, + dropout: float = 0.3, + rnn_type: str = "GRU", + ): + super().__init__() + self.linear_embeddings = torch.nn.Linear(input_size, input_size) + self.dropout_input = torch.nn.Dropout(dropout) + + rnn_klass_str = rnn_type.upper().split("-")[0] + if rnn_klass_str not in ["GRU", "LSTM", "RNN"]: + raise ValueError(f"Unrecognized rnn_type: {rnn_type}") + rnn_klass = {"GRU": torch.nn.GRU, "LSTM": torch.nn.LSTM, "RNN": torch.nn.RNN}[ + rnn_klass_str + ] + rnn_kwargs = {} + if rnn_klass_str == "RNN": + rnn_kwargs["nonlinearity"] = rnn_type.lower().split("-")[-1] + self.rnn = rnn_klass( + input_size, + hidden_size, + num_layers, + batch_first=True, + dropout=dropout if num_layers > 1 else 0.0, + **rnn_kwargs, + ) + self.rnn_type = rnn_klass_str + + self.output_dropout = torch.nn.Dropout(dropout) + if isinstance(output_size, int): + output_size = {"output": output_size} + self.heads = torch.nn.ModuleDict( + { + name: torch.nn.Linear(hidden_size, size) + for name, size in output_size.items() + } + ) + + @classmethod + def infer_config_from_state_dict( + cls, state_dict: dict, rnn_type: str = "GRU" + ) -> dict[str, int | float]: + """ + This method is specific to each processor. + + Args: + state_dict: The state dict of the model. + rnn_type: The type of RNN used in the model (e.g., 'GRU', 'LSTM', 'RNN-Tanh', 'RNN-ReLU'). + + Returns: + A dictionary of model parameters obtained from the state dict. + + """ + output_size = { + key.split(".")[1]: param.shape[0] + for key, param in state_dict.items() + if key.startswith("heads.") and key.endswith(".bias") + } + + return { + # Infer input_size from linear_embeddings.weight (shape: [input_size, input_size]) + "input_size": state_dict["linear_embeddings.weight"].shape[1], + # Infer hidden_size from rnn.weight_ih_l0 (shape: [hidden_size * 3, input_size]) + "hidden_size": state_dict["rnn.weight_ih_l0"].shape[0] + // cls._get_gate_count(rnn_type), + # Infer num_layers by counting rnn layers in state_dict (e.g., weight_ih_l) + "num_layers": sum(1 for key in state_dict if "rnn.weight_ih_l" in key), + "output_size": output_size, + } + + @staticmethod + def _get_gate_count(rnn_type: str) -> int: + if rnn_type.upper() == "GRU": + return 3 + elif rnn_type.upper() == "LSTM": + return 4 + elif rnn_type.upper().startswith("RNN"): + return 1 + else: + raise ValueError(f"Unsupported rnn_type: {rnn_type}") + + def init_hidden(self, batch_size: int, device: torch.device) -> torch.Tensor: + """ + Initialize the hidden state for the RNN. + Args: + batch_size (int): Size of the batch. + device (torch.device): Device to place the hidden state on (e.g., 'cpu' or 'cuda'). + Returns: + torch.Tensor | tuple[torch.Tensor, torch.Tensor]: Initial hidden state for the RNN. + For LSTM, returns a tuple of (h_n, c_n) where h_n is the hidden state and c_n is the cell state. + For GRU or vanilla RNN, returns just h_n. + """ + shape = (self.rnn.num_layers, batch_size, self.rnn.hidden_size) + if self.rnn_type == "LSTM": + return ( + torch.zeros(shape, device=device), + torch.zeros(shape, device=device), + ) + else: + return torch.zeros(shape, device=device) + + def forward( + self, + x: torch.Tensor, + input_lens: Optional[torch.Tensor] = None, + hx: Optional[torch.Tensor | tuple[torch.Tensor, torch.Tensor]] = None, + ) -> tuple[dict[str, torch.Tensor], torch.Tensor | tuple]: + """ + Forward pass through the RNN model. + Args: + x (torch.Tensor): Input tensor of shape (batch, seq_len, input_size). + input_lens (Optional[torch.Tensor]): Optional tensor of lengths for each sequence in the batch. + If provided, sequences will be packed before passing through the RNN. + hx (Optional[torch.Tensor | tuple[torch.Tensor, torch.Tensor]]): Optional initial hidden state for the RNN. + Returns: + tuple[dict[str, torch.Tensor], torch.Tensor | tuple]: + A dictionary mapping head names to output tensors of shape (batch, seq_len, output_size). + If the RNN is LSTM, the second element is the hidden state (h_n, c_n) or just h_n if GRU. + """ + x = self.linear_embeddings(x) + x = self.dropout_input(x) + total_length = x.shape[1] + if input_lens is not None: + x = torch.nn.utils.rnn.pack_padded_sequence( + x, input_lens, batch_first=True, enforce_sorted=False + ) + x_out, hx_out = self.rnn(x, hx) + if input_lens is not None: + x_out, _ = torch.nn.utils.rnn.pad_packed_sequence( + x_out, batch_first=True, total_length=total_length + ) + x_out = self.output_dropout(x_out) + return {name: head(x_out) for name, head in self.heads.items()}, hx_out diff --git a/src/ezmsg/learn/process/rnn.py b/src/ezmsg/learn/process/rnn.py new file mode 100644 index 0000000..35b0d7c --- /dev/null +++ b/src/ezmsg/learn/process/rnn.py @@ -0,0 +1,266 @@ +import typing + +import ezmsg.core as ez +import numpy as np +import torch +from ezmsg.sigproc.base import BaseAdaptiveTransformer, BaseAdaptiveTransformerUnit +from ezmsg.sigproc.sampler import SampleMessage +from ezmsg.sigproc.util.profile import profile_subpub +from ezmsg.util.messages.axisarray import AxisArray +from ezmsg.util.messages.util import replace + +from .base import ModelInitMixin +from .torch import ( + TorchModelSettings, + TorchModelState, + TorchProcessorMixin, +) + + +class RNNSettings(TorchModelSettings): + model_class: str = "ezmsg.learn.model.rnn.RNNModel" + """ + Fully qualified class path of the model to be used. + This should be "ezmsg.learn.model.rnn.RNNModel" for this. + """ + reset_hidden_on_fit: bool = True + """ + Whether to reset the hidden state on each fit call. + If True, the hidden state will be reset to zero after each fit. + If False, the hidden state will be maintained across fit calls. + """ + preserve_state_across_windows: bool | typing.Literal["auto"] = "auto" + """ + Whether to preserve the hidden state across windows. + If True, the hidden state will be preserved across windows. + If False, the hidden state will be reset at the start of each window. + If "auto", preserve if there is no overlap in time windows, otherwise reset. + """ + + +class RNNState(TorchModelState): + hx: typing.Optional[torch.Tensor] = None + + +class RNNProcessor( + BaseAdaptiveTransformer[RNNSettings, AxisArray, AxisArray, RNNState], + TorchProcessorMixin, + ModelInitMixin, +): + def _infer_output_sizes( + self, model: torch.nn.Module, n_input: int + ) -> dict[str, int]: + """Simple inference to get output channel size.""" + dummy_input = torch.zeros(1, 50, n_input, device=self._state.device) + with torch.no_grad(): + output, _ = model(dummy_input) + + if isinstance(output, dict): + return {k: v.shape[-1] for k, v in output.items()} + else: + return {"output": output.shape[-1]} + + def _reset_state(self, message: AxisArray) -> None: + model_kwargs = dict(self.settings.model_kwargs or {}) + self._common_reset_state(message, model_kwargs) + self._init_optimizer() + self._validate_loss_keys(list(self._state.chan_ax.keys())) + + batch_size = 1 if message.data.ndim == 2 else message.data.shape[0] + self.reset_hidden(batch_size) + + def _maybe_reset_state(self, message: AxisArray, batch_size: int) -> bool: + preserve_state = self.settings.preserve_state_across_windows + if preserve_state == "auto": + axes = message.axes + if batch_size < 2: + # Single window, so preserve + preserve_state = True + elif "time" not in axes or "win" not in axes: + # Default fallback + ez.logger.warning( + "Missing 'time' or 'win' axis for auto preserve-state logic. Defaulting to reset." + ) + preserve_state = False + else: + # Calculate stride between windows (assuming uniform spacing) + win_stride = axes["win"].gain + # Calculate window length from time axis length and gain + time_len = message.data.shape[message.get_axis_idx("time")] + gain = getattr(axes["time"], "gain", None) + if gain is None: + ez.logger.warning( + "Time axis gain not found, using default gain of 1.0." + ) + gain = 1.0 # fallback default + win_len = time_len * gain + # Determine if we should preserve state + preserve_state = win_stride >= win_len + + # Preserve if windows do NOT overlap: stride >= window length + if not preserve_state: + self.reset_hidden(batch_size) + else: + # If preserving state, only reset if batch size isn't 1 + hx_batch_size = ( + self._state.hx[0].shape[1] + if isinstance(self._state.hx, tuple) + else self._state.hx.shape[1] + ) + if hx_batch_size != 1: + ez.logger.debug( + f"Resetting hidden state due to batch size mismatch (hx: {hx_batch_size}, new: 1)" + ) + self.reset_hidden(1) + return preserve_state + + def _process(self, message: AxisArray) -> list[AxisArray]: + x = message.data + if not isinstance(x, torch.Tensor): + x = torch.tensor( + x, + dtype=torch.float32 + if self.settings.single_precision + else torch.float64, + device=self._state.device, + ) + + # Add batch dimension if missing + x, added_batch_dim = self._ensure_batched(x) + + batch_size = x.shape[0] + preserve_state = self._maybe_reset_state(message, batch_size) + + with torch.no_grad(): + # If we are preserving state and have multiple batches, process sequentially + if preserve_state and batch_size > 1: + y_data = {} + for x_batch in x: + x_batch = x_batch.unsqueeze(0) + y, self._state.hx = self._state.model(x_batch, hx=self._state.hx) + for key, out in y.items(): + if key not in y_data: + y_data[key] = [] + y_data[key].append(out.cpu().numpy()) + # Concatenate outputs for each key + y_data = { + key: np.concatenate(outputs, axis=0) + for key, outputs in y_data.items() + } + else: + y, self._state.hx = self._state.model(x, hx=self._state.hx) + y_data = { + key: ( + out.cpu().numpy().squeeze(0) + if added_batch_dim + else out.cpu().numpy() + ) + for key, out in y.items() + } + + return [ + replace( + message, + data=out, + axes={**message.axes, "ch": self._state.chan_ax[key]}, + key=key, + ) + for key, out in y_data.items() + ] + + def reset_hidden(self, batch_size: int) -> None: + self._state.hx = self._state.model.init_hidden(batch_size, self._state.device) + + def _train_step( + self, + X: torch.Tensor, + y_targ: dict[str, torch.Tensor], + loss_fns: dict[str, torch.nn.Module], + ) -> None: + y_pred, self._state.hx = self._state.model(X, hx=self._state.hx) + if not isinstance(y_pred, dict): + y_pred = {"output": y_pred} + + loss_weights = self.settings.loss_weights or {} + losses = [] + for key in y_targ.keys(): + loss_fn = loss_fns.get(key) + if loss_fn is None: + raise ValueError(f"Loss function for key '{key}' is not defined.") + if isinstance(loss_fn, torch.nn.CrossEntropyLoss): + loss = loss_fn(y_pred[key].permute(0, 2, 1), y_targ[key].long()) + else: + loss = loss_fn(y_pred[key], y_targ[key]) + weight = loss_weights.get(key, 1.0) + losses.append(loss * weight) + + total_loss = sum(losses) + ez.logger.debug( + f"Training step loss: {total_loss.item()} (individual losses: {[loss.item() for loss in losses]})" + ) + + self._state.optimizer.zero_grad() + total_loss.backward() + self._state.optimizer.step() + if self._state.scheduler is not None: + self._state.scheduler.step() + + def partial_fit(self, message: SampleMessage) -> None: + self._state.model.train() + + X = self._to_tensor(message.sample.data) + + # Add batch dimension if missing + X, batched = self._ensure_batched(X) + + batch_size = X.shape[0] + preserve_state = self._maybe_reset_state(message.sample, batch_size) + + y_targ = message.trigger.value + if not isinstance(y_targ, dict): + y_targ = {"output": y_targ} + y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()} + # Add batch dimension to y_targ values if missing + if batched: + for key in y_targ: + y_targ[key] = y_targ[key].unsqueeze(0) + + loss_fns = self.settings.loss_fn + if loss_fns is None: + raise ValueError("loss_fn must be provided in settings to use partial_fit") + if not isinstance(loss_fns, dict): + loss_fns = {k: loss_fns for k in y_targ.keys()} + + with torch.set_grad_enabled(True): + if preserve_state: + self._train_step(X, y_targ, loss_fns) + else: + for i in range(batch_size): + self._train_step( + X[i].unsqueeze(0), + {key: value[i].unsqueeze(0) for key, value in y_targ.items()}, + loss_fns, + ) + + self._state.model.eval() + if self.settings.reset_hidden_on_fit: + self.reset_hidden(X.shape[0]) + + +class RNNUnit( + BaseAdaptiveTransformerUnit[ + RNNSettings, + AxisArray, + AxisArray, + RNNProcessor, + ] +): + SETTINGS = RNNSettings + + @ez.subscriber(BaseAdaptiveTransformerUnit.INPUT_SIGNAL, zero_copy=True) + @ez.publisher(BaseAdaptiveTransformerUnit.OUTPUT_SIGNAL) + @profile_subpub(trace_oldest=False) + async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator: + results = await self.processor.__acall__(message) + for result in results: + yield self.OUTPUT_SIGNAL, result diff --git a/tests/integration/test_rnn_system.py b/tests/integration/test_rnn_system.py new file mode 100644 index 0000000..d700d9e --- /dev/null +++ b/tests/integration/test_rnn_system.py @@ -0,0 +1,75 @@ +import os +import tempfile +from pathlib import Path + +import ezmsg.core as ez +from ezmsg.sigproc.synth import Counter, CounterSettings +from ezmsg.util.messagecodec import message_log +from ezmsg.util.messagelogger import MessageLogger, MessageLoggerSettings +from ezmsg.util.messages.axisarray import AxisArray +from ezmsg.util.terminate import TerminateOnTotal, TerminateOnTotalSettings + +from ezmsg.learn.process.rnn import RNNUnit + + +def test_torch_model_unit_system(): + fs = 10.0 + block_size = 4 + duration = 2.0 # seconds + input_size = 3 + output_size = 2 + hidden_size = 30 + num_layers = 2 + single_precision = True + + test_filename = Path(tempfile.gettempdir()) + test_filename = test_filename / Path("test_torch_system.txt") + with open(test_filename, "w"): + pass + ez.logger.info(f"Logging to {test_filename}") + + comps = { + "SRC": Counter( + CounterSettings( + fs=fs, + n_ch=input_size, + n_time=block_size, + dispatch_rate=duration, + mod=None, + ) + ), + "MODEL": RNNUnit( + single_precision=single_precision, + learning_rate=1e-2, + scheduler_gamma=0.0, + weight_decay=0.0, + device="cpu", + model_kwargs={ + "hidden_size": hidden_size, + "num_layers": num_layers, + "output_size": output_size, + }, + ), + "LOG": MessageLogger(MessageLoggerSettings(output=test_filename)), + "TERM": TerminateOnTotal( + TerminateOnTotalSettings(total=int(duration * fs / block_size)) + ), + } + + conns = ( + (comps["SRC"].OUTPUT_SIGNAL, comps["MODEL"].INPUT_SIGNAL), + (comps["MODEL"].OUTPUT_SIGNAL, comps["LOG"].INPUT_MESSAGE), + (comps["LOG"].OUTPUT_MESSAGE, comps["TERM"].INPUT_MESSAGE), + ) + + ez.run(components=comps, connections=conns) + + # Read from message log + messages: list[AxisArray] = [_ for _ in message_log(test_filename)] + os.remove(test_filename) + + # Check basic structure + assert all(msg.data.shape[-1] == output_size for msg in messages) + assert all("time" in msg.dims and "ch" in msg.dims for msg in messages) + assert all("ch" in msg.axes for msg in messages) + assert messages[0].axes["ch"].data.shape[0] == output_size diff --git a/tests/unit/test_rnn.py b/tests/unit/test_rnn.py new file mode 100644 index 0000000..1a308d1 --- /dev/null +++ b/tests/unit/test_rnn.py @@ -0,0 +1,361 @@ +import tempfile + +import numpy as np +import pytest +import torch +import torch.nn +from ezmsg.sigproc.sampler import SampleMessage, SampleTriggerMessage +from ezmsg.util.messages.axisarray import AxisArray + +from ezmsg.learn.process.rnn import RNNProcessor + + +@pytest.fixture +def simple_message() -> AxisArray: + n_ch = 192 + data = np.arange(100 * n_ch).reshape(100, n_ch) + ch_labels = np.array([f"ch{i}" for i in range(n_ch)]) + message = AxisArray( + data=data, + dims=["time", "ch"], + axes={ + "time": AxisArray.TimeAxis(fs=512.0), + "ch": AxisArray.CoordinateAxis(data=ch_labels, dims=["ch"]), + }, + key="test_rnn", + ) + return message + + +@pytest.mark.parametrize("rnn_type", ["GRU", "LSTM", "RNN-Tanh", "RNN-ReLU"]) +@pytest.mark.parametrize("hidden_size", [20, 30, 40]) +@pytest.mark.parametrize("num_layers", [1, 2, 3]) +@pytest.mark.parametrize("output_size", [2, 3]) +def test_rnn_init(rnn_type, hidden_size, num_layers, output_size, simple_message): + single_precision = True + + proc = RNNProcessor( + single_precision=single_precision, + device="cpu", + model_kwargs={ + "hidden_size": hidden_size, + "num_layers": num_layers, + "output_size": output_size, + "rnn_type": rnn_type, + }, + ) + + # Verify the settings were registered properly + assert proc.settings.single_precision == single_precision + + # The processor creates its model upon receipt of the first message. + proc(simple_message) + + # Verify the settings were incorporated into the model + mdl: torch.nn.Module = proc.state.model + + in_dim = simple_message.data.shape[simple_message.get_axis_idx("ch")] + + assert mdl.linear_embeddings.in_features == in_dim + assert mdl.rnn.input_size == in_dim + assert mdl.rnn.hidden_size == hidden_size + assert mdl.rnn.num_layers == num_layers + assert mdl.rnn.dropout == 0.3 if num_layers > 1 else mdl.rnn.dropout == 0.0 + assert mdl.heads["output"].out_features == output_size + + rnn_mdl = list(mdl.children())[2] + expected_module = { + "GRU": torch.nn.GRU, + "LSTM": torch.nn.LSTM, + "RNN-Tanh": torch.nn.RNN, + "RNN-ReLU": torch.nn.RNN, + }[rnn_type] + assert isinstance(rnn_mdl, expected_module) + + +@pytest.mark.parametrize("rnn_type", ["GRU", "LSTM", "RNN-Tanh", "RNN-ReLU"]) +def test_rnn_process(rnn_type, simple_message): + hidden_size = 16 + num_layers = 1 + output_size = 2 + single_precision = True + + proc = RNNProcessor( + single_precision=single_precision, + learning_rate=1e-2, + scheduler_gamma=0.0, + weight_decay=0.0, + device="cpu", + model_kwargs={ + "hidden_size": hidden_size, + "num_layers": num_layers, + "output_size": output_size, + "rnn_type": rnn_type, + }, + ) + + output = proc(simple_message)[0] + assert output.data.shape == simple_message.data.shape[:-1] + (output_size,) + if rnn_type == "LSTM": + for rnn_state in proc.state.hx: + assert torch.any(rnn_state != 0) + else: + assert torch.any(proc.state.hx != 0) + + # Try calling the model directly and compare the result. + # We don't pass in the hx state so it should be initialized to zeros, same as in the first call to proc. + in_tensor = torch.tensor(simple_message.data[None, ...], dtype=torch.float32) + with torch.no_grad(): + expected_result = ( + proc.state.model(in_tensor)[0]["output"].cpu().numpy().squeeze(0) + ) + assert np.allclose(output.data, expected_result) + + +def test_rnn_partial_fit(simple_message): + hidden_size = 16 + num_layers = 1 + output_size = 2 + single_precision = True + + proc = RNNProcessor( + single_precision=single_precision, + learning_rate=1e-2, + scheduler_gamma=0.0, + weight_decay=0.0, + loss_fn=torch.nn.MSELoss(), + device="cpu", + model_kwargs={ + "hidden_size": hidden_size, + "num_layers": num_layers, + "output_size": output_size, + }, + ) + + proc(simple_message) # Initialize the model + + initial_weights = [p.detach().clone() for p in proc.state.model.parameters()] + + target_shape = (simple_message.data.shape[0], output_size) + target_value = np.ones(target_shape, dtype=np.float32) + sample_message = SampleMessage( + trigger=SampleTriggerMessage(timestamp=0.0, value=target_value), + sample=simple_message, + ) + + proc(sample_message) + + assert not proc.state.model.training + updated_weights = [p.detach() for p in proc.state.model.parameters()] + + assert any( + not torch.equal(w0, w1) for w0, w1 in zip(initial_weights, updated_weights) + ) + + +def test_rnn_checkpoint_save_load(simple_message): + hidden_size = 16 + num_layers = 1 + output_size = 2 + single_precision = True + + proc = RNNProcessor( + single_precision=single_precision, + device="cpu", + model_kwargs={ + "hidden_size": hidden_size, + "num_layers": num_layers, + "output_size": output_size, + }, + ) + + # First pass to initialize model + proc(simple_message) + + # Save full checkpoint (state_dict + config) + with tempfile.NamedTemporaryFile(suffix=".pt") as tmp: + proc.save_checkpoint(tmp.name) + + # Load from checkpoint + proc2 = RNNProcessor( + checkpoint_path=tmp.name, + single_precision=single_precision, + device="cpu", + model_kwargs={ + "hidden_size": hidden_size, + "num_layers": num_layers, + "output_size": output_size, + }, + ) + + proc2(simple_message) + + # Compare state_dicts directly + state_dict1 = proc.state.model.state_dict() + state_dict2 = proc2.state.model.state_dict() + + for key in state_dict1: + assert key in state_dict2, f"Missing key {key} in loaded state_dict" + assert torch.equal(state_dict1[key], state_dict2[key]), ( + f"Mismatch in parameter {key}" + ) + + +def test_rnn_partial_fit_multiloss(simple_message): + hidden_size = 16 + num_layers = 1 + output_heads = {"traj": 3, "state": 3} + single_precision = True + + proc = RNNProcessor( + single_precision=single_precision, + learning_rate=1e-2, + scheduler_gamma=0.0, + weight_decay=0.0, + loss_fn={"traj": torch.nn.MSELoss(), "state": torch.nn.CrossEntropyLoss()}, + loss_weights={"traj": 1.0, "state": 1.0}, + device="cpu", + model_kwargs={ + "hidden_size": hidden_size, + "num_layers": num_layers, + "output_size": output_heads, + }, + ) + + output = proc(simple_message) + initial_weights = [p.detach().clone() for p in proc.state.model.parameters()] + + # Build targets + traj_target = torch.tensor( + np.random.randn(*output[0].data.shape), + dtype=torch.float32, + ) + state_target = torch.tensor( + np.random.randint(0, output_heads["state"], size=output[1].data.shape[:-1]), + dtype=torch.long, + ) + + sample_message = SampleMessage( + trigger=SampleTriggerMessage( + timestamp=0.0, + value={"traj": traj_target, "state": state_target}, + ), + sample=simple_message, + ) + + proc.partial_fit(sample_message) + + updated_weights = [p.detach() for p in proc.state.model.parameters()] + assert any( + not torch.equal(w0, w1) for w0, w1 in zip(initial_weights, updated_weights) + ) + + +@pytest.mark.parametrize( + "preserve_state_across_windows, win_stride, win_len, should_preserve", + [ + (True, 0.1, 0.1, True), + (False, 0.1, 0.1, False), + ("auto", 0.1, 0.1, True), # non-overlapping → preserve + ("auto", 0.05, 0.1, False), # overlapping → reset + ], +) +def test_rnn_preserve_state( + preserve_state_across_windows, win_stride, win_len, should_preserve +): + hidden_size = 16 + num_layers = 1 + output_size = 2 + single_precision = True + + proc = RNNProcessor( + single_precision=single_precision, + learning_rate=1e-2, + scheduler_gamma=0.0, + weight_decay=0.0, + device="cpu", + preserve_state_across_windows=preserve_state_across_windows, + model_kwargs={ + "hidden_size": hidden_size, + "num_layers": num_layers, + "output_size": output_size, + }, + ) + + fs = 512.0 + n_time = int(fs * win_len) + n_win = 3 + n_ch = 192 + ch_labels = np.array([f"ch{i}" for i in range(n_ch)]) + + data = np.random.randn(n_win, n_time, n_ch).astype(np.float32) + + msg = AxisArray( + data=data, + dims=["win", "time", "ch"], + axes={ + "time": AxisArray.TimeAxis(fs=fs), + "win": AxisArray.LinearAxis(unit="s", gain=win_stride), + "ch": AxisArray.CoordinateAxis(data=ch_labels, dims=["ch"]), + }, + key="test_auto_param", + ) + + val0 = proc(msg)[0].data + + val1 = proc(msg)[0].data + + # Values should be the same when state is reset but different if state is preserved + if should_preserve: + assert not np.allclose(val0, val1) + else: + assert np.allclose(val0, val1) + + +def test_rnn_preserve_state_batch_size_change(): + hidden_size = 8 + output_size = 2 + n_ch = 192 + ch_labels = np.array([f"ch{i}" for i in range(n_ch)]) + + proc = RNNProcessor( + single_precision=True, + device="cpu", + preserve_state_across_windows=True, + model_kwargs={"hidden_size": hidden_size, "output_heads": output_size}, + ) + + # First message: 1 window + data1 = np.random.randn(1, 50, n_ch).astype(np.float32) + msg1 = AxisArray( + data=data1, + dims=["win", "time", "ch"], + axes={ + "time": AxisArray.TimeAxis(fs=512.0), + "win": AxisArray.LinearAxis(unit="s", gain=0.1), + "ch": AxisArray.CoordinateAxis(data=ch_labels, dims=["ch"]), + }, + key="batch1", + ) + + # Second message: 3 windows + data2 = np.random.randn(2, 50, n_ch).astype(np.float32) + data2 = np.append(data2, data1, axis=0) + msg2 = AxisArray( + data=data2, + dims=["win", "time", "ch"], + axes={ + "time": AxisArray.TimeAxis(fs=512.0), + "win": AxisArray.LinearAxis(unit="s", gain=0.1), + "ch": AxisArray.CoordinateAxis(data=ch_labels, dims=["ch"]), + }, + key="batch2", + ) + + val0 = proc(msg1)[0].data + + proc(msg2) + val1 = proc(msg1)[0].data + + # Values should be different because state should not be reset + assert not np.allclose(val0, val1) From 3abc4379999033c6d00f2097aa9fd8269150bf76 Mon Sep 17 00:00:00 2001 From: kylmcgr Date: Fri, 22 Aug 2025 12:00:25 -0400 Subject: [PATCH 06/16] First commit of Transformer model and custom transformer processor --- src/ezmsg/learn/model/transformer.py | 175 ++++++++++ src/ezmsg/learn/process/transformer.py | 222 +++++++++++++ tests/integration/test_transformer_system.py | 77 +++++ tests/unit/test_transformer.py | 317 +++++++++++++++++++ 4 files changed, 791 insertions(+) create mode 100644 src/ezmsg/learn/model/transformer.py create mode 100644 src/ezmsg/learn/process/transformer.py create mode 100644 tests/integration/test_transformer_system.py create mode 100644 tests/unit/test_transformer.py diff --git a/src/ezmsg/learn/model/transformer.py b/src/ezmsg/learn/model/transformer.py new file mode 100644 index 0000000..776b665 --- /dev/null +++ b/src/ezmsg/learn/model/transformer.py @@ -0,0 +1,175 @@ +from typing import Optional + +import torch + + +class TransformerModel(torch.nn.Module): + """ + Transformer-based encoder (optional decoder) neural network. + + If `decoder_layers > 0`, the model includes a Transformer decoder. In this case, the `tgt` argument must be + provided: during training, it is typically the ground-truth target sequence (i.e. teacher forcing); during + inference, it can be constructed autoregressively from previous predictions. + + Attributes: + input_size (int): Number of input features per time step. + hidden_size (int): Dimensionality of the transformer model. + encoder_layers (int, optional): Number of transformer encoder layers. Default is 1. + decoder_layers (int, optional): Number of transformer decoder layers. Default is 0. + output_size (int | dict[str, int], optional): Number of output features or classes if single head output, or a + dictionary mapping head names to output sizes if multi-head output. Default is 2 (single head). + dropout (float, optional): Dropout rate applied after input and transformer output. Default is 0.3. + attention_heads (int, optional): Number of attention heads in the transformer. Default is 4. + max_seq_len (int, optional): Maximum sequence length for positional embeddings. Default is 512. + + Returns: + dict[str, torch.Tensor]: Dictionary of decoded predictions mapping head names to tensors of shape + (batch, seq_len, output_size). If single head output, the key is "output". + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + encoder_layers: int = 1, + decoder_layers: int = 0, + output_size: int | dict[str, int] = 2, + dropout: float = 0.3, + attention_heads: int = 4, + max_seq_len: int = 512, + autoregressive_head: str | None = None, + ): + super().__init__() + + self.decoder_layers = decoder_layers + self.hidden_size = hidden_size + + if isinstance(output_size, int): + autoregressive_size = output_size + else: + autoregressive_size = list(output_size.values())[0] + if isinstance(output_size, dict): + autoregressive_size = output_size.get( + autoregressive_head, autoregressive_size + ) + self.start_token = torch.nn.Parameter(torch.zeros(1, 1, autoregressive_size)) + self.output_to_hidden = torch.nn.Linear(autoregressive_size, hidden_size) + + self.input_proj = torch.nn.Linear(input_size, hidden_size) + self.pos_embedding = torch.nn.Embedding(max_seq_len, hidden_size) + self.dropout = torch.nn.Dropout(dropout) + + self.encoder = torch.nn.TransformerEncoder( + torch.nn.TransformerEncoderLayer( + d_model=hidden_size, + nhead=attention_heads, + dim_feedforward=hidden_size * 4, + dropout=dropout, + batch_first=True, + ), + num_layers=encoder_layers, + ) + + self.decoder = None + if decoder_layers > 0: + self.decoder = torch.nn.TransformerDecoder( + torch.nn.TransformerDecoderLayer( + d_model=hidden_size, + nhead=attention_heads, + dim_feedforward=hidden_size * 4, + dropout=dropout, + batch_first=True, + ), + num_layers=decoder_layers, + ) + + if isinstance(output_size, int): + output_size = {"output": output_size} + self.heads = torch.nn.ModuleDict( + { + name: torch.nn.Linear(hidden_size, out_dim) + for name, out_dim in output_size.items() + } + ) + + @classmethod + def infer_config_from_state_dict(cls, state_dict: dict) -> dict[str, int | float]: + # Infer output size from heads..bias (shape: [output_size]) + output_size = { + key.split(".")[1]: param.shape[0] + for key, param in state_dict.items() + if key.startswith("heads.") and key.endswith(".bias") + } + + return { + # Infer input_size from input_proj.weight (shape: [hidden_size, input_size]) + "input_size": state_dict["input_proj.weight"].shape[1], + # Infer hidden_size from input_proj.weight (shape: [hidden_size, input_size]) + "hidden_size": state_dict["input_proj.weight"].shape[0], + "output_size": output_size, + # Infer encoder_layers from transformer layers in state_dict + "encoder_layers": len( + [k for k in state_dict if k.startswith("encoder.layers")] + ), + # Infer decoder_layers from transformer decoder layers in state_dict + "decoder_layers": len( + {k.split(".")[2] for k in state_dict if k.startswith("decoder.layers")} + ) + if any(k.startswith("decoder.layers") for k in state_dict) + else 0, + } + + def forward( + self, + src: torch.Tensor, + tgt: Optional[torch.Tensor] = None, + src_mask: Optional[torch.Tensor] = None, + tgt_mask: Optional[torch.Tensor] = None, + start_pos: int = 0, + ) -> dict[str, torch.Tensor]: + """ + Forward pass through the transformer model. + Args: + src (torch.Tensor): Input tensor of shape (batch, seq_len, input_size). + tgt (Optional[torch.Tensor]): Target tensor for decoder, shape (batch, seq_len, input_size). + Required if `decoder_layers > 0`. In training, this can be the ground-truth target sequence + (i.e. teacher forcing). During inference, this is constructed autoregressively. + src_mask (Optional[torch.Tensor]): Optional attention mask for the encoder input. Should be broadcastable + to shape (batch, seq_len, seq_len) or (seq_len, seq_len). + tgt_mask (Optional[torch.Tensor]): Optional attention mask for the decoder input. Used to enforce causal + decoding (i.e. autoregressive generation) during training or inference. + start_pos (int): Starting offset for positional embeddings. Used for streaming inference to maintain + correct positional indices. Default is 0. + Returns: + dict[str, torch.Tensor]: Dictionary of output tensors each output head, each with shape (batch, seq_len, + output_size). + """ + B, T, _ = src.shape + device = src.device + + x = self.input_proj(src) + pos_ids = torch.arange(start_pos, start_pos + T, device=device).expand(B, T) + x = x + self.pos_embedding(pos_ids) + x = self.dropout(x) + + memory = self.encoder(x, mask=src_mask) + + if self.decoder is not None: + if tgt is None: + tgt = self.start_token.expand(B, -1, -1).to(device) + tgt_proj = self.output_to_hidden(tgt) + tgt_pos_ids = torch.arange(tgt.shape[1], device=device).expand( + B, tgt.shape[1] + ) + tgt_proj = tgt_proj + self.pos_embedding(tgt_pos_ids) + tgt_proj = self.dropout(tgt_proj) + out = self.decoder( + tgt_proj, + memory, + tgt_mask=tgt_mask, + memory_mask=src_mask, + ) + else: + out = memory + + return {name: head(out) for name, head in self.heads.items()} diff --git a/src/ezmsg/learn/process/transformer.py b/src/ezmsg/learn/process/transformer.py new file mode 100644 index 0000000..4dbe2a5 --- /dev/null +++ b/src/ezmsg/learn/process/transformer.py @@ -0,0 +1,222 @@ +import typing + +import ezmsg.core as ez +import torch +from ezmsg.sigproc.base import BaseAdaptiveTransformer, BaseAdaptiveTransformerUnit +from ezmsg.sigproc.sampler import SampleMessage +from ezmsg.sigproc.util.profile import profile_subpub +from ezmsg.util.messages.axisarray import AxisArray +from ezmsg.util.messages.util import replace + +from .base import ModelInitMixin +from .torch import ( + TorchModelSettings, + TorchModelState, + TorchProcessorMixin, +) + + +class TransformerSettings(TorchModelSettings): + model_class: str = "ezmsg.learn.model.transformer.TransformerModel" + """ + Fully qualified class path of the model to be used. + This should be "ezmsg.learn.model.transformer.TransformerModel" for this. + """ + autoregressive_head: str | None = None + """ + The name of the output head used for autoregressive decoding. + This should match one of the keys in the model's output dictionary. + If None, the first output head will be used. + """ + max_cache_len: int | None = 128 + """ + Maximum length of the target sequence cache for autoregressive decoding. + This limits the context length during decoding to prevent excessive memory usage. + If set to None, the cache will grow indefinitely. + """ + + +class TransformerState(TorchModelState): + ar_head: str | None = None + """ + The name of the autoregressive head used for decoding. + This is set based on the `autoregressive_head` setting. + If None, the first output head will be used. + """ + tgt_cache: typing.Optional[torch.Tensor] = None + """ + Cache for the target sequence used in autoregressive decoding. + This is updated with each processed message to maintain context. + """ + + +class TransformerProcessor( + BaseAdaptiveTransformer[ + TransformerSettings, AxisArray, AxisArray, TransformerState + ], + TorchProcessorMixin, + ModelInitMixin, +): + @property + def has_decoder(self) -> bool: + return self.settings.model_kwargs.get("decoder_layers", 0) > 0 + + def reset_cache(self) -> None: + self._state.tgt_cache = None + + def _reset_state(self, message: AxisArray) -> None: + model_kwargs = dict(self.settings.model_kwargs or {}) + self._common_reset_state(message, model_kwargs) + self._init_optimizer() + self._validate_loss_keys(list(self._state.chan_ax.keys())) + + self._state.tgt_cache = None + if ( + self.settings.autoregressive_head is not None + and self.settings.autoregressive_head not in self._state.chan_ax + ): + raise ValueError( + f"Autoregressive head '{self.settings.autoregressive_head}' not found in target dictionary keys: {list(self._state.chan_ax.keys())}" + ) + self._state.ar_head = ( + self.settings.autoregressive_head + if self.settings.autoregressive_head is not None + else list(self._state.chan_ax.keys())[0] + ) + + def _process(self, message: AxisArray) -> list[AxisArray]: + # If has_decoder is False, fallback to regular processing + if not self.has_decoder: + return self._common_process(message) + + x = self._to_tensor(message.data) + x, _ = self._ensure_batched(x) + if x.shape[0] > 1: + raise ValueError("Autoregressive decoding only supports batch size 1.") + + with torch.no_grad(): + y_pred = self._state.model(x, tgt=self._state.tgt_cache) + + pred = y_pred[self._state.ar_head] + if self._state.tgt_cache is None: + self._state.tgt_cache = pred[:, -1:, :] + else: + self._state.tgt_cache = torch.cat( + [self._state.tgt_cache, pred[:, -1:, :]], dim=1 + ) + if self.settings.max_cache_len is not None: + if self._state.tgt_cache.shape[1] > self.settings.max_cache_len: + # Trim the cache to the maximum length + self._state.tgt_cache = self._state.tgt_cache[ + :, -self.settings.max_cache_len :, : + ] + + if isinstance(y_pred, dict): + return [ + replace( + message, + data=out.squeeze(0).cpu().numpy(), + axes={**message.axes, "ch": self._state.chan_ax[key]}, + key=key, + ) + for key, out in y_pred.items() + ] + else: + return [ + replace( + message, + data=y_pred.squeeze(0).cpu().numpy(), + axes={**message.axes, "ch": self._state.chan_ax["output"]}, + ) + ] + + def partial_fit(self, message: SampleMessage) -> None: + self._state.model.train() + + X = self._to_tensor(message.sample.data) + X, batched = self._ensure_batched(X) + + y_targ = message.trigger.value + if not isinstance(y_targ, dict): + y_targ = {"output": y_targ} + y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()} + # Add batch dimension to y_targ values if missing + if batched: + for key in y_targ: + y_targ[key] = y_targ[key].unsqueeze(0) + + loss_fns = self.settings.loss_fn + if loss_fns is None: + raise ValueError("loss_fn must be provided in settings to use partial_fit") + if not isinstance(loss_fns, dict): + loss_fns = {k: loss_fns for k in y_targ.keys()} + + weights = self.settings.loss_weights or {} + + if self.has_decoder: + if X.shape[0] != 1: + raise ValueError("Autoregressive decoding only supports batch size 1.") + + # Create shifted target for autoregressive head + tgt_tensor = y_targ[self._state.ar_head] + tgt = torch.cat( + [ + torch.zeros( + (1, 1, tgt_tensor.shape[-1]), + dtype=tgt_tensor.dtype, + device=tgt_tensor.device, + ), + tgt_tensor[:, :-1, :], + ], + dim=1, + ) + + # Reset tgt_cache at start of partial_fit to avoid stale context + self.reset_cache() + y_pred = self._state.model(X, tgt=tgt) + else: + # For non-autoregressive models, use the model directly + y_pred = self._state.model(X) + + if not isinstance(y_pred, dict): + y_pred = {"output": y_pred} + + with torch.set_grad_enabled(True): + losses = [] + for key in y_targ.keys(): + loss_fn = loss_fns.get(key) + if loss_fn is None: + raise ValueError( + f"Loss function for key '{key}' is not defined in settings." + ) + loss = loss_fn(y_pred[key], y_targ[key]) + weight = weights.get(key, 1.0) + losses.append(loss * weight) + total_loss = sum(losses) + + self._state.optimizer.zero_grad() + total_loss.backward() + self._state.optimizer.step() + if self._state.scheduler is not None: + self._state.scheduler.step() + + self._state.model.eval() + + +class TransformerUnit( + BaseAdaptiveTransformerUnit[ + TransformerSettings, + AxisArray, + AxisArray, + TransformerProcessor, + ] +): + SETTINGS = TransformerSettings + + @ez.subscriber(BaseAdaptiveTransformerUnit.INPUT_SIGNAL, zero_copy=True) + @ez.publisher(BaseAdaptiveTransformerUnit.OUTPUT_SIGNAL) + @profile_subpub(trace_oldest=False) + async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator: + results = await self.processor.__acall__(message) + for result in results: + yield self.OUTPUT_SIGNAL, result diff --git a/tests/integration/test_transformer_system.py b/tests/integration/test_transformer_system.py new file mode 100644 index 0000000..5407df3 --- /dev/null +++ b/tests/integration/test_transformer_system.py @@ -0,0 +1,77 @@ +import os +import tempfile +from pathlib import Path + +import ezmsg.core as ez +from ezmsg.sigproc.synth import Counter, CounterSettings +from ezmsg.util.messagecodec import message_log +from ezmsg.util.messagelogger import MessageLogger, MessageLoggerSettings +from ezmsg.util.messages.axisarray import AxisArray +from ezmsg.util.terminate import TerminateOnTotal, TerminateOnTotalSettings + +from ezmsg.learn.process.transformer import TransformerUnit + + +def test_torch_model_unit_system(): + fs = 10.0 + block_size = 4 + duration = 2.0 # seconds + input_size = 3 + output_size = 2 + hidden_size = 32 + attention_heads = 4 + num_layers = 2 + single_precision = True + + test_filename = Path(tempfile.gettempdir()) + test_filename = test_filename / Path("test_torch_system.txt") + with open(test_filename, "w"): + pass + ez.logger.info(f"Logging to {test_filename}") + + comps = { + "SRC": Counter( + CounterSettings( + fs=fs, + n_ch=input_size, + n_time=block_size, + dispatch_rate=duration, + mod=None, + ) + ), + "MODEL": TransformerUnit( + single_precision=single_precision, + learning_rate=1e-2, + scheduler_gamma=0.0, + weight_decay=0.0, + device="cpu", + model_kwargs={ + "hidden_size": hidden_size, + "encoder_layers": num_layers, + "output_size": output_size, + "attention_heads": attention_heads, + }, + ), + "LOG": MessageLogger(MessageLoggerSettings(output=test_filename)), + "TERM": TerminateOnTotal( + TerminateOnTotalSettings(total=int(duration * fs / block_size)) + ), + } + + conns = ( + (comps["SRC"].OUTPUT_SIGNAL, comps["MODEL"].INPUT_SIGNAL), + (comps["MODEL"].OUTPUT_SIGNAL, comps["LOG"].INPUT_MESSAGE), + (comps["LOG"].OUTPUT_MESSAGE, comps["TERM"].INPUT_MESSAGE), + ) + + ez.run(components=comps, connections=conns) + + # Read from message log + messages: list[AxisArray] = [_ for _ in message_log(test_filename)] + os.remove(test_filename) + + # Check basic structure + assert all(msg.data.shape[-1] == output_size for msg in messages) + assert all("time" in msg.dims and "ch" in msg.dims for msg in messages) + assert all("ch" in msg.axes for msg in messages) + assert messages[0].axes["ch"].data.shape[0] == output_size diff --git a/tests/unit/test_transformer.py b/tests/unit/test_transformer.py new file mode 100644 index 0000000..17048e3 --- /dev/null +++ b/tests/unit/test_transformer.py @@ -0,0 +1,317 @@ +import tempfile + +import numpy as np +import pytest +import torch +import torch.nn +from ezmsg.sigproc.sampler import SampleMessage, SampleTriggerMessage +from ezmsg.util.messages.axisarray import AxisArray + +from ezmsg.learn.process.transformer import TransformerProcessor + + +@pytest.fixture +def simple_message() -> AxisArray: + n_ch = 192 + data = np.arange(100 * n_ch).reshape(100, n_ch) + ch_labels = np.array([f"ch{i}" for i in range(n_ch)]) + message = AxisArray( + data=data, + dims=["time", "ch"], + axes={ + "time": AxisArray.TimeAxis(fs=512.0), + "ch": AxisArray.CoordinateAxis(data=ch_labels, dims=["ch"]), + }, + key="test_transformer", + ) + return message + + +@pytest.mark.parametrize("hidden_size,attention_heads", [(20, 4), (30, 5), (40, 4)]) +@pytest.mark.parametrize("encoder_layers", [1, 2, 3]) +@pytest.mark.parametrize("decoder_layers", [0, 1]) +@pytest.mark.parametrize("output_size", [2, 3]) +def test_transformer_init( + hidden_size, + attention_heads, + encoder_layers, + decoder_layers, + output_size, + simple_message, +): + single_precision = True + + proc = TransformerProcessor( + single_precision=single_precision, + device="cpu", + model_kwargs={ + "hidden_size": hidden_size, + "encoder_layers": encoder_layers, + "decoder_layers": decoder_layers, + "output_size": output_size, + "attention_heads": attention_heads, + }, + ) + + # Verify the settings were registered properly + assert proc.settings.single_precision == single_precision + + # The processor creates its model upon receipt of the first message. + proc(simple_message) + + # Verify the settings were incorporated into the model + mdl: torch.nn.Module = proc.state.model + + in_dim = simple_message.data.shape[simple_message.get_axis_idx("ch")] + + assert mdl.input_proj.in_features == in_dim + assert mdl.hidden_size == hidden_size + assert len(mdl.encoder.layers) == encoder_layers + for layer in mdl.encoder.layers: + assert layer.dropout.p == 0.3 + assert mdl.heads["output"].out_features == output_size + + +@pytest.mark.parametrize("decoder_layers", [0, 1]) +def test_transformer_process(simple_message, decoder_layers): + hidden_size = 16 + encoder_layers = 1 + output_size = 2 + single_precision = True + + proc = TransformerProcessor( + single_precision=single_precision, + learning_rate=1e-2, + scheduler_gamma=0.0, + weight_decay=0.0, + device="cpu", + model_kwargs={ + "hidden_size": hidden_size, + "encoder_layers": encoder_layers, + "decoder_layers": decoder_layers, + "output_size": output_size, + }, + ) + + output = proc(simple_message)[0] + + time_dim = (1,) if decoder_layers > 0 else simple_message.data.shape[:-1] + assert output.data.shape == time_dim + (output_size,) + + # Bypass processor and call the model directly to verify output + in_tensor = torch.tensor(simple_message.data[None, ...], dtype=torch.float32) + with torch.no_grad(): + expected_result = proc.state.model(in_tensor)["output"].cpu().numpy().squeeze(0) + assert np.allclose(output.data, expected_result) + if decoder_layers > 0: + assert proc.has_decoder + # If decoder_layers > 0, tgt_cache should be initialized + assert proc.state.tgt_cache is not None + + +@pytest.mark.parametrize("decoder_layers", [0, 1]) +def test_transformer_partial_fit(simple_message, decoder_layers): + hidden_size = 16 + encoder_layers = 1 + output_size = 2 + single_precision = True + + proc = TransformerProcessor( + single_precision=single_precision, + learning_rate=1e-2, + scheduler_gamma=0.0, + weight_decay=0.0, + loss_fn=torch.nn.MSELoss(), + device="cpu", + model_kwargs={ + "hidden_size": hidden_size, + "encoder_layers": encoder_layers, + "decoder_layers": decoder_layers, + "output_size": output_size, + }, + ) + + proc(simple_message) # Initialize the model + + initial_weights = [p.detach().clone() for p in proc.state.model.parameters()] + + target_shape = (simple_message.data.shape[0], output_size) + target_value = np.ones(target_shape, dtype=np.float32) + sample_message = SampleMessage( + trigger=SampleTriggerMessage(timestamp=0.0, value=target_value), + sample=simple_message, + ) + + proc.partial_fit(sample_message) + + assert not proc.state.model.training + assert proc.state.tgt_cache is None + updated_weights = [p.detach() for p in proc.state.model.parameters()] + + assert any( + not torch.equal(w0, w1) for w0, w1 in zip(initial_weights, updated_weights) + ) + + +def test_transformer_checkpoint_save_load(simple_message): + hidden_size = 16 + encoder_layers = 1 + output_size = 2 + single_precision = True + + proc = TransformerProcessor( + single_precision=single_precision, + device="cpu", + model_kwargs={ + "hidden_size": hidden_size, + "encoder_layers": encoder_layers, + "output_size": output_size, + }, + ) + + # First pass to initialize model + proc(simple_message) + + # Save full checkpoint (state_dict + config) + with tempfile.NamedTemporaryFile(suffix=".pt") as tmp: + proc.save_checkpoint(tmp.name) + + # Load from checkpoint + proc2 = TransformerProcessor( + checkpoint_path=tmp.name, + single_precision=single_precision, + device="cpu", + model_kwargs={ + "hidden_size": hidden_size, + "encoder_layers": encoder_layers, + "output_size": output_size, + }, + ) + + proc2(simple_message) + + # Compare state_dicts directly + state_dict1 = proc.state.model.state_dict() + state_dict2 = proc2.state.model.state_dict() + + for key in state_dict1: + assert key in state_dict2, f"Missing key {key} in loaded state_dict" + assert torch.equal(state_dict1[key], state_dict2[key]), ( + f"Mismatch in parameter {key}" + ) + + +def test_transformer_partial_fit_multiloss(simple_message): + hidden_size = 16 + encoder_layers = 1 + output_heads = {"traj": 3, "state": 3} + single_precision = True + + proc = TransformerProcessor( + single_precision=single_precision, + learning_rate=1e-2, + scheduler_gamma=0.0, + weight_decay=0.0, + loss_fn={"traj": torch.nn.MSELoss(), "state": torch.nn.CrossEntropyLoss()}, + loss_weights={"traj": 1.0, "state": 1.0}, + device="cpu", + model_kwargs={ + "hidden_size": hidden_size, + "encoder_layers": encoder_layers, + "output_size": output_heads, + }, + ) + + output = proc(simple_message) + initial_weights = [p.detach().clone() for p in proc.state.model.parameters()] + + # Build targets + traj_target = torch.tensor( + np.random.randn(*output[0].data.shape), + dtype=torch.float32, + ) + state_target = torch.tensor( + np.random.randint(0, output_heads["state"], size=output[1].data.shape), + dtype=torch.long, + ) + + sample_message = SampleMessage( + trigger=SampleTriggerMessage( + timestamp=0.0, + value={"traj": traj_target, "state": state_target}, + ), + sample=simple_message, + ) + + proc.partial_fit(sample_message) + + updated_weights = [p.detach() for p in proc.state.model.parameters()] + assert any( + not torch.equal(w0, w1) for w0, w1 in zip(initial_weights, updated_weights) + ) + + +def test_autoregressive_cache_behavior(simple_message): + proc = TransformerProcessor( + single_precision=True, + device="cpu", + model_kwargs={ + "hidden_size": 8, + "encoder_layers": 1, + "decoder_layers": 1, # Enable autoregressive mode + "output_size": 2, + }, + ) + + # First call initializes model and cache + proc(simple_message) + cache1 = proc.state.tgt_cache.clone() + + # Second call should extend the cache + proc(simple_message) + cache2 = proc.state.tgt_cache + + assert cache2.shape[1] > cache1.shape[1] + assert not torch.equal(cache1, cache2), "Cache should be updated with new data" + + +def test_cache_truncation(simple_message): + max_len = 10 + + proc = TransformerProcessor( + single_precision=True, + device="cpu", + model_kwargs={ + "hidden_size": 8, + "encoder_layers": 1, + "decoder_layers": 1, + "output_size": 2, + }, + max_cache_len=max_len, + ) + + for _ in range(20): + proc(simple_message) + + assert proc.state.tgt_cache.shape[1] <= max_len + + +def test_invalid_autoregressive_head_raises(simple_message): + proc = TransformerProcessor( + single_precision=True, + device="cpu", + loss_fn=torch.nn.MSELoss(), + model_kwargs={ + "hidden_size": 8, + "encoder_layers": 1, + "decoder_layers": 1, + "output_size": 2, + }, + autoregressive_head="my_output_head", # Invalid key + ) + + with pytest.raises( + ValueError, + match="Autoregressive head 'my_output_head' not found", + ): + proc(simple_message) From 2eeb5461a65677b6e900b1110fa610d3bf44398d Mon Sep 17 00:00:00 2001 From: hernst1 <120542324+hernst1@users.noreply.github.com> Date: Fri, 22 Aug 2025 12:01:38 -0400 Subject: [PATCH 07/16] First commit of RefitKalmanFilter --- src/ezmsg/learn/model/refit_kalman.py | 401 +++++++++++++ src/ezmsg/learn/process/refit_kalman.py | 407 +++++++++++++ tests/integration/test_refit_kalman_system.py | 126 +++++ tests/unit/test_refit_kalman.py | 535 ++++++++++++++++++ 4 files changed, 1469 insertions(+) create mode 100644 src/ezmsg/learn/model/refit_kalman.py create mode 100644 src/ezmsg/learn/process/refit_kalman.py create mode 100644 tests/integration/test_refit_kalman_system.py create mode 100644 tests/unit/test_refit_kalman.py diff --git a/src/ezmsg/learn/model/refit_kalman.py b/src/ezmsg/learn/model/refit_kalman.py new file mode 100644 index 0000000..ef2c683 --- /dev/null +++ b/src/ezmsg/learn/model/refit_kalman.py @@ -0,0 +1,401 @@ +# refit_kalman.py + +import numpy as np +from numpy.linalg import LinAlgError +from scipy.linalg import solve_discrete_are + + +class RefitKalmanFilter: + """ + Refit Kalman filter for adaptive neural decoding. + + This class implements a Kalman filter that can be refitted online during operation. + Unlike the standard Kalman filter, this version can adapt its observation model + (H and Q matrices) based on new data while maintaining the state transition model + (A and W matrices). This is particularly useful for brain-computer interfaces + where the relationship between neural activity and intended movements may change + over time. + + The filter operates in two phases: + 1. Initial fitting: Learns all system matrices (A, W, H, Q) from training data + 2. Refitting: Updates only the observation model (H, Q) based on new data + + Attributes: + A_state_transition_matrix: The state transition matrix A (n_states x n_states). + W_process_noise_covariance: The process noise covariance matrix W (n_states x n_states). + H_observation_matrix: The observation matrix H (n_observations x n_states). + Q_measurement_noise_covariance: The measurement noise covariance matrix Q (n_observations x n_observations). + K_kalman_gain: The Kalman gain matrix (n_states x n_observations). + P_state_covariance: The state error covariance matrix (n_states x n_states). + steady_state: Whether to use steady-state Kalman gain computation. + is_fitted: Whether the model has been fitted with data. + + Example: + >>> # Create and fit the filter + >>> rkf = RefitKalmanFilter(steady_state=True) + >>> rkf.fit(X_train, y_train) + >>> + >>> # Refit with new data + >>> rkf.refit(X_new, Y_state, velocity_indices, targets, cursors, holds) + >>> + >>> # Predict with updated model + >>> x_updated = rkf.predict_and_update(measurement, current_state) + """ + + def __init__( + self, + A_state_transition_matrix=None, + W_process_noise_covariance=None, + H_observation_matrix=None, + Q_measurement_noise_covariance=None, + steady_state=False, + enforce_state_structure=False, + alpha_fading_memory=1.000, + process_noise_scale=1, + measurement_noise_scale=1.2, + ): + self.A_state_transition_matrix = A_state_transition_matrix + self.W_process_noise_covariance = W_process_noise_covariance + self.H_observation_matrix = H_observation_matrix + self.Q_measurement_noise_covariance = Q_measurement_noise_covariance + self.K_kalman_gain = None + self.P_state_covariance = None + self.alpha_fading_memory = alpha_fading_memory + + # Noise scaling factors for smoothing control + self.process_noise_scale = process_noise_scale + self.measurement_noise_scale = measurement_noise_scale + + self.steady_state = steady_state + self.enforce_state_structure = enforce_state_structure + self.is_fitted = False + + def _validate_state_vector(self, Y_state): + """ + Validate that the state vector has proper dimensions. + + Args: + Y_state: State vector to validate + + Raises: + ValueError: If state vector has invalid dimensions + """ + if Y_state.ndim != 2: + raise ValueError(f"State vector must be 2D, got {Y_state.ndim}D") + + if ( + not hasattr(self, "H_observation_matrix") + or self.H_observation_matrix is None + ): + raise ValueError("Model must be fitted before refitting") + + expected_states = self.H_observation_matrix.shape[1] + if Y_state.shape[1] != expected_states: + raise ValueError( + f"State vector has {Y_state.shape[1]} dimensions, expected {expected_states}" + ) + + def fit(self, X_train, y_train): + """ + Fit the Refit Kalman filter to the training data. + + This method learns all system matrices (A, W, H, Q) from training data + using least-squares estimation, then computes the steady-state solution. + This is the initial fitting phase that establishes the baseline model. + + Args: + X_train: Neural activity (n_samples, n_neurons). + y_train: Outputs being predicted (n_samples, n_states). + + Raises: + ValueError: If training data has invalid dimensions. + LinAlgError: If matrix operations fail during fitting. + """ + # self._validate_state_vector(y_train) + + X = np.array(y_train) + Z = np.array(X_train) + n_samples = X.shape[0] + + # Calculate the transition matrix (from x_t to x_t+1) using least-squares + X2 = X[1:, :] # x_{t+1} + X1 = X[:-1, :] # x_t + A = X2.T @ X1 @ np.linalg.inv(X1.T @ X1) # Transition matrix + W = ( + (X2 - X1 @ A.T).T @ (X2 - X1 @ A.T) / (n_samples - 1) + ) # Covariance of transition matrix + + # Calculate the measurement matrix (from x_t to z_t) using least-squares + H = Z.T @ X @ np.linalg.inv(X.T @ X) # Measurement matrix + Q = ( + (Z - X @ H.T).T @ (Z - X @ H.T) / Z.shape[0] + ) # Covariance of measurement matrix + + self.A_state_transition_matrix = A + self.W_process_noise_covariance = W * self.process_noise_scale + self.H_observation_matrix = H + self.Q_measurement_noise_covariance = Q * self.measurement_noise_scale + + self._compute_gain() + self.is_fitted = True + + def refit( + self, + X_neural: np.ndarray, + Y_state: np.ndarray, + intention_velocity_indices: int | None = None, + target_positions: np.ndarray | None = None, + cursor_positions: np.ndarray | None = None, + hold_indices: np.ndarray | None = None, + ): + """ + Refit the observation model based on new data. + + This method updates only the observation model (H and Q matrices) while + keeping the state transition model (A and W matrices) unchanged. The refitting + process modifies the intended states based on target positions and hold flags + to better align with user intentions. + + The refitting process: + 1. Modifies intended states based on target positions and hold flags + 2. Recalculates the observation matrix H using least-squares + 3. Recalculates the measurement noise covariance Q + 4. Updates the Kalman gain accordingly + + Args: + X_neural: Neural activity data (n_samples, n_neurons). + Y_state: State estimates (n_samples, n_states). + intention_velocity_indices: Index of velocity components in state vector. + target_positions: Target positions for each sample (n_samples, 2). + cursor_positions: Current cursor positions (n_samples, 2). + hold_indices: Boolean flags indicating hold periods (n_samples,). + + Raises: + ValueError: If input data has invalid dimensions or the model is not fitted. + """ + self._validate_state_vector(Y_state) + + # Check if velocity indices are provided + if intention_velocity_indices is None: + # Assume (x, y, vx, vy) + vel_idx = 2 if Y_state.shape[1] >= 4 else 0 + print( + f"[RefitKalmanFilter] No velocity index provided — defaulting to {vel_idx}" + ) + else: + if isinstance(intention_velocity_indices, (list, tuple)): + if len(intention_velocity_indices) != 1: + raise ValueError( + "Only one velocity start index should be provided." + ) + vel_idx = intention_velocity_indices[0] + else: + vel_idx = intention_velocity_indices + + # Only remap velocity if target and cursor positions are provided + if target_positions is None or cursor_positions is None: + intended_states = Y_state.copy() + else: + intended_states = Y_state.copy() + # Calculate intended velocities for each sample + for i, (state, pos, target) in enumerate( + zip(Y_state, cursor_positions, target_positions) + ): + is_hold = hold_indices[i] if hold_indices is not None else False + + if is_hold: + # During hold periods, intended velocity is zero + intended_states[i, vel_idx : vel_idx + 2] = 0.0 + if i > 0: + intended_states[i, :2] = intended_states[ + i - 1, :2 + ] # Same position as previous + else: + # Calculate direction to target + to_target = target - pos + target_distance = np.linalg.norm(to_target) + + if target_distance > 1e-5: # Avoid division by zero + # Get current decoded velocity magnitude + current_velocity = state[vel_idx : vel_idx + 2] + current_speed = np.linalg.norm(current_velocity) + + # Calculate intended velocity: same speed, but toward target + target_direction = to_target / target_distance + intended_velocity = target_direction * current_speed + + # Update intended state with new velocity + intended_states[i, vel_idx : vel_idx + 2] = intended_velocity + # If target is very close, keep original velocity + else: + intended_states[i, vel_idx : vel_idx + 2] = state[ + vel_idx : vel_idx + 2 + ] + + intended_states = np.array(intended_states) + Z = np.array(X_neural) + + # Recalculate observation matrix and noise covariance + H = ( + Z.T @ intended_states @ np.linalg.pinv(intended_states.T @ intended_states) + ) # Using pinv() instead of inv() to avoid singular matrix errors + Q = (Z - intended_states @ H.T).T @ (Z - intended_states @ H.T) / Z.shape[0] + + self.H_observation_matrix = H + self.Q_measurement_noise_covariance = Q + + self._compute_gain() + + def _compute_gain(self): + """ + Compute the Kalman gain matrix. + + This method computes the Kalman gain matrix based on the current system + parameters. In steady-state mode, it solves the discrete-time algebraic + Riccati equation to find the optimal steady-state gain. In non-steady-state + mode, it computes the gain using the current covariance matrix. + + Raises: + LinAlgError: If the Riccati equation cannot be solved or matrix operations fail. + """ + ## TODO: consider removing non-steady-state for compute_gain() - non_steady_state updates will occur during predict() and update() + # if self.steady_state: + try: + # Try with original matrices + self.P_state_covariance = solve_discrete_are( + self.A_state_transition_matrix.T, + self.H_observation_matrix.T, + self.W_process_noise_covariance, + self.Q_measurement_noise_covariance, + ) + self.K_kalman_gain = ( + self.P_state_covariance + @ self.H_observation_matrix.T + @ np.linalg.inv( + self.H_observation_matrix + @ self.P_state_covariance + @ self.H_observation_matrix.T + + self.Q_measurement_noise_covariance + ) + ) + except LinAlgError: + # Apply regularization and retry + # A_reg = self.A_state_transition_matrix * 0.999 # Slight damping + # W_reg = self.W_process_noise_covariance + 1e-7 * np.eye( + # self.W_process_noise_covariance.shape[0] + # ) + Q_reg = self.Q_measurement_noise_covariance + 1e-7 * np.eye( + self.Q_measurement_noise_covariance.shape[0] + ) + + try: + self.P_state_covariance = solve_discrete_are( + self.A_state_transition_matrix.T, + self.H_observation_matrix.T, + self.W_process_noise_covariance, + Q_reg, + ) + self.K_kalman_gain = ( + self.P_state_covariance + @ self.H_observation_matrix.T + @ np.linalg.inv( + self.H_observation_matrix + @ self.P_state_covariance + @ self.H_observation_matrix.T + + Q_reg + ) + ) + print("Warning: Used regularized matrices for DARE solution") + except LinAlgError: + # Fallback to identity or manual initialization + print("Warning: DARE failed, using identity covariance") + self.P_state_covariance = np.eye( + self.A_state_transition_matrix.shape[0] + ) + + # else: + # n_states = self.A_state_transition_matrix.shape[0] + # self.P_state_covariance = ( + # np.eye(n_states) * 1000 + # ) # Large initial uncertainty + + # P_m = ( + # self.A_state_transition_matrix + # @ self.P_state_covariance + # @ self.A_state_transition_matrix.T + # + self.W_process_noise_covariance + # ) + + # S = ( + # self.H_observation_matrix @ P_m @ self.H_observation_matrix.T + # + self.Q_measurement_noise_covariance + # ) + + # self.K_kalman_gain = P_m @ self.H_observation_matrix.T @ np.linalg.pinv(S) + + # I_mat = np.eye(self.A_state_transition_matrix.shape[0]) + # self.P_state_covariance = ( + # I_mat - self.K_kalman_gain @ self.H_observation_matrix + # ) @ P_m + + def predict(self, x_current: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """ + Predict the next state and covariance. + + This method predicts the next state and covariance using the current state. + """ + x_predicted = self.A_state_transition_matrix @ x_current + if self.steady_state is True: + return x_predicted, None + else: + P_predicted = self.alpha_fading_memory**2 * ( + self.A_state_transition_matrix + @ self.P_state_covariance + @ self.A_state_transition_matrix.T + + self.W_process_noise_covariance + ) + return x_predicted, P_predicted + + def update( + self, + z_measurement: np.ndarray, + x_predicted: np.ndarray, + P_predicted: np.ndarray | None = None, + ) -> np.ndarray: + """Update state estimate and covariance based on measurement z.""" + + # Compute residual + innovation = z_measurement - self.H_observation_matrix @ x_predicted + + if self.steady_state: + x_updated = x_predicted + self.K_kalman_gain @ innovation + return x_updated + + if P_predicted is None: + raise ValueError("P_predicted must be provided for non-steady-state mode") + + # Non-steady-state mode + # System uncertainty + S = ( + self.H_observation_matrix @ P_predicted @ self.H_observation_matrix.T + + self.Q_measurement_noise_covariance + ) + + # Kalman gain + K = P_predicted @ self.H_observation_matrix.T @ np.linalg.pinv(S) + + # Updated state + x_updated = x_predicted + K @ innovation + + # Covariance update + I_mat = np.eye(self.A_state_transition_matrix.shape[0]) + P_updated = (I_mat - K @ self.H_observation_matrix) @ P_predicted @ ( + I_mat - K @ self.H_observation_matrix + ).T + K @ self.Q_measurement_noise_covariance @ K.T + + # Save updated values + self.P_state_covariance = P_updated + self.K_kalman_gain = K + # self.S = S # Optional: for diagnostics + + return x_updated diff --git a/src/ezmsg/learn/process/refit_kalman.py b/src/ezmsg/learn/process/refit_kalman.py new file mode 100644 index 0000000..4ebc8cb --- /dev/null +++ b/src/ezmsg/learn/process/refit_kalman.py @@ -0,0 +1,407 @@ +import pickle +from pathlib import Path + +import ezmsg.core as ez +import numpy as np +from ezmsg.sigproc.base import ( + BaseAdaptiveTransformer, + BaseAdaptiveTransformerUnit, + processor_state, +) +from ezmsg.sigproc.sampler import SampleMessage +from ezmsg.util.messages.axisarray import AxisArray +from ezmsg.util.messages.util import replace + +from ..model.refit_kalman import RefitKalmanFilter + + +class RefitKalmanFilterSettings(ez.Settings): + """ + Settings for the Refit Kalman filter processor. + + This class defines the configuration parameters for the Refit Kalman filter processor. + The RefitKalmanFilter is designed for online processing and playback. + + Attributes: + checkpoint_path: Path to saved model parameters (optional). + If provided, loads pre-trained parameters instead of learning from data. + steady_state: Whether to use steady-state Kalman filter. + If True, uses pre-computed Kalman gain; if False, updates dynamically. + """ + + checkpoint_path: str | None = None + steady_state: bool = False + velocity_indices: tuple[int, int] = (2, 3) + + +@processor_state +class RefitKalmanFilterState: + """ + State management for the Refit Kalman filter processor. + + This class manages the persistent state of the Refit Kalman filter processor, + including the model instance, current state estimates, and data buffers for refitting. + + Attributes: + model: The RefitKalmanFilter model instance. + x: Current state estimate (n_states,). + P: Current state covariance matrix (n_states x n_states). + buffer_neural: Buffer for storing neural activity data for refitting. + buffer_state: Buffer for storing state estimates for refitting. + buffer_cursor_positions: Buffer for storing cursor positions for refitting. + buffer_target_positions: Buffer for storing target positions for refitting. + buffer_hold_flags: Buffer for storing hold flags for refitting. + current_position: Current cursor position estimate (2,). + """ + + model: RefitKalmanFilter | None = None + x: np.ndarray | None = None + P: np.ndarray | None = None + + buffer_neural: list | None = None + buffer_state: list | None = None + buffer_cursor_positions: list | None = None + buffer_target_positions: list | None = None + buffer_hold_flags: list | None = None + + +class RefitKalmanFilterProcessor( + BaseAdaptiveTransformer[ + RefitKalmanFilterSettings, + AxisArray, + AxisArray, + RefitKalmanFilterState, + ] +): + """ + Processor for implementing a Refit Kalman filter in the ezmsg framework. + + This processor integrates the RefitKalmanFilter model into the ezmsg + message passing system. It handles the conversion between AxisArray messages + and the internal Refit Kalman filter operations. + + The processor performs the following operations: + 1. Configures the Refit Kalman filter model with provided settings + 2. Processes incoming measurement messages + 3. Performs prediction and update steps + 4. Logs data for potential refitting + 5. Supports online refitting of the observation model + 6. Returns filtered state estimates as AxisArray messages + 7. Maintains state between message processing calls + + The processor can operate in two modes: + 1. Pre-trained mode: Loads parameters from checkpoint_path + 2. Learning mode: Collects data and fits the model when buffer is full + + Key features: + - Online refitting capability for adaptive neural decoding + - Data logging for retrospective analysis + - Position tracking for cursor control applications + - Hold period detection and handling + + Attributes: + settings: Configuration settings for the Refit Kalman filter. + _state: Internal state management object. + + Example: + >>> # Create settings with checkpoint path + >>> settings = RefitKalmanFilterSettings( + ... checkpoint_path="path/to/checkpoint.pkl", + ... steady_state=True + ... ) + >>> + >>> # Create processor + >>> processor = RefitKalmanFilterProcessor(settings) + >>> + >>> # Process measurement message + >>> result = processor(measurement_message) + >>> + >>> # Log data for refitting + >>> processor.log_for_refit(message, target_pos, hold_flag) + >>> + >>> # Refit the model + >>> processor.refit_model() + """ + + def _config_from_settings(self) -> dict: + """ + Returns: + dict: Dictionary containing configuration parameters for model initialization. + """ + return { + "steady_state": self.settings.steady_state, + } + + def _init_model(self, **kwargs): + """ + Initialize a new RefitKalmanFilter model with current settings. + + Args: + **kwargs: Keyword arguments for model initialization. + """ + config = self._config_from_settings() + config.update(kwargs) + self._state.model = RefitKalmanFilter(**config) + + def fit(self, X: np.ndarray, y: np.ndarray) -> None: + if self._state.model is None: + self._init_model() + if hasattr(self._state.model, "fit"): + self._state.model.fit(X, y) + + def load_from_checkpoint(self, checkpoint_path: str) -> None: + """ + Load model parameters from a serialized checkpoint file. + + Args: + checkpoint_path (str): Path to the saved checkpoint file. + + Side Effects: + - Initializes a new model if not already set. + - Sets model matrices A, W, H, Q from the checkpoint. + - Computes Kalman gain based on restored parameters. + """ + checkpoint_file = Path(checkpoint_path) + with open(checkpoint_file, "rb") as f: + checkpoint_data = pickle.load(f) + self._init_model(**checkpoint_data) + self._state.model._compute_gain() + self._state.model.is_fitted = True + + def save_checkpoint(self, checkpoint_path: str) -> None: + """ + Save current model parameters to a checkpoint file. + + Args: + checkpoint_path (str): Destination file path for saving model parameters. + + Raises: + ValueError: If the model is not initialized or has not been fitted. + """ + if not self._state.model or not self._state.model.is_fitted: + raise ValueError("Cannot save checkpoint: model not fitted") + checkpoint_data = { + "A_state_transition_matrix": self._state.model.A_state_transition_matrix, + "W_process_noise_covariance": self._state.model.W_process_noise_covariance, + "H_observation_matrix": self._state.model.H_observation_matrix, + "Q_measurement_noise_covariance": self._state.model.Q_measurement_noise_covariance, + } + checkpoint_file = Path(checkpoint_path) + checkpoint_file.parent.mkdir(parents=True, exist_ok=True) + with open(checkpoint_file, "wb") as f: + pickle.dump(checkpoint_data, f) + + def _reset_state( + self, + message: AxisArray = None, + ): + """ + This method initializes or reinitializes the state vector (x), state covariance (P), + and cursor position. If a checkpoint path is specified in the settings, the model + is loaded from the checkpoint. + + Args: + message (AxisArray): Time-series message containing neural measurements. + x_init (np.ndarray): Initial state vector. + P_init (np.ndarray): Initial state covariance matrix. + """ + if not self._state.model: + if self.settings.checkpoint_path: + self.load_from_checkpoint(self.settings.checkpoint_path) + else: + self._init_model() + ## TODO: fit the model - how to do this given expected inputs X and y? + # for unit test purposes only, given a known kinematic state size + state_dim = 2 + + # # If A is None, the model has not been fitted or loaded from checkpoint + # if self._state.model.A_state_transition_matrix is None: + # raise RuntimeError( + # "Cannot reset state — model has not been fitted or loaded from checkpoint." + # ) + + if self._state.model.A_state_transition_matrix is not None: + state_dim = self._state.model.A_state_transition_matrix.shape[0] + + self._state.x = np.zeros(state_dim) + self._state.P = np.eye(state_dim) + + self._state.buffer_neural = [] + self._state.buffer_state = [] + self._state.buffer_cursor_positions = [] + self._state.buffer_target_positions = [] + self._state.buffer_hold_flags = [] + + def _process(self, message: AxisArray) -> AxisArray: + """ + Process an incoming message using the Kalman filter. + + For each time point in the message: + - Predict the next state + - Update the estimate using the current measurement + - Track the velocity and estimate position + + Args: + message (AxisArray): Time-series message containing neural measurements. + + Returns: + AxisArray: Filtered message containing updated state estimates. + """ + # If checkpoint, load the model from the checkpoint + if not self._state.model and self.settings.checkpoint_path: + self.load_from_checkpoint(self.settings.checkpoint_path) + # No checkpoint means you need to initialize and fit the model + elif not self._state.model: + self._init_model() + state_dim = self._state.model.A_state_transition_matrix.shape[0] + if self._state.x is None: + self._state.x = np.zeros(state_dim) + + filtered_data = np.zeros( + ( + message.data.shape[0], + self._state.model.A_state_transition_matrix.shape[0], + ) + ) + + for i in range(message.data.shape[0]): + measurement = message.data[i] + # Predict + x_pred, P_pred = self._state.model.predict(self._state.x) + + # Update + x_updated = self._state.model.update(measurement, x_pred, P_pred) + + # Store + self._state.x = x_updated.copy() + self._state.P = self._state.model.P_state_covariance.copy() + filtered_data[i] = self._state.x + + return replace( + message, + data=filtered_data, + dims=["time", "state"], + key=f"{message.key}_filtered" if hasattr(message, "key") else "filtered", + ) + + def partial_fit(self, message: SampleMessage) -> None: + """ + Perform refitting using externally provided data. + + Expects message.sample.data (neural input) and message.trigger.value as a dict with: + - Y_state: (n_samples, n_states) array + - intention_velocity_indices: Optional[int] + - target_positions: Optional[np.ndarray] + - cursor_positions: Optional[np.ndarray] + - hold_flags: Optional[list[bool]] + """ + if not hasattr(message, "sample") or not hasattr(message, "trigger"): + raise ValueError("Invalid message format for partial_fit.") + + X = np.array(message.sample.data) + values = message.trigger.value + + if not isinstance(values, dict) or "Y_state" not in values: + raise ValueError( + "partial_fit expects trigger.value to include at least 'Y_state'." + ) + + kwargs = { + "X_neural": X, + "Y_state": np.array(values["Y_state"]), + } + + # Optional fields + for key in [ + "intention_velocity_indices", + "target_positions", + "cursor_positions", + "hold_flags", + ]: + if key in values and values[key] is not None: + kwargs[key if key != "hold_flags" else "hold_indices"] = np.array( + values[key] + ) + + # Call model refit + self._state.model.refit(**kwargs) + + def log_for_refit( + self, + message: AxisArray, + target_position: np.ndarray | None = None, + hold_flag: bool | None = None, + ): + """ + Log data for potential refitting of the model. + + This method stores measurement data, state estimates, and contextual + information (target positions, cursor positions, hold flags) in buffers + for later use in refitting the observation model. This data is used + to adapt the model to changing neural-to-behavioral relationships. + + Args: + message: AxisArray message containing measurement data. + target_position: Target position for the current time point (2,). + hold_flag: Boolean flag indicating if this is a hold period. + """ + if target_position is not None: + self._state.buffer_target_positions.append(target_position.copy()) + if hold_flag is not None: + self._state.buffer_hold_flags.append(hold_flag) + + measurement = message.data[-1] + self._state.buffer_neural.append(measurement.copy()) + self._state.buffer_state.append(self._state.x.copy()) + + def refit_model(self): + """ + Refit the observation model (H, Q) using buffered measurements and contextual data. + + This method updates the model's understanding of the neural-to-state mapping + by calculating a new observation matrix and noise covariance, based on: + - Logged neural data + - Cursor state estimates + - Hold flags and target positions + + Args: + velocity_indices (tuple): Indices in the state vector corresponding to velocity components. + Default assumes 2D velocity at indices (0, 1). + + Raises: + ValueError: If no buffered data exists. + """ + if not self._state.buffer_neural: + print("No buffered data to refit") + return + + kwargs = { + "X_neural": np.array(self._state.buffer_neural), + "Y_state": np.array(self._state.buffer_state), + "intention_velocity_indices": self.settings.velocity_indices[0], + } + + if self._state.buffer_target_positions and self._state.buffer_cursor_positions: + kwargs["target_positions"] = np.array(self._state.buffer_target_positions) + kwargs["cursor_positions"] = np.array(self._state.buffer_cursor_positions) + if self._state.buffer_hold_flags: + kwargs["hold_indices"] = np.array(self._state.buffer_hold_flags) + + self._state.model.refit(**kwargs) + + self._state.buffer_neural.clear() + self._state.buffer_state.clear() + self._state.buffer_cursor_positions.clear() + self._state.buffer_target_positions.clear() + self._state.buffer_hold_flags.clear() + + +class RefitKalmanFilterUnit( + BaseAdaptiveTransformerUnit[ + RefitKalmanFilterSettings, + AxisArray, + AxisArray, + RefitKalmanFilterProcessor, + ] +): + SETTINGS = RefitKalmanFilterSettings diff --git a/tests/integration/test_refit_kalman_system.py b/tests/integration/test_refit_kalman_system.py new file mode 100644 index 0000000..dedff5d --- /dev/null +++ b/tests/integration/test_refit_kalman_system.py @@ -0,0 +1,126 @@ +import pickle +import tempfile +import numpy as np +from pathlib import Path +import os + +import ezmsg.core as ez +from ezmsg.util.messages.axisarray import AxisArray +from ezmsg.util.messagecodec import message_log +from ezmsg.util.messagelogger import MessageLogger, MessageLoggerSettings +from ezmsg.sigproc.synth import Counter, CounterSettings +from ezmsg.util.terminate import TerminateOnTotal, TerminateOnTimeout +from ezmsg.util.terminate import TerminateOnTotalSettings, TerminateOnTimeoutSettings + +from ezmsg.learn.process.refit_kalman import ( + RefitKalmanFilterUnit, + RefitKalmanFilterSettings, +) + + +class RefitKalmanSystemSettings(ez.Settings): + counter_settings: CounterSettings + unit_settings: RefitKalmanFilterSettings + log_settings: MessageLoggerSettings + terminate_total: TerminateOnTotalSettings + terminate_timeout: TerminateOnTimeoutSettings + + +class RefitKalmanSystem(ez.Collection): + SETTINGS = RefitKalmanSystemSettings + + SOURCE = Counter() + UNIT = RefitKalmanFilterUnit() + LOG = MessageLogger() + TERM_TOTAL = TerminateOnTotal() + TERM_TIMEOUT = TerminateOnTimeout() + + def configure(self) -> None: + self.SOURCE.apply_settings(self.SETTINGS.counter_settings) + self.UNIT.apply_settings(self.SETTINGS.unit_settings) + self.LOG.apply_settings(self.SETTINGS.log_settings) + self.TERM_TOTAL.apply_settings(self.SETTINGS.terminate_total) + self.TERM_TIMEOUT.apply_settings(self.SETTINGS.terminate_timeout) + + def network(self) -> ez.NetworkDefinition: + return ( + (self.SOURCE.OUTPUT_SIGNAL, self.UNIT.INPUT_SIGNAL), + (self.UNIT.OUTPUT_SIGNAL, self.LOG.INPUT_MESSAGE), + (self.LOG.OUTPUT_MESSAGE, self.TERM_TOTAL.INPUT_MESSAGE), + (self.LOG.OUTPUT_MESSAGE, self.TERM_TIMEOUT.INPUT), + ) + + +def test_refit_kalman_system(): + """Test complete RefitKalmanFilter system integration. + + This test verifies that the RefitKalmanFilter can be successfully + integrated into a complete ezmsg processing system. The test creates + a realistic processing pipeline that generates synthetic neural signals, + processes them through the Kalman filter, and logs the results. + + The integration test workflow: + 1. Create realistic Kalman filter checkpoint with model parameters + 2. Configure complete system with signal generation and processing + 3. Execute the system with specified duration and parameters + 4. Verify that messages are processed and logged correctly + 5. Validate output message format and data integrity + + Note: + This test creates temporary files that are automatically cleaned up + after execution. The test uses synthetic data and pre-trained model + parameters to ensure consistent and reproducible results. + """ + state_dim = 2 + duration = 2.0 + fs = 10.0 + block_size = 4 + test_path = Path(tempfile.gettempdir()) / "refit_kalman_log.txt" + + A = np.array([[1, 0.1], [0, 1]]) + H = np.array([[1, 0], [0, 1]]) + W = np.eye(2) * 0.05 + Q = np.eye(2) * 0.1 + + # Write checkpoint + checkpoint = { + "A_state_transition_matrix": A, + "H_observation_matrix": H, + "W_process_noise_covariance": W, + "Q_measurement_noise_covariance": Q, + } + with tempfile.NamedTemporaryFile(delete=False, suffix=".pkl", mode="wb") as f: + pickle.dump(checkpoint, f) + checkpoint_file = f.name + + settings = RefitKalmanSystemSettings( + counter_settings=CounterSettings( + fs=fs, + n_ch=1, + n_time=block_size, + dispatch_rate=duration, + ), + unit_settings=RefitKalmanFilterSettings( + checkpoint_path=checkpoint_file, + steady_state=True, + ), + log_settings=MessageLoggerSettings(output=test_path), + terminate_total=TerminateOnTotalSettings(total=int(duration * fs / block_size)), + terminate_timeout=TerminateOnTimeoutSettings(time=5.0), + ) + + system = RefitKalmanSystem(settings=settings) + ez.run(SYSTEM=system) + + messages = [_ for _ in message_log(test_path)] + os.remove(test_path) + os.remove(checkpoint_file) + + assert len(messages) > 0, "No messages logged" + + for msg in messages: + assert isinstance(msg, AxisArray) + assert msg.dims == ["time", "state"] + assert msg.data.ndim == 2 + assert msg.data.shape[1] == state_dim + assert np.all(np.isfinite(msg.data)) diff --git a/tests/unit/test_refit_kalman.py b/tests/unit/test_refit_kalman.py new file mode 100644 index 0000000..9f6cfea --- /dev/null +++ b/tests/unit/test_refit_kalman.py @@ -0,0 +1,535 @@ +import pickle +import tempfile +from pathlib import Path +import numpy as np +import pytest +from ezmsg.util.messages.axisarray import AxisArray + +from ezmsg.learn.process.refit_kalman import ( + RefitKalmanFilterSettings, + RefitKalmanFilterProcessor, +) + + +@pytest.fixture +def create_test_message(): + """Create a standard test message with synthetic neural data. + + Returns: + AxisArray: Test message with 10 time steps, 2 channels, 100 Hz sampling rate + """ + n_time, n_ch = 10, 2 + data = np.linspace(0, 1, n_time * n_ch).reshape(n_time, n_ch) + fs = 100.0 + + return AxisArray( + data=data, + dims=["time", "channels"], + axes={ + "time": AxisArray.TimeAxis(fs=fs), + "channels": AxisArray.CoordinateAxis( + data=np.array([f"ch{i}" for i in range(n_ch)]), + dims=["channels"], + ), + }, + key="test_message", + ) + + +@pytest.fixture +def checkpoint_file(): + """Create a realistic refit-ready Kalman checkpoint file. + + Creates temporary pickle file with Kalman filter parameters for testing. + File should be cleaned up after each test to prevent accumulation. + + Returns: + str: Path to temporary checkpoint file + """ + dt = 0.1 + A = np.array([[1, dt], [0, 1]]) + H = np.array([[1, 0], [0, 1]]) + W = np.eye(2) * 0.05 + Q = np.eye(2) * 0.1 + checkpoint_data = { + "A_state_transition_matrix": A, + "W_process_noise_covariance": W, + "H_observation_matrix": H, + "Q_measurement_noise_covariance": Q, + } + with tempfile.NamedTemporaryFile(delete=False, suffix=".pkl", mode="wb") as f: + pickle.dump(checkpoint_data, f) + return f.name + + +@pytest.mark.parametrize("steady_state", [True, False]) +def test_processor_initialization_with_checkpoint(checkpoint_file, steady_state): + """Test that processor initializes correctly with checkpoint file. + + Verifies processor can load pre-trained model from checkpoint and mark it as fitted. + Tests both steady-state and non-steady-state modes. + + Args: + checkpoint_file: Path to temporary checkpoint file + steady_state: Whether to use steady-state Kalman filter mode + """ + settings = RefitKalmanFilterSettings( + checkpoint_path=checkpoint_file, steady_state=steady_state + ) + processor = RefitKalmanFilterProcessor(settings) + + assert processor._state.model is None + + processor._reset_state() + assert processor._state.model is not None + assert processor._state.model.is_fitted is True + + Path(checkpoint_file).unlink() + + +@pytest.mark.parametrize("steady_state", [True, False]) +def test_processor_initialization_without_checkpoint(steady_state): + """Test that processor initializes correctly without checkpoint file. + + Verifies processor creates unfitted model that requires training before processing. + Tests both steady-state and non-steady-state modes. + + Args: + steady_state: Whether to use steady-state Kalman filter mode + """ + settings = RefitKalmanFilterSettings(steady_state=steady_state) + processor = RefitKalmanFilterProcessor(settings) + + assert processor._state.model is None + + processor._reset_state() + assert processor._state.model is not None + assert processor._state.model.is_fitted is False + + +@pytest.mark.parametrize("steady_state", [True, False]) +def test_message_processing_with_checkpoint( + create_test_message, checkpoint_file, steady_state +): + """Test that messages can be processed with pre-loaded checkpoint. + + Verifies processor can make predictions using pre-trained model parameters. + Tests both steady-state and non-steady-state processing modes. + + Args: + create_test_message: Test message fixture with synthetic neural data + checkpoint_file: Path to temporary checkpoint file + steady_state: Whether to use steady-state Kalman filter mode + """ + settings = RefitKalmanFilterSettings( + checkpoint_path=checkpoint_file, steady_state=steady_state + ) + processor = RefitKalmanFilterProcessor(settings) + msg = create_test_message + + processor._reset_state(msg) + result = processor._process(msg) + + assert processor._state.model is not None + assert isinstance(result, AxisArray) + assert result.data.shape == ( + 10, + processor._state.model.A_state_transition_matrix.shape[0], + ) + assert np.all(np.isfinite(result.data)) + Path(checkpoint_file).unlink() + + +@pytest.mark.parametrize("steady_state", [True, False]) +def test_message_processing_without_checkpoint_requires_fit( + create_test_message, steady_state +): + """Test that processing without checkpoint requires model fitting first. + + Verifies processor prevents processing when no fitted model is available. + Should raise appropriate error to prevent invalid operations. + + Args: + create_test_message: Test message fixture with synthetic neural data + steady_state: Whether to use steady-state Kalman filter mode + + Raises: + AttributeError or ValueError: Expected when processing without fitted model + """ + settings = RefitKalmanFilterSettings(steady_state=steady_state) + processor = RefitKalmanFilterProcessor(settings) + msg = create_test_message + + # Should raise error when processing without fitted model + with pytest.raises((AttributeError, ValueError)): + processor._process(msg) + + +def test_state_update_during_processing(create_test_message, checkpoint_file): + """Test that processor state is updated during message processing. + + Verifies Kalman filter state evolves properly as new data is processed. + State should change from initial values after processing. + + Args: + create_test_message: Test message fixture with synthetic neural data + checkpoint_file: Path to temporary checkpoint file + """ + settings = RefitKalmanFilterSettings( + checkpoint_path=checkpoint_file, steady_state=False + ) + processor = RefitKalmanFilterProcessor(settings) + msg = create_test_message + + processor._reset_state(msg) + x_prior = np.array([0, 0]) + + processor._process(msg) + + x_post = processor._state.x + assert not np.allclose(x_prior, x_post) + # assert np.all(np.isfinite(x_post)) + Path(checkpoint_file).unlink() + + +def test_fit_method_functionality(create_test_message): + """Test fit method functionality for training the model. + + Verifies processor can train model using provided neural and state data. + After fitting, model should be marked as fitted and able to process new data. + + Args: + create_test_message: Test message fixture with synthetic neural data + """ + settings = RefitKalmanFilterSettings(steady_state=True) + processor = RefitKalmanFilterProcessor(settings) + + # Test refit with minimal parameters + X_neural = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]) + Y_state = np.array([[1.0, 2.0], [1.1, 2.1], [1.2, 2.2]]) + + # Test that fit method works with the processor + processor.fit(X_neural, Y_state) + + # Check that model parameters were updated + assert processor._state.model.is_fitted + + # Test that processing works after fit + msg = create_test_message + result = processor._process(msg) + assert result.data.shape == (10, 2) + assert np.all(np.isfinite(result.data)) + + +def test_refit_functionality_with_buffered_data(create_test_message, checkpoint_file): + """Test refit functionality using buffered data from online processing. + + Verifies processor can perform online refitting using collected buffer data. + Model parameters should change after refit, and buffers should be cleared. + + Args: + create_test_message: Test message fixture with synthetic neural data + checkpoint_file: Path to temporary checkpoint file + """ + settings = RefitKalmanFilterSettings( + checkpoint_path=checkpoint_file, steady_state=True + ) + processor = RefitKalmanFilterProcessor(settings) + msg = create_test_message + processor._reset_state(msg) + + # Get initial model parameters + with open(checkpoint_file, "rb") as f: + checkpoint_data = pickle.load(f) + H_initial = checkpoint_data["H_observation_matrix"] + Q_initial = checkpoint_data["Q_measurement_noise_covariance"] + + # Log data for refit + target_pos = np.array([1.0, 1.0]) + for i in range(10): + processor.log_for_refit(msg, target_pos, hold_flag=False) + + # Perform refit + processor.refit_model() + + # Check that model parameters changed + H_after_refit = processor._state.model.H_observation_matrix + Q_after_refit = processor._state.model.Q_measurement_noise_covariance + + assert not np.allclose(H_initial, H_after_refit) + assert not np.allclose(Q_initial, Q_after_refit) + + # Check that buffers are cleared after refit + assert len(processor._state.buffer_neural) == 0 + assert len(processor._state.buffer_state) == 0 + + Path(checkpoint_file).unlink() + + +def test_refit_functionality_without_buffered_data( + create_test_message, checkpoint_file +): + """Test refit functionality when no buffered data exists. + + Verifies processor handles refit requests gracefully with empty buffers. + Should not raise errors and buffers should remain empty. + + Args: + create_test_message: Test message fixture with synthetic neural data + checkpoint_file: Path to temporary checkpoint file + """ + settings = RefitKalmanFilterSettings( + checkpoint_path=checkpoint_file, steady_state=True + ) + processor = RefitKalmanFilterProcessor(settings) + processor._reset_state() + + # Try to refit with empty buffer (should not raise error) + processor.refit_model() + + # Check that buffers remain empty + assert len(processor._state.buffer_neural) == 0 + assert len(processor._state.buffer_state) == 0 + + Path(checkpoint_file).unlink() + + +def test_partial_fit_functionality(create_test_message, checkpoint_file): + """Test partial fit functionality for incremental model updates. + + Verifies processor can perform partial fitting using individual sample messages. + Model parameters should be updated incrementally after partial fit. + + Args: + create_test_message: Test message fixture with synthetic neural data + checkpoint_file: Path to temporary checkpoint file + """ + settings = RefitKalmanFilterSettings( + checkpoint_path=checkpoint_file, steady_state=True + ) + processor = RefitKalmanFilterProcessor(settings) + msg = create_test_message + processor._reset_state(msg) + + with open(checkpoint_file, "rb") as f: + checkpoint_data = pickle.load(f) + H_initial = checkpoint_data["H_observation_matrix"] + Q_initial = checkpoint_data["Q_measurement_noise_covariance"] + + # Create a mock SampleMessage with the expected structure + class MockSampleMessage: + def __init__(self, neural_data, trigger_value): + self.sample = type("obj", (object,), {"data": neural_data})() + self.trigger = type("obj", (object,), {"value": trigger_value})() + + # Create test data + neural_data = np.array( + [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]] + ) # 3 samples, 2 channels + trigger_value = { + "Y_state": np.array( + [[1.0, 2.0], [1.1, 2.1], [1.2, 2.2]] + ), # 3 samples, 2 states + "intention_velocity_indices": 0, + "target_positions": np.array([[1.0, 1.0], [1.1, 1.1], [1.2, 1.2]]), + "cursor_positions": np.array([[0.0, 0.0], [0.1, 0.1], [0.2, 0.2]]), + "hold_flags": [False, False, False], + } + + mock_message = MockSampleMessage(neural_data, trigger_value) + processor.partial_fit(mock_message) + + assert not np.allclose(H_initial, processor._state.model.H_observation_matrix) + assert not np.allclose( + Q_initial, processor._state.model.Q_measurement_noise_covariance + ) + + Path(checkpoint_file).unlink() + + +def test_hold_periods_functionality(create_test_message, checkpoint_file): + """Test hold periods functionality in refit data collection. + + Verifies processor correctly handles hold and non-hold periods during data collection. + Buffers should be properly managed and cleared after refit. + + Args: + create_test_message: Test message fixture with synthetic neural data + checkpoint_file: Path to temporary checkpoint file + """ + settings = RefitKalmanFilterSettings( + checkpoint_path=checkpoint_file, steady_state=True + ) + processor = RefitKalmanFilterProcessor(settings) + msg = create_test_message + processor._reset_state(msg) + target_pos = np.array([1.0, 1.0]) + + # Log data with mixed hold/non-hold periods + for i in range(10): + hold_flag = i % 2 == 0 # Alternate hold/non-hold + processor.log_for_refit(msg, target_pos, hold_flag=hold_flag) + + # Perform refit + processor.refit_model() + + # Check that buffers are cleared after refit + assert len(processor._state.buffer_neural) == 0 + assert len(processor._state.buffer_state) == 0 + assert len(processor._state.buffer_cursor_positions) == 0 + assert len(processor._state.buffer_target_positions) == 0 + assert len(processor._state.buffer_hold_flags) == 0 + + Path(checkpoint_file).unlink() + + +def test_message_processing_integration(create_test_message, checkpoint_file): + """Test complete message processing integration workflow. + + Verifies complete integration from initialization through message processing. + All components should work together to process neural data and produce valid output. + + Args: + create_test_message: Test message fixture with synthetic neural data + checkpoint_file: Path to temporary checkpoint file + """ + settings = RefitKalmanFilterSettings( + checkpoint_path=checkpoint_file, steady_state=True + ) + processor = RefitKalmanFilterProcessor(settings) + + msg = create_test_message + processor._reset_state(msg) + result = processor._process(msg) + + # Check that processing worked + assert isinstance(result, AxisArray) + assert result.data.shape == (10, 2) + assert np.all(np.isfinite(result.data)) + + # Test that state was updated + assert processor._state.x is not None + assert np.all(np.isfinite(processor._state.x)) + + Path(checkpoint_file).unlink() + + +def test_checkpoint_save_and_load_functionality(create_test_message, checkpoint_file): + """Test checkpoint saving and loading functionality. + + Verifies processor can save model state and new processor can load it successfully. + All model parameters should be preserved across save/load cycles. + + Args: + create_test_message: Test message fixture with synthetic neural data + checkpoint_file: Path to temporary checkpoint file + """ + settings = RefitKalmanFilterSettings(checkpoint_path=checkpoint_file) + processor = RefitKalmanFilterProcessor(settings) + msg = create_test_message + processor._reset_state(msg) + # Save checkpoint + new_checkpoint_path = checkpoint_file.replace(".pkl", "_new.pkl") + processor.save_checkpoint(new_checkpoint_path) + + # Create new processor and load checkpoint + settings_new = RefitKalmanFilterSettings(checkpoint_path=new_checkpoint_path) + processor_new = RefitKalmanFilterProcessor(settings_new) + processor_new._reset_state() + # Check that models are equivalent + assert np.allclose( + processor._state.model.A_state_transition_matrix, + processor_new._state.model.A_state_transition_matrix, + ) + assert np.allclose( + processor._state.model.H_observation_matrix, + processor_new._state.model.H_observation_matrix, + ) + assert np.allclose( + processor._state.model.W_process_noise_covariance, + processor_new._state.model.W_process_noise_covariance, + ) + assert np.allclose( + processor._state.model.Q_measurement_noise_covariance, + processor_new._state.model.Q_measurement_noise_covariance, + ) + + Path(checkpoint_file).unlink() + Path(new_checkpoint_path).unlink() + + +def test_error_handling_for_unfitted_model(): + """Test error handling for unfitted model during checkpoint saving. + + Verifies processor raises appropriate error when attempting to save checkpoint + without fitted model. Should prevent creation of invalid checkpoint files. + + Raises: + ValueError: Expected when attempting to save checkpoint without fitted model + """ + # Test saving checkpoint without fitted model + settings = RefitKalmanFilterSettings(checkpoint_path=None, steady_state=True) + processor = RefitKalmanFilterProcessor(settings) + + with pytest.raises(ValueError, match="Cannot save checkpoint: model not fitted"): + processor.save_checkpoint("test.pkl") + + +def test_error_handling_for_processing_without_checkpoint(create_test_message): + """Test error handling for processing without checkpoint or fitted model. + + Verifies processor prevents processing when no fitted model is available. + Should raise appropriate error to prevent invalid operations. + + Args: + create_test_message: Test message fixture with synthetic neural data + + Raises: + AttributeError or ValueError: Expected when processing without fitted model + """ + settings = RefitKalmanFilterSettings(checkpoint_path=None, steady_state=False) + processor = RefitKalmanFilterProcessor(settings) + msg = create_test_message + + # Should raise error when processing without fitted model + with pytest.raises((AttributeError, ValueError)): + processor._process(msg) + + +def test_steady_state_vs_non_steady_state_processing( + create_test_message, checkpoint_file +): + """Test differences between steady-state and non-steady-state processing. + + Verifies processor produces different results in different modes due to + different Kalman gain computation. Both modes should produce valid results. + + Args: + create_test_message: Test message fixture with synthetic neural data + checkpoint_file: Path to temporary checkpoint file + """ + # Test steady-state mode + settings_steady = RefitKalmanFilterSettings( + checkpoint_path=checkpoint_file, steady_state=True + ) + processor_steady = RefitKalmanFilterProcessor(settings_steady) + msg = create_test_message + processor_steady._reset_state(msg) + + result_steady = processor_steady._process(msg) + + # Test non-steady-state mode + settings_nonsteady = RefitKalmanFilterSettings( + checkpoint_path=checkpoint_file, steady_state=False + ) + processor_nonsteady = RefitKalmanFilterProcessor(settings_nonsteady) + processor_nonsteady._reset_state(msg) + + result_nonsteady = processor_nonsteady._process(msg) + + # Results should be different due to different Kalman gain computation + assert not np.allclose(result_steady.data, result_nonsteady.data) + assert np.all(np.isfinite(result_steady.data)) + assert np.all(np.isfinite(result_nonsteady.data)) + + Path(checkpoint_file).unlink() From a286fba80d2383e055bb57c01707a44dddebcc88 Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Fri, 22 Aug 2025 12:05:28 -0400 Subject: [PATCH 08/16] Update GHA tests script --- .github/workflows/python-tests.yml | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index 08dfcd4..1237b14 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -2,10 +2,13 @@ name: Test package on: push: - branches: [main] + branches: + - main + - dev pull_request: branches: - main + - dev workflow_dispatch: jobs: @@ -23,16 +26,12 @@ jobs: - uses: actions/checkout@v4 - name: Install uv - uses: astral-sh/setup-uv@v2 + uses: astral-sh/setup-uv@v6 with: - enable-cache: true - cache-dependency-glob: "uv.lock" - - - name: Set up Python ${{ matrix.python-version }} - run: uv python install ${{ matrix.python-version }} + python-version: ${{ matrix.python-version }} - name: Install the project - run: uv sync --all-extras + run: uv sync - name: Lint run: From e3955832ee5315129e05e41614940aec6a4c0056 Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Fri, 22 Aug 2025 12:06:18 -0400 Subject: [PATCH 09/16] GHA publish - bump uv version --- .github/workflows/python-publish-ezmsg-learn.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-publish-ezmsg-learn.yml b/.github/workflows/python-publish-ezmsg-learn.yml index 3cad99c..c0c6454 100644 --- a/.github/workflows/python-publish-ezmsg-learn.yml +++ b/.github/workflows/python-publish-ezmsg-learn.yml @@ -17,7 +17,7 @@ jobs: - uses: actions/checkout@v4 - name: Install uv - uses: astral-sh/setup-uv@v2 + uses: astral-sh/setup-uv@v6 - name: Build Package run: uv build From e422692b0502d6d596bc2c1e457e228bfa8f6131 Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Fri, 22 Aug 2025 12:10:04 -0400 Subject: [PATCH 10/16] Add pre-commit config --- .pre-commit-config.yaml | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..a525e91 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,7 @@ +repos: +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.11.12 + hooks: + - id: ruff + args: [ --fix ] + - id: ruff-format \ No newline at end of file From e84f6b81a68472dc089b024895bb2ba7b1567e3a Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Fri, 22 Aug 2025 12:39:50 -0400 Subject: [PATCH 11/16] Add '.' to pythonpath when running pytest so the tests can find local test models. --- pyproject.toml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 633a9b1..4d0ed33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,3 +41,9 @@ version-file = "src/ezmsg/learn/__version__.py" [tool.hatch.build.targets.wheel] packages = ["src/ezmsg"] + +[tool.pytest.ini_options] +pythonpath = [ + "src", + ".", +] From 34650a07f946148a8c392c358a7193b6d432a264 Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Fri, 22 Aug 2025 12:47:50 -0400 Subject: [PATCH 12/16] Clear mps memory during pytest --- tests/unit/test_torch.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/unit/test_torch.py b/tests/unit/test_torch.py index e1ace17..d835d9b 100644 --- a/tests/unit/test_torch.py +++ b/tests/unit/test_torch.py @@ -54,6 +54,14 @@ def infer_config_from_state_dict(cls, state_dict): return {"input_size": state_dict["head_a.weight"].shape[1]} +@pytest.fixture(autouse=True) +def mps_memory_cleanup(): + """Fixture to clean up MPS memory after each test.""" + yield + if torch.backends.mps.is_available(): + torch.mps.empty_cache() + + @pytest.fixture def batch_message(): input_dim = 6 From 2c092c60136d2302482b338f49c4508c93f59631 Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Fri, 22 Aug 2025 12:57:38 -0400 Subject: [PATCH 13/16] Attempt to fix MPS out of memory error. --- .github/workflows/python-tests.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index 1237b14..b7ebdb4 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -39,3 +39,5 @@ jobs: - name: Run tests run: uv run pytest tests + env: + PYTORCH_MPS_HIGH_WATERMARK_RATIO: 0.0 From d09fbedce183f6274c8ea22ce478e95df7c8815d Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Fri, 22 Aug 2025 13:13:17 -0400 Subject: [PATCH 14/16] Revert previous commit --- .github/workflows/python-tests.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index b7ebdb4..1237b14 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -39,5 +39,3 @@ jobs: - name: Run tests run: uv run pytest tests - env: - PYTORCH_MPS_HIGH_WATERMARK_RATIO: 0.0 From ce8e69f9a108af855c4efb0df5800f745a0a2b23 Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Fri, 22 Aug 2025 13:13:34 -0400 Subject: [PATCH 15/16] Use "cpu" instead of "mps" when running on GitHub Actions. --- tests/unit/test_torch.py | 39 ++++++++++++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/tests/unit/test_torch.py b/tests/unit/test_torch.py index d835d9b..c537133 100644 --- a/tests/unit/test_torch.py +++ b/tests/unit/test_torch.py @@ -1,4 +1,6 @@ from pathlib import Path +import os +import sys import numpy as np import pytest @@ -54,6 +56,14 @@ def infer_config_from_state_dict(cls, state_dict): return {"input_size": state_dict["head_a.weight"].shape[1]} +@pytest.fixture +def device(): + """Returns 'cpu' if on macOS in GitHub Actions, otherwise None.""" + if os.getenv("GITHUB_ACTIONS") == "true" and sys.platform == "darwin": + return "cpu" + return None + + @pytest.fixture(autouse=True) def mps_memory_cleanup(): """Fixture to clean up MPS memory after each test.""" @@ -78,7 +88,7 @@ def batch_message(): @pytest.mark.parametrize("input_size,output_size", [(4, 2), (6, 3), (8, 1)]) -def test_inference_shapes(input_size, output_size): +def test_inference_shapes(input_size, output_size, device): data = np.random.randn(12, input_size) msg = AxisArray( data=data, @@ -94,6 +104,7 @@ def test_inference_shapes(input_size, output_size): "input_size": input_size, "output_size": output_size, }, + device=device, ) out = proc(msg)[0] # Check output last dim matches output_size @@ -138,13 +149,14 @@ def test_checkpoint_loading_and_weights(batch_message): @pytest.mark.parametrize("dropout", [0.0, 0.1, 0.5]) -def test_model_kwargs_propagation(dropout, batch_message): +def test_model_kwargs_propagation(dropout, batch_message, device): proc = TorchModelProcessor( model_class=DUMMY_MODEL_CLASS, model_kwargs={ "output_size": 2, "dropout": dropout, }, + device=device, ) proc(batch_message) model = proc._state.model @@ -155,7 +167,7 @@ def test_model_kwargs_propagation(dropout, batch_message): assert model.dropout is None -def test_partial_fit_changes_weights(batch_message): +def test_partial_fit_changes_weights(batch_message, device): proc = TorchModelProcessor( model_class=DUMMY_MODEL_CLASS, loss_fn=torch.nn.MSELoss(), @@ -163,6 +175,7 @@ def test_partial_fit_changes_weights(batch_message): model_kwargs={ "output_size": 2, }, + device=device, ) x = batch_message.data[:1] y = np.random.randn(1, 2) @@ -198,6 +211,7 @@ def test_partial_fit_changes_weights(batch_message): "input_size": x.shape[-1], "output_size": 2, }, + device=device, ) bad_proc(sample) with pytest.raises(ValueError): @@ -209,8 +223,11 @@ def test_model_runs_on_devices(device, batch_message): # Skip unavailable devices if device == "cuda" and not torch.cuda.is_available(): pytest.skip("CUDA not available") - if device == "mps" and not torch.backends.mps.is_available(): - pytest.skip("MPS not available") + if device == "mps": + if not torch.backends.mps.is_available(): + pytest.skip("MPS not available") + if os.getenv("GITHUB_ACTIONS") == "true": + pytest.skip("MPS memory limit too low on free GitHub Actions runner") proc = TorchModelProcessor( model_class=DUMMY_MODEL_CLASS, @@ -226,7 +243,7 @@ def test_model_runs_on_devices(device, batch_message): @pytest.mark.parametrize("batch_size", [1, 5, 10]) -def test_batch_processing(batch_size): +def test_batch_processing(batch_size, device): input_dim = 4 output_dim = 2 data = np.random.randn(batch_size, input_dim) @@ -245,6 +262,7 @@ def test_batch_processing(batch_size): "input_size": input_dim, "output_size": output_dim, }, + device=device, ) out = proc(msg)[0] assert out.data.shape[0] == batch_size @@ -273,10 +291,11 @@ def test_input_size_mismatch_raises_error(): )(msg) -def test_multihead_output(batch_message): +def test_multihead_output(batch_message, device): proc = TorchModelProcessor( model_class=MULTIHEAD_MODEL_CLASS, model_kwargs={"input_size": batch_message.data.shape[1]}, + device=device, ) results = proc(batch_message) @@ -286,7 +305,7 @@ def test_multihead_output(batch_message): assert r.data.ndim == 2 -def test_multihead_partial_fit_with_loss_dict(batch_message): +def test_multihead_partial_fit_with_loss_dict(batch_message, device): proc = TorchModelProcessor( model_class=MULTIHEAD_MODEL_CLASS, loss_fn={ @@ -294,6 +313,7 @@ def test_multihead_partial_fit_with_loss_dict(batch_message): "head_b": torch.nn.L1Loss(), }, model_kwargs={"input_size": batch_message.data.shape[1]}, + device=device, ) proc(batch_message) # initialize model @@ -324,7 +344,7 @@ def test_multihead_partial_fit_with_loss_dict(batch_message): assert not torch.allclose(before_b, after_b) -def test_partial_fit_with_loss_weights(batch_message): +def test_partial_fit_with_loss_weights(batch_message, device): proc = TorchModelProcessor( model_class=MULTIHEAD_MODEL_CLASS, loss_fn={ @@ -336,6 +356,7 @@ def test_partial_fit_with_loss_weights(batch_message): "head_b": 0.5, }, model_kwargs={"input_size": batch_message.data.shape[1]}, + device=device, ) proc(batch_message) From cc2a023b073572953f25787bcfeb6ef90fdf0239 Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Fri, 22 Aug 2025 13:53:18 -0400 Subject: [PATCH 16/16] Fix GHA Windows runner error about tempfile lock. --- tests/unit/test_rnn.py | 19 ++++++++++++++----- tests/unit/test_transformer.py | 17 +++++++++++++---- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/tests/unit/test_rnn.py b/tests/unit/test_rnn.py index 1a308d1..f490d20 100644 --- a/tests/unit/test_rnn.py +++ b/tests/unit/test_rnn.py @@ -1,4 +1,5 @@ import tempfile +from pathlib import Path import numpy as np import pytest @@ -172,13 +173,17 @@ def test_rnn_checkpoint_save_load(simple_message): # First pass to initialize model proc(simple_message) - # Save full checkpoint (state_dict + config) - with tempfile.NamedTemporaryFile(suffix=".pt") as tmp: - proc.save_checkpoint(tmp.name) + # Create a temporary file that is closed immediately + with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as tmp: + checkpoint_path = Path(tmp.name) + + try: + # Save full checkpoint (state_dict + config) + proc.save_checkpoint(str(checkpoint_path)) # Load from checkpoint proc2 = RNNProcessor( - checkpoint_path=tmp.name, + checkpoint_path=str(checkpoint_path), single_precision=single_precision, device="cpu", model_kwargs={ @@ -200,6 +205,10 @@ def test_rnn_checkpoint_save_load(simple_message): f"Mismatch in parameter {key}" ) + finally: + # Ensure the temporary file is deleted + checkpoint_path.unlink(missing_ok=True) + def test_rnn_partial_fit_multiloss(simple_message): hidden_size = 16 @@ -322,7 +331,7 @@ def test_rnn_preserve_state_batch_size_change(): single_precision=True, device="cpu", preserve_state_across_windows=True, - model_kwargs={"hidden_size": hidden_size, "output_heads": output_size}, + model_kwargs={"hidden_size": hidden_size, "output_size": output_size}, ) # First message: 1 window diff --git a/tests/unit/test_transformer.py b/tests/unit/test_transformer.py index 17048e3..71952b0 100644 --- a/tests/unit/test_transformer.py +++ b/tests/unit/test_transformer.py @@ -1,4 +1,5 @@ import tempfile +from pathlib import Path import numpy as np import pytest @@ -172,13 +173,17 @@ def test_transformer_checkpoint_save_load(simple_message): # First pass to initialize model proc(simple_message) - # Save full checkpoint (state_dict + config) - with tempfile.NamedTemporaryFile(suffix=".pt") as tmp: - proc.save_checkpoint(tmp.name) + # Create a temporary file that is closed immediately + with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as tmp: + checkpoint_path = Path(tmp.name) + + try: + # Save full checkpoint (state_dict + config) + proc.save_checkpoint(str(checkpoint_path)) # Load from checkpoint proc2 = TransformerProcessor( - checkpoint_path=tmp.name, + checkpoint_path=str(checkpoint_path), single_precision=single_precision, device="cpu", model_kwargs={ @@ -200,6 +205,10 @@ def test_transformer_checkpoint_save_load(simple_message): f"Mismatch in parameter {key}" ) + finally: + # Ensure the temporary file is deleted + checkpoint_path.unlink(missing_ok=True) + def test_transformer_partial_fit_multiloss(simple_message): hidden_size = 16