-
Notifications
You must be signed in to change notification settings - Fork 0
/
mechanics.py
180 lines (152 loc) · 5.67 KB
/
mechanics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
"""Discretize and step plant models.
:copyright: Copyright 2023-2024 by Matt Laporte.
:license: Apache 2.0. See LICENSE for details.
"""
from collections import OrderedDict
from collections.abc import Mapping, Sequence
from functools import cached_property
import logging
from typing import Optional, Type, Union
import diffrax as dfx
import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array, PRNGKeyArray, PyTree
from feedbax.intervene import AbstractIntervenor
from feedbax.mechanics.plant import AbstractPlant, PlantState
from feedbax._model import wrap_stateless_callable
from feedbax._staged import AbstractStagedModel, ModelStage
from feedbax.state import AbstractState, CartesianState
logger = logging.getLogger(__name__)
class MechanicsState(AbstractState):
    """State PyTree that `Mechanics` instances read and update.

    Attributes:
        plant: The state of the plant.
        effector: The state of the end effector.
        solver: The state of the Diffrax solver.
    """

    # NOTE: field order is part of the PyTree structure; keep it stable.
    plant: PlantState
    effector: CartesianState
    solver: PyTree
class Mechanics(AbstractStagedModel[MechanicsState]):
    """Discretizes a plant's dynamics and iterates them with the plant statics.

    Attributes:
        plant: The plant model.
        dt: The time step duration.
        solver: The Diffrax solver.
        intervenors: The intervenors associated with each stage of the model.
    """

    plant: AbstractPlant
    dt: float
    solver: dfx.AbstractSolver
    intervenors: Mapping[str, AbstractIntervenor]

    def __init__(
        self,
        plant: AbstractPlant,
        dt: float,
        solver_type: Type[dfx.AbstractSolver] = dfx.Euler,
        intervenors: Optional[
            Union[
                Sequence[AbstractIntervenor],
                Mapping[str, Sequence[AbstractIntervenor]],
            ]
        ] = None,
        *,
        key: Optional[PRNGKeyArray] = None,
    ):
        """
        Arguments:
            plant: The plant model.
            dt: The time step duration.
            solver_type: The type of Diffrax solver to use. It is instantiated
                here with no arguments.
            intervenors: The intervenors associated with each stage of the
                model. Passed through `self._get_intervenors_dict` before
                being stored.
            key: Not used by this constructor.
        """
        self.plant = plant
        self.dt = dt
        self.solver = solver_type()
        self.intervenors = self._get_intervenors_dict(intervenors)

    @property
    def model_spec(self) -> OrderedDict[str, ModelStage]:
        """Specifies the stages of the model.

        The stages, in order, are `"convert_effector_force"`,
        `"kinematics_update"`, `"dynamics_step"` and `"get_effector"`. Each
        `ModelStage` pairs a callable (looked up on the model) with
        `where_input` and `where_state` selectors.
        """
        # Effector force is handed to the skeleton, acting on the skeleton state.
        effector_force_stage = ModelStage(
            callable=lambda self: (
                self.plant.skeleton.update_state_given_effector_force
            ),
            where_input=lambda input, state: state.effector.force,
            where_state=lambda state: state.plant.skeleton,
        )

        # The `plant` module itself is called for the non-ODE operations.
        kinematics_stage = ModelStage(
            callable=lambda self: self.plant,
            where_input=lambda input, state: input,
            where_state=lambda state: state.plant,
        )

        # One solver step; operates on the entire mechanics state.
        dynamics_stage = ModelStage(
            callable=lambda self: self.dynamics_step,
            where_input=lambda input, state: input,
            where_state=lambda state: state,
        )

        # The effector state is recomputed from the skeleton state; the
        # wrapped callable is not passed a key.
        effector_stage = ModelStage(
            callable=lambda self: wrap_stateless_callable(
                self.plant.skeleton.effector, pass_key=False
            ),
            where_input=lambda input, state: state.plant.skeleton,
            where_state=lambda state: state.effector,
        )

        return OrderedDict(
            [
                ("convert_effector_force", effector_force_stage),
                ("kinematics_update", kinematics_stage),
                ("dynamics_step", dynamics_stage),
                ("get_effector", effector_stage),
            ]
        )

    @cached_property
    def _term(self) -> dfx.AbstractTerm:
        """The Diffrax term for the aggregate vector field of the plant."""
        vector_field = self.plant.vector_field
        return dfx.ODETerm(vector_field)

    def dynamics_step(
        self,
        input: PyTree[Array],
        state: MechanicsState,
        *,
        key: Optional[PRNGKeyArray] = None,
    ) -> MechanicsState:
        """Return an updated state after a single step of plant dynamics.

        Arguments:
            input: Passed to the solver step as the extra arguments of the
                vector field.
            state: The current mechanics state; its `plant` and `solver`
                entries are fed to the solver.
            key: Not used by this method.

        Returns:
            A copy of `state` whose `plant` and `solver` entries are replaced
            by the solver's outputs. `effector` is left as it was.
        """
        # The solver returns a 5-tuple; only the first and fourth entries
        # (new plant state, new solver state) are kept. Every step is taken
        # with start time `0` and end time `self.dt`.
        next_plant, _, _, next_solver, _ = self.solver.step(
            self._term,
            0,
            self.dt,
            state.plant,
            input,
            state.solver,
            made_jump=False,
        )

        replacements = (next_plant, next_solver)
        return eqx.tree_at(
            lambda tree: (tree.plant, tree.solver),
            state,
            replacements,
        )

    @property
    def memory_spec(self):
        """A `MechanicsState` of booleans: `True` for `plant` and `effector`,
        `False` for `solver`.

        NOTE(review): presumably this selects which parts of the state are
        retained in state histories; confirm against `AbstractStagedModel`.
        """
        return MechanicsState(plant=True, effector=True, solver=False)

    def init(
        self,
        *,
        key: PRNGKeyArray,
    ):
        """Returns an initial state for use with the `Mechanics` module.

        Arguments:
            key: Forwarded to `self.plant.init`.
        """
        initial_plant = self.plant.init(key=key)
        # The solver state is initialized against an all-zeros input of
        # length `self.plant.input_size`.
        zero_input = jnp.zeros((self.plant.input_size,))

        initial_effector = self.plant.skeleton.effector(initial_plant.skeleton)
        initial_solver = self.solver.init(
            self._term, 0, self.dt, initial_plant, zero_input
        )
        return MechanicsState(
            plant=initial_plant,
            effector=initial_effector,
            solver=initial_solver,
        )
# def n_vars(self, where):
# """
# TODO: Given a function that returns a PyTree of leaves of `mechanics_state`,
# return the sum of the sizes of the last dimensions of the leaves.
# Alternatively, just return an empty `mechanics_state`.
# This is useful to automatically determine the number of feedback inputs
# during model construction, when a `mechanics_state` instance isn't yet available.
# See `get_model` in notebook 8.
# """
# # tree.tree_sum_n_features
# ...