-
Notifications
You must be signed in to change notification settings - Fork 0
/
task.py
1106 lines (901 loc) · 38.9 KB
/
task.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""Tasks on which models are trained and evaluated.
TODO:
- Maybe allow initial mods to model parameters, in addition to substates.
- Some of the private functions could be public.
- Refactor `get_target_seq` and `get_scalar_epoch_seq` redundancy.
- Also, the way `seq` and `seqs` are generated is similar to `states` in
`ForgetfulIterator.init`...
:copyright: Copyright 2023-2024 by Matt L Laporte.
:license: Apache 2.0, see LICENSE for details.
"""
#! Can't do this because `AbstractVar` annotations can't be stringified.
# from __future__ import annotations
from abc import abstractmethod, abstractproperty
from collections.abc import Callable, Mapping
import dis
from functools import cached_property
import logging
from typing import (
TYPE_CHECKING,
Optional,
Tuple,
)
import equinox as eqx
from equinox import AbstractVar, Module, field
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
from jaxtyping import Array, Float, Int, PRNGKeyArray, PyTree, Shaped
import numpy as np
from feedbax.intervene import AbstractIntervenorInput, TimeSeriesParam
from feedbax.loss import AbstractLoss, LossDict
from feedbax._mapping import AbstractTransformedOrderedDict
from feedbax._model import ModelInput
from feedbax.state import AbstractState, CartesianState, StateT
from feedbax._tree import tree_call
if TYPE_CHECKING:
from feedbax._model import AbstractModel
logger = logging.getLogger(__name__)

# Number of spatial dimensions assumed for positions in this module; used to
# reshape center-out reach endpoints to `(2, -1, N_DIM)`.
N_DIM = 2
def _get_where_str(where_func: Callable) -> str:
"""
Given a function that accesses a tree of attributes of a single parameter,
return a string repesenting the attributes.
This is useful for getting a unique string representation of a substate
of an `AbstractState` or `AbstractModel` object, as defined by a `where`
function, so we can compare two such functions and see if they refer to
the same substate.
TODO:
- I'm not sure it's good practice to introspect on bytecode like this.
"""
bytecode = dis.Bytecode(where_func)
return ".".join(instr.argrepr for instr in bytecode if instr.opname == "LOAD_ATTR")
@jtu.register_pytree_node_class
class WhereDict(
    AbstractTransformedOrderedDict[
        str, Callable[[PyTree[Array]], PyTree[Array, "T"]], PyTree[Array, "T"]
    ]
):
    """An `OrderedDict` that accepts a limited form of `where` lambdas as keys.

    A key may be a string, or a lambda of a single argument that returns a
    single (possibly nested) attribute of that argument. Lambda keys are
    converted to an equivalent dotted string, so both forms can be used
    interchangeably. For example, when `init_spec` is a `WhereDict`, these
    two lookups are equivalent:

    ```python
    init_spec[lambda state: state.mechanics.effector]
    ```

    ```python
    init_spec['mechanics.effector']
    ```

    ??? dev-note "Performance tests"
        - Construction is about 100x slower than `OrderedDict`. For dicts of
          size relevant to our use case, this means ~100 us instead of ~1 us.
          A modest increase in dict size only changes this modestly.
        - Access is about 800x slower, and as expected this doesn't
          change with dict size because there's just a constant overhead
          for doing the key transformation on each access.
            - Each access is about 26 us, and this is also the duration of
              a call to `_get_where_str`.
        - `list(init_spec.items())` is about 100x slower (24 us versus 234 ns)
          for a single-entry `init_spec`, and about 260x slower (149 us versus
          571 ns) for 6 entries, which is about as many as I expect anyone to use
          in the near future, when constructing a task.

        This is pretty slow, but since we only need to do a single
        construction and a single access of `init_spec` per batch/evaluation,
        it shouldn't matter too much in practice—overhead of about 125 us/batch,
        with a batch normally taking about 20,000+ us to train.

        Optimizations should focus on `_get_where_str`.
    """

    def _key_transform(self, key: str | Callable) -> str:
        # Non-callable keys (i.e. strings) are used as-is.
        if not isinstance(key, Callable):
            return key
        # Callable keys are replaced by the dotted attribute path they access.
        where_str = _get_where_str(key)
        if where_str:
            return where_str
        raise ValueError(
            "WhereDict keys must be lambdas that perform attribute access."
        )

    def __repr__(self):
        # Render each stored key back into lambda form for readability.
        entries = []
        for where_str, (_, value) in self.store.items():
            separator = "." if where_str else ""
            entries.append(f"(lambda state: state{separator}{where_str}, {value})")
        return f"{type(self).__name__}([{', '.join(entries)}])"
class AbstractTaskInputs(Module):
    """Base class for the inputs that a task presents to a model.

    !!! Note ""
        Each field of a subclass is normally a PyTree of arrays, where every
        array's leading dimension indexes time steps. When the PyTree holds a
        batch of trials, the time dimension moves to second place, after the
        batch dimension.
    """
# TODO: Once `target` is specified as a `WhereDict` as well, the only thing
# that will really change between classes of tasks is `inputs`. In that case,
# we could just make this `TaskTrialSpec` and have it be generic over
# `AbstractTaskInputs`.
class AbstractTaskTrialSpec(Module):
    """Base class for the trial specifications that a task generates.

    Attributes:
        inits: Maps `lambda`s that select model substates to the values those
            substates are initialized with.
        inputs: The PyTree of inputs presented to the model.
        target: The PyTree of target states.
        intervene: Maps unique intervenor names to per-trial intervention
            parameters.
    """

    inits: AbstractVar[WhereDict]
    inputs: AbstractVar[AbstractTaskInputs]
    target: AbstractVar[PyTree[Array]]
    intervene: AbstractVar[Mapping[str, Array]]
class AbstractReachTrialSpec(AbstractTaskTrialSpec):
    """Abstract base class for trial specifications for reaching tasks.

    Attributes:
        inits: A mapping from `lambdas` that select model substates to be
            initialized, to substates to initialize them with.
        inputs: A PyTree of inputs to the model, including data about the
            reach target.
        target: The target trajectory for the mechanical end effector,
            used for computing the loss.
        intervene: A mapping from unique intervenor names, to per-trial
            intervention parameters.
    """

    inits: AbstractVar[WhereDict]
    inputs: AbstractVar[AbstractTaskInputs]
    target: AbstractVar[CartesianState]
    intervene: AbstractVar[Mapping[str, Array]]

    @cached_property
    def goal(self):
        """The final state in the target trajectory for the mechanical end effector.

        Computed by indexing `[:, -1]` on every leaf of `target`.
        """
        # NOTE(review): `[:, -1]` selects the last time step only when the
        # leaves carry a leading batch (trial) axis, i.e. `(batch, time, ...)`.
        # For an unbatched spec with leaves shaped `(time, ...)`, it would index
        # the wrong axis. Confirm `goal` is only used on batched specs.
        # `jtu.tree_map` is used because `jax.tree_map` is deprecated
        # (JAX 0.4.26) and removed in later JAX releases.
        return jtu.tree_map(lambda x: x[:, -1], self.target)
class SimpleReachTaskInputs(AbstractTaskInputs):
    """Inputs presented to the model on a simple reaching trial.

    Attributes:
        effector_target: The effector target states, over time, shown to the
            model.
    """

    #! column vector: why here?
    effector_target: CartesianState
class DelayedReachTaskInputs(AbstractTaskInputs):
    """Model input for a delayed reaching task.

    Derives from `AbstractTaskInputs`, like `SimpleReachTaskInputs`, so that
    it satisfies the `inputs: AbstractVar[AbstractTaskInputs]` contract of the
    trial-spec base classes.

    Attributes:
        effector_target: The trajectory of effector target states to be presented to
            the model.
        hold: The hold/go (1/0 signal) to be presented to the model.
        target_on: A signal indicating to the model when the value of `effector_target`
            should be interpreted as a reach target. Otherwise, if zeros are passed for
            the target during (say) the hold period, the model may interpret this as
            meaningful—that is, "your reach target is at 0".
    """

    effector_target: CartesianState  # PyTree[Float[Array, "time ..."]]
    # TODO: do these need to be typed as column vectors, here?
    # NOTE(review): `DelayedReaches._get_sequences` builds these with a fill
    # value of `1.0`, so they may actually be float arrays — confirm the
    # `Int` annotations.
    hold: Int[Array, "time 1"]
    target_on: Int[Array, "time 1"]
class SimpleReachTrialSpec(AbstractReachTrialSpec):
    """Specification of a single trial (or batch of trials) of simple reaching.

    Attributes:
        inits: Maps `lambda`s that select model substates to the values those
            substates take at the start of each trial.
        inputs: Carries the reach target to the model.
        target: The target trajectory of the mechanical end effector.
        intervene: Maps unique intervenor names to per-trial intervention
            parameters. Empty by default.
    """

    inits: WhereDict
    inputs: SimpleReachTaskInputs
    target: CartesianState
    intervene: Mapping[str, Array] = field(default_factory=dict)
class DelayedReachTrialSpec(AbstractReachTrialSpec):
    """Specification of a single trial (or batch of trials) of delayed reaching.

    Attributes:
        inits: Maps `lambda`s that select model substates to the values those
            substates take at the start of each trial.
        inputs: Carries the reach target and the hold signal to the model.
        target: The target trajectory of the mechanical end effector.
        epoch_start_idxs: The time index at which each epoch of the trial
            begins.
        intervene: Maps unique intervenor names to per-trial intervention
            parameters. Empty by default.
    """

    inits: WhereDict
    inputs: DelayedReachTaskInputs
    target: CartesianState
    epoch_start_idxs: Int[Array, "n_epochs"]
    intervene: Mapping[str, Array] = field(default_factory=dict)
class AbstractTask(Module):
    """Abstract base class for tasks.

    Provides methods for evaluating suitable models or ensembles of models on
    training and validation trials.

    !!! Note ""
        Subclasses must provide:

        - a method that generates a training trial (`get_train_trial`)
        - a method that generates the validation trials
          (`get_validation_trials`), and a property giving their number
          (`n_validation_trials`)
        - a field for a loss function that grades performance on the task

    Attributes:
        loss_func: The loss function that grades task performance.
        n_steps: The number of time steps in the task trials.
        seed_validation: The random seed for generating the validation trials.
        intervention_specs: A mapping from unique intervenor names, to specifications
            for generating per-trial intervention parameters on training trials.
        intervention_specs_validation: A mapping from unique intervenor names, to
            specifications for generating per-trial intervention parameters on
            validation trials.
    """

    loss_func: AbstractVar[AbstractLoss]
    n_steps: AbstractVar[int]
    seed_validation: AbstractVar[int]
    # TODO: The following annotations are wrong: each entry will have the same
    # PyTree structure as `AbstractIntervenorInput`, but will be filled with
    # callables that specify a trial distribution for the leaves.
    intervention_specs: AbstractVar[Mapping[str, "AbstractIntervenorInput"]]
    intervention_specs_validation: AbstractVar[Mapping[str, "AbstractIntervenorInput"]]

    def _intervention_params(
        self,
        intervention_specs: Mapping[str, "AbstractIntervenorInput"],
        trial_spec: AbstractTaskTrialSpec,
        key: PRNGKeyArray,
    ):
        """Evaluate intervention specifications into per-trial parameter arrays.

        Arguments:
            intervention_specs: Intervention specifications, evaluated via
                `tree_call` with `trial_spec` and `key`.
            trial_spec: The trial specification that the parameters are for.
            key: A random key for any per-trial randomness in the specs.

        Returns:
            The evaluated parameters. `TimeSeriesParam` leaves are unwrapped
            into their arrays. Every other leaf is broadcast along a new leading
            axis of length `self.n_steps - 1`.
        """
        # Evaluate any parameters that vary by trial. Leave `TimeSeriesParam`
        # instances wrapped for now; they are unwrapped separately below.
        intervention_params = tree_call(
            intervention_specs,
            trial_spec,
            key=key,
            exclude=lambda x: isinstance(x, TimeSeriesParam),
        )

        timeseries, other = eqx.partition(
            intervention_params,
            lambda x: isinstance(x, TimeSeriesParam),
            is_leaf=lambda x: isinstance(x, TimeSeriesParam),
        )

        # Unwrap the `TimeSeriesParam` instances.
        timeseries_arrays = tree_call(
            timeseries, is_leaf=lambda x: isinstance(x, TimeSeriesParam)
        )

        # Broadcast the non-timeseries values along a new leading axis.
        other_broadcasted = jtu.tree_map(
            lambda x: jnp.broadcast_to(x, (self.n_steps - 1, *x.shape)),
            jtu.tree_map(jnp.array, other),
        )

        return eqx.combine(timeseries_arrays, other_broadcasted)

    @eqx.filter_jit
    def get_train_trial_with_intervenor_params(
        self,
        key: PRNGKeyArray,
    ) -> AbstractTaskTrialSpec:
        """Return a single training trial specification, including intervention parameters.

        Arguments:
            key: A random key for generating the trial.
        """
        key, key_intervene = jr.split(key)

        with jax.named_scope(f"{type(self).__name__}.get_train_trial"):
            trial_spec = self.get_train_trial(key)

        trial_spec = eqx.tree_at(
            lambda x: x.intervene,
            trial_spec,
            self._intervention_params(
                self.intervention_specs,
                trial_spec,
                key_intervene,
            ),
            is_leaf=lambda x: x is None,
        )

        return trial_spec

    @abstractmethod
    def get_train_trial(
        self,
        key: PRNGKeyArray,
    ) -> AbstractTaskTrialSpec:
        """Return a single training trial specification.

        Arguments:
            key: A random key for generating the trial.
        """
        ...

    @abstractmethod
    def get_validation_trials(
        self,
        key: PRNGKeyArray,
    ) -> AbstractTaskTrialSpec:
        """Return a set of validation trials, given a random key.

        !!! Note ""
            Subclasses must override this method. However, the validation set
            used during training and provided by `self.validation_trials`
            will be determined by the field `self.seed_validation`, which must
            also be implemented by subclasses.

        Arguments:
            key: A random key for generating the validation set.
        """
        ...

    @cached_property
    def validation_trials(self) -> AbstractTaskTrialSpec:
        """The set of validation trials associated with the task.

        Generated from `self.seed_validation`. Includes per-trial intervention
        parameters evaluated from `self.intervention_specs_validation`.
        """
        key = jr.PRNGKey(self.seed_validation)
        keys = jr.split(key, self.n_validation_trials)

        # NOTE(review): `key` both seeds `keys` (above) and generates the
        # trials (below). Reusing a key is discouraged in JAX, but changing the
        # split would alter the validation set produced by existing seeds.
        trial_specs = self.get_validation_trials(key)

        trial_specs = eqx.tree_at(
            lambda x: x.intervene,
            trial_specs,
            eqx.filter_vmap(self._intervention_params, in_axes=(None, 0, 0))(
                self.intervention_specs_validation,
                trial_specs,
                keys,
            ),
            is_leaf=lambda x: x is None,
        )

        return trial_specs

    # `abc.abstractproperty` has been deprecated since Python 3.3; stacking
    # `property` over `abstractmethod` is the supported equivalent.
    @property
    @abstractmethod
    def n_validation_trials(self) -> int:
        """Number of trials in the validation set."""
        ...

    @eqx.filter_jit
    @jax.named_scope("fbx.AbstractTask.eval_trials")
    def eval_trials(
        self,
        model: "AbstractModel[StateT]",
        trial_specs: AbstractTaskTrialSpec,
        keys: PRNGKeyArray,
    ) -> Tuple[LossDict, StateT]:
        """Evaluate a model on a set of trials.

        Arguments:
            model: The model to evaluate.
            trial_specs: The set of trials to evaluate the model on.
            keys: For providing randomness during model evaluation.

        Returns:
            The losses for the trials.
            The evaluated model states.
        """
        init_states = jax.vmap(model.init)(key=keys)

        # Overwrite the selected substates with the trial-specified values.
        for where_substate, init_substates in trial_specs.inits.items():
            init_states = eqx.tree_at(
                where_substate,
                init_states,
                init_substates,
            )

        init_states = jax.vmap(model.step.state_consistency_update)(init_states)

        states = eqx.filter_vmap(model)(
            ModelInput(trial_specs.inputs, trial_specs.intervene),
            init_states,
            keys,
        )

        losses = self.loss_func(states, trial_specs)

        return losses, states

    def eval_with_loss(
        self,
        model: "AbstractModel[StateT]",
        key: PRNGKeyArray,
    ) -> Tuple[LossDict, StateT]:
        """Evaluate a model on the task's validation set of trials.

        Arguments:
            model: The model to evaluate.
            key: For providing randomness during model evaluation.

        Returns:
            The losses for the trials in the validation set.
            The evaluated model states.
        """
        keys = jr.split(key, self.n_validation_trials)
        trial_specs = self.validation_trials

        return self.eval_trials(model, trial_specs, keys)

    def eval(
        self,
        model: "AbstractModel[StateT]",
        key: PRNGKeyArray,
    ) -> StateT:
        """Return states for a model evaluated on the task's set of validation trials.

        Arguments:
            model: The model to evaluate.
            key: For providing randomness during model evaluation.
        """
        return self.eval_with_loss(model, key)[1]

    @eqx.filter_jit
    def eval_ensemble(
        self,
        models: "AbstractModel[StateT]",
        n_replicates: int,
        key: PRNGKeyArray,
    ) -> StateT:
        """Return states for an ensemble of models evaluated on the task's set of
        validation trials.

        Arguments:
            models: The ensemble of models to evaluate.
            n_replicates: The number of models in the ensemble.
            key: For providing randomness during model evaluation.
                Will be split into `n_replicates` keys.
        """
        # TODO: Why not just use `eqx.filter_vmap`? It should handle the array partitioning.
        models_arrays, models_other = eqx.partition(models, eqx.is_array)

        def evaluate_single(model_arrays, model_other, key):
            model = eqx.combine(model_arrays, model_other)
            return self.eval(model, key)

        # TODO: Instead, we should expect the user to provide `keys` instead of `key`,
        # if they are vmapping `eval`.
        keys_eval = jr.split(key, n_replicates)

        return eqx.filter_vmap(evaluate_single, in_axes=(0, None, 0))(
            models_arrays, models_other, keys_eval
        )

    @eqx.filter_jit
    def eval_train_batch(
        self,
        model: "AbstractModel[StateT]",
        batch_size: int,
        key: PRNGKeyArray,
    ) -> Tuple[LossDict, StateT, AbstractTaskTrialSpec]:
        """Evaluate a model on a single batch of training trials.

        Arguments:
            model: The model to evaluate.
            batch_size: The number of trials in the batch.
            key: For providing randomness during model evaluation.

        Returns:
            The losses for the trials in the batch.
            The evaluated model states.
            The trial specifications for the batch.
        """
        key_batch, key_eval = jr.split(key)
        keys_batch = jr.split(key_batch, batch_size)
        keys_eval = jr.split(key_eval, batch_size)

        trials = jax.vmap(self.get_train_trial_with_intervenor_params)(keys_batch)

        losses, states = self.eval_trials(model, trials, keys_eval)

        return losses, states, trials

    @eqx.filter_jit
    def eval_ensemble_train_batch(
        self,
        models: "AbstractModel[StateT]",
        n_replicates: int,
        batch_size: int,
        key: PRNGKeyArray,
    ) -> Tuple[LossDict, StateT, AbstractTaskTrialSpec]:
        """Evaluate an ensemble of models on a single training batch.

        Arguments:
            models: The ensemble of models to evaluate.
            n_replicates: The number of models in the ensemble.
            batch_size: The number of trials in the batch to evaluate.
            key: For providing randomness during model evaluation.

        Returns:
            The losses for the trials in the batch, for each model in the ensemble.
            The evaluated model states, for each trial and each model in the ensemble.
            The trial specifications for the batch.
        """
        models_arrays, models_other = eqx.partition(models, eqx.is_array)

        def evaluate_single(model_arrays, model_other, batch_size, key):
            model = eqx.combine(model_arrays, model_other)
            return self.eval_train_batch(model, batch_size, key)

        keys_eval = jr.split(key, n_replicates)

        return eqx.filter_vmap(evaluate_single, in_axes=(0, None, None, 0))(
            models_arrays, models_other, batch_size, keys_eval
        )
def _pos_only_states(pos_endpoints: Float[Array, "... ndim=2"]):
    """Construct Cartesian init and target states with zero force and velocity.

    Arguments:
        pos_endpoints: Positions, iterated along their leading axis.

    Returns:
        A list with one `CartesianState` per entry along the leading axis of
        `pos_endpoints`. Each state's velocity and force are zero arrays with
        the same shape as its position.
    """
    # A plain comprehension replaces the former `jax.tree_map` over a list of
    # tuples. `jax.tree_map` is deprecated and removed in later JAX releases,
    # and the result is the same list.
    vel_endpoints = jnp.zeros_like(pos_endpoints)
    forces = jnp.zeros_like(pos_endpoints)
    return [
        CartesianState(pos, vel, force)
        for pos, vel, force in zip(pos_endpoints, vel_endpoints, forces)
    ]
def internal_grid_points(
    bounds: Float[Array, "bounds=2 ndim=2"], n: int = 2
) -> Float[Array, "n**ndim ndim=2"]:
    """Return a list of evenly-spaced grid points internal to the bounds.

    Along each dimension, the interval between the two bounds is divided into
    `n + 1` equal parts. The `n` interior division points are kept and the
    bounds themselves are excluded.

    Arguments:
        bounds: The outer bounds of the grid, as `(lower_corner, upper_corner)`.
            Any array-like (e.g. nested tuples) is accepted.
        n: The number of internal grid points along each dimension.

    Returns:
        An array of shape `(n ** ndim, ndim)` holding the grid points.

    !!! Example
        ```python
        internal_grid_points(
            bounds=((0, 0), (9, 9)),
            n=2,
        )
        ```
        ```>> Array([[3., 3.], [6., 3.], [3., 6.], [6., 6.]]).```
    """
    # Accept nested tuples as in the example above; `.T` requires an array.
    bounds = jnp.asarray(bounds)
    # One row of `n` interior ticks per dimension.
    ticks = jax.vmap(lambda b: jnp.linspace(b[0], b[1], n + 2)[1:-1])(bounds.T)
    grids = jnp.meshgrid(*ticks)
    points = jnp.vstack([jnp.ravel(grid) for grid in grids]).T
    return points
def _centerout_endpoints_grid(
    workspace: Float[Array, "bounds=2 ndim=2"],
    eval_grid_n: int,
    eval_n_directions: int,
    eval_reach_length: float,
):
    """Endpoints of center-out reach sets, centered on a grid over the workspace.

    Arguments:
        workspace: The bounds of the rectangular workspace.
        eval_grid_n: The number of internal grid points along each dimension;
            one center-out reach set is placed at each grid point.
        eval_n_directions: The number of reaches in each center-out set.
        eval_reach_length: The length of each reach.

    Returns:
        An array reshaped to `(2, -1, N_DIM)`. Callers unpack the leading axis
        into init and target positions.
    """
    grid_centers = internal_grid_points(workspace, eval_grid_n)
    per_center = jax.vmap(
        centreout_endpoints, in_axes=(0, None, None), out_axes=1
    )(grid_centers, eval_n_directions, eval_reach_length)
    # Merge the per-center and per-direction axes into one trial axis.
    return per_center.reshape((2, -1, N_DIM))
def _forceless_task_inputs(
    target_states: CartesianState,
) -> CartesianState:
    """Keep only the position and velocity of the targets shown to the model.

    Returns a new `CartesianState` whose `force` is `None`.
    """
    pos, vel = target_states.pos, target_states.vel
    return CartesianState(pos=pos, vel=vel, force=None)
class SimpleReaches(AbstractTask):
    """Reaches between random endpoints in a rectangular workspace. No hold signal.

    Validation set is center-out reaches.

    !!! Note
        This passes a trajectory of target velocities all equal to zero, assuming
        that the user will choose a loss function that penalizes only the initial
        or final velocities. If the loss function penalizes the intervening velocities,
        this task no longer makes sense as a reaching task.

    Attributes:
        n_steps: The number of time steps in each task trial.
        loss_func: The loss function that grades performance on each trial.
        workspace: The rectangular workspace in which the reaches are distributed.
        seed_validation: The random seed for generating the validation trials.
        intervention_specs: A mapping from unique intervenor names, to specifications
            for generating per-trial intervention parameters on training trials.
        intervention_specs_validation: A mapping from unique intervenor names, to
            specifications for generating per-trial intervention parameters on
            validation trials.
        eval_grid_n: The number of evenly-spaced internal grid points of the
            workspace at which a set of center-out reach is placed.
        eval_n_directions: The number of evenly-spread center-out reaches
            starting from each workspace grid point in the validation set. The number
            of trials in the validation set is equal to
            `eval_n_directions * eval_grid_n ** 2`.
        eval_reach_length: The length (in space) of each reach in the validation set.
    """

    n_steps: int
    loss_func: AbstractLoss
    workspace: Float[Array, "bounds=2 ndim=2"] = field(converter=jnp.asarray)
    seed_validation: int = 5555
    intervention_specs: Mapping[str, "AbstractIntervenorInput"] = field(
        default_factory=dict
    )
    intervention_specs_validation: Mapping[str, "AbstractIntervenorInput"] = field(
        default_factory=dict
    )
    eval_n_directions: int = 7
    eval_reach_length: float = 0.5
    eval_grid_n: int = 1  # e.g. 2 -> 2x2 grid of center-out reach sets

    def get_train_trial(self, key: PRNGKeyArray) -> SimpleReachTrialSpec:
        """Random reach endpoints across the rectangular workspace.

        Arguments:
            key: A random key for generating the trial.
        """
        effector_pos_endpoints = uniform_tuples(key, n=2, bounds=self.workspace)

        effector_init_state, effector_target_state = _pos_only_states(
            effector_pos_endpoints
        )

        # Broadcast the fixed targets to a sequence with the desired number of
        # time steps, since that's what `ForgetfulIterator` and `Loss` will expect.
        # Hopefully this should not use up any extra memory.
        effector_target_state = jtu.tree_map(
            lambda x: jnp.broadcast_to(x, (self.n_steps, *x.shape)),
            effector_target_state,
        )

        # Model inputs omit the final time step of the target trajectory.
        effector_target = _forceless_task_inputs(
            jtu.tree_map(lambda x: x[:-1], effector_target_state)
        )

        # TODO: It might be better here to use an `Intervenor`-like callable
        # instead of `WhereDict`, which is slow. For example:
        # def init_func(state):
        #     return eqx.tree_at(
        #         lambda state: state.mechanics.effector,
        #         state,
        #         effector_init_state,
        #     )
        return SimpleReachTrialSpec(
            inits=WhereDict(
                {(lambda state: state.mechanics.effector): effector_init_state}
            ),
            inputs=SimpleReachTaskInputs(effector_target=effector_target),
            target=effector_target_state,
        )

    def get_validation_trials(self, key: PRNGKeyArray) -> SimpleReachTrialSpec:
        """Center-out reach sets in a grid across the rectangular workspace.

        Arguments:
            key: Unused; the validation set is fully determined by the task fields.
        """
        effector_pos_endpoints = _centerout_endpoints_grid(
            self.workspace,
            self.eval_grid_n,
            self.eval_n_directions,
            self.eval_reach_length,
        )

        effector_init_states, effector_target_states = _pos_only_states(
            effector_pos_endpoints
        )

        # Broadcast to the desired number of time steps. Awkwardly, we also
        # need to use `swapaxes` because the batch dimension is explicit, here.
        effector_target_states = jtu.tree_map(
            lambda x: jnp.swapaxes(jnp.broadcast_to(x, (self.n_steps, *x.shape)), 0, 1),
            effector_target_states,
        )

        task_inputs = _forceless_task_inputs(
            jtu.tree_map(lambda x: x[:, :-1], effector_target_states)
        )

        return SimpleReachTrialSpec(
            inits=WhereDict(
                {(lambda state: state.mechanics.effector): effector_init_states}
            ),
            # Wrap in `SimpleReachTaskInputs` so validation inputs have the same
            # PyTree structure as the inputs built in `get_train_trial`.
            inputs=SimpleReachTaskInputs(effector_target=task_inputs),
            target=effector_target_states,
        )

    @property
    def n_validation_trials(self) -> int:
        """Number of trials in the validation set."""
        return self.eval_grid_n**2 * self.eval_n_directions
class DelayedReaches(AbstractTask):
    """Uniform random endpoints in a rectangular workspace.

    e.g. allows for a stimulus epoch, followed by a delay period, then movement.

    Attributes:
        loss_func: The loss function that grades performance on each trial.
        workspace: The rectangular workspace in which the reaches are distributed.
        n_steps: The number of time steps in each task trial.
        epoch_len_ranges: The ranges from which to uniformly sample the durations of
            the task phases for each task trial.
        target_on_epochs: The epochs in which the "target on" signal is turned on.
        hold_epochs: The epochs in which the hold signal is turned on.
        eval_n_directions: The number of evenly-spread center-out reaches
            starting from each workspace grid point in the validation set. The number
            of trials in the validation set is equal to
            `eval_n_directions * eval_grid_n ** 2`.
        eval_reach_length: The length (in space) of each reach in the validation set.
        eval_grid_n: The number of evenly-spaced internal grid points of the
            workspace at which a set of center-out reach is placed.
        seed_validation: The random seed for generating the validation trials.
        intervention_specs: A mapping from unique intervenor names, to specifications
            for generating per-trial intervention parameters on training trials.
        intervention_specs_validation: A mapping from unique intervenor names, to
            specifications for generating per-trial intervention parameters on
            validation trials.
    """

    loss_func: AbstractLoss
    workspace: Float[Array, "bounds=2 ndim=2"] = field(converter=jnp.asarray)
    n_steps: int
    epoch_len_ranges: Tuple[Tuple[int, int], ...] = field(
        default=(
            (5, 15),  # start
            (10, 20),  # target on ("stim")
            (10, 25),  # delay
        )
    )
    target_on_epochs: Int[Array, "_"] = field(default=(1,), converter=jnp.asarray)
    hold_epochs: Int[Array, "_"] = field(default=(0, 1, 2), converter=jnp.asarray)
    eval_n_directions: int = 7
    eval_reach_length: float = 0.5
    eval_grid_n: int = 1
    seed_validation: int = 5555
    # Concrete values for the `AbstractTask` abstract fields. Without them the
    # class stays abstract and cannot be instantiated.
    intervention_specs: Mapping[str, "AbstractIntervenorInput"] = field(
        default_factory=dict
    )
    intervention_specs_validation: Mapping[str, "AbstractIntervenorInput"] = field(
        default_factory=dict
    )

    def get_train_trial(self, key: PRNGKeyArray) -> DelayedReachTrialSpec:
        """Random reach endpoints across the rectangular workspace.

        Arguments:
            key: A random key for generating the trial.
        """
        key1, key2 = jr.split(key)
        effector_pos_endpoints = uniform_tuples(key1, n=2, bounds=self.workspace)

        effector_init_state, effector_target_state = _pos_only_states(
            effector_pos_endpoints
        )

        task_inputs, effector_target_states, epoch_start_idxs = self._get_sequences(
            effector_init_state, effector_target_state, key2
        )

        return DelayedReachTrialSpec(
            inits=WhereDict(
                {(lambda state: state.mechanics.effector): effector_init_state}
            ),
            inputs=task_inputs,
            target=effector_target_states,
            epoch_start_idxs=epoch_start_idxs,
        )

    def get_validation_trials(self, key: PRNGKeyArray) -> DelayedReachTrialSpec:
        """Center-out reach sets in a grid across the rectangular workspace.

        Arguments:
            key: A random key for sampling the epoch durations of each trial.
                `AbstractTask.validation_trials` passes
                `jr.PRNGKey(self.seed_validation)`.
        """
        effector_pos_endpoints = _centerout_endpoints_grid(
            self.workspace,
            self.eval_grid_n,
            self.eval_n_directions,
            self.eval_reach_length,
        )

        effector_init_states, effector_target_states = _pos_only_states(
            effector_pos_endpoints
        )

        # Use the provided key rather than rebuilding one from
        # `self.seed_validation`, so the method honors its argument. Via
        # `validation_trials`, the key is the same one used before.
        epochs_keys = jr.split(key, effector_init_states.pos.shape[0])
        task_inputs, effector_target_states, epoch_start_idxs = jax.vmap(
            self._get_sequences
        )(effector_init_states, effector_target_states, epochs_keys)

        return DelayedReachTrialSpec(
            inits=WhereDict(
                {(lambda state: state.mechanics.effector): effector_init_states}
            ),
            inputs=task_inputs,
            target=effector_target_states,
            epoch_start_idxs=epoch_start_idxs,
        )

    def _get_sequences(
        self,
        init_states: CartesianState,
        target_states: CartesianState,
        key: PRNGKeyArray,
    ) -> Tuple[DelayedReachTaskInputs, CartesianState, Int[Array, "n_epochs"]]:
        """Convert static task inputs to sequences, and make hold signal.

        Arguments:
            init_states: The initial effector state of the trial.
            target_states: The reach target state of the trial.
            key: A random key for sampling the epoch durations.

        Returns:
            The model inputs (target stimulus, hold signal, target-on signal).
            The target trajectory.
            The start index of each epoch.
        """
        epoch_lengths = gen_epoch_lengths(key, self.epoch_len_ranges)
        # Prepend 0 so that entry `i` is the start index of epoch `i`.
        epoch_start_idxs = jnp.pad(
            jnp.cumsum(epoch_lengths), (1, 0), constant_values=(0, -1)
        )
        epoch_masks = get_masks(self.n_steps, epoch_start_idxs)
        move_epoch_mask = jnp.logical_not(jnp.prod(epoch_masks, axis=0))[None, :]

        stim_seqs = get_masked_seqs(
            _forceless_task_inputs(target_states), epoch_masks[self.target_on_epochs]
        )
        # Sum of the target states masked to the movement period and the init
        # states masked to the hold epochs (see `get_masks` for mask semantics).
        target_seqs = jtu.tree_map(
            lambda x, y: x + y,
            get_masked_seqs(target_states, move_epoch_mask),
            get_masked_seqs(init_states, epoch_masks[self.hold_epochs]),
        )
        stim_on_seq = get_scalar_epoch_seq(
            epoch_start_idxs, self.n_steps, 1.0, self.target_on_epochs
        )
        hold_seq = get_scalar_epoch_seq(
            epoch_start_idxs, self.n_steps, 1.0, self.hold_epochs
        )

        task_input = DelayedReachTaskInputs(stim_seqs, hold_seq, stim_on_seq)
        target_states = target_seqs

        return task_input, target_states, epoch_start_idxs

    # Must be a property: `AbstractTask` uses `self.n_validation_trials` as an
    # int (e.g. `jr.split(key, self.n_validation_trials)`).
    @property
    def n_validation_trials(self) -> int:
        """Number of trials in the validation set."""
        return self.eval_grid_n**2 * self.eval_n_directions
class Stabilization(AbstractTask):
    """Postural stabilization task at random points in workspace.

    On each trial, the effector starts at a point and its target is that same
    point, held constant over the trial. The validation set uses the points of
    a regular grid over the workspace (or over `eval_workspace`, if given).

    Attributes:
        loss_func: The loss function that grades performance on each trial.
        workspace: The rectangular workspace from which training points are drawn.
        n_steps: The number of time steps in each task trial.
        eval_grid_n: The number of grid points along each dimension of the
            validation grid. The validation set has `eval_grid_n ** 2` trials.
        eval_workspace: The workspace covered by the validation grid. When
            `None`, `workspace` is used.
        seed_validation: The random seed for generating the validation trials.
        intervention_specs: A mapping from unique intervenor names, to specifications
            for generating per-trial intervention parameters on training trials.
        intervention_specs_validation: A mapping from unique intervenor names, to
            specifications for generating per-trial intervention parameters on
            validation trials.
    """

    loss_func: AbstractLoss
    workspace: Float[Array, "bounds=2 ndim=2"] = field(converter=jnp.asarray)
    n_steps: int
    eval_grid_n: int  # e.g. 2 -> 2x2 grid
    eval_workspace: Optional[Float[Array, "bounds=2 ndim=2"]] = field(
        # `jnp.asarray(None)` does not return `None`: depending on the JAX
        # version, it warns and gives NaN, or raises. Pass `None` through so
        # the `is None` check in `get_validation_trials` works.
        converter=lambda x: x if x is None else jnp.asarray(x),
        default=None,
    )
    # Concrete values for the `AbstractTask` abstract fields, matching the
    # defaults of the other tasks in this module.
    seed_validation: int = 5555
    intervention_specs: Mapping[str, "AbstractIntervenorInput"] = field(
        default_factory=dict
    )
    intervention_specs_validation: Mapping[str, "AbstractIntervenorInput"] = field(
        default_factory=dict
    )

    @eqx.filter_jit
    @jax.named_scope("fbx.Stabilization.get_train_trial")
    def get_train_trial(self, key: PRNGKeyArray) -> SimpleReachTrialSpec:
        """Random stabilization point in the 2D rectangular workspace.

        Arguments:
            key: A random key for generating the trial.
        """
        points = uniform_tuples(key, n=1, bounds=self.workspace)
        # `_pos_only_states` returns one state per entry along the leading axis
        # of `points`. This assumes that axis has length `n` (here 1), as
        # `SimpleReaches.get_train_trial` relies on for `n=2`.
        (target_state,) = _pos_only_states(points)
        init_state = target_state

        # Broadcast the fixed target to a sequence with the desired number of
        # time steps, since that's what `ForgetfulIterator` and `Loss` will expect.
        # Hopefully this should not use up any extra memory.
        target_state = jtu.tree_map(
            lambda x: jnp.broadcast_to(x, (self.n_steps, *x.shape)),
            target_state,
        )

        # NOTE(review): unlike `SimpleReaches`, inputs here are not trimmed to
        # `n_steps - 1` steps (the length used for intervention parameters).
        # Confirm the expected input length.
        task_input = _forceless_task_inputs(target_state)

        return SimpleReachTrialSpec(
            inits=WhereDict({(lambda state: state.mechanics.effector): init_state}),
            inputs=SimpleReachTaskInputs(effector_target=task_input),
            target=target_state,
        )

    def get_validation_trials(self, key: PRNGKeyArray) -> SimpleReachTrialSpec:
        """Stabilization at each point of a regular grid over the workspace.

        Arguments:
            key: Unused; the validation set is fully determined by the task fields.
        """
        if self.eval_workspace is None:
            workspace = self.workspace
        else:
            workspace = self.eval_workspace

        pos_endpoints = _points_grid(
            workspace,
            self.eval_grid_n,
        )

        # Build a single batched state. `_forceless_task_inputs` needs a
        # `CartesianState`, not the list returned by `_pos_only_states`.
        # NOTE(review): assumes `_points_grid` returns `(n_points, ndim)`.
        target_states = CartesianState(
            pos=pos_endpoints,
            vel=jnp.zeros_like(pos_endpoints),
            force=jnp.zeros_like(pos_endpoints),
        )
        init_states = target_states

        # Broadcast to the desired number of time steps. Awkwardly, we also
        # need to use `swapaxes` because the batch dimension is explicit, here.
        target_states = jtu.tree_map(
            lambda x: jnp.swapaxes(jnp.broadcast_to(x, (self.n_steps, *x.shape)), 0, 1),
            target_states,
        )

        task_inputs = _forceless_task_inputs(target_states)

        return SimpleReachTrialSpec(
            inits=WhereDict({(lambda state: state.mechanics.effector): init_states}),
            inputs=SimpleReachTaskInputs(effector_target=task_inputs),
            target=target_states,
        )

    @property
    def n_validation_trials(self) -> int:
        """Size of the validation set."""
        return self.eval_grid_n**2
def _points_grid(
workspace: Float[Array, "bounds=2 ndim=2"],
grid_n: int | Tuple[int, int],
):
"""A regular grid of points over a rectangular workspace.
Args:
grid_n: Number of grid points in each dimension.
"""
if isinstance(grid_n, int):
grid_n = (grid_n, grid_n)