-
Notifications
You must be signed in to change notification settings - Fork 0
/
dynamics.py
122 lines (96 loc) · 3.43 KB
/
dynamics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""Base classes for continuous dynamical systems.
:copyright: Copyright 2023-2024 by Matt L Laporte.
:license: Apache 2.0, see LICENSE for details.
"""
from abc import abstractmethod, abstractproperty
import logging
from typing import Optional
import equinox as eqx
from equinox import AbstractVar
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, PRNGKeyArray, PyTree
from feedbax._model import AbstractModel
from feedbax.state import CartesianState, StateBounds, StateT
logger = logging.getLogger(__name__)
class AbstractDynamicalSystem(AbstractModel[StateT]):
"""Base class for continuous dynamical systems.
??? dev-note "Development note"
The signature of `vector_field` matches that expected by
`diffrax.ODETerm`. However, the argument that is called
`args` in the Diffrax documentation, we call `input`, since
we use it for input signals (e.g. forces from the controller)
Vector fields for biomechanical models are generally not
time-dependent. That is, the argument `t` to `vector_field` is
typically unused. This is apparent in the way we alias `vector_field`
to `__call__`, which is a method that `AbstractModel` requires.
Perhaps it is unnecessary to inherit from `AbstractModel`, though.
"""
def __call__(
self,
input: PyTree[Array],
state: StateT,
key: PRNGKeyArray,
) -> StateT:
"""Alias for `vector_field`, with a modified signature."""
return self.vector_field(None, state, input)
@abstractmethod
def vector_field(
self,
t: float | None,
state: StateT,
input: PyTree[Array], # controls
) -> StateT:
"""Returns scalar (e.g. time) derivatives of the system's states."""
...
@abstractproperty
def input_size(self) -> int:
"""Number of input variables."""
...
@abstractmethod
def init(self, *, key: PRNGKeyArray) -> StateT:
"""Returns the initial state of the system."""
...
@property
def step(self) -> "AbstractDynamicalSystem[StateT]":
return self
class AbstractLTISystem(AbstractDynamicalSystem[Array]):
"""
!!! ref inline end ""
Inspired by [this Diffrax example](https://docs.kidger.site/diffrax/examples/kalman_filter/).
A linear, continuous, time-invariant system.
Attributes:
A: The state evolution matrix.
B: The control matrix.
C: The observation matrix.
"""
A: AbstractVar[Float[Array, "state state"]]
B: AbstractVar[Float[Array, "state input"]]
C: AbstractVar[Float[Array, "obs state"]]
@jax.named_scope("fbx.AbstractLTISystem")
def vector_field(
self,
t: float | None,
state: Array,
input: Array,
) -> Array:
"""Returns time derivatives of the system's states."""
return self.A @ state + self.B @ input
@property
def input_size(self) -> int:
"""Number of control variables."""
return self.B.shape[1]
@property
def state_size(self) -> int:
"""Number of state variables."""
return self.A.shape[1]
@property
def bounds(self) -> StateBounds[Array]:
return StateBounds(low=None, high=None)
def init(
self,
*,
key: Optional[PRNGKeyArray] = None,
) -> Array:
"""Return a default state for the linear system."""
return jnp.zeros(self.state_size)