-
Couldn't load subscription status.
- Fork 84
Add SINDy model #660
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+167
−0
Merged
Add SINDy model #660
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| SINDy | ||
| ======================= | ||
| .. currentmodule:: pina.model.sindy | ||
|
|
||
| .. autoclass:: SINDy | ||
| :members: | ||
| :show-inheritance: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,102 @@ | ||
| """Module for the SINDy model class.""" | ||
|
|
||
| from typing import Callable | ||
| import torch | ||
| from ..utils import check_consistency, check_positive_integer | ||
|
|
||
|
|
||
| class SINDy(torch.nn.Module): | ||
| r""" | ||
| SINDy model class. | ||
|
|
||
| The Sparse Identification of Nonlinear Dynamics (SINDy) model identifies the | ||
| governing equations of a dynamical system from data by learning a sparse | ||
| linear combination of non-linear candidate functions. | ||
|
|
||
| The output of the model is expressed as product of a library matrix and a | ||
| coefficient matrix: | ||
|
|
||
| .. math:: | ||
|
|
||
| \dot{X} = \Theta(X) \Xi | ||
|
|
||
| where: | ||
| - :math:`X \in \mathbb{R}^{B \times D}` is the input snapshots of the | ||
| system state. Here, :math:`B` is the batch size and :math:`D` is the | ||
| number of state variables. | ||
| - :math:`\Theta(X) \in \mathbb{R}^{B \times L}` is the library matrix | ||
| obtained by evaluating a set of candidate functions on the input data. | ||
| Here, :math:`L` is the number of candidate functions in the library. | ||
| - :math:`\Xi \in \mathbb{R}^{L \times D}` is the learned coefficient | ||
| matrix that defines the sparse model. | ||
|
|
||
| .. seealso:: | ||
|
|
||
| **Original reference**: | ||
| Brunton, S.L., Proctor, J.L., and Kutz, J.N. (2016). | ||
| *Discovering governing equations from data: Sparse identification of | ||
| non-linear dynamical systems.* | ||
| Proceedings of the National Academy of Sciences, 113(15), 3932-3937. | ||
| DOI: `10.1073/pnas.1517384113 | ||
| <https://doi.org/10.1073/pnas.1517384113>`_ | ||
| """ | ||
|
|
||
| def __init__(self, library, output_dimension): | ||
| """ | ||
| Initialization of the :class:`SINDy` class. | ||
|
|
||
| :param list[Callable] library: The collection of candidate functions | ||
| used to construct the library matrix. Each function must accept an | ||
| input tensor of shape ``[..., D]`` and return a tensor of shape | ||
| ``[..., 1]``. | ||
| :param int output_dimension: The number of output variables, typically | ||
| the number of state derivatives. It determines the number of columns | ||
| in the coefficient matrix. | ||
| :raises ValueError: If ``library`` is not a list of callables. | ||
| :raises AssertionError: If ``output_dimension`` is not a positive | ||
| integer. | ||
| """ | ||
| super().__init__() | ||
|
|
||
| # Check consistency | ||
| check_positive_integer(output_dimension, strict=True) | ||
| check_consistency(library, Callable) | ||
| if not isinstance(library, list): | ||
| raise ValueError("`library` must be a list of callables.") | ||
|
|
||
| # Initialization | ||
| self._library = library | ||
| self._coefficients = torch.nn.Parameter( | ||
| torch.zeros(len(library), output_dimension) | ||
| ) | ||
|
|
||
| def forward(self, x): | ||
| """ | ||
| Forward pass of the :class:`SINDy` model. | ||
|
|
||
| :param torch.Tensor x: The input batch of state variables. | ||
| :return: The predicted time derivatives of the state variables. | ||
| :rtype: torch.Tensor | ||
| """ | ||
| theta = torch.stack([f(x) for f in self.library], dim=-2) | ||
| return torch.einsum("...li , lo -> ...o", theta, self.coefficients) | ||
|
|
||
| @property | ||
| def library(self): | ||
| """ | ||
| The library of candidate functions. | ||
|
|
||
| :return: The library. | ||
| :rtype: list[Callable] | ||
| """ | ||
| return self._library | ||
|
|
||
| @property | ||
| def coefficients(self): | ||
| """ | ||
| The coefficients of the model. | ||
|
|
||
| :return: The coefficients. | ||
| :rtype: torch.Tensor | ||
| """ | ||
| return self._coefficients |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,55 @@ | ||
| import torch | ||
| import pytest | ||
| from pina.model import SINDy | ||
|
|
||
| # Define a simple library of candidate functions and some test data | ||
| library = [lambda x: torch.pow(x, 2), lambda x: torch.sin(x)] | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("data", [torch.rand((20, 1)), torch.rand((5, 20, 1))]) | ||
| def test_constructor(data): | ||
| SINDy(library, data.shape[-1]) | ||
|
|
||
| # Should fail if output_dimension is not a positive integer | ||
| with pytest.raises(AssertionError): | ||
| SINDy(library, "not_int") | ||
| with pytest.raises(AssertionError): | ||
| SINDy(library, -1) | ||
|
|
||
| # Should fail if library is not a list | ||
| with pytest.raises(ValueError): | ||
| SINDy(lambda x: torch.pow(x, 2), 3) | ||
|
|
||
| # Should fail if library is not a list of callables | ||
| with pytest.raises(ValueError): | ||
| SINDy([1, 2, 3], 3) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("data", [torch.rand((20, 1)), torch.rand((5, 20, 1))]) | ||
| def test_forward(data): | ||
|
|
||
| # Define model | ||
| model = SINDy(library, data.shape[-1]) | ||
| with torch.no_grad(): | ||
| model.coefficients.data.fill_(1.0) | ||
|
|
||
| # Evaluate model | ||
| output_ = model(data) | ||
| vals = data.pow(2) + torch.sin(data) | ||
|
|
||
| print(data.shape, output_.shape, vals.shape) | ||
|
|
||
| assert output_.shape == data.shape | ||
| assert torch.allclose(output_, vals, atol=1e-6, rtol=1e-6) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("data", [torch.rand((20, 1)), torch.rand((5, 20, 1))]) | ||
| def test_backward(data): | ||
|
|
||
| # Define and evaluate model | ||
| model = SINDy(library, data.shape[-1]) | ||
| output_ = model(data.requires_grad_()) | ||
|
|
||
| loss = output_.mean() | ||
| loss.backward() | ||
| assert data.grad.shape == data.shape | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.