-
Notifications
You must be signed in to change notification settings - Fork 0
/
plant.py
459 lines (393 loc) · 16.2 KB
/
plant.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
"""
:copyright: Copyright 2023-2024 by Matt Laporte.
:license: Apache 2.0. See LICENSE for details.
"""
from abc import abstractmethod, abstractproperty
from collections import OrderedDict
from collections.abc import Callable, Mapping, Sequence
from functools import cached_property
import logging
from typing import Generic, Optional, Self, Tuple, Union
import equinox as eqx
from equinox import AbstractVar, Module, field
import jax
from jax import Array
import jax.numpy as jnp
from jaxtyping import Float, PRNGKeyArray, PyTree
from feedbax.dynamics import AbstractDynamicalSystem
from feedbax.intervene import AbstractIntervenor
from feedbax.mechanics.muscle import AbstractMuscle, AbstractMuscleState
from feedbax.mechanics.skeleton.arm import TwoLink
from feedbax.mechanics.skeleton.skeleton import AbstractSkeleton, AbstractSkeletonState
from feedbax._staged import AbstractStagedModel, ModelStage
from feedbax.state import AbstractState, StateBounds, StateT, clip_state
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)
class PlantState(Module):
    """The state of a biomechanical model.

    Some models may only possess a skeleton, with forces input directly by a
    controller; such models leave `muscles` as `None`.

    Attributes:
        skeleton: The state of the skeleton.
        muscles: The state of the muscles, if included in the model.
    """

    skeleton: AbstractSkeletonState
    muscles: AbstractMuscleState | None = None
class DynamicsComponent(eqx.Module, Generic[StateT]):
    """Pairs a dynamical system with selectors for its input and its substate.

    `AbstractPlant.vector_field` loops over these components. For each one it
    selects a substate with `where_state` and an input with `where_input`. It
    then evaluates `dynamics.vector_field` on them and writes the result into
    the matching part of the state derivative.

    Attributes:
        dynamics: The dynamical system governing the selected substate.
        where_input: Maps the plant's input and full state to the input that
            is passed to `dynamics`.
        where_state: Maps the full state to the substate whose time
            derivative `dynamics` computes.
    """

    # NOTE: `dynamics` should ideally be typed as generic over the substate
    # type returned by `where_state`; this is not expressed here.
    dynamics: AbstractDynamicalSystem
    where_input: Callable[[PyTree[Array], StateT], PyTree[Array]]
    where_state: Callable[[StateT], PyTree[Array]]
