Skip to content

Commit

Permalink
Merge pull request #151 from decargroup/bugfix/150-fix-performance-is…
Browse files Browse the repository at this point in the history
…sues-in-predict_trajectory

Fix performance issues in `pykoop.predict_trajectory()`
  • Loading branch information
sdahdah committed Aug 31, 2023
2 parents 18ec6e8 + 7b6bb78 commit 77a7a0d
Show file tree
Hide file tree
Showing 13 changed files with 437 additions and 122 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ dmypy.json

# profiling data
.prof
*.prof

### Vim ###
# Swap
Expand Down
38 changes: 38 additions & 0 deletions benchmarks/benchmark_predict_trajectory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""Benchmark :func:`pykoop.predict_trajectory()`.
Outputs a ``.prof`` file that can be visualized using ``snakeviz``.
"""

import cProfile

import pykoop


def main():
"""Benchmark :func:`pykoop.predict_trajectory()`."""
pykoop.set_config(skip_validation=True)

# Get example mass-spring-damper data
eg = pykoop.example_data_pendulum()
# Create pipeline
kp = pykoop.KoopmanPipeline(
lifting_functions=[
('pl', pykoop.PolynomialLiftingFn(order=2)),
('dl', pykoop.DelayLiftingFn(n_delays_state=2, n_delays_input=2)),
],
regressor=pykoop.Edmd(alpha=1),
)
# Fit the pipeline
kp.fit(
eg['X_train'],
n_inputs=eg['n_inputs'],
episode_feature=eg['episode_feature'],
)
# Predict using the pipeline
with cProfile.Profile() as pr:
X_pred = kp.predict_trajectory(eg['X_train'])
pr.dump_stats('benchmark_predict_trajectory.prof')


if __name__ == '__main__':
main()
27 changes: 27 additions & 0 deletions benchmarks/benchmark_unique_episodes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Benchmark :func:`pykoop.koopman_pipeline._unique_episodes()`.
It's very hard to do better than :func:`pandas.unique`, so I will stop messing
with it. Another approach could be to store the unique episodes somewhere for
reuse, but that could be convoluted.
"""

import timeit

import numpy as np

import pykoop


def main():
"""Benchmark :func:`pykoop.koopman_pipeline._unique_episodes()`."""
pykoop.set_config(skip_validation=True)
"""Benchmark :func:`pykoop.unique_episodes()`."""
X_ep = np.array([0] * 100 + [1] * 1000 + [2] * 500 + [10] * 1000)
n_loop = 100_000
time = timeit.timeit(lambda: pykoop.unique_episodes(X_ep), number=n_loop)
print(f' Total time: {time} s')
print(f'Time per loop: {time / n_loop} s')


if __name__ == '__main__':
main()
20 changes: 19 additions & 1 deletion doc/pykoop.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ The data matrices provided to :func:`fit` (as well as :func:`transform`
and :func:`inverse_transform`) must obey the following format:

1. If ``episode_feature`` is true, the first feature must indicate
which episode each timestep belongs to.
which episode each timestep belongs to. The episode feature must contain
positive integers only.
2. The last ``n_inputs`` features must be exogenous inputs.
3. The remaining features are considered to be states (input-independent).

Expand Down Expand Up @@ -84,6 +85,9 @@ State 0 State 1 State 2
In the above case, each timestep is assumed to belong to the same
episode.

.. important::
The episode feature must contain positive integers only!

Koopman regressors, which implement the interface defined in
:class:`pykoop.KoopmanRegressor` are distinct from ``scikit-learn`` regressors
in that they support the episode feature and state tracking attributes used by
Expand Down Expand Up @@ -272,6 +276,20 @@ The following class and function implementations are located in
pykoop.dynamic_models.Pendulum


Configuration
=============

The following functions allow the user to interact with ``pykoop``'s global
configuration.

.. autosummary::
:toctree: _autosummary/

pykoop.get_config
pykoop.set_config
pykoop.config_context


Extending ``pykoop``
====================

Expand Down
6 changes: 4 additions & 2 deletions pykoop/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Koopman operator identification library in Python."""

from ._sklearn_config.config import config_context, get_config, set_config
from .centers import (
Centers,
ClusterCenters,
Expand Down Expand Up @@ -29,17 +30,18 @@
shift_episodes,
split_episodes,
strip_initial_conditions,
unique_episodes,
)
from .lifting_functions import (
BilinearInputLiftingFn,
ConstantLiftingFn,
DelayLiftingFn,
KernelApproxLiftingFn,
PolynomialLiftingFn,
RbfLiftingFn,
KernelApproxLiftingFn,
SkLearnLiftingFn,
)
from .regressors import Dmd, Dmdc, Edmd, EdmdMeta, DataRegressor
from .regressors import DataRegressor, Dmd, Dmdc, Edmd, EdmdMeta
from .tsvd import Tsvd
from .util import (
AnglePreprocessor,
Expand Down
29 changes: 29 additions & 0 deletions pykoop/_sklearn_config/LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
BSD 3-Clause License

Copyright (c) 2007-2022 The scikit-learn developers.
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Empty file.
99 changes: 99 additions & 0 deletions pykoop/_sklearn_config/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""Global configuration for ``pykoop``.
Based on code from the ``scikit-learn`` project. Original author of the file
is Joel Nothman. Specifically, the original file is
``scikit-learn/sklearn/_config.py`` from commit ``894b335``.
Distributed under the BSD-3-Clause License. See ``LICENSE`` in this directory
for the full license.
"""

import contextlib
import os
import threading
from typing import Any, Dict, Optional

_global_config = {
'skip_validation': False,
}
_threadlocal = threading.local()


def _get_threadlocal_config() -> Dict[str, Any]:
"""Get a threadlocal mutable configuration.
If the configuration does not exist, copy the default global configuration.
"""
if not hasattr(_threadlocal, 'global_config'):
_threadlocal.global_config = _global_config.copy()
return _threadlocal.global_config


def get_config() -> Dict[str, Any]:
"""Retrieve current values for configuration set by :func:`set_config`.
Returns
-------
config : dict
Keys are parameter names that can be passed to :func:`set_config`.
Examples
--------
Get configuation
>>> pykoop.get_config()
{'skip_validation': False}
"""
# Return a copy of the threadlocal configuration so that users will
# not be able to modify the configuration with the returned dict.
return _get_threadlocal_config().copy()


def set_config(skip_validation: Optional[bool] = None) -> None:
"""Set global configuration.
Parameters
----------
skip_validation : Optional[bool]
Set to ``True`` to skip all parameter validation. Can save significant
time, especially in func:`pykoop.predict_trajectory()` but risks
crashes.
Examples
--------
Set configuation
>>> pykoop.set_config(skip_validation=False)
"""
local_config = _get_threadlocal_config()
# Set parameters
if skip_validation is not None:
local_config['skip_validation'] = skip_validation


@contextlib.contextmanager
def config_context(*, skip_validation=None):
"""Context manager for global configuration.
Parameters
----------
skip_validation : Optional[bool]
Set to ``True`` to skip all parameter validation. Can save significant
time, especially in func:`pykoop.predict_trajectory()` but risks
crashes.
Examples
--------
Use config context manager
>>> with pykoop.config_context(skip_validation=False):
... pykoop.KoopmanPipeline()
KoopmanPipeline()
"""
old_config = get_config()
set_config(skip_validation=skip_validation)

try:
yield
finally:
set_config(**old_config)

0 comments on commit 77a7a0d

Please sign in to comment.