Fix the "opt_state changes shape" issue
- Add a custom `__init__` to `Iterator`, because the field `_step` is
  private.
- Change all `init` method annotations from `Optional[PRNGKeyArray]`
  to `PRNGKeyArray`.
  - I'm not sure if this is ideal.
- Rename `filter_spec` to `where_train_spec` in `TaskTrainer`, which is
  more explicit.
- Filter the model by both `where_train_spec` and `eqx.is_array` before
  passing it to `optimizer.init` in `TaskTrainer`. This ensures that only
  the trainable leaves are initialized in `opt_state`, so it does not
  change shape during the first iteration, and there is no longer an
  additional delay on the second batch iteration due (I think) to JIT
  recompilation. See the sketch after this list.
- Remove the `tqdm` loop surrounding compilation. Use a timer instead,
  and report the durations in the console output.
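
A minimal sketch of the `opt_state` fix, assuming Equinox and Optax. The `Linear` model and the way `where_train_spec` is built here are illustrative stand-ins, not feedbax code:

```python
import equinox as eqx
import jax
import jax.random as jr
import optax

# Stand-in model: we will train the weight and freeze the bias.
model = eqx.nn.Linear(2, 3, key=jr.PRNGKey(0))

# A bool PyTree in the style of `where_train_spec`: True on trainable leaves.
where_train_spec = jax.tree_util.tree_map(lambda _: False, model)
where_train_spec = eqx.tree_at(
    lambda spec: spec.weight, where_train_spec, replace=True
)

optimizer = optax.adam(1e-3)

# Before the fix: `opt_state` was initialized over *all* array leaves, but
# each update only supplies gradients for the trainable leaves, so the
# structure of `opt_state` changed on the first step, triggering a
# JIT recompilation.
opt_state_old = optimizer.init(eqx.filter(model, eqx.is_array))

# After the fix: initialize over the trainable array leaves only, so the
# structure of `opt_state` is stable across training iterations.
model_trainables = eqx.filter(eqx.filter(model, where_train_spec), eqx.is_array)
opt_state = optimizer.init(model_trainables)
```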
mlprt committed Feb 29, 2024
1 parent 147fb42 commit 8f080c6
Showing 11 changed files with 58 additions and 40 deletions.
2 changes: 1 addition & 1 deletion feedbax/_tree.py
@@ -24,7 +24,7 @@


 def filter_spec_leaves(
-    tree: PyTree[Any, "T"], leaf_func: Callable
+    tree: PyTree[Any, "T"], leaf_func: Callable,
 ) -> PyTree[bool, "T"]:
     """Returns a filter specification for tree leaves matching `leaf_func`."""
     filter_spec = jax.tree_util.tree_map(lambda _: False, tree)
8 changes: 4 additions & 4 deletions feedbax/bodies.py
@@ -130,11 +130,11 @@ def __init__(
         # (say) a dict of dicts.
         feedback_specs = _convert_feedback_spec(feedback_spec)

-        init_mechanics_state = mechanics.init()
+        example_mechanics_state = mechanics.init(key=jr.PRNGKey(0))

         def _build_feedback_channel(spec: ChannelSpec):
             return Channel(spec.delay, spec.noise_std, jnp.nan).change_input(
-                spec.where(init_mechanics_state)
+                spec.where(example_mechanics_state)
             )

         self.feedback_channels = jax.tree_map(
@@ -271,9 +271,9 @@ def get_nn_input_size(
     not an instance method because we want to construct the network
     before we construct `SimpleFeedback`.
     """
-    init_mechanics_state = mechanics.init()
+    example_mechanics_state = mechanics.init(key=jr.PRNGKey(0))
     example_feedback = jax.tree_map(
-        lambda spec: spec.where(init_mechanics_state),
+        lambda spec: spec.where(example_mechanics_state),
         _convert_feedback_spec(feedback_spec),
         is_leaf=lambda x: isinstance(x, ChannelSpec),
     )
2 changes: 1 addition & 1 deletion feedbax/channel.py
@@ -150,7 +150,7 @@ def memory_spec(self) -> ChannelState:
             noise=False,
         )

-    def init(self, *, key: Optional[PRNGKeyArray] = None) -> ChannelState:
+    def init(self, *, key: PRNGKeyArray) -> ChannelState:
         """Returns an empty `ChannelState` for the channel."""
         input_init = jax.tree_map(
             lambda x: jnp.full_like(x, self._init_value), self.input_proto
2 changes: 1 addition & 1 deletion feedbax/dynamics.py
@@ -63,7 +63,7 @@ def input_size(self) -> int:
         ...

     @abstractmethod
-    def init(self, *, key: Optional[PRNGKeyArray] = None) -> StateT:
+    def init(self, *, key: PRNGKeyArray) -> StateT:
         """Returns the initial state of the system."""
         ...

14 changes: 14 additions & 0 deletions feedbax/iterate.py
@@ -36,6 +36,7 @@ def init(self, *, key: PRNGKeyArray) -> StateT:

     @property
     def step(self) -> AbstractModel[StateT]:
+        """The model to be iterated."""
         return self._step

     def state_consistency_update(self, state: StateT) -> StateT:
@@ -61,6 +62,19 @@ class Iterator(AbstractIterator[StateT]):
     _step: AbstractModel[StateT]
     n_steps: int

+    def __init__(
+        self,
+        step: AbstractModel[StateT],
+        n_steps: int,
+    ):
+        """
+        Arguments:
+            step: The model to be iterated.
+            n_steps: The number of steps to iterate for.
+        """
+        self._step = step
+        self.n_steps = n_steps
+
     def __call__(
         self,
         input: PyTree,
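
For context on the custom `__init__` above: Equinox modules otherwise get a dataclass-generated constructor whose parameter is the private field name itself. A minimal sketch of the pattern with a toy module (not feedbax's actual `Iterator`):

```python
import equinox as eqx

class Wrapper(eqx.Module):
    _step: int
    n_steps: int

    def __init__(self, step: int, n_steps: int):
        # Expose the private `_step` field under the public name `step`;
        # the auto-generated __init__ would instead take `_step` as its
        # keyword argument.
        self._step = step
        self.n_steps = n_steps

w = Wrapper(step=3, n_steps=10)  # rather than Wrapper(_step=3, n_steps=10)
```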
4 changes: 2 additions & 2 deletions feedbax/mechanics/mechanics.py
@@ -151,11 +151,11 @@ def memory_spec(self):
     def init(
         self,
         *,
-        key: Optional[PRNGKeyArray] = None,
+        key: PRNGKeyArray,
     ):
         """Returns an initial state for use with the `Mechanics` module."""

-        plant_state = self.plant.init()
+        plant_state = self.plant.init(key=key)
         init_input = jnp.zeros((self.plant.input_size,))

         return MechanicsState(
6 changes: 3 additions & 3 deletions feedbax/mechanics/muscle.py
@@ -25,7 +25,7 @@
 import numpy as np

 from feedbax.dynamics import AbstractDynamicalSystem
-from feedbax.model import AbstractModel
+from feedbax._model import AbstractModel
 from feedbax.state import AbstractState, StateBounds


@@ -79,7 +79,7 @@ def _tau_diff(self):
     def input_size(self):
         return 1

-    def init(self):
+    def init(self, *, key: PRNGKeyArray) -> Array:
         raise NotImplementedError("No state PyTree associated with ActivationFilter.")


@@ -234,7 +234,7 @@ class VirtualMuscle(AbstractMuscle):
     force_func: AbstractFLVFunction
     noise_func: Optional[Callable[[Array, Array, Array], Array]] = None

-    def init(self, *, key: Optional[PRNGKeyArray] = None) -> VirtualMuscleState:
+    def init(self, *, key: PRNGKeyArray) -> VirtualMuscleState:
         """Return a default state for the model."""
         state = VirtualMuscleState(
             activation=jnp.zeros(self.n_muscles),
13 changes: 7 additions & 6 deletions feedbax/mechanics/plant.py
@@ -123,7 +123,7 @@ def dynamics_spec(self) -> dict[str, DynamicsComponent[PlantState]]:
         ...

     @abstractmethod
-    def init(self, *, key: Optional[PRNGKeyArray] = None) -> PlantState:
+    def init(self, *, key: PRNGKeyArray) -> PlantState:
         """Returns a default state for the plant."""
         ...

@@ -234,10 +234,10 @@ def memory_spec(self) -> PyTree[bool]:
             muscles=False,
         )

-    def init(self, *, key: Optional[PRNGKeyArray] = None) -> PlantState:
+    def init(self, *, key: PRNGKeyArray) -> PlantState:
         """Return a default state for the plant."""
         return PlantState(
-            skeleton=self.skeleton.init(),
+            skeleton=self.skeleton.init(key=key),
             muscles=None,
         )

@@ -441,11 +441,12 @@ def memory_spec(self) -> PlantState:
             muscles=True,
         )

-    def init(self, *, key: Optional[PRNGKeyArray] = None) -> PlantState:
+    def init(self, *, key: PRNGKeyArray) -> PlantState:
         """Return a default state for the muscled arm."""
+        key1, key2 = jax.random.split(key)
         return PlantState(
-            skeleton=self.skeleton.init(),
-            muscles=self.muscle_model.init(),
+            skeleton=self.skeleton.init(key=key1),
+            muscles=self.muscle_model.init(key=key2),
         )

     @property
4 changes: 2 additions & 2 deletions feedbax/mechanics/skeleton/skeleton.py
@@ -41,7 +41,7 @@ def vector_field(
         ...

     @abstractmethod
-    def init(self, *, key: Optional[PRNGKeyArray] = None) -> StateT:
+    def init(self, *, key: PRNGKeyArray) -> StateT:
         """Return a default state for the skeleton."""
         ...

@@ -74,7 +74,7 @@ def update_state_given_effector_force(
         ...

     # @abstractmethod
-    # def init(self, *, key: Optional[PRNGKeyArray] = None) -> StateT:
+    # def init(self, *, key: PRNGKeyArray) -> StateT:
     #     """Returns the initial state of the system.
     #     """
     #     ...
2 changes: 1 addition & 1 deletion feedbax/nn.py
@@ -337,7 +337,7 @@ def memory_spec(self):
             encoding=True,
         )

-    def init(self, *, key: Optional[PRNGKeyArray] = None):
+    def init(self, *, key: PRNGKeyArray):
         if self.out_size is None:
             output = None
         else:
41 changes: 22 additions & 19 deletions feedbax/train.py
@@ -24,7 +24,7 @@

 from feedbax import loss
 from feedbax.loss import AbstractLoss, LossDict
-from feedbax.misc import TqdmLoggingHandler, delete_contents
+from feedbax.misc import Timer, TqdmLoggingHandler, delete_contents
 from feedbax._model import AbstractModel, ModelInput
 import feedbax.plot as plot
 from feedbax.state import StateT
@@ -177,24 +177,23 @@ def __call__(
             key: The random key.
         """

-        filter_spec = filter_spec_leaves(model, where_train)
+        where_train_spec = filter_spec_leaves(model, where_train)
+        model_trainables = eqx.filter(eqx.filter(model, where_train_spec), eqx.is_array)

         if ensembled:
             # Infer the number of replicates from shape of trainable arrays
             n_replicates = jax.tree_leaves(eqx.filter(model, eqx.is_array))[0].shape[0]
             loss_array_shape = (n_batches, n_replicates)
-            opt_state = jax.vmap(self.optimizer.init)(
-                eqx.filter(model, eqx.is_array)  # Is this necessary?
-            )
+            opt_state = jax.vmap(self.optimizer.init)(model_trainables)
         else:
             loss_array_shape = (n_batches,)
-            opt_state = self.optimizer.init(eqx.filter(model, eqx.is_array))
+            opt_state = self.optimizer.init(model_trainables)

         # TODO: ensembling
         if save_model_trainables:
             model_train_history = jax.tree_map(
                 lambda x: jnp.empty((n_batches,) + x.shape) if eqx.is_array(x) else x,
-                eqx.filter(model, filter_spec),
+                model_trainables,
             )
         else:
             model_train_history = None
@@ -258,7 +257,9 @@ def __call__(

         # Finish the JIT compilation before the first training iteration.
         if not jax.config.jax_disable_jit:
-            for _ in tqdm(range(1), desc="compile", disable=disable_tqdm):
+            timer = Timer()
+
+            with timer:
                 if ensembled:
                     key_compile = jr.split(key, n_replicates)
                 else:
@@ -271,15 +272,17 @@ def __call__(
                     treedef_model,
                     flat_opt_state,
                     treedef_opt_state,
-                    filter_spec,
+                    where_train_spec,
                     key_compile,
                 )
-            if not disable_tqdm:
-                tqdm.write(f"Training step compiled.", file=sys.stdout)
+
+            logger.info(f"Training step compiled in {timer.time:.2f} seconds.")
+
+            with timer:
                 evaluate(model, key_compile)
-            if not disable_tqdm:
-                tqdm.write(f"Validation step compiled.", file=sys.stdout)
+
+            logger.info(f"Validation step compiled in {timer.time:.2f} seconds.")

         else:
             logger.debug("JIT globally disabled, skipping pre-run compilation")
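
The `Timer` implementation is not part of this diff; below is a rough sketch of the kind of reusable timing context manager being used here (the real `feedbax.misc.Timer` may differ):

```python
import time

class Timer:
    """Reusable wall-clock timing context manager (sketch only)."""

    def __init__(self):
        self.time = 0.0

    def __enter__(self):
        self._start = time.perf_counter()
        return self

    def __exit__(self, *exc_info):
        # Store the duration of the most recent `with` block.
        self.time = time.perf_counter() - self._start
        return False
```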

@@ -307,7 +310,7 @@ def __call__(
                         treedef_model,
                         flat_opt_state,
                         treedef_opt_state,
-                        filter_spec,
+                        where_train_spec,
                         key_train,
                     )
                 )
@@ -322,7 +325,7 @@ def __call__(
                     lambda history: history.model_trainables,
                     history,
                     tree_set(
-                        history.model_trainables, eqx.filter(model, filter_spec), batch
+                        history.model_trainables, eqx.filter(model, where_train_spec), batch
                     ),
                 )

@@ -474,7 +477,7 @@ def _train_step(
     treedef_model,
     flat_opt_state,
     treedef_opt_state,
-    filter_spec,  #! can't do AbstractModel[StateT[bool]]
+    where_train_spec,  #! can't do AbstractModel[StateT[bool]]
     key: PRNGKeyArray,
 ):
     """Executes a single training step of the model.
@@ -519,7 +522,7 @@ def _train_step(

     init_states = jax.vmap(model.step.state_consistency_update)(init_states)

-    diff_model, static_model = eqx.partition(model, filter_spec)
+    diff_model, static_model = eqx.partition(model, where_train_spec)

     opt_state = jtu.tree_unflatten(treedef_opt_state, flat_opt_state)
@@ -643,10 +646,10 @@ def _grad_wrap_task_loss_func(loss_func: AbstractLoss):
     Note that we are assuming that
-    1) `TaskTrainer` will manage a `filter_spec` on the trainable parameters.
+    1) `TaskTrainer` will manage a `where_train_spec` on the trainable parameters.
        When `jax.grad` is applied to the wrapper, the gradient will be
        taken with respect to the first argument `diff_model` only, and the
-       `filter_spec` defines this split.
+       `where_train_spec` defines this split.
     2) Model modules will use a `target_state, init_state, key` signature.

     TODO:
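
A small illustration of the split described in point 1), assuming Equinox (toy model and loss, not feedbax's actual training code):

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr

model = eqx.nn.Linear(2, 1, key=jr.PRNGKey(0))

# Mark only the weight as trainable.
where_train_spec = jax.tree_util.tree_map(lambda _: False, model)
where_train_spec = eqx.tree_at(
    lambda spec: spec.weight, where_train_spec, replace=True
)

diff_model, static_model = eqx.partition(model, where_train_spec)

def loss_func(diff_model, static_model, x):
    model = eqx.combine(diff_model, static_model)
    return jnp.sum(model(x) ** 2)

# `jax.grad` differentiates with respect to the first argument only, so
# gradients are computed just for the leaves marked True in the spec.
grads = jax.grad(loss_func)(diff_model, static_model, jnp.ones(2))
```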