class AbstractPlant(
    AbstractStagedModel[PlantState],
    AbstractDynamicalSystem[PlantState],
):
    """Base class for models of musculoskeletal systems.

    !!! Note ""
        These models describe both 1) the continuous dynamics, and 2) instantaneous
        dependencies between variables.

        For example, we may model 1) the rate of change of skeletal joint angles
        given torques applied to the joints, but also 2) the force generated by
        a muscle, which is not described by a differential equation but is
        directly dependent in each instant on the skeletal geometry.

        Kinematic/geometric updates are specified in `model_spec`, and dynamic updates
        in `dynamics_spec`. Calling an `AbstractPlant` instance will only perform the
        kinematic updates defined in `model_spec`. Normally we wrap the instance in a
        `Mechanics` instance to discretize the dynamics—then, calling the `Mechanics`
        instance will perform both sets of updates.

    Attributes:
        skeleton: The model of skeletal dynamics.
        muscle_model: The muscle model, if the model includes muscles.
        clip_states: Whether to clip the state to its bounds.
    """

    skeleton: AbstractVar[AbstractSkeleton]
    muscle_model: AbstractVar[Optional[AbstractMuscle]]
    clip_states: AbstractVar[bool]

    def vector_field(
        self, t: float | None, state: PlantState, input: PyTree[Array]
    ) -> PlantState:
        """Return the time derivatives of musculoskeletal variables,
        where those derivatives are defined.

        !!! Note ""
            Aggregates vector fields for different substates of the plant
            state, as described by `dynamics_spec`.

        Arguments:
            t: The simulation time. Typically this is unused by any of the
                constituent fields.
            state: The state of the musculoskeletal system.
            input: The control inputs to the musculoskeletal system.

        Returns:
            A `PlantState`-shaped PyTree of time derivatives. Leaves not
            covered by any component of `dynamics_spec` are zero.
        """
        # Start from zeros: substates that have no differential equation
        # (e.g. those only updated kinematically via `model_spec`) keep a zero
        # time derivative.
        # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map`,
        # which has been removed from recent JAX releases.
        d_state = jax.tree_util.tree_map(jnp.zeros_like, state)
        for component in self.dynamics_spec.values():
            d_substate = component.dynamics.vector_field(
                t,
                component.where_state(state),
                component.where_input(input, state),
            )
            # Write this component's derivative into its slot of the
            # aggregate derivative.
            d_state = eqx.tree_at(component.where_state, d_state, d_substate)
        return d_state

    @property
    @abstractmethod
    def model_spec(self) -> OrderedDict[str, ModelStage[Self, PlantState]]:
        """Specifies kinematic updates to the musculoskeletal state."""
        ...

    @property
    @abstractmethod
    def dynamics_spec(self) -> dict[str, DynamicsComponent[PlantState]]:
        """Aggregates differential equations for different substates of the
        musculoskeletal state."""
        ...

    @abstractmethod
    def init(self, *, key: PRNGKeyArray) -> PlantState:
        """Returns a default state for the plant."""
        ...

    @property
    @abstractmethod
    def input_size(self) -> int:
        """Number of control inputs."""
        ...

    @property
    def bounds(self) -> PyTree[StateBounds]:
        """Aggregates the bounds specified by the skeletal and muscle models.

        When there is no muscle model, the muscle bounds are unbounded
        (`low=None`, `high=None`).
        """
        if self.muscle_model is not None:
            muscle_bounds = self.muscle_model.bounds
        else:
            muscle_bounds = StateBounds(low=None, high=None)
        return PlantState(
            skeleton=self.skeleton.bounds,
            muscles=muscle_bounds,
        )

    def _clip_state(self, input, state, *, key: Optional[PRNGKeyArray] = None):
        """Clip `state` to the bounds passed as `input`, if `clip_states` is set.

        Used as a model stage, so `input` is the relevant substate's
        `StateBounds`. When clipping is disabled, `state` is returned
        unchanged. `key` is accepted for stage-interface compatibility and is
        unused.
        """
        if self.clip_states:
            return clip_state(input, state)
        else:
            return state
