diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst index 25f0e3062..965a286b5 100644 --- a/docs/source/_rst/_code.rst +++ b/docs/source/_rst/_code.rst @@ -106,6 +106,7 @@ Models GraphNeuralKernel PirateNet EquivariantGraphNeuralOperator + SINDy Blocks ------------- diff --git a/docs/source/_rst/model/sindy.rst b/docs/source/_rst/model/sindy.rst new file mode 100644 index 000000000..bd507603b --- /dev/null +++ b/docs/source/_rst/model/sindy.rst @@ -0,0 +1,7 @@ +SINDy +======================= +.. currentmodule:: pina.model.sindy + +.. autoclass:: SINDy + :members: + :show-inheritance: \ No newline at end of file diff --git a/pina/model/__init__.py b/pina/model/__init__.py index ee343e53d..1edeacd1a 100644 --- a/pina/model/__init__.py +++ b/pina/model/__init__.py @@ -15,6 +15,7 @@ "GraphNeuralOperator", "PirateNet", "EquivariantGraphNeuralOperator", + "SINDy", ] from .feed_forward import FeedForward, ResidualFeedForward @@ -28,3 +29,4 @@ from .graph_neural_operator import GraphNeuralOperator from .pirate_network import PirateNet from .equivariant_graph_neural_operator import EquivariantGraphNeuralOperator +from .sindy import SINDy diff --git a/pina/model/sindy.py b/pina/model/sindy.py new file mode 100644 index 000000000..a40fa37b4 --- /dev/null +++ b/pina/model/sindy.py @@ -0,0 +1,102 @@ +"""Module for the SINDy model class.""" + +from typing import Callable +import torch +from ..utils import check_consistency, check_positive_integer + + +class SINDy(torch.nn.Module): + r""" + SINDy model class. + + The Sparse Identification of Nonlinear Dynamics (SINDy) model identifies the + governing equations of a dynamical system from data by learning a sparse + linear combination of non-linear candidate functions. + + The output of the model is expressed as product of a library matrix and a + coefficient matrix: + + .. math:: + + \dot{X} = \Theta(X) \Xi + + where: + - :math:`X \in \mathbb{R}^{B \times D}` is the input snapshots of the + system state. Here, :math:`B` is the batch size and :math:`D` is the + number of state variables. + - :math:`\Theta(X) \in \mathbb{R}^{B \times L}` is the library matrix + obtained by evaluating a set of candidate functions on the input data. + Here, :math:`L` is the number of candidate functions in the library. + - :math:`\Xi \in \mathbb{R}^{L \times D}` is the learned coefficient + matrix that defines the sparse model. + + .. seealso:: + + **Original reference**: + Brunton, S.L., Proctor, J.L., and Kutz, J.N. (2016). + *Discovering governing equations from data: Sparse identification of + non-linear dynamical systems.* + Proceedings of the National Academy of Sciences, 113(15), 3932-3937. + DOI: `10.1073/pnas.1517384113 + `_ + """ + + def __init__(self, library, output_dimension): + """ + Initialization of the :class:`SINDy` class. + + :param list[Callable] library: The collection of candidate functions + used to construct the library matrix. Each function must accept an + input tensor of shape ``[..., D]`` and return a tensor of shape + ``[..., 1]``. + :param int output_dimension: The number of output variables, typically + the number of state derivatives. It determines the number of columns + in the coefficient matrix. + :raises ValueError: If ``library`` is not a list of callables. + :raises AssertionError: If ``output_dimension`` is not a positive + integer. + """ + super().__init__() + + # Check consistency + check_positive_integer(output_dimension, strict=True) + check_consistency(library, Callable) + if not isinstance(library, list): + raise ValueError("`library` must be a list of callables.") + + # Initialization + self._library = library + self._coefficients = torch.nn.Parameter( + torch.zeros(len(library), output_dimension) + ) + + def forward(self, x): + """ + Forward pass of the :class:`SINDy` model. + + :param torch.Tensor x: The input batch of state variables. + :return: The predicted time derivatives of the state variables. + :rtype: torch.Tensor + """ + theta = torch.stack([f(x) for f in self.library], dim=-2) + return torch.einsum("...li , lo -> ...o", theta, self.coefficients) + + @property + def library(self): + """ + The library of candidate functions. + + :return: The library. + :rtype: list[Callable] + """ + return self._library + + @property + def coefficients(self): + """ + The coefficients of the model. + + :return: The coefficients. + :rtype: torch.Tensor + """ + return self._coefficients diff --git a/tests/test_model/test_sindy.py b/tests/test_model/test_sindy.py new file mode 100644 index 000000000..223c4eba2 --- /dev/null +++ b/tests/test_model/test_sindy.py @@ -0,0 +1,55 @@ +import torch +import pytest +from pina.model import SINDy + +# Define a simple library of candidate functions and some test data +library = [lambda x: torch.pow(x, 2), lambda x: torch.sin(x)] + + +@pytest.mark.parametrize("data", [torch.rand((20, 1)), torch.rand((5, 20, 1))]) +def test_constructor(data): + SINDy(library, data.shape[-1]) + + # Should fail if output_dimension is not a positive integer + with pytest.raises(AssertionError): + SINDy(library, "not_int") + with pytest.raises(AssertionError): + SINDy(library, -1) + + # Should fail if library is not a list + with pytest.raises(ValueError): + SINDy(lambda x: torch.pow(x, 2), 3) + + # Should fail if library is not a list of callables + with pytest.raises(ValueError): + SINDy([1, 2, 3], 3) + + +@pytest.mark.parametrize("data", [torch.rand((20, 1)), torch.rand((5, 20, 1))]) +def test_forward(data): + + # Define model + model = SINDy(library, data.shape[-1]) + with torch.no_grad(): + model.coefficients.data.fill_(1.0) + + # Evaluate model + output_ = model(data) + vals = data.pow(2) + torch.sin(data) + + print(data.shape, output_.shape, vals.shape) + + assert output_.shape == data.shape + assert torch.allclose(output_, vals, atol=1e-6, rtol=1e-6) + + +@pytest.mark.parametrize("data", [torch.rand((20, 1)), torch.rand((5, 20, 1))]) +def test_backward(data): + + # Define and evaluate model + model = SINDy(library, data.shape[-1]) + output_ = model(data.requires_grad_()) + + loss = output_.mean() + loss.backward() + assert data.grad.shape == data.shape