Generalize loss calculations (2 -- working)
- Convert `AbstractTaskTrialSpec` to `TaskTrialSpec`, and remove the
  existing subclasses like `DelayedReachTrialSpec`.
- Remove references to `trial_specs.goal` from `feedbax.plot` and
  `feedbax.plotly`.
- Add a default spec to `TargetSpecLoss`, so that a loss term may
  either be static (e.g. penalize non-zero states), requiring no info
  from `trial_specs.targets`, or receive data that varies
  trial-by-trial (e.g. penalize distance from this trial's goal); see
  the sketch below.
- Rename `TargetSpecLoss` to `TargetStateLoss`.
- Switch `WhereDict` to use `OrderedDict` tree flatten, which it
  should have used in the first place.
- Draft docstring for `TargetStateLoss`.
- Debug `TargetStateLoss` using example notebook 1.

Breaks the old loss classes like `EffectorPositionLoss`, since
`target` is no longer a field of `TaskTrialSpec`. However, I have not
removed/replaced these classes yet.
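
A minimal sketch of the two usage modes described above, based on the new `TargetStateLoss` API in feedbax/loss.py below; the state paths (`states.net.hidden`, `states.mechanics.effector.pos`) are illustrative assumptions:

```python
from feedbax.loss import TargetStateLoss, target_zero

# Static mode: the default spec is complete, so no entry in
# `trial_specs.targets` is needed; penalizes non-zero hidden activity.
activity_loss = TargetStateLoss(
    "nn_activity",
    where=lambda states: states.net.hidden,
    spec=target_zero,
)

# Trial-varying mode: no complete default spec is given; the task supplies
# a `TargetSpec` under this loss's `key` in `trial_specs.targets`, and it
# is merged over `spec` via `eqx.combine`.
position_loss = TargetStateLoss(
    "effector_position",
    where=lambda states: states.mechanics.effector.pos,
)
```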
mlprt committed Apr 27, 2024
1 parent 9d1056a commit 4f735f4
Showing 11 changed files with 358 additions and 221 deletions.
10 changes: 6 additions & 4 deletions docs/api/task.md
@@ -17,24 +17,26 @@

::: feedbax.task.SimpleReachTaskInputs

::: feedbax.task.SimpleReachTrialSpec
<!-- ::: feedbax.task.SimpleReachTrialSpec -->

::: feedbax.task.SimpleReaches

### Delayed (cued) reaching

::: feedbax.task.DelayedReachTaskInputs

::: feedbax.task.DelayedReachTrialSpec
<!-- ::: feedbax.task.DelayedReachTrialSpec -->

::: feedbax.task.DelayedReaches

## Task trial specifications

::: feedbax.task.TaskTrialSpec

## Abstract base classes

<!-- ::: feedbax.task.AbstractTaskInputs -->

::: feedbax.task.AbstractTaskTrialSpec

::: feedbax.task.AbstractTask

## Useful functions for building tasks
2 changes: 1 addition & 1 deletion docs/structure.md
@@ -29,7 +29,7 @@ In Feedbax, models are trained to perform tasks. Typically, this means running t

The base class for all types of tasks is [`AbstractTask`][feedbax.task.AbstractTask]. It provides 1) specifications for training trials, 2) specifications for validation trials, 3) a loss function, which scores a model's performance on a trial, and 4) methods for running a model on a given set of trials.

[Trial specifications][feedbax.task.AbstractTaskTrialSpec] are always composed of three things:
[Trial specifications][feedbax.task.TaskTrialSpec] are always composed of three things:

1. Data with which to initialize one or more parts of a model's state, prior to a trial;
2. Target data which the loss function will use to score the history of a model's states, over a trial;
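As a point of reference for the changes below, a hedged sketch of that three-part structure as a `TaskTrialSpec` might express it after this commit; only `targets` (a `WhereDict` of `TargetSpec` entries, read by `TargetStateLoss`) is confirmed by this diff, and the other field names are assumptions:

```python
# Hypothetical field names, for illustration only:
trial_spec = TaskTrialSpec(
    inits=...,    # 1. data to initialize parts of the model state
    targets=...,  # 2. target data scored by the loss function
    inputs=...,   # 3. data provided as input to the model over the trial
)
```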
17 changes: 12 additions & 5 deletions feedbax/_mapping.py
@@ -10,7 +10,7 @@

import dis
import logging
from typing import Generic, TypeVar
from typing import Generic, TypeVar, overload

import equinox as eqx
from equinox._pretty_print import tree_pp, bracketed
@@ -23,7 +23,7 @@

logger = logging.getLogger(__name__)


T = TypeVar("T")
KT1 = TypeVar("KT1")
KT2 = TypeVar("KT2")
VT = TypeVar("VT")
@@ -54,6 +54,13 @@ def __getitem__(self, key: KT1 | KT2) -> VT:
k = self._key_transform(key)
return self.store[k][1]

def get(self, key: KT1 | KT2, /, default: VT | T | None = None) -> VT | T | None:
k = self._key_transform(key)
if k in self.store:
return self.store[k][1]
else:
return default

def __setitem__(self, key: KT2, value: VT):
self.store[self._key_transform(key)] = (key, value)

@@ -74,7 +81,7 @@ def _key_transform(self, key: KT1 | KT2) -> KT1:

def tree_flatten(self):
"""The same flatten function used by JAX for `dict`"""
return unzip2(sorted(self.items()))[::-1]
return tuple(self.values()), tuple(self.keys())

@classmethod
def tree_unflatten(cls, keys, values):
@@ -178,11 +185,11 @@ class WhereDict(
at least 20,000 us to train.
"""

def _key_transform(self, key: str | Callable) -> str:
def _key_transform(self, key: str | Callable | tuple[Callable, str]) -> str:
return self.key_transform(key)

@staticmethod
def key_transform(key: str | Callable) -> str:
def key_transform(key: str | Callable | tuple[Callable, str]) -> str:

if isinstance(key, str):
pass
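A short sketch of the `WhereDict` behavior changed above, assuming the class is registered as a JAX pytree (it defines `tree_flatten`/`tree_unflatten`); the lambda keys are illustrative:

```python
import jax

from feedbax._mapping import WhereDict

d = WhereDict()
d[lambda states: states.mechanics.effector.pos] = "position target"
# `(callable, label)` keys are now accepted by `key_transform`:
d[(lambda states: states.net.hidden, "activity")] = "activity target"

# The new `get` returns a default instead of raising on a missing key:
value = d.get(lambda states: states.missing, None)  # -> None

# Flattening now preserves insertion order, as JAX does for `OrderedDict`,
# rather than sorting keys as it does for plain `dict`:
leaves, treedef = jax.tree_util.tree_flatten(d)
```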
2 changes: 1 addition & 1 deletion feedbax/intervene/schedule.py
@@ -250,7 +250,7 @@ def is_timeseries_param(x):

def _eval_intervenor_param_spec(
intervention_spec: InterventionSpec,
trial_spec, #: AbstractTaskTrialSpec,
trial_spec, #: TaskTrialSpec,
key: PRNGKeyArray,
):
# Unwrap any `TimeSeriesParam` instances:
128 changes: 100 additions & 28 deletions feedbax/loss.py
@@ -58,7 +58,7 @@

if TYPE_CHECKING:
from feedbax.bodies import SimpleFeedbackState
from feedbax.task import AbstractTaskTrialSpec
from feedbax.task import TaskTrialSpec


logger = logging.getLogger(__name__)
@@ -108,15 +108,15 @@ class AbstractLoss(Module):
def __call__(
self,
states: PyTree,
trial_specs: "AbstractTaskTrialSpec",
trial_specs: "TaskTrialSpec",
) -> LossDict:
return LossDict({self.label: self.term(states, trial_specs)})

@abstractmethod
def term(
self,
states: PyTree,
trial_specs: "AbstractTaskTrialSpec",
trial_specs: "TaskTrialSpec",
) -> Array:
"""Implement this to calculate a loss term."""
...
@@ -284,7 +284,7 @@ def __or__(self, other: "CompositeLoss") -> "CompositeLoss":
def __call__(
self,
states: State,
trial_specs: "AbstractTaskTrialSpec",
trial_specs: "TaskTrialSpec",
) -> LossDict:
"""Evaluate, weight, and return all component terms.
@@ -313,42 +313,114 @@ def __call__(
def term(
self,
states: PyTree,
trial_specs: "AbstractTaskTrialSpec",
trial_specs: "TaskTrialSpec",
) -> Array:
return self(states, trial_specs).total


# Maybe rename to `TargetValueSpec`; I feel like a `TargetSpec` would include a `where` field
class TargetSpec(Module):
label: str
data: PyTree[Array]
norm: Callable = jnp.linalg.norm
"""Associate a state's target value with time indices and discounting factors."""
# `target_value` may be `None` when we specify default values for the other fields
target_value: Optional[PyTree[Array]] = None
# TODO: If `time_idxs` is `Array`, it must be 1D or we'll lose the time dimension before we sum over it!
time_idxs: Optional[Array | slice] = field(default_factory=lambda: slice(0, None))
discount: Optional[Array | slice] = field(default_factory=lambda: jnp.array(1))
time_idxs: Optional[Array] = None
discount: Optional[Array] = None # field(default_factory=lambda: jnp.array([1.0]))

def __and__(self, other):
# Allows user to do `target_zero & target_final_state`, for example.
return eqx.combine(self, other)

def __rand__(self, other):
# Necessary for edge case of `None & spec`
return eqx.combine(other, self)

@property
def batch_axes(self) -> PyTree[None | int]:
# Assume that only the target value will vary between trials.
# TODO: (Low priority.) It's probably better to give control over this to
# `AbstractTask`, since in some cases we might want to vary these parameters
# over trials and not just across batches. And if we don't want to vary them
# at all, then why are time_idxs and discount not just fields of
# `TargetStateLoss`?
return TargetSpec(
target_value=0,
time_idxs=None,
discount=None,
)


"""Useful partial target specs"""
target_final_state = TargetSpec(None, jnp.array([-1], dtype=int), None)
target_zero = TargetSpec(jnp.array(0.0), None, None)
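
These partial specs compose via `&`, which is `eqx.combine` per the `__and__` method above: each field is taken from whichever operand is not `None`. For example:

```python
# Penalize non-zero values of a state, but only at the final time step:
final_zero = target_zero & target_final_state
# equivalent to TargetSpec(jnp.array(0.0), jnp.array([-1], dtype=int), None)
```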


class TargetSpecLoss(AbstractLoss):
class TargetStateLoss(AbstractLoss):
"""Penalize a state variable in comparison to a target value.
!!! Note ""
Currently only supports `where` functions that select a
single state array, not a `PyTree[Array]`.
Arguments:
label: The label for the loss term.
where: Function that takes the PyTree of model states, and
returns the substate to be penalized.
norm: Function which takes the difference between
the substate and the target, and transforms it into a distance. For example,
if the substate is effector position, then the substate-target difference
gives the difference between the $x$ and $y$ position components, and the
default `norm` function (`jnp.linalg.norm` on `axis=-1`) returns the
Euclidean distance between the actual and target positions.
spec: Gives default/constant values for the substate target, discount, and
time index.
"""
label: str
where: Callable
norm: Callable = lambda x: jnp.linalg.norm(x, axis=-1) # Spatial distance
spec: Optional[TargetSpec] = None # Default/constant values.

@cached_property
def _where_str(self):
return WhereDict.key_transform(self.where)
def key(self):
return WhereDict.key_transform((self.where, self.label))

def term(
self,
states: PyTree,
trial_specs: "AbstractTaskTrialSpec",
trial_specs: "TaskTrialSpec",
) -> Array:
"""
Arguments:
trial_specs: Trial-by-trial information. In particular, if
`trial_specs.targets` contains a `TargetSpec` entry mapped by
`self.key`, the values of that `TargetSpec` instance will
take precedence over the defaults specified by `self.spec`.
This allows `AbstractTask` subclasses to specify trial-by-trial
targets, where appropriate.
"""

# TODO: Support PyTrees, not just single arrays
state = self.where(states)[:, 1:]

if (task_target_spec := trial_specs.targets.get(self.key, None)) is None:
if self.spec is None:
raise ValueError("`TargetSpec` must be provided on construction of "
"`TargetStateLoss`, or as part of the trial "
"specifications")

target_spec = self.spec
else:
# Override default spec with trial-by-trial spec provided by the task, if any
target_spec: TargetSpec = eqx.combine(self.spec, task_target_spec)

if target_spec.time_idxs is not None:
state = state[..., target_spec.time_idxs, :]

target_spec = trial_specs.targets[self._where_str]
loss_over_time = self.norm(state - target_spec.target_value)

state = self.where(states)[..., target_spec.time_idxs, :]
target = target_spec.data
loss_over_time = target_spec.norm(state - target, axis=-1)
discounted_loss_over_time = loss_over_time * target_spec.discount
if target_spec.discount is not None:
loss_over_time = loss_over_time * target_spec.discount

return jnp.sum(discounted_loss_over_time, axis=-1)
return jnp.sum(loss_over_time, axis=-1)
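
To make the reduction in `term` concrete, a toy sketch with assumed shapes (a batch of 8 trials, 100 time steps, 2 spatial dimensions), using the default `norm`:

```python
import jax.numpy as jnp

state = jnp.zeros((8, 100, 2))   # e.g. effector position over time
target = jnp.ones((8, 1, 2))     # target value, broadcast over time
loss_over_time = jnp.linalg.norm(state - target, axis=-1)  # (8, 100)
loss_per_trial = jnp.sum(loss_over_time, axis=-1)          # (8,)
```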


class EffectorPositionLoss(AbstractLoss):
@@ -389,7 +461,7 @@ class EffectorPositionLoss(AbstractLoss):
def term(
self,
states: "SimpleFeedbackState",
trial_specs: "AbstractTaskTrialSpec",
trial_specs: "TaskTrialSpec",
) -> Array:

# Sum over X, Y, giving the squared Euclidean distance
@@ -433,7 +505,7 @@ class EffectorStraightPathLoss(AbstractLoss):
def term(
self,
states: "SimpleFeedbackState",
trial_specs: "AbstractTaskTrialSpec",
trial_specs: "TaskTrialSpec",
) -> Array:

effector_pos = states.mechanics.effector.pos
@@ -471,7 +543,7 @@ class EffectorFixationLoss(AbstractLoss):
def term(
self,
states: PyTree,
trial_specs: "AbstractTaskTrialSpec", # DelayedReachTrialSpec
trial_specs: "TaskTrialSpec", # DelayedReachTrialSpec
) -> Array:

loss = jnp.sum(
@@ -504,7 +576,7 @@ class EffectorVelocityLoss(AbstractLoss):
def term(
self,
states: "SimpleFeedbackState",
trial_specs: "AbstractTaskTrialSpec",
trial_specs: "TaskTrialSpec",
) -> Array:

# Sum over X, Y, giving the squared Euclidean distance
@@ -541,7 +613,7 @@ class EffectorFinalVelocityLoss(AbstractLoss):
def term(
self,
states: PyTree,
trial_specs: "AbstractTaskTrialSpec",
trial_specs: "TaskTrialSpec",
) -> Array:

loss = jnp.sum(
@@ -564,7 +636,7 @@ class NetworkOutputLoss(AbstractLoss):
def term(
self,
states: PyTree,
trial_specs: "AbstractTaskTrialSpec",
trial_specs: "TaskTrialSpec",
) -> Array:

# Sum over output channels
@@ -588,7 +660,7 @@ class NetworkActivityLoss(AbstractLoss):
def term(
self,
states: PyTree,
trial_specs: "AbstractTaskTrialSpec",
trial_specs: "TaskTrialSpec",
) -> Array:

# Sum over hidden units