Use this notebook to run the Neural ODE-based end-to-end imitation learning algorithm, learning a black-box model of the dynamics.


First, we load the Neural ODE end-to-end imitation learning algorithm.

In [12]:
# Load the necessary libraries
from myriad.config import HParams, Config, SystemType
from myriad.experiments.node_e2e_sysid import run_node_endtoend

Next, we choose the hyperparameters, as in previous settings. We use a small neural network (`hidden_layers=(50, 50)`) for better stability and faster training; larger networks can also be used.

In [13]:
# Create hyperparameter and config objects
# This is the place to specify environment or solver hyperparameters; see config.py for a full list.
hp = HParams(system=SystemType.CANCERTREATMENT,
             intervals=1,
             controls_per_interval=100,
             hidden_layers=(50, 50))
cfg = Config()
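
To get a feel for how the `hidden_layers` setting affects model size, the sketch below counts the weights and biases of a fully connected network. It is an illustration only: it assumes the network takes the concatenated state and control (2 inputs for this system) and outputs the state derivative (1 output), with a bias on every layer. `mlp_param_count` is a hypothetical helper, not part of myriad.

```python
def mlp_param_count(in_dim, hidden_layers, out_dim):
    """Count weights and biases of a fully connected network."""
    sizes = (in_dim, *hidden_layers, out_dim)
    # Each layer has an (n_in x n_out) weight matrix and n_out biases.
    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes[:-1], sizes[1:]))


# Assumed dimensions: 1 state + 1 control in, 1 state derivative out.
small = mlp_param_count(2, (50, 50), 1)
large = mlp_param_count(2, (100, 100), 1)
print(small, large)  # 2751 10501
```

Under these assumptions, doubling the width of both hidden layers roughly quadruples the number of parameters to train.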

Note that because we are learning a black-box dynamics model, there is no simple way to interpret the learned parameters, as there is in the parametric model case.
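
As a rough picture of what "black-box" means here, the NumPy sketch below uses a small network as the right-hand side of an ODE, rolls it out under a control sequence, and compares the result with a reference trajectory. Everything in it is an assumption made for illustration: the forward Euler integrator, the `tanh` activation, the mean-squared-error loss, the initial state, and the unit time horizon are not taken from myriad, whose integrator and loss may differ.

```python
import numpy as np


def init_mlp(rng, sizes):
    """Random weights and zero biases for a fully connected network."""
    return [(rng.normal(0.0, 1.0 / np.sqrt(n_in), (n_in, n_out)), np.zeros(n_out))
            for n_in, n_out in zip(sizes[:-1], sizes[1:])]


def dynamics(params, x, u):
    """Black-box model of dx/dt given state x and control u."""
    h = np.concatenate([x, u])
    for W, b in params[:-1]:
        h = np.tanh(h @ W + b)
    W, b = params[-1]
    return h @ W + b


def rollout(params, x0, us, dt):
    """Integrate the learned dynamics with forward Euler (an assumed integrator)."""
    xs = [x0]
    for u in us[:-1]:
        xs.append(xs[-1] + dt * dynamics(params, xs[-1], u))
    return np.stack(xs)


def imitation_loss(params, x0, us, expert_xs, dt):
    """Mean squared error between the model rollout and a reference trajectory."""
    return float(np.mean((rollout(params, x0, us, dt) - expert_xs) ** 2))


rng = np.random.default_rng(0)
params = init_mlp(rng, (2, 50, 50, 1))       # 1 state + 1 control in, 1 derivative out
x0 = np.array([0.5])                         # hypothetical initial state
us = rng.uniform(0.0, 2.0, size=(101, 1))    # controls within the bounds [0, 2]
dt = 1.0 / 100                               # hypothetical horizon of 1 over 100 steps
xs_pred = rollout(params, x0, us, dt)
print(xs_pred.shape)
```

The individual entries of `params` carry no physical meaning on their own; only the trajectories the network produces can be compared against the true system.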

Losses are saved in the `losses` directory, and the current neural network parameters in the `params` directory. Intermediate guesses at the best control trajectory are saved in the `intermediate_guesses` directory. Finally, various plots are generated and saved in the `plots` directory, so you can see how the learning process went.
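
A small helper like the one below can be used to check what has been written so far. It only lists file names, since the file formats are not described here; the `root` argument is an assumption about where the four directories live, and `list_run_artifacts` is a hypothetical name rather than a myriad function.

```python
from pathlib import Path


def list_run_artifacts(root="."):
    """Return the file names found in each output directory under `root`."""
    names = ("losses", "params", "intermediate_guesses", "plots")
    found = {}
    for name in names:
        directory = Path(root) / name
        # A directory that does not exist yet is reported as empty.
        found[name] = sorted(p.name for p in directory.iterdir()) if directory.is_dir() else []
    return found


for name, files in list_run_artifacts(".").items():
    print(f"{name}: {len(files)} file(s)")
```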


In [None]:
run_node_endtoend(hp, cfg)

initial controls shape (106, 101, 1)
Generating training control trajectories between bounds:
  u lower [0.]
  u upper [2.]
of shapes:
  xs shape (106, 101, 1)
  us shape (106, 101, 1)
  together (106, 101, 2)
Generated training trajectories of shape (100, 101, 2)
Generated validation trajectories of shape (3, 101, 2)
Generated test trajectories of shape (3, 101, 2)
node: params initialized with:  (2,)
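
The shapes printed above can be reproduced with placeholder data. In the sketch below, the sizes and control bounds come from the output, while the zero-valued states and the slicing order used for the train/validation/test split are assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# 106 trajectories, 101 time points, 1 control dimension, sampled within [0, 2].
us = rng.uniform(0.0, 2.0, size=(106, 101, 1))
# Placeholder states; in the real run these come from simulating the system.
xs = np.zeros_like(us)

# Stack state and control along the last axis.
together = np.concatenate([xs, us], axis=-1)

# Assumed split order: 100 training, 3 validation, 3 test trajectories.
train, val, test = together[:100], together[100:103], together[103:]
print(together.shape, train.shape, val.shape, test.shape)
```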


node: minibatches are of shape (3, 101, 2)
node: initialized network weights
the controls guess is (101, 1)
hp opt type OptimizerType.SHOOTING
hp quadrature rule QuadratureRule.TRAPEZOIDAL
guess.shape = (103,)
bounds.shape = (103, 2)
the controls guess is (101, 1)
hp opt type OptimizerType.SHOOTING
hp quadrature rule QuadratureRule.TRAPEZOIDAL
guess.shape = (103,)
bounds.shape = (103, 2)
successfully loaded the saved optimal trajectory
unable to find the params, so we'll guess and then optimize and save
saving current params
saving guesses so far
saving imitation losses
resetting guess
the controls guess is (101, 1)
hp opt type OptimizerType.SHOOTING
hp quadrature rule QuadratureRule.TRAPEZOIDAL
guess.shape = (103,)
bounds.shape = (103, 2)
0 loss 0.4495123507170944


1 loss 0.2914974151202434
2 loss 0.2424433843977694
3 loss 0.23071141813890475
4 loss 0.23257572152268985
5 loss 0.23928261228283254
6 loss 0.246954546012945
7 loss 0.25361483554696473
8 loss 0.25824379704739153
9 loss 0.2603965260380783
10 loss 0.2600188163857264
20 loss 0.2079137985671257
30 loss 0.22005862916828678
40 loss 0.1895171007539043
50 loss 0.14815609033528765
60 loss 0.10434232840327949
70 loss 0.06801993302629927
80 loss 0.05203422276918464
90 loss 0.03803649653869477
100 loss 0.03100611809974632
110 loss 0.027503860882019907
120 loss 0.026465588288021827
130 loss 0.025051464925933516
140 loss 0.02356707277369361
150 loss 0.022540536392685176
160 loss 0.021998686600569885
170 loss 0.021520070953116175
180 loss 0.02090414304026943
190 loss 0.020365090586653978
200 loss 0.019986688688688758
210 loss 0.01967532768791725
220 loss 0.0193277918798006
230 loss 0.018992879631413318
240 loss 0.018719747345665803
250 loss 0.018483444326307812
260 loss 0.018241766143154595
270 loss 