-
Notifications
You must be signed in to change notification settings - Fork 0
/
state.py
127 lines (95 loc) · 2.98 KB
/
state.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""Common types used across the package.
:copyright: Copyright 2023-2024 by Matt Laporte.
:license: Apache 2.0. See LICENSE for details.
"""
from collections.abc import Callable
from copy import deepcopy
from functools import cached_property
import logging
from typing import (
Optional,
Generic,
Protocol,
TypeVar,
runtime_checkable,
)
import equinox as eqx
from equinox import Module, field
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, PyTree
# Module-level logger, named after this module so its output can be filtered
# per-module via the standard `logging` hierarchy.
logger = logging.getLogger(__name__)
class AbstractState(Module):
    """Common base class for all model states.

    !!! NOTE ""
        This class intentionally declares no fields or methods of its own; it
        exists so that its subclasses can be referred to collectively in
        type annotations.
    """
# Type variable for a state PyTree whose leaves are arrays. It ties the type of
# `StateBounds.low` / `StateBounds.high` to the type of the state being bounded
# (see `clip_state`, which takes a `StateBounds[StateT]` and a `StateT`).
StateT = TypeVar("StateT", bound=PyTree[Array])
class StateBounds(Module, Generic[StateT]):
    """Specifies bounds on a state.

    Attributes:
        low: A state PyTree giving lower bounds, or `None` if there is no
            lower bound. Individual leaves may also be `None` to leave the
            corresponding part of the state unbounded from below.
        high: A state PyTree giving upper bounds, or `None` if there is no
            upper bound. Individual leaves may also be `None` to leave the
            corresponding part of the state unbounded from above.
    """

    low: Optional[StateT]
    high: Optional[StateT]

    @cached_property
    def filter_spec(self) -> PyTree[bool]:
        """A matching PyTree indicating which parts of the state are bounded.

        The result has the same structure as this `StateBounds`: every leaf
        that is `None` becomes `False`, and every other leaf becomes `True`.
        If `low` or `high` is `None` as a whole, the corresponding entry of
        the result is simply `False`.

        Computed on first access and cached (via `functools.cached_property`).
        """
        # `jax.tree_map` is deprecated (and removed in newer JAX releases);
        # `jax.tree_util.tree_map` is the stable equivalent.
        return jax.tree_util.tree_map(
            lambda x: x is not None,
            self,
            # Treat `None` as a leaf rather than as an empty subtree, so that
            # unbounded entries map to `False` instead of vanishing from the
            # result's structure.
            is_leaf=lambda x: isinstance(x, Array) or x is None,
        )
class CartesianState(AbstractState):
    """Cartesian state of a mechanical system in two spatial dimensions.

    Attributes:
        pos: The position coordinates of the point(s) in the system.
        vel: The respective velocities.
        force: The respective forces.
    """

    # Each field defaults to a freshly created `jnp.zeros(2)` (one call per
    # instance, via `default_factory`), i.e. a single point at the origin with
    # zero velocity and zero force. The trailing axis of size 2 holds the two
    # spatial coordinates; leading axes (`...`) presumably index multiple
    # points -- confirm against callers.
    pos: Float[Array, "... 2"] = field(default_factory=lambda: jnp.zeros(2))
    vel: Float[Array, "... 2"] = field(default_factory=lambda: jnp.zeros(2))
    force: Float[Array, "... 2"] = field(default_factory=lambda: jnp.zeros(2))
@runtime_checkable
class HasEffectorState(Protocol):
    """Structural type for any object exposing an `effector` Cartesian state.

    Because the protocol is `runtime_checkable`, it can be used with
    `isinstance`. Per the `typing` documentation, such checks only verify
    that an `effector` attribute exists, not that it is a `CartesianState`.
    """

    effector: CartesianState
@runtime_checkable
class HasMechanicsEffectorState(Protocol):
    """Structural type for an object whose `mechanics` attribute in turn
    satisfies `HasEffectorState` (i.e. exposes an `effector` state).

    Runtime `isinstance` checks against this protocol only verify that a
    `mechanics` attribute is present; they do not inspect it further.
    """

    mechanics: HasEffectorState
def clip_state(
    bounds: StateBounds[StateT],
    state: StateT,
) -> StateT:
    """Clip a state elementwise to the given bounds.

    The lower bound is applied first, then the upper bound. A side whose
    bound is `None` is skipped entirely, and leaves that are `None` within a
    bound leave the matching part of the state unclipped.

    Arguments:
        bounds: The lower and upper bounds to clip the state to.
        state: The state to clip.

    Returns:
        The clipped state, with the same structure as `state`.
    """
    # (attribute name, comparison under which the current value is kept)
    sides = (
        ("low", jnp.greater),
        ("high", jnp.less),
    )
    for side, keep_value_if in sides:
        limit = getattr(bounds, side)
        if limit is None:
            continue
        # `filter_spec` is only consulted once we know this side is bounded.
        mask = getattr(bounds.filter_spec, side)
        state = _clip_state_to_bound(state, limit, mask, keep_value_if)
    return state
def _clip_state_to_bound(
    state: StateT,
    bound: StateT,
    filter_spec: PyTree[bool],
    op: Callable,
) -> StateT:
    """A single (one-sided) clipping operation.

    Arguments:
        state: The state to clip.
        bound: A PyTree of bound values matching `state`; `None` leaves mark
            unbounded parts.
        filter_spec: A PyTree of bools marking which parts of `state` are
            bounded and should therefore be clipped.
        op: Elementwise comparison called as `op(value, bound)`. Where it is
            true the original value is kept; otherwise it is replaced by the
            bound. For example, `jnp.greater` enforces a lower bound and
            `jnp.less` an upper bound.

    Returns:
        A state with the same structure as `state`, with the selected parts
        clipped and everything else passed through unchanged.
    """
    # Split off the bounded parts of the state; the remainder passes through.
    to_clip, untouched = eqx.partition(state, filter_spec)
    # `jax.tree_map` is deprecated (and removed in newer JAX releases);
    # `jax.tree_util.tree_map` is the stable equivalent.
    clipped = jax.tree_util.tree_map(
        lambda x, y: jnp.where(op(x, y), x, y),
        to_clip,
        bound,
    )
    return eqx.combine(untouched, clipped)