class DirectForceInput(AbstractPlant):
    """Model of a skeleton controlled directly by forces/torques—no muscles.

    !!! Note ""
        This is essentially a wrapper for an `AbstractSkeleton` that makes the
        skeleton conform with the interface expected by `Mechanics`.

        It also adds optional state clipping.

    Attributes:
        skeleton: The model of skeletal dynamics.
        muscle_model: None.
        clip_states: Whether to clip the state to its bounds.
        intervenors: Intervenors associated with each model stage.
    """

    skeleton: AbstractSkeleton
    muscle_model: None = None
    clip_states: bool
    intervenors: Mapping[str, AbstractIntervenor]

    def __init__(
        self,
        skeleton: AbstractSkeleton,
        clip_states: bool = True,
        intervenors: Optional[
            Union[
                Sequence[AbstractIntervenor], Mapping[str, Sequence[AbstractIntervenor]]
            ]
        ] = None,
        *,
        key: Optional[PRNGKeyArray] = None,
    ):
        """
        Arguments:
            skeleton: The model of skeletal dynamics.
            clip_states: Whether to clip the state to its bounds.
            intervenors: Intervenors associated with each model stage.
            key: Unused; accepted for interface compatibility.
        """
        self.skeleton = skeleton
        self.clip_states = clip_states
        self.intervenors = self._get_intervenors_dict(intervenors)

    @property
    def model_spec(self) -> OrderedDict[str, ModelStage[Self, PlantState]]:
        """Specifies at most one model stage: state clipping, if it is enabled."""
        Stage = ModelStage[Self, PlantState]
        spec = OrderedDict()
        if self.clip_states:
            spec["clip_skeleton_state"] = Stage(
                # `callable` maps a model instance to the function performing
                # this stage (the lambda's `self` shadows the outer one).
                callable=lambda self: self._clip_state,
                # The skeleton's bounds are passed as the stage input.
                where_input=lambda input, state: self.bounds.skeleton,
                where_state=lambda state: state.skeleton,
            )
        return spec

    @cached_property
    def dynamics_spec(self) -> dict[str, DynamicsComponent[PlantState]]:
        """Specifies a single dynamical component: the skeleton.

        The plant's input (forces/torques) is passed straight to the skeleton.
        """
        return {
            "skeleton": DynamicsComponent[PlantState](
                dynamics=self.skeleton,
                where_input=lambda input, state: input,
                where_state=lambda state: state.skeleton,
            ),
        }

    @property
    def memory_spec(self) -> PyTree[bool]:
        """A simple plant has no muscles, and no muscle state to remember."""
        return PlantState(
            skeleton=True,
            muscles=False,
        )

    def init(self, *, key: PRNGKeyArray) -> PlantState:
        """Return a default state for the plant, with no muscle state."""
        return PlantState(
            skeleton=self.skeleton.init(key=key),
            muscles=None,
        )

    @property
    def input_size(self) -> int:
        """Equal to the skeleton's input size."""
        return self.skeleton.input_size
