The abstract-final pattern and generics: should AbstractState be eliminated? #24

Open
mlprt opened this issue Feb 29, 2024 · 6 comments

@mlprt
Owner

mlprt commented Feb 29, 2024

If we follow the abstract-final pattern strictly, such as by setting strict=True when subclassing equinox.Module, then

  • The base class AbstractState would need to include an AbstractVar for every field that appears in every subclass;
  • Every final subclass of AbstractState would need to implement every one of those fields.

Clearly this doesn't make sense, since different types of state PyTrees usually don't share any fields.

This might also be a reason we should expect to see issues with generic typing of AbstractModel[StateT], where StateT is bound to AbstractState. However, my understanding is that type invariance should be preserved as long as we respect the abstract-final pattern when subclassing whichever AbstractState subclass is ultimately used as the type argument for a final subclass of AbstractModel.

I suspect the solution is:

  1. Replace StateT with a type variable bound to equinox.Module;
  2. Use different base classes (AbstractFooState) that inherit from equinox.Module, for different final subclasses of AbstractModel. Each of these should respect the abstract-final pattern.
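A minimal sketch of the proposal above, using only the standard library so that it runs without Equinox. Here `Module` is a stand-in for equinox.Module, and `AbstractFooState`, `FooState`, `FooModel` and `as_tuple` are hypothetical names used for illustration only; none of them are part of feedbax.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, TypeVar, final


class Module:
    """Stand-in for equinox.Module (assumption: keeps the sketch stdlib-only)."""


# The type variable is bound to the generic module type, not to AbstractState.
StateT = TypeVar("StateT", bound=Module)


class AbstractModel(ABC, Generic[StateT]):
    """Abstract: declares the interface, is never instantiated."""

    @abstractmethod
    def __call__(self, state: StateT) -> StateT:
        ...


class AbstractFooState(Module, ABC):
    """Hypothetical per-model abstract state base."""

    @abstractmethod
    def as_tuple(self) -> tuple:
        ...


@final
@dataclass(frozen=True)
class FooState(AbstractFooState):
    """Final: all fields are defined here, and it is not subclassed further."""

    position: float
    velocity: float

    def as_tuple(self) -> tuple:
        return (self.position, self.velocity)


@final
class FooModel(AbstractModel[FooState]):
    """Final model whose state type argument is fixed to FooState."""

    def __call__(self, state: FooState) -> FooState:
        return FooState(state.position + state.velocity, state.velocity)
```

With this layout each model family gets its own small abstract-final hierarchy of states, and nothing forces unrelated state types to share fields.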
@mlprt mlprt added the typing label Feb 29, 2024

mlprt commented Feb 29, 2024

A problem that may be directly related: I had been thinking that a given final subclass of AbstractTask would not need to be associated with a single final subclass of AbstractModel (e.g. SimpleFeedback). Different models can be evaluated on the same task, right?

For example, SimpleReaches defines trial specifications whose init selects a part of the state to initialize: lambda state: state.mechanics.effector. This makes sense when state is a SimpleFeedbackState, which has a field mechanics: MechanicsState -- and MechanicsState has a field effector: CartesianState.

However, SimpleFeedbackState has a bunch of other fields as well. What if we have a model FooFeedback[FooFeedbackState] where FooFeedbackState has a field mechanics, but lacks some of the other fields of SimpleFeedbackState? Shouldn't it still make sense to initialize FooFeedbackState using the init from SimpleReaches?

A related problem: loss functions are passed the history of states over a trial, as a states argument. Generally a loss function only depends on a subset of states. For example, EffectorPositionLoss only depends on states.mechanics.effector.pos. Is there some way to type this states argument so that it does not expect SimpleFeedbackState specifically, but any state that has an appropriate mechanics: MechanicsState field?

I have tried to use protocols to solve this:

from typing import Protocol

from jaxtyping import Array

from feedbax.loss import AbstractLoss
from feedbax.state import CartesianState
from feedbax.task import AbstractTaskTrialSpec


class HasEffectorState(Protocol):
    effector: CartesianState


class HasMechanicsEffectorState(Protocol):
    mechanics: HasEffectorState


class EffectorFooLoss(AbstractLoss):
    ...

    def term(
        self,
        states: HasMechanicsEffectorState,
        trial_specs: AbstractTaskTrialSpec,
    ) -> Array:
        ...

Clearly this gets ugly pretty quickly.
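For reference, here is a stdlib-only sketch of how nested protocols like these behave with two unrelated state classes. The dataclasses below are simplified stand-ins for the feedbax state types (in particular `pos` is a plain tuple rather than an array), and `net_hidden`, `FooFeedbackState` and `effector_pos` are hypothetical. The protocol members are declared as read-only properties because static checkers generally treat plain protocol attributes as mutable and therefore invariant, which would stop `mechanics: MechanicsState` from matching `mechanics: HasEffector`.

```python
from dataclasses import dataclass
from typing import Protocol


@dataclass(frozen=True)
class CartesianState:
    pos: tuple  # stand-in for an array of positions


@dataclass(frozen=True)
class MechanicsState:
    effector: CartesianState


class HasEffector(Protocol):
    @property
    def effector(self) -> CartesianState: ...


class HasMechanics(Protocol):
    @property
    def mechanics(self) -> HasEffector: ...


@dataclass(frozen=True)
class SimpleFeedbackState:
    mechanics: MechanicsState
    net_hidden: float  # hypothetical extra field


@dataclass(frozen=True)
class FooFeedbackState:
    mechanics: MechanicsState  # shares only this subtree


def effector_pos(state: HasMechanics) -> tuple:
    """Depends only on the state.mechanics.effector subtree."""
    return state.mechanics.effector.pos
```

Both state classes satisfy `HasMechanics` structurally, so `effector_pos` accepts either, but every additional subtree a loss or task depends on needs its own chain of protocols.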

My current understanding is that:

  • A final subclass of AbstractTask should be associated with a final subclass of AbstractModel. For example, we should design SimpleReaches and EffectorPositionLoss to operate on states: SimpleFeedbackState specifically;
  • We shouldn't try to type arguments as possessing only specific subtrees;
  • SimpleFeedback is pretty general, and we should not see it as just another model. In principle we could design other high-level models where the controller-dynamics (agent-environment) separation is not so simple, but if I'm going to do that, I should probably be prepared to be specific about my tasks and loss functions.


mlprt commented Feb 29, 2024

This issue also applies to AbstractIntervenorInput and AbstractTaskInputs.


mlprt commented Feb 29, 2024

I've switched StateT to be bound to equinox.Module. However, I haven't eliminated AbstractState yet as I'd like to wait a bit for feedback on this issue before I make the change and refactor the docs.


mlprt commented Mar 1, 2024

One related pyright error that doesn't make sense to me is:

./feedbax/intervene.py:519:34 - error: Type variable "StateT" has no meaning in this context (reportGeneralTypeIssues)

This refers to the appearance of StateT in the signature of schedule_intervenor:

def schedule_intervenor(
    tasks: PyTree["AbstractTask"],
    models: PyTree[AbstractModel[StateT]],
    intervenor: AbstractIntervenor | Type[AbstractIntervenor],
    # TODO: intervenor_validation
    where: Callable[[AbstractModel[StateT]], Any] = lambda model: model,
    stage_name: Optional[str] = None,
    validation_same_schedule: bool = True,
    intervention_spec: Optional[
        AbstractIntervenorInput
    ] = None,  #! wrong! distribution functions are allowed. only the PyTree structure is the same
    intervention_spec_validation: Optional[AbstractIntervenorInput] = None,
    default_active: bool = False,
) -> Tuple["AbstractTask", AbstractModel[StateT]]:

The type variable appears in both the arguments and the return -- I don't see why it is meaningless. (The return type should be a tuple of two PyTrees, though, since we can modify multiple tasks/models now.)


mlprt commented Mar 4, 2024

I've switched StateT again, this time to be bound to PyTree[Array].

StateT = TypeVar("StateT", bound=PyTree[Array])

This allows AbstractModel, AbstractStagedModel, AbstractDynamicalSystem subclasses to take an appropriate equinox.Module as a type argument -- but also, to take a single array. For example, LTISystem operates on a single array.

class LTISystem(AbstractDynamicalSystem[Array]):


mlprt commented Apr 23, 2024

I just revisited this issue and was confused about why I had thought that the abstract-final pattern implied that AbstractState would need to specify AbstractVars for every possible field of a subclass. I don't think the abstract-final pattern requires that -- even in the strict=True case we are allowed to define more fields in a subclass, as long as all concrete fields are defined in that one subclass.

(Perhaps I was confused about the "need" for contravariance in the state argument to the __call__ method of subclasses of AbstractModel... but that shouldn't be an issue, since AbstractStagedModel is generic and the state argument is effectively invariant once we're in a concrete subclass of AbstractStagedModel whose generic type argument is specified as some type of PyTree.)

Note how Equinox determines whether a module is abstract, when performing strict=True checks:

def _is_abstract(cls):
    return (
        _is_force_abstract[cls]
        or len(cls.__abstractmethods__) > 0
        or len(cls.__abstractvars__) > 0
        or len(cls.__abstractclassvars__) > 0
    )

Thus a totally empty class does not count as abstract, since it contains no abstract attributes or methods. But this is currently the way that AbstractState is defined:

feedbax/feedbax/state.py

Lines 29 to 36 in 0dbaffb

class AbstractState(Module):
    """Base class for model states.

    !!! NOTE ""
        Currently this is empty, and only used for collectively typing its subclasses.
    """

    ...
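The standard library's abc module behaves in a comparable way, which may help build intuition. This is only an analogy and does not use Equinox; `EmptyBase` and `RealAbstractBase` are hypothetical names.

```python
import inspect
from abc import ABC, abstractmethod


class EmptyBase(ABC):
    """Has no abstract members, so Python does not treat it as abstract."""


class RealAbstractBase(ABC):
    """Has one abstract method, so it cannot be instantiated."""

    @abstractmethod
    def step(self) -> None:
        ...


empty_is_abstract = inspect.isabstract(EmptyBase)        # False
real_is_abstract = inspect.isabstract(RealAbstractBase)  # True

# Instantiating the empty base succeeds, despite the ABC base class.
empty_instance = EmptyBase()
```

In both cases, a class is only treated as abstract if it actually declares something abstract; a name starting with Abstract or an empty body is not enough.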

I suppose Equinox's strategy makes sense, since this isn't really an abstract class. It's more like an alias for equinox.Module. If we want to improve readability by sometimes referring to PyTree as State, a TypeAlias is probably the right tool, as long as we don't need AbstractState to specify an interface that all state modules are expected to implement.
