Skip to content

Commit

Permalink
Bind state typevars to Module instead of AbstractState
Browse files Browse the repository at this point in the history
- Make `feedbax.state.StateT` bound to `eqx.Module` instead of
  `AbstractState`.
- Split `wrap_stateless_callable` into two, producing
  `wrap_stateless_keyless_callable` as well. Add to docs.
- Fix control flow and typing in the body of `add_intervenors`.
- Fix typing in the body of `schedule_intervenor`.
- Switch a few more annotations to `Self`, where appropriate.
  • Loading branch information
mlprt committed Mar 1, 2024
1 parent 0c6e824 commit a6b68bd
Show file tree
Hide file tree
Showing 13 changed files with 113 additions and 104 deletions.
4 changes: 3 additions & 1 deletion docs/api/model.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,6 @@

::: feedbax.ModelInput

::: feedbax.wrap_stateless_callable
::: feedbax.wrap_stateless_callable

::: feedbax.wrap_stateless_keyless_callable
10 changes: 7 additions & 3 deletions feedbax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""TODO
"""
:copyright: Copyright 2023-2024 by Matt L Laporte.
:license: Apache 2.0, see LICENSE for details.
"""
Expand All @@ -10,7 +9,12 @@
import warnings

from feedbax._io import save, load
from feedbax._model import AbstractModel, ModelInput, wrap_stateless_callable
from feedbax._model import (
AbstractModel,
ModelInput,
wrap_stateless_callable,
wrap_stateless_keyless_callable,
)
from feedbax._staged import (
AbstractStagedModel,
ModelStage,
Expand Down
41 changes: 21 additions & 20 deletions feedbax/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
TYPE_CHECKING,
Generic,
Optional,
Self,
)

import equinox as eqx
from equinox import Module
import jax
from jaxtyping import Array, PRNGKeyArray, PyTree
import numpy as np
Expand All @@ -23,14 +25,15 @@
from feedbax._tree import random_split_like_tree

if TYPE_CHECKING:
from feedbax._staged import AbstractStagedModel
from feedbax.intervene import AbstractIntervenorInput
from feedbax.task import AbstractTaskInputs


logger = logging.getLogger(__name__)


class AbstractModel(eqx.Module, Generic[StateT]):
class AbstractModel(Module, Generic[StateT]):
"""Base class for models that operate on `AbstractState` objects."""

@abstractmethod
Expand All @@ -50,7 +53,7 @@ def __call__(
...

@abstractproperty
def step(self) -> "AbstractModel[StateT]":
def step(self) -> Module:
"""The part of the model PyTree specifying a single time step of the model.
For non-iterated models, this should trivially return `step`.
Expand Down Expand Up @@ -112,7 +115,7 @@ def memory_spec(self) -> PyTree[bool]:
return True


class ModelInput(eqx.Module):
class ModelInput(Module):
"""PyTree that contains all inputs to a model."""

input: PyTree[Array]
Expand Down Expand Up @@ -155,7 +158,7 @@ def init(
)

@property
def step(self):
def step(self) -> Module:
return self

def _get_keys(self, key):
Expand All @@ -164,10 +167,7 @@ def _get_keys(self, key):
)


def wrap_stateless_callable(
callable: Callable,
pass_key: bool = True,
):
def wrap_stateless_callable(callable: Callable):
"""Makes a 'stateless' callable compatible with state-passing.
!!! Info
Expand All @@ -186,20 +186,21 @@ def wrap_stateless_callable(
Arguments:
callable: The callable to wrap.
pass_key: If `True`, the keyword argument `key` will be forwarded to the wrapped
callable. If `False`, it will be discarded. Discarding is useful if the
wrapped callable does not accept a `key` argument.
"""
if pass_key:
@wraps(callable)
def wrapped(input, state, *args, **kwargs):
return callable(input, *args, **kwargs)

@wraps(callable)
def wrapped(input, state, *args, **kwargs):
return callable(input, *args, **kwargs)
return wrapped

else:
def wrap_stateless_keyless_callable(callable: Callable):
"""Like `wrap_stateless_callable`, for a callable that also takes no `key`.
@wraps(callable)
def wrapped(input, state, *args, key: Optional[PRNGKeyArray] = None, **kwargs):
return callable(input, *args, **kwargs)
Arguments:
callable: The callable to wrap.
"""
@wraps(callable)
def wrapped(input, state, *args, key: Optional[PRNGKeyArray] = None, **kwargs):
return callable(input, *args, **kwargs)

return wrapped
return wrapped
19 changes: 11 additions & 8 deletions feedbax/_staged.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
Generic,
Optional,
Protocol,
Self,
TypeVar,
Union,
)

import equinox as eqx
from equinox import AbstractVar, field
from equinox import AbstractVar, Module, field
import jax
import jax.random as jr
from jaxtyping import Array, PRNGKeyArray, PyTree
Expand All @@ -38,22 +39,22 @@
logger = logging.getLogger(__name__)


ModelT = TypeVar("ModelT", bound=eqx.Module)
StateT = TypeVar("StateT", bound=eqx.Module)
ModelT = TypeVar("ModelT", bound=Module)
StateT = TypeVar("StateT", bound=Module)


class ModelStageCallable(Protocol):
# This is part of the `ModelInput` hack.
def __call__(self, input: ModelInput, state: eqx.Module, key: PRNGKeyArray) -> PyTree[Array]:
def __call__(self, input: ModelInput, state: PyTree[Array], *, key: PRNGKeyArray) -> PyTree[Array]:
...


class OtherStageCallable(Protocol):
def __call__(self, input: PyTree[Array], state: eqx.Module, key: PRNGKeyArray) -> PyTree[Array]:
def __call__(self, input: PyTree[Array], state: PyTree[Array], *, key: PRNGKeyArray) -> PyTree[Array]:
...


class ModelStage(eqx.Module, Generic[ModelT, StateT]):
class ModelStage(Module, Generic[ModelT, StateT]):
"""Specification for a stage in a subclass of `AbstractStagedModel`.
Each stage of a model is a callable that performs a modification to part
Expand Down Expand Up @@ -198,7 +199,7 @@ def __call__(
def init(
self,
*,
key: Optional[PRNGKeyArray] = None,
key: PRNGKeyArray,
) -> StateT:
"""Return a default state for the model."""
...
Expand Down Expand Up @@ -259,13 +260,15 @@ def _get_intervenors_dict(
return intervenors_dict

@property
def step(self) -> "AbstractStagedModel[StateT]":
def step(self) -> Module:
"""The model step.
For an `AbstractStagedModel`, this is trivially the model itself.
"""
return self

# TODO: Avoid referencing `AbstractIntervenor` here, to avoid a circular import
# with `feedbax.intervene`.
@property
def _all_intervenor_labels(self):
model_leaves = jax.tree_util.tree_leaves(
Expand Down
16 changes: 5 additions & 11 deletions feedbax/bodies.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,19 +287,13 @@ def get_nn_input_size(
def state_consistency_update(
self, state: SimpleFeedbackState
) -> SimpleFeedbackState:
"""Adjust the state
"""Returns a corrected initial state for the model.
Update the plant configuration state, given that the user has
1. Update the plant configuration state, given that the user has
initialized the effector state.
Also fill the feedback queues with the initial feedback states. This
is less problematic than passing all zeros.
TODO:
- Check which of the two (effector or config) initialized, and update the other one.
Might require initializing them to NaN or something in `init`.
- Only initialize feedback channels whose *queues* are NaN, don't just check if
the entire channel is NaN and updated all-or-none of them.
2. Fill the feedback queues with the initial feedback states. This
is less problematic than passing all zeros until the delay elapses
for the first time.
"""
state = eqx.tree_at(
lambda state: state.mechanics.plant.skeleton,
Expand Down
2 changes: 1 addition & 1 deletion feedbax/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def __check_init__(self):
if not isinstance(self.delay, int):
raise ValueError("Delay must be an integer")

def _update_queue(self, input, state, *, key):
def _update_queue(self, input: PyTree[Array], state: ChannelState, *, key: PRNGKeyArray):
return ChannelState(
output=state.queue[0],
queue=state.queue[1:] + (input,),
Expand Down
7 changes: 3 additions & 4 deletions feedbax/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,13 @@
import logging
from typing import Optional

import equinox as eqx
from equinox import AbstractVar
from equinox import Module
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, PRNGKeyArray, PyTree

from feedbax._model import AbstractModel
from feedbax.state import CartesianState, StateBounds, StateT
from feedbax.state import StateBounds, StateT


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -68,7 +67,7 @@ def init(self, *, key: PRNGKeyArray) -> StateT:
...

@property
def step(self) -> "AbstractDynamicalSystem[StateT]":
def step(self) -> Module:
return self


Expand Down
Loading

0 comments on commit a6b68bd

Please sign in to comment.