
Generalized loss terms from trial specifications #10

Closed
mlprt opened this issue Feb 17, 2024 · 2 comments
Labels
enhancement, structure

Comments

@mlprt
Owner

mlprt commented Feb 17, 2024

Currently:

  • the init field of AbstractTaskTrialSpec specifies 1) lambdas that pick out subtrees of the model state, and 2) replacements for those subtrees, which are used to initialize the state on a given trial.
  • for each batch of training trials, TaskTrainer passes the following to the loss function: 1) the evaluated trajectories of model states, 2) the trial specifications. Thus, every AbstractLoss subclass in feedbax.loss works through hardcoded references to parts of states and trial_specs. For example, feedbax.loss.EffectorPositionLoss is a function of the difference states.mechanics.effector.pos - trial_specs.target.pos.
  • the target field of AbstractTaskTrialSpec is a CartesianState that only specifies the state of a single effector.
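
For concreteness, here is a minimal sketch of the hardcoded pattern (illustrative only; the actual EffectorPositionLoss also handles discounting and aggregation over time steps):

```python
import jax.numpy as jnp

# Illustrative sketch of the current, hardcoded pattern: the loss reaches into
# fixed locations in both `states` and `trial_specs`, so penalizing a second
# effector would require writing another class.
class EffectorPositionLoss:
    def __call__(self, states, trial_specs):
        error = states.mechanics.effector.pos - trial_specs.target.pos
        return jnp.sum(error ** 2, axis=-1)
```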

But what if we want to (say) include losses that penalize the position of two different effectors, given by CartesianState leaves at different locations in the states PyTree? Then trial_specs.target should specify two different CartesianState targets. Should the relationship between the leaves of trial_specs.target and the leaves of states be hardcoded into a subclass of AbstractLoss? This could lead to a proliferation of AbstractLoss classes. (However, see #24.)

Instead, trial_specs.target could be defined similarly to trial_specs.init. In that case, each member of target would provide 1) a lambda that picks out a part of the state, and 2) a target trajectory for that state. A loss term could be automatically constructed from each entry -- something like target.where(states) - target.value.
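
Schematically, with hypothetical where and value fields on each target entry:

```python
import jax.numpy as jnp

# Hypothetical sketch: each target entry pairs a state selector (`where`) with
# a target trajectory (`value`), and a loss term is constructed mechanically
# from the pair.
def loss_term(target, states):
    error = target.where(states) - target.value
    return jnp.sum(error ** 2, axis=-1)

def total_loss(targets, states):
    # One automatically-constructed term per target entry.
    return sum(loss_term(t, states) for t in targets)
```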

I think some losses should still be defined the same way they already are. But this feature might make it easier to train more complex models, since the user will only need to specify a couple of lambdas when subclassing AbstractTask, instead of needing to write multiple subclasses of AbstractLoss as well.

@mlprt mlprt added the enhancement and structure labels Feb 17, 2024
@mlprt mlprt changed the title Generalized loss terms inferred from trial specifications Generalized loss terms from trial specifications Feb 17, 2024
@mlprt
Owner Author

mlprt commented Apr 28, 2024

A working implementation is available as of 4f735f4.

In feedbax.loss, [TargetStateLoss](https://github.com/mlprt/feedbax/blob/4f735f434ededd122ddeee6957fd911a9e8870c7/feedbax/loss.py#L357) is a subclass of AbstractLoss that associates a where, a norm function, and a TargetSpec.

A TargetSpec provides information about 1) the target value of the state, 2) the time indices at which the state's value should be penalized (e.g. penalize effector velocity on final time step only), and 3) an array of discounting factors. All of these fields are optional, and partial specifications may be combined.
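
For example (a sketch assuming only the three fields described above; the actual TargetSpec in feedbax may differ in names and details):

```python
from typing import Optional

import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array

# Sketch of a TargetSpec with all-optional fields, so that partial
# specifications can be merged leaf-wise.
class TargetSpec(eqx.Module):
    value: Optional[Array] = None      # target value(s) for the state
    time_idxs: Optional[Array] = None  # time indices at which to penalize
    discount: Optional[Array] = None   # discounting factors over time

# A partial spec: penalize only the final time step.
final_step_only = TargetSpec(time_idxs=jnp.array([-1]))

# Merge with a spec that supplies the target value; eqx.combine keeps the
# first non-None leaf at each position.
spec = eqx.combine(TargetSpec(value=jnp.zeros((2,))), final_step_only)
```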

In TaskTrialSpec, there is now a field targets: WhereDict[TargetSpec], through which a subclass of AbstractTask can provide trial-by-trial TargetSpec information to instances of TargetStateLoss (via TaskTrainer).

When a TargetStateLoss instance is called, its spec: Optional[TargetSpec] field is eqx.combine'd with any matching entries in trial_specs.targets. This lets the user supply default target values when instantiating TargetStateLoss, while also allowing the task to be designed (example) so that it provides trial-by-trial targets. A target value must be specified either trial-by-trial or as a default; an error is raised if no target value is available.
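
Schematically, reusing the TargetSpec sketch above (a simplification, not the actual feedbax implementation):

```python
from typing import Callable, Optional

import equinox as eqx

# Simplified sketch of the call-time logic described above.
class TargetStateLoss(eqx.Module):
    where: Callable                    # picks the penalized substate out of `states`
    norm: Callable                     # e.g. squared L2 norm over the last axis
    spec: Optional[TargetSpec] = None  # default spec, set on instantiation

    def __call__(self, states, trial_spec: TargetSpec):
        # Trial-by-trial entries take precedence over the default spec.
        if self.spec is None:
            spec = trial_spec
        else:
            spec = eqx.combine(trial_spec, self.spec)
        if spec.value is None:
            raise ValueError("No target value specified for this loss term.")
        return self.norm(self.where(states) - spec.value)
```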

A loss_func must still be passed when instantiating an AbstractTask subclass. Composing the terms of the loss function is now a little more complicated than it used to be, since we add TargetStateLoss instances by specifying the part of the state to penalize. We could probably replace the old loss classes like EffectorPositionLoss with factories/wrappers for TargetStateLoss, which would simplify loss function construction again in typical use cases.
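
For instance, a hypothetical wrapper (not something that exists in feedbax yet) might look like:

```python
import jax.numpy as jnp

# Hypothetical factory: recovers the convenience of the old
# EffectorPositionLoss by baking in the `where` and the norm.
def effector_position_loss(**kwargs):
    return TargetStateLoss(
        where=lambda states: states.mechanics.effector.pos,
        norm=lambda err: jnp.sum(err ** 2, axis=-1),
        **kwargs,
    )
```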

@mlprt mlprt closed this as completed Apr 28, 2024
@mlprt
Owner Author

mlprt commented Apr 28, 2024

One issue that remains is the possibility of multiple targets being specified with respect to the same part of the state. For example, in delayed reaching we might want separate loss terms for the effector position error with respect to 1) the reach goal, and 2) the initial fixation. The possibility of multiple loss terms on a single target is why I enabled tuple[Callable, str] keys for WhereDict, so that a where lambda can be combined with a unique label (see the sketch after the list below). However, this means that we have to make sure that the label field of TargetStateLoss matches the string entry in TargetSpecs constructed by the task. There are a couple of other options here:

  1. Allow the values of trial_specs.targets to be a Mapping[str, TargetSpec]. This doesn't solve the string-matching issue, but it does simplify the allowable keys of WhereDict.
  2. Only allow a single target for each part of the state. This should be possible (e.g. goal and fixation targets happen at different times during the reach) but it would mean that some other mechanism (say, in AbstractLoss) would be necessary if we want users to be able to distinguish loss contributions from different epochs of a trial.
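
For reference, a sketch of the tuple-key approach (construction details are illustrative; goal_spec and fixation_spec stand in for TargetSpec values):

```python
# The same `where` lambda appears under two distinct labels, so two loss terms
# can target the same part of the state.
effector_pos = lambda states: states.mechanics.effector.pos

targets = WhereDict({
    (effector_pos, "goal"): goal_spec,
    (effector_pos, "fixation"): fixation_spec,
})
```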
