Generalized loss terms from trial specifications #10
A working implementation is available as of 4f735f4.
One issue that remains is the possibility of multiple targets being specified with respect to the same part of the state. For example, in delayed reaching we might want separate loss terms for the effector position error with respect to 1) the reach goal, and 2) the initial fixation. The possibility of multiple loss terms on a single target is why I enabled
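The multiple-targets-on-one-state-part situation can be sketched in plain Python (names like `TargetSpec` are illustrative, not feedbax's actual API): two target entries whose `where` lambdas pick out the same state leaf, yielding two independent loss terms.

```python
from dataclasses import dataclass
from typing import Any, Callable

# Hypothetical sketch, not feedbax's actual API: a target spec pairs a
# `where` lambda that picks out part of the state with a target value.
@dataclass
class TargetSpec:
    where: Callable[[Any], tuple]  # picks out a subtree of the state
    value: tuple                   # target for that subtree
    weight: float = 1.0

@dataclass
class EffectorState:
    pos: tuple  # (x, y)

@dataclass
class State:
    effector: EffectorState

def squared_error(a, b):
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))

# Two targets that both reference the same state leaf (effector position):
# in delayed reaching, one term for the reach goal and one for fixation.
specs = [
    TargetSpec(where=lambda s: s.effector.pos, value=(1.0, 0.0)),  # reach goal
    TargetSpec(where=lambda s: s.effector.pos, value=(0.0, 0.0)),  # fixation
]

state = State(effector=EffectorState(pos=(0.5, 0.0)))
losses = [spec.weight * squared_error(spec.where(state), spec.value)
          for spec in specs]
print(losses)  # → [0.25, 0.25]: two independent terms on a single leaf
```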
Currently:

- The `init` field of `AbstractTaskTrialSpec` specifies 1) lambdas that pick out subtrees of the model state, and 2) replacements for those subtrees, used to initialize the state on a given trial.
- `TaskTrainer` passes the following to the loss function: 1) the evaluated trajectories of model states, and 2) the trial specifications. Thus, every `AbstractLoss` subclass in `feedbax.loss` works through hardcoded references to parts of `states` and `trial_specs`. For example, `feedbax.loss.EffectorPositionLoss` is a function of the difference `states.mechanics.effector.pos - trial_specs.target.pos`.
- The `target` field of `AbstractTaskTrialSpec` is a `CartesianState` that only specifies the state of a single effector.

But what if we want to (say) include losses that penalize the position of two different effectors, given by
`CartesianState` leaves at different locations in the `states` PyTree? Then `trial_specs.target` should specify two different `CartesianState` targets. Should the relationship between the leaves of `trial_specs.target` and the leaves of `states` be hardcoded into a subclass of `AbstractLoss`? This could lead to a proliferation of `AbstractLoss` subclasses. (However, see #24.)

Instead,
`trial_specs.target` could be defined similarly to `trial_specs.init`. In that case, each member of `target` would provide 1) a lambda that picks out a part of the state, and 2) a target trajectory for that state. A loss term could be constructed automatically from each entry -- something like `target.where(states) - target.value`.

I think some losses should still be defined the way they already are. But this feature might make it easier to train more complex models, since the user will only need to specify a couple of lambdas when subclassing `AbstractTask`, instead of also needing to write multiple subclasses of `AbstractLoss`.
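The `target.where(states) - target.value` idea can be sketched framework-agnostically (plain Python stands in for JAX PyTrees here, and all names are illustrative rather than feedbax's actual API): one squared-error term is constructed mechanically per target entry, so targets on two different state leaves need no per-combination `AbstractLoss` subclass.

```python
from dataclasses import dataclass
from typing import Any, Callable

# Hypothetical sketch of the proposal: each target entry supplies a `where`
# lambda and a target value; a loss term is built from where(states) - value.
@dataclass
class TargetSpec:
    where: Callable[[Any], tuple]
    value: tuple

@dataclass
class States:
    left_pos: tuple   # one effector's position
    right_pos: tuple  # a second effector's position, elsewhere in the "tree"

def loss_from_specs(specs: dict, states) -> dict:
    """Construct one squared-error term per target entry, automatically."""
    return {
        name: sum((a - b) ** 2 for a, b in zip(spec.where(states), spec.value))
        for name, spec in specs.items()
    }

# Targets on two different leaves of the state, specified only by lambdas:
specs = {
    "left": TargetSpec(where=lambda s: s.left_pos, value=(1.0, 1.0)),
    "right": TargetSpec(where=lambda s: s.right_pos, value=(-1.0, 0.0)),
}
states = States(left_pos=(1.0, 0.0), right_pos=(0.0, 0.0))
terms = loss_from_specs(specs, states)
total = sum(terms.values())
print(terms, total)  # → {'left': 1.0, 'right': 1.0} 2.0
```

The dict of named terms mirrors how a composite loss might report per-term values for logging, while `total` is what a trainer would differentiate.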