Allow user to modify params of intervenors already scheduled with a task
- Combine `AbstractTask.intervention_specs` and
  `.intervention_specs_validation` into a single `intervention_specs` Module
  with `training` and `validation` fields.
- Add a function `update_intervenor_param_schedule` that returns
  an updated task whose `intervention_specs` contain different
  parameters. This is an alternative to rescheduling an
  intervention, and only requires the intervention label and the
  names and values of the param(s) to be changed (see the sketch
  below).
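
For illustration, a minimal usage sketch of the two changes. The `task`, the "CurlField" label, and the "amplitude" parameter are hypothetical placeholders; only the `intervention_specs` fields and `update_intervenor_param_schedule` come from this commit, and the import path assumes the function is used directly from the module where it is defined.

# Hedged sketch -- assumes a task that already has an intervenor scheduled
# under the label "CurlField" with a scalar parameter "amplitude".
from feedbax.intervene.schedule import update_intervenor_param_schedule

# Training and validation specs now live on a single Module-valued field:
training_specs = task.intervention_specs.training      # {label: InterventionSpec}
validation_specs = task.intervention_specs.validation

# Change a parameter of the already-scheduled intervenor without rescheduling it:
task = update_intervenor_param_schedule(
    task,
    {"CurlField": {"amplitude": 2.0}},  # label -> {param name: new value}
    training=True,
    validation=True,
)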
mlprt committed Apr 20, 2024
1 parent d08f403 commit 0dbaffb
Showing 2 changed files with 105 additions and 92 deletions.
109 changes: 79 additions & 30 deletions feedbax/intervene/schedule.py
@@ -244,6 +244,10 @@ def __call__(self):
return self.param


def is_timeseries_param(x):
return isinstance(x, TimeSeriesParam)


def _eval_intervenor_param_spec(
intervention_spec: InterventionSpec,
trial_spec, #: AbstractTaskTrialSpec,
@@ -257,10 +261,10 @@ def _eval_intervenor_param_spec(
trial_spec,
key=key,
# Don't unwrap `TimeSeriesParam`s yet:
exclude=lambda x: isinstance(x, TimeSeriesParam),
is_leaf=lambda x: isinstance(x, TimeSeriesParam),
exclude=is_timeseries_param,
is_leaf=is_timeseries_param,
),
is_leaf=lambda x: isinstance(x, TimeSeriesParam),
is_leaf=is_timeseries_param,
)


@@ -365,7 +369,7 @@ def schedule_intervenor(
invalid_labels_tasks = jax.tree_util.tree_reduce(
lambda x, y: x + y,
jax.tree_map(
lambda task: tuple(task.all_intervention_specs.keys()),
lambda task: tuple(task.intervention_specs.all.keys()),
tasks,
is_leaf=is_module, # AbstractTask
),
@@ -377,17 +381,16 @@
label = type(intervenor_).__name__
label = get_unique_label(label, invalid_labels)

# Construct the additions to `AbstractTask.intervenor_specs*`
# Set to active (`True`) by default.
intervention_specs = {label: (True, InterventionSpec(
    # Construct the additions to `AbstractTask.intervention_specs`
intervention_specs = {label: InterventionSpec(
intervenor=intervenor_,
where=where,
stage_name=stage_name,
default_active=default_active,
))}
)}

if intervenor_params_validation is not None:
intervention_specs_validation = {label: (True, InterventionSpec(
intervention_specs_validation = {label: InterventionSpec(
intervenor=eqx.tree_at(
lambda intervenor: intervenor.params,
intervenor_,
@@ -396,7 +399,7 @@
where=where,
stage_name=stage_name,
default_active=default_active,
))}
)}
elif validation_same_schedule:
intervention_specs_validation = intervention_specs
else:
@@ -405,11 +408,14 @@
# Add the spec intervenors to every task in `tasks`
tasks = jax.tree_map(
lambda task: eqx.tree_at(
lambda task: (task.intervention_specs, task.intervention_specs_validation),
lambda task: (
task.intervention_specs.training,
task.intervention_specs.validation,
),
task,
(
task.intervention_specs | intervention_specs,
task.intervention_specs_validation | intervention_specs_validation,
task.intervention_specs.training | intervention_specs,
task.intervention_specs.validation | intervention_specs_validation,
),
),
tasks,
@@ -433,7 +439,7 @@ def schedule_intervenor(
intervenor,
_eval_intervenor_param_spec(
# Prefer the validation parameters, if they exist.
(intervention_specs | intervention_specs_validation)[label][1],
(intervention_specs | intervention_specs_validation)[label],
trial_spec_example,
key_example,
)
@@ -459,22 +465,65 @@
return tasks, models


# def update_intervenor_param_schedule(
# task: "AbstractTask",
# params: Mapping[IntervenorLabelStr, dict[str, Any]],
# training: bool = True,
# validation: bool = True,
# ) -> "AbstractTask":
# for cond, suffix in {training: "", validation: "_validation"}.items():
# if cond:
# specs = _get
def update_intervenor_param_schedule(
task: "AbstractTask",
params: Mapping[IntervenorLabelStr, Mapping[str, Any]],
training: bool = True,
validation: bool = True,
is_leaf: Optional[Callable[..., bool]] = None,
) -> "AbstractTask":
"""Return a task with updated specifications for intervention parameters.
This might fail if the parameter is passed, or already assigned, as an `eqx.Module`
or other PyTree, since `tree_leaves` will flatten its contents. In that case you
should set `is_leaf=is_module` (or similar) so that the entire object is treated
as the parameter.
TODO: Just... flatten the nested dict instead of using `tree_leaves`, to avoid this issue.
Arguments:
task: The task to modify.
params: A mapping from intervenor labels (a subset of the keys from the fields
of `task.intervention_specs`) to mappings from parameter names to updated
parameter values.
training: Whether to apply the update to the training trial intervention
specifications.
validation: Whether to apply the update to the validation trial intervention
specifications.
is_leaf: A function that returns `True` for objects that should be treated
as parameters.
"""
if isinstance(is_leaf, Callable):
is_leaf_or_timeseries = lambda x: is_leaf(x) or is_timeseries_param(x)
else:
is_leaf_or_timeseries = is_timeseries_param

params_flat = jax.tree_leaves(params, is_leaf=is_leaf_or_timeseries)

for cond, suffix in {training: "training", validation: "validation"}.items():
if cond:
specs = getattr(task.intervention_specs, suffix)

# task = eqx.tree_at(
# lambda task: getattr(task, f"intervention_specs{suffix}"),
# task,
# specs,
# )
# return task
specs = eqx.tree_at(
lambda specs: jax.tree_leaves({
intervenor_label: {
param_name: getattr(
specs[intervenor_label].intervenor.params,
param_name,
)
for param_name in ps
}
for intervenor_label, ps in params.items()
}, is_leaf=is_leaf_or_timeseries),
specs,
params_flat,
)

task = eqx.tree_at(
lambda task: getattr(task.intervention_specs, suffix),
task,
specs,
)
return task


# TODO: take `Sequence[IntervenorSpec]` or `dict[IntervenorLabel, IntervenorSpec]`
Expand All @@ -485,7 +534,7 @@ def update_fixed_intervenor_param(
param_name: str,
replace: Any,
labels: Optional[PyTree[str, 'T']] = None,
):
) -> "AbstractStagedModel":
if labels is None:
labels = jax.tree_map(
lambda spec: f"FIXED_{type(spec.intervenor).__name__}",
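The docstring of `update_intervenor_param_schedule` above notes that a parameter which is itself a PyTree (e.g. an `eqx.Module`) would be flattened by `tree_leaves` unless `is_leaf` is supplied. A hedged sketch of that workaround, with a hypothetical "CurlField" label, "field_fn" parameter, and `new_field_module` value; an inline predicate is used in place of the `is_module` helper the docstring mentions.

import equinox as eqx

# Treat whole Modules as single parameters instead of flattening their leaves:
task = update_intervenor_param_schedule(
    task,
    {"CurlField": {"field_fn": new_field_module}},  # new_field_module: an eqx.Module
    is_leaf=lambda x: isinstance(x, eqx.Module),
)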
88 changes: 26 additions & 62 deletions feedbax/task.py
@@ -291,10 +291,17 @@ class DelayedReachTrialSpec(AbstractReachTrialSpec):
NonCharSequence: TypeAlias = MutableSequence[T] | tuple[T, ...]


active: TypeAlias = Literal[True]
inactive: TypeAlias = Literal[False]
IsActive: TypeAlias = active | inactive
TaskInterventionSpecs: TypeAlias = Mapping[IntervenorLabelStr, tuple[IsActive, InterventionSpec]]
LabeledInterventionSpecs: TypeAlias = Mapping[IntervenorLabelStr, InterventionSpec]


class TaskInterventionSpecs(Module):
training: LabeledInterventionSpecs = field(default_factory=dict)
validation: LabeledInterventionSpecs = field(default_factory=dict)

@cached_property
def all(self) -> LabeledInterventionSpecs:
# Validation specs are assumed to take precedence, in case of conflicts.
return {**self.training, **self.validation}


class AbstractTask(Module):
@@ -314,18 +321,15 @@ class AbstractTask(Module):
loss_func: The loss function that grades task performance.
n_steps: The number of time steps in the task trials.
seed_validation: The random seed for generating the validation trials.
# TODO
intervention_specs: A mapping from unique intervenor names, to specifications
for generating per-trial intervention parameters on training trials.
intervention_specs_validation: A mapping from unique intervenor names, to
specifications for generating per-trial intervention parameters on
validation trials.
intervention_specs: Mappings from unique intervenor names, to specifications
for generating per-trial intervention parameters. Distinct fields provide
mappings for training and validation trials, though the two may be identical
depending on scheduling.
"""

loss_func: AbstractVar[AbstractLoss]
n_steps: AbstractVar[int]
seed_validation: AbstractVar[int]

intervention_specs: AbstractVar[TaskInterventionSpecs]

def __check_init__(self):
@@ -346,28 +350,6 @@ def get_train_trial(
"""
...

@cached_property
def active_intervention_specs(self):
return {
k: v for k, (is_active, v) in self.intervention_specs.items()
if is_active
}

@cached_property
def active_intervention_specs_validation(self):
return {
k: v for k, (is_active, v) in self.intervention_specs_validation.items()
if is_active
}

@cached_property
def all_intervention_specs(self) -> Mapping[str, InterventionSpec]:
return {k: v for k, (_, v) in self.intervention_specs.items()}

@cached_property
def all_intervention_specs_validation(self) -> Mapping[str, InterventionSpec]:
return {k: v for k, (_, v) in self.intervention_specs_validation.items()}

@eqx.filter_jit
def get_train_trial_with_intervenor_params(
self,
@@ -387,7 +369,7 @@ def get_train_trial_with_intervenor_params(
lambda x: x.intervene,
trial_spec,
self._intervenor_params(
self.active_intervention_specs,
self.intervention_specs.training,
trial_spec,
key_intervene,
),
@@ -468,7 +450,7 @@ def validation_trials(self) -> AbstractTaskTrialSpec:
lambda x: x.intervene,
trial_specs,
eqx.filter_vmap(self._intervenor_params, in_axes=(None, 0, 0))(
self.active_intervention_specs_validation,
self.intervention_specs.validation,
trial_specs,
keys,
),
@@ -725,10 +707,11 @@ def add_intervenors_to_base_model(

# Use schedule_intervenors to reproduce `self`, along with the modified model
task, model_ = base_task, base_model
for label, spec in self.all_intervention_specs.items():
if label in self.all_intervention_specs_validation:
for label, spec in self.intervention_specs.training.items():
#! This won't work if an intervenor spec is only present in the validation dict
if label in self.intervention_specs.validation:
intervenor_params_val = (
self.all_intervention_specs_validation[label].intervenor.params
self.intervention_specs.validation[label].intervenor.params
)
else:
intervenor_params_val = None
@@ -759,7 +742,7 @@ def activate_interventions(
"""

if labels == 'all':
labels = list(self.intervention_specs.keys())
labels = list(self.intervention_specs.training.keys())
elif labels == 'none':
labels = []

@@ -769,7 +752,7 @@
if validation_same_schedule:
labels_validation = labels
elif validation_same_schedule == 'all':
labels_validation = list(self.intervention_specs_validation.keys())
labels_validation = list(self.intervention_specs.validation.keys())
elif validation_same_schedule == 'none':
labels_validation = []

@@ -855,10 +838,6 @@ def _forceless_task_inputs(
)


def _empty_task_intervention_specs():
return dict()


class SimpleReaches(AbstractTask):
"""Reaches between random endpoints in a rectangular workspace. No hold signal.
@@ -893,12 +872,7 @@ class SimpleReaches(AbstractTask):
loss_func: AbstractLoss
workspace: Float[Array, "bounds=2 ndim=2"] = field(converter=jnp.asarray)
seed_validation: int = 5555
intervention_specs: TaskInterventionSpecs = field(
default_factory=_empty_task_intervention_specs
)
intervention_specs_validation: TaskInterventionSpecs = field(
default_factory=_empty_task_intervention_specs
)
intervention_specs: TaskInterventionSpecs = TaskInterventionSpecs()
eval_n_directions: int = 7
eval_reach_length: float = 0.5
eval_grid_n: int = 1 # e.g. 2 -> 2x2 grid of center-out reach sets
@@ -1031,12 +1005,7 @@ class DelayedReaches(AbstractTask):
eval_reach_length: float = 0.5
eval_grid_n: int = 1
seed_validation: int = 5555
intervention_specs: TaskInterventionSpecs = field(
default_factory=_empty_task_intervention_specs
)
intervention_specs_validation: TaskInterventionSpecs = field(
default_factory=_empty_task_intervention_specs
)
intervention_specs: TaskInterventionSpecs = TaskInterventionSpecs()

def get_train_trial(self, key: PRNGKeyArray) -> DelayedReachTrialSpec:
"""Random reach endpoints across the rectangular workspace.
@@ -1148,12 +1117,7 @@ class Stabilization(AbstractTask):
eval_workspace: Optional[Float[Array, "bounds=2 ndim=2"]] = field(
converter=jnp.asarray, default=None
)
intervention_specs: TaskInterventionSpecs = field(
default_factory=_empty_task_intervention_specs
)
intervention_specs_validation: TaskInterventionSpecs = field(
default_factory=_empty_task_intervention_specs
)
intervention_specs: TaskInterventionSpecs = TaskInterventionSpecs()

def get_train_trial(self, key: PRNGKeyArray) -> SimpleReachTrialSpec:
"""Random reach endpoints in a 2D rectangular workspace."""
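A minimal sketch of the merge behaviour of the new `TaskInterventionSpecs` container from feedbax/task.py. Here `train_spec` and `val_spec` stand in for real `InterventionSpec` instances and are not part of the commit.

from feedbax.task import TaskInterventionSpecs

specs = TaskInterventionSpecs(
    training={"CurlField": train_spec},
    validation={"CurlField": val_spec},
)
# `.all` merges the two mappings; validation entries shadow training ones on conflict.
assert specs.all["CurlField"] is val_spec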