class MuscledArm(AbstractPlant):
    """Model of a two-link arm actuated by muscles.

    Attributes:
        skeleton: The model of skeletal dynamics.
        muscle_model: The muscle model.
        activator: The muscle activator, such as
            [`ActivationFilter`][feedbax.mechanics.muscle.ActivationFilter].
        clip_states: Whether to clip the states to their bounds.
        n_muscles: The number of muscles.
        moment_arms: The moment arms of the muscles with respect to the joints.
        theta0: The optimal angles of the muscles with respect to the joints
            (radians).
        l0: The optimal length for each muscle.
        f0: The maximum isometric force for each muscle.
        intervenors: Intervenors associated with each model stage.
    """

    skeleton: AbstractSkeleton
    muscle_model: AbstractMuscle
    activator: AbstractDynamicalSystem
    clip_states: bool
    n_muscles: int
    moment_arms: Float[Array, "links=2 muscles"] = field(converter=jnp.array)
    theta0: Float[Array, "links=2 muscles"] = field(converter=jnp.array)
    l0: Float[Array, "muscles"] = field(converter=jnp.array)
    f0: Float[Array, "muscles"] = field(converter=jnp.array)
    intervenors: Mapping[str, AbstractIntervenor]

    def __init__(
        self,
        muscle_model: AbstractMuscle,
        activator: AbstractDynamicalSystem,
        skeleton: AbstractSkeleton = TwoLink(),
        clip_states: bool = True,
        # NOTE: no trailing comma inside these parentheses, or the default
        # becomes a 1-tuple wrapping the array.
        moment_arms: Float[Array, "links=2 muscles"] | Sequence[Sequence[float]] = (
            jnp.array(
                (
                    (2.0, -2.0, 0.0, 0.0, 1.50, -2.0),  # [cm]
                    (0.0, 0.0, 2.0, -2.0, 2.0, -1.50),
                )
            )
        ),
        theta0: Float[Array, "links=2 muscles"] | Sequence[Sequence[float]] = (
            2
            * jnp.pi
            * jnp.array(
                (
                    (15.0, 4.88, 0.0, 0.0, 4.5, 2.12),  # [deg], converted to [rad]
                    (0.0, 0.0, 80.86, 109.32, 92.96, 91.52),
                )
            )
            / 360.0
        ),
        l0: Float[Array, "muscles"] | Sequence[float] = jnp.array(
            (7.32, 3.26, 6.4, 4.26, 5.95, 4.04)  # [cm]
        ),
        f0: Float[Array, "muscles"] | Sequence[float] = jnp.array(
            (1.0, 1.0, 1.0, 1.0, 1.0, 1.0)  # [N] = [N/cm^2] * [cm^2]
            # Alternative (per-muscle values): 31.8 * jnp.array((22., 12., 18., 14., 5., 10.))
        ),
        intervenors: Optional[
            Union[
                Sequence[AbstractIntervenor], Mapping[str, Sequence[AbstractIntervenor]]
            ]
        ] = None,
        *,
        key: Optional[PRNGKeyArray] = None,
    ):
        """
        Arguments:
            muscle_model: The muscle model.
            activator: The muscle activator.
            skeleton: The model of skeletal dynamics.
            clip_states: Whether to clip the state to its bounds.
            moment_arms: The moment arms of the muscles with respect to the joints.
            theta0: The optimal angles of the muscles with respect to the joints (radians).
            l0: The optimal length for each muscle.
            f0: The maximum isometric force for each muscle.
            intervenors: Intervenors associated with each model stage.
            key: Unused; accepted for interface compatibility.

        Raises:
            ValueError: If `moment_arms`, `theta0`, and `l0` disagree on the
                number of muscles.
        """
        self.skeleton = skeleton
        self.activator = activator
        self.clip_states = clip_states

        # Coerce to arrays before validating, so `.shape` is available even
        # when plain (nested) sequences are passed, as the type hints allow.
        moment_arms = jnp.asarray(moment_arms)
        theta0 = jnp.asarray(theta0)
        l0 = jnp.asarray(l0)
        f0 = jnp.asarray(f0)

        if not theta0.shape[1] == l0.shape[0] == moment_arms.shape[1]:
            raise ValueError(
                "moment_arms, theta0, and l0 must have the same number of "
                "columns (i.e. number of muscles)"
            )

        self.moment_arms = moment_arms
        self.theta0 = theta0
        self.l0 = l0
        self.f0 = f0
        self.n_muscles = moment_arms.shape[1]

        # Make sure the muscle model has the right number of muscles.
        self.muscle_model = muscle_model.change_n_muscles(self.n_muscles)
        self.intervenors = self._get_intervenors_dict(intervenors)

    @property
    def model_spec(self) -> OrderedDict[str, ModelStage[Self, PlantState]]:
        """Specifies kinematic updates to the musculoskeletal state.

        Stages, in order:
        1. Clip the skeleton state to its bounds.
        2. Compute muscle lengths and velocities from the skeleton state.
        3. Clip the muscle state to its bounds.
        4. Apply the muscle model, with the muscle activation as input.
        5. Compute joint torques from the muscle state.
        """
        Stage = ModelStage[Self, PlantState]
        return OrderedDict(
            {
                "clip_skeleton_state": Stage(
                    callable=lambda self: self._clip_state,
                    where_input=lambda input, state: self.bounds.skeleton,
                    where_state=lambda state: state.skeleton,
                ),
                "muscle_geometry": Stage(
                    callable=lambda self: self._muscle_geometry,
                    where_input=lambda input, state: state.skeleton,
                    where_state=lambda state: (
                        state.muscles.length,
                        state.muscles.velocity,
                    ),
                ),
                "clip_muscle_state": Stage(
                    # Activation shouldn't be below 0, and length has an UB.
                    callable=lambda self: self._clip_state,
                    where_input=lambda input, state: self.bounds.muscles,
                    where_state=lambda state: state.muscles,
                ),
                "muscle_tension": Stage(
                    callable=lambda self: self.muscle_model,
                    where_input=lambda input, state: state.muscles.activation,
                    where_state=lambda state: state.muscles,
                ),
                "muscle_torques": Stage(
                    callable=lambda self: self._muscle_torques,
                    where_input=lambda input, state: state.muscles,
                    where_state=lambda state: state.skeleton.torque,
                ),
            }
        )

    @cached_property
    def dynamics_spec(self) -> Mapping[str, DynamicsComponent[PlantState]]:
        """Specifies the components of the muscled arm dynamics.

        The plant's input drives the activator, which updates the muscle
        activation. The skeleton is driven by `state.skeleton.torque`, which
        the `muscle_torques` stage of `model_spec` writes.
        """
        return {
            "muscle_activation": DynamicsComponent(
                dynamics=self.activator,
                where_input=lambda input, state: input,
                where_state=lambda state: state.muscles.activation,
            ),
            # NOTE(review): is this applying the torques twice? The skeleton's
            # vector field may compute `input_torque + state.torque`, and here
            # the input *is* `state.skeleton.torque` — confirm.
            "skeleton": DynamicsComponent(
                dynamics=self.skeleton,
                where_input=lambda input, state: state.skeleton.torque,
                where_state=lambda state: state.skeleton,
            ),
        }

    def _muscle_geometry(
        self, input: AbstractSkeletonState, state: Tuple[Array, Array], *, key=None
    ):
        """Return `(length, velocity)` of the muscles given the skeleton state.

        `state` (the current length/velocity) is not used; both are
        recomputed from the skeleton's `angle` and `d_angle`.
        """
        skeleton_state = input
        length = self._muscle_length(skeleton_state.angle)
        velocity = self._muscle_velocity(skeleton_state.d_angle)
        return (length, velocity)

    def _muscle_length(self, angle: Array) -> Array:
        """Normalized muscle lengths for the given joint angles.

        Computes `1 + (M[0]·(θ0[0] − angle[0]) + M[1]·(θ0[1] − angle[1])) / l0`.
        Only the first two rows of the moment arms and optimal angles are
        used (one per link).
        """
        # TODO: should this be a function? how general is it?
        moment_arms, l0, theta0 = self.moment_arms, self.l0, self.theta0
        length = (
            1
            + (
                moment_arms[0] * (theta0[0] - angle[0])
                + moment_arms[1] * (theta0[1] - angle[1])
            )
            / l0
        )
        return length

    def _muscle_velocity(self, d_angle: Array) -> Array:
        """Normalized muscle velocities for the given joint angular velocities.

        NOTE(review): differentiating `_muscle_length` gives
        `-(M[0]·d_angle[0] + M[1]·d_angle[1]) / l0`; this returns the opposite
        sign, presumably a shortening-positive convention expected by the
        muscle model — confirm.
        """
        moment_arms, l0 = self.moment_arms, self.l0
        v = (moment_arms[0] * d_angle[0] + moment_arms[1] * d_angle[1]) / l0
        return v

    def _muscle_torques(self, input, state, *, key=None) -> Array:
        """Joint torques: moment arms applied to `f0`-scaled muscle tensions.

        `input` is the muscle state (its `tension` is read); `state` is unused.
        """
        torque = self.moment_arms @ (self.f0 * input.tension)
        return torque

    @property
    def memory_spec(self) -> PyTree[bool]:
        """Both the skeleton and muscle states are remembered."""
        return PlantState(
            skeleton=True,
            muscles=True,
        )

    def init(self, *, key: PRNGKeyArray) -> PlantState:
        """Return a default state for the muscled arm."""
        key1, key2 = jax.random.split(key)
        return PlantState(
            skeleton=self.skeleton.init(key=key1),
            muscles=self.muscle_model.init(key=key2),
        )

    @property
    def input_size(self) -> int:
        """Equal to the number of muscles."""
        return self.n_muscles