Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/_rst/_code.rst
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ Models
GraphNeuralKernel <model/graph_neural_operator_integral_kernel.rst>
PirateNet <model/pirate_network.rst>
EquivariantGraphNeuralOperator <model/equivariant_graph_neural_operator.rst>
SINDy <model/sindy.rst>

Blocks
-------------
Expand Down
7 changes: 7 additions & 0 deletions docs/source/_rst/model/sindy.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
SINDy
=======================
.. currentmodule:: pina.model.sindy

.. autoclass:: SINDy
:members:
:show-inheritance:
2 changes: 2 additions & 0 deletions pina/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"GraphNeuralOperator",
"PirateNet",
"EquivariantGraphNeuralOperator",
"SINDy",
]

from .feed_forward import FeedForward, ResidualFeedForward
Expand All @@ -28,3 +29,4 @@
from .graph_neural_operator import GraphNeuralOperator
from .pirate_network import PirateNet
from .equivariant_graph_neural_operator import EquivariantGraphNeuralOperator
from .sindy import SINDy
102 changes: 102 additions & 0 deletions pina/model/sindy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""Module for the SINDy model class."""

from typing import Callable
import torch
from ..utils import check_consistency, check_positive_integer


class SINDy(torch.nn.Module):
r"""
SINDy model class.

The Sparse Identification of Nonlinear Dynamics (SINDy) model identifies the
governing equations of a dynamical system from data by learning a sparse
linear combination of non-linear candidate functions.

The output of the model is expressed as product of a library matrix and a
coefficient matrix:

.. math::

\dot{X} = \Theta(X) \Xi

where:
- :math:`X \in \mathbb{R}^{B \times D}` is the input snapshots of the
system state. Here, :math:`B` is the batch size and :math:`D` is the
number of state variables.
- :math:`\Theta(X) \in \mathbb{R}^{B \times L}` is the library matrix
obtained by evaluating a set of candidate functions on the input data.
Here, :math:`L` is the number of candidate functions in the library.
- :math:`\Xi \in \mathbb{R}^{L \times D}` is the learned coefficient
matrix that defines the sparse model.

.. seealso::

**Original reference**:
Brunton, S.L., Proctor, J.L., and Kutz, J.N. (2016).
*Discovering governing equations from data: Sparse identification of
non-linear dynamical systems.*
Proceedings of the National Academy of Sciences, 113(15), 3932-3937.
DOI: `10.1073/pnas.1517384113
<https://doi.org/10.1073/pnas.1517384113>`_
"""

def __init__(self, library, output_dimension):
"""
Initialization of the :class:`SINDy` class.

:param list[Callable] library: The collection of candidate functions
used to construct the library matrix. Each function must accept an
input tensor of shape ``[..., D]`` and return a tensor of shape
``[..., 1]``.
:param int output_dimension: The number of output variables, typically
the number of state derivatives. It determines the number of columns
in the coefficient matrix.
:raises ValueError: If ``library`` is not a list of callables.
:raises AssertionError: If ``output_dimension`` is not a positive
integer.
"""
super().__init__()

# Check consistency
check_positive_integer(output_dimension, strict=True)
check_consistency(library, Callable)
if not isinstance(library, list):
raise ValueError("`library` must be a list of callables.")

# Initialization
self._library = library
self._coefficients = torch.nn.Parameter(
torch.zeros(len(library), output_dimension)
)

def forward(self, x):
"""
Forward pass of the :class:`SINDy` model.

:param torch.Tensor x: The input batch of state variables.
:return: The predicted time derivatives of the state variables.
:rtype: torch.Tensor
"""
theta = torch.stack([f(x) for f in self.library], dim=-2)
return torch.einsum("...li , lo -> ...o", theta, self.coefficients)

@property
def library(self):
"""
The library of candidate functions.

:return: The library.
:rtype: list[Callable]
"""
return self._library

@property
def coefficients(self):
"""
The coefficients of the model.

:return: The coefficients.
:rtype: torch.Tensor
"""
return self._coefficients
55 changes: 55 additions & 0 deletions tests/test_model/test_sindy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import torch
import pytest
from pina.model import SINDy

# Define a simple library of candidate functions and some test data
library = [lambda x: torch.pow(x, 2), lambda x: torch.sin(x)]


@pytest.mark.parametrize("data", [torch.rand((20, 1)), torch.rand((5, 20, 1))])
def test_constructor(data):
SINDy(library, data.shape[-1])

# Should fail if output_dimension is not a positive integer
with pytest.raises(AssertionError):
SINDy(library, "not_int")
with pytest.raises(AssertionError):
SINDy(library, -1)

# Should fail if library is not a list
with pytest.raises(ValueError):
SINDy(lambda x: torch.pow(x, 2), 3)

# Should fail if library is not a list of callables
with pytest.raises(ValueError):
SINDy([1, 2, 3], 3)


@pytest.mark.parametrize("data", [torch.rand((20, 1)), torch.rand((5, 20, 1))])
def test_forward(data):

# Define model
model = SINDy(library, data.shape[-1])
with torch.no_grad():
model.coefficients.data.fill_(1.0)

# Evaluate model
output_ = model(data)
vals = data.pow(2) + torch.sin(data)

print(data.shape, output_.shape, vals.shape)

assert output_.shape == data.shape
assert torch.allclose(output_, vals, atol=1e-6, rtol=1e-6)


@pytest.mark.parametrize("data", [torch.rand((20, 1)), torch.rand((5, 20, 1))])
def test_backward(data):

# Define and evaluate model
model = SINDy(library, data.shape[-1])
output_ = model(data.requires_grad_())

loss = output_.mean()
loss.backward()
assert data.grad.shape == data.shape