Typing ModelStage #23

Open
mlprt opened this issue Feb 28, 2024 · 3 comments
mlprt (Owner) commented Feb 28, 2024

State operations ("stages") in an AbstractStagedModel are defined by the property model_spec which is an OrderedDict[str, ModelStage].

For example, here's one entry from SimpleFeedback.model_spec:

feedbax/feedbax/bodies.py

Lines 195 to 199 in 2ce8b1c

"mechanics_step": ModelStage(
    callable=lambda self: self.mechanics,
    where_input=lambda input, state: state.net.output,
    where_state=lambda state: state.mechanics,
),

The state arguments in these lambdas should be typed as SimpleFeedbackState, so that the type checker recognizes that state.mechanics is a valid reference: SimpleFeedbackState has a field mechanics: MechanicsState.

Currently, ModelStage is generic in the type variable StateT = TypeVar('StateT', AbstractState, Array), which is constrained to AbstractState or Array. All of the state PyTrees like SimpleFeedbackState inherit from AbstractState. Here is a slight simplification:

StateT = TypeVar('StateT', AbstractState, Array)

class ModelStage(eqx.Module, Generic[StateT]):
    callable: Callable[[AbstractStagedModel[StateT]], Callable]
    where_input: Callable[[AbstractTaskInputs, StateT], PyTree]
    where_state: Callable[[StateT], PyTree]
    intervenors: Sequence[AbstractIntervenor] = field(default_factory=tuple)

However, nowhere do we subclass ModelStage and give an argument for this type variable.

Throughout Feedbax, Pyright raises errors for only some of the references found in lambdas of model_spec properties. For example, it raises an error for the callable field of the "mechanics_step" stage given at the start of this issue,

feedbax/bodies.py:196:48 - error: Cannot access member "mechanics" for type "AbstractStagedModel[Unknown]"
    Member "mechanics" is unknown (reportAttributeAccessIssue)

but not for its where_input or where_state fields. Checking the Pylance tooltips, the input arguments are typed as AbstractTaskInputs, but the state fields are Any, which is at odds with the StateT annotation in ModelStage.
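One reading of these tooltips, consistent with the `AbstractStagedModel[Unknown]` in the error above, is that a generic class written without type arguments leaves its type variable unsolved, so lambda parameters typed by it fall back to an unknown type. The sketch below illustrates the difference with stand-in classes: `Stage`, `FeedbackState`, and `MechState` are hypothetical, built on `dataclasses` rather than `equinox.Module`, and are not Feedbax code.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Generic, TypeVar

StateT = TypeVar("StateT")


@dataclass
class MechState:
    position: float


@dataclass
class FeedbackState:
    # Stand-in for SimpleFeedbackState; only one field is needed here.
    mechanics: MechState


@dataclass
class Stage(Generic[StateT]):
    # Stand-in for ModelStage, reduced to a single selector field.
    where_state: Callable[[StateT], object]


# Bare `Stage`: no type argument is given, so a checker has no concrete
# type for `state` and cannot verify the attribute access in the lambda.
untyped: Dict[str, Stage] = {
    "mechanics_step": Stage(where_state=lambda state: state.mechanics),
}

# `Stage[FeedbackState]`: with the type argument supplied, `state` is a
# FeedbackState to the checker, so `state.mechanics` can be verified.
typed: Dict[str, Stage[FeedbackState]] = {
    "mechanics_step": Stage[FeedbackState](
        where_state=lambda state: state.mechanics
    ),
}

state = FeedbackState(mechanics=MechState(position=0.0))
```

Both dictionaries behave identically at runtime; the difference is only in what the type checker can conclude about `state`.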

I suspect it will be necessary (#24) to eliminate AbstractState. Then, the type arguments to final subclasses of AbstractStagedModel will need to be distinct subclasses of equinox.Module. Similarly, the type variable in ModelStage will need to be bound to equinox.Module instead of AbstractState. This doesn't seem problematic, as ModelStage will only be used inside an AbstractStagedModel with which it shares a type argument. For example, ModelStage[SimpleFeedbackState] will be used in SimpleFeedback[SimpleFeedbackState].

I've tried explicitly writing "mechanics_step": ModelStage[SimpleFeedbackState](...) in the model_spec entries. Supplying type arguments at the point of instantiation is not a syntax I've used before, and it doesn't resolve the type checker's errors. I wouldn't expect it to resolve the error raised for the callable field of ModelStage anyway, since the type argument doesn't specify that self should be of type SimpleFeedback.
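For reference, subscripting a generic class at the call site is ordinary Python: `typing` returns an alias that, when called, constructs an instance of the unsubscripted class. A minimal sketch, using a hypothetical `Box` class that is not part of Feedbax:

```python
from dataclasses import dataclass
from typing import Generic, TypeVar

T = TypeVar("T")


@dataclass
class Box(Generic[T]):
    # Hypothetical container used only to illustrate the syntax.
    value: T


# The subscript matters only to the type checker, which treats `explicit`
# as Box[int] without having to infer T from the argument.
explicit = Box[int](1)
inferred = Box(1)

# At runtime both calls build a plain Box.
same_class = type(explicit) is type(inferred) is Box

# The type argument is not enforced at runtime: a checker should reject
# this line, but it executes without error.
unchecked = Box[int]("not an int")
```

This fits the observation above: the type argument pins down the state type, but it says nothing about the type of `self` passed to `callable`.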

How should this be properly typed?

mlprt (Owner, Author) commented Feb 29, 2024

When the return annotation OrderedDict[str, ModelStage] is removed from SimpleFeedback.model_spec, Pyright starts typing the state arguments of the lambdas in "mechanics_step" as StateT@ModelStage instead of Any. It then also raises errors for the references in the lambdas:

  ./feedbax/bodies.py:197:60 - error: Cannot access member "net" for type "AbstractState*"
    Member "net" is unknown (reportAttributeAccessIssue)
  ./feedbax/bodies.py:197:60 - error: Cannot access member "net" for type "Array*"
    Member "net" is unknown (reportAttributeAccessIssue)
  ./feedbax/bodies.py:198:53 - error: Cannot access member "mechanics" for type "AbstractState*"
    Member "mechanics" is unknown (reportAttributeAccessIssue)
  ./feedbax/bodies.py:198:53 - error: Cannot access member "mechanics" for type "Array*"
    Member "mechanics" is unknown (reportAttributeAccessIssue)
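The doubled errors (one for `AbstractState*`, one for `Array*`) are consistent with StateT being a constrained type variable: it must resolve to exactly one of the listed types, so the lambda body is checked against each in turn, and neither has a `net` or `mechanics` field. A bound type variable behaves differently, because it can resolve to a subclass. The sketch below contrasts the two declarations; `BaseState`, `ArrayLike`, `FeedbackState`, and `Stage` are stand-ins assumed for illustration, not Feedbax classes.

```python
from dataclasses import dataclass
from typing import Callable, Generic, TypeVar


class BaseState:
    """Stand-in for AbstractState."""


class ArrayLike:
    """Stand-in for Array."""


@dataclass
class FeedbackState(BaseState):
    """Stand-in for SimpleFeedbackState."""

    mechanics: float


# Constrained: resolves to exactly BaseState or ArrayLike, never to a
# subclass such as FeedbackState.
ConstrainedT = TypeVar("ConstrainedT", BaseState, ArrayLike)

# Bound: may resolve to BaseState or to any subclass of it.
BoundT = TypeVar("BoundT", bound=BaseState)


@dataclass
class Stage(Generic[BoundT]):
    # Stand-in for ModelStage, reduced to a single selector field.
    where_state: Callable[[BoundT], object]


# With the bound variable, the concrete state type can be supplied, and
# the attribute access inside the lambda becomes checkable.
stage = Stage[FeedbackState](where_state=lambda state: state.mechanics)
selected = stage.where_state(FeedbackState(mechanics=1.5))
```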

mlprt (Owner, Author) commented Feb 29, 2024

Typing of all the arguments is working now. Two things needed to be done:

  1. Change ModelStage to a Generic[ModelT, StateT] where both type variables are bound to equinox.Module.
  2. Type each stage, and the return, of model_spec with explicit type arguments to ModelStage.

For example, here's the updated model_spec of SimpleFeedback:

feedbax/feedbax/bodies.py

Lines 168 to 203 in 0c6e824

@property
def model_spec(self) -> OrderedDict[str, ModelStage[Self, SimpleFeedbackState]]:
    """Specifies the stages of the model in terms of state operations."""
    Stage = ModelStage[Self, SimpleFeedbackState]
    return OrderedDict(
        {
            "update_feedback": Stage(
                callable=lambda self: self._feedback_module,
                where_input=lambda input, state: jax.tree_map(
                    lambda spec: spec.where(state.mechanics),
                    self._feedback_specs,
                    is_leaf=lambda x: isinstance(x, ChannelSpec),
                ),
                where_state=lambda state: state.feedback,
            ),
            "nn_step": Stage(
                callable=lambda self: self.net,
                where_input=lambda input, state: (
                    input,
                    # Get the output state for each feedback channel.
                    jax.tree_map(
                        lambda state: state.output,
                        state.feedback,
                        is_leaf=lambda x: isinstance(x, ChannelState),
                    ),
                ),
                where_state=lambda state: state.net,
            ),
            "mechanics_step": Stage(
                callable=lambda self: self.mechanics,
                where_input=lambda input, state: state.net.output,
                where_state=lambda state: state.mechanics,
            ),
        }
    )

Making this change for all of the model_spec definitions in Feedbax reduced the number of Pyright errors from ~290 to 261. Quite a few errors are now being raised due to the return types of the lambdas, but these can now be addressed normally.

Unfortunately, Self can only be imported from typing in Python>=3.11, so as written this will only work there. To support (#25) earlier versions of Python, another solution will need to be found, or we'll need to ignore types within model stages.
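One possible route for earlier Python versions, offered as a sketch that has not been tried against Feedbax: name the concrete class instead of `Self`. The return annotation can be a string forward reference, and inside the method body the class already exists by the time the property is read. This gives up the subclass-awareness of `Self`, so it only suits models that are not subclassed further. (The `typing_extensions` package also backports `Self`, which may be the smaller change.) Everything below is a stand-in built on `dataclasses`, not Feedbax or Equinox code, and `run_stage` is a hypothetical consumer added only to make the sketch runnable.

```python
from collections import OrderedDict
from dataclasses import dataclass
from typing import Callable, Generic, TypeVar

ModelT = TypeVar("ModelT")
StateT = TypeVar("StateT")


@dataclass
class ModelStage(Generic[ModelT, StateT]):
    # Stand-in for the two-parameter ModelStage described above.
    callable: Callable[[ModelT], Callable[[float], float]]
    where_input: Callable[[float, StateT], float]
    where_state: Callable[[StateT], float]


@dataclass
class State:
    net: float
    mechanics: float


@dataclass
class Model:
    gain: float

    def mechanics_fn(self, x: float) -> float:
        return self.gain * x

    @property
    def model_spec(self) -> "OrderedDict[str, ModelStage[Model, State]]":
        # `Model` is fully defined by the time this body runs, so the
        # concrete class can take the place of `Self`.
        Stage = ModelStage[Model, State]
        return OrderedDict(
            {
                "mechanics_step": Stage(
                    callable=lambda self: self.mechanics_fn,
                    where_input=lambda input, state: state.net,
                    where_state=lambda state: state.mechanics,
                ),
            }
        )


def run_stage(model: Model, name: str, input: float, state: State) -> float:
    # Hypothetical consumer: look up a stage, select its input from the
    # state, and apply the stage's callable to it.
    stage = model.model_spec[name]
    return stage.callable(model)(stage.where_input(input, state))


result = run_stage(
    Model(gain=2.0), "mechanics_step", 0.0, State(net=3.0, mechanics=0.0)
)
```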

mlprt (Owner, Author) commented Feb 29, 2024

Being uncertain about potential alternatives to my solution, I'm leaving this issue open for now.
