Copyright 2023 Google LLC.

SPDX-License-Identifier: Apache-2.0

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# NOTE: This Colab is only for documentation

## **Learning Hybrid Continuous Time Policies**

## Definition of [Hybrid Continuous Time (HCT) Policies](https://arxiv.org/abs/2203.08715)

Suppose that we observe image(s) $s_t \in \mathbb{R}^{H \times W \times C}$ at
every discrete time index $t \in \mathbb{N}$, and that between time indices
$t$ and $t + 1$ we have access to continuous-time state observations from other
sensing modalities (this is relaxed below to “higher-frequency” observations),
denoted by the function $x_t(\cdot) : \tau \in [0,T] \rightarrow
\mathbb{R}^n$. The variable $\tau$ is referred to as the "interpolation"
time; it indexes the continuous time between image observations $s_t$ and
$s_{t+1}$. The interval $[0, T]$ representing its domain is arbitrary, with
$\tau = 0$ coinciding with discrete time index $t$ and
$\tau = T$ coinciding with discrete time index $t+1$.

An HCT policy is defined as a functional map from the observation tuple $o_t :=
(s_t,x_t(\cdot))$ to a control function $u_t(\cdot): \tau \in [0,T]
\rightarrow U$, where $U$ is the control space. For ease of notation, assume
that $U = \mathbb{R}^m$. Therefore, in “MDP-notation”, our action "$a_t$",
mapped from $o_t$, is the control function $u_t(\cdot)$, drawing natural
analogies with hierarchical policies.

> An HCT Policy therefore is a *functional* map from $o_t$ to $u_t$. Since
> the policy must be realizable with incoming observations, this map must
> additionally be causal.

## Working with Discrete-Time Measurements

While the functional representation will be useful in designing the
architectures, in actuality we receive observations and generate actions at
fixed frequencies. Thus, we introduce some additional notation.

Assume that over the interval $\tau \in [0, T)$, we output $M > 1$ equally
spaced actions at $\tau_0 = 0, \tau_1 = \frac{T}{M}, \ldots, \tau_{M-1} =
\frac{(M-1)T}{M}$. Recall that at $\tau = T$, we reset $\tau = 0$,
coinciding with the arrival of the next image observation $s_{t+1}$. Thus, the control *frequency* is $M$ times the image-observation frequency. Let
$\mathbf{u}_t$ denote the set of actions $\{u_t(\tau_0), \ldots,
u_t(\tau_{M-1})\}$.

Similarly, we assume that the signal $x_t(\cdot)$ is observed at $N$ times the control frequency, where $N \geq 1$. To express this, we define $\mathbf{x}_t^i$ to be the set of $N+1$ equally spaced observations of $x_t(\cdot)$ within the interval $[\tau_{i-1}, \tau_i]$, for $i = 1,\ldots, M-1$. That is, $\mathbf{x}_t^{i} =
\{x_t(\tau_{i-1}), x_t(\tau_{i-1}+\frac{T}{MN}),\ldots, x_t(\tau_{i}) \}$.
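
As a rough illustration of this notation, the following sketch computes the action times $\tau_i$ and the sample times of $\mathbf{x}_t^i$ with NumPy. The helper names `action_times` and `hf_obs_times` are hypothetical and are not part of this library.

```python
import numpy as np


def action_times(T, M):
    """Return tau_0, ..., tau_{M-1}, where tau_i = i * T / M."""
    return np.arange(M) * T / M


def hf_obs_times(T, M, N, i):
    """Return the N + 1 sample times of x_t^i on [tau_{i-1}, tau_i]."""
    if not 1 <= i <= M - 1:
        raise ValueError("i must lie in 1, ..., M - 1")
    tau_prev = (i - 1) * T / M
    # Consecutive samples are spaced T / (M * N) apart.
    return tau_prev + np.arange(N + 1) * T / (M * N)


T, M, N = 1.0, 4, 2
taus = action_times(T, M)              # action times tau_0, ..., tau_3
x_times = hf_obs_times(T, M, N, i=1)   # sample times spanning [tau_0, tau_1]
```

Note that consecutive sets $\mathbf{x}_t^{i}$ and $\mathbf{x}_t^{i+1}$ share the endpoint $x_t(\tau_i)$, since each set includes both ends of its interval.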

See the figure below for an illustration of these variables. Note that for ease of batching data, we additionally define $\mathbf{x}_t^0$ as the set of $N+1$ high-frequency state observations in-between $u_{t-1}(\tau_{M-1})$ and $u_t(\tau_0)$.

![HybridContinuousTimeStructure](InFuser.png)

By the causal assumption, our architecture must conform to the following
functional relations:

$$
\begin{eqnarray}
(s_t, \mathbf{x}_t^0) &\rightarrow &u_t(0) \\
(s_t, \mathbf{x}_t^{0}, \ldots, \mathbf{x}_t^{j}) &\rightarrow &u_t(\tau_j), \text{ for } j = 1, \ldots, M-1
\end{eqnarray}
$$
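
One way to read these relations is as a lower-triangular dependency pattern between the observation sets and the actions. The sketch below builds such a pattern as a boolean mask; this is only an illustration of the constraint, and the name `causal_mask` is an assumption rather than something defined in this library.

```python
import numpy as np

M = 4  # number of actions per image interval (illustrative value)

# causal_mask[j, i] is True iff x_t^i may be used when computing u_t(tau_j).
causal_mask = np.tril(np.ones((M, M), dtype=bool))

for j in range(M):
    allowed = [f"x_t^{i}" for i in range(M) if causal_mask[j, i]]
    print(f"u_t(tau_{j}) <- s_t,", ", ".join(allowed))
```

The image $s_t$ is available to every action in the interval, so it does not need a mask entry.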

# NDP

This is an adaptation of the
[Neural Dynamic Policies](https://shikharbahl.github.io/neural-dynamic-policies/resources/ndp.pdf)
paper. The key aspect of this model is that it uses only the observations
$s_t$ and $x_t(0)$, and generates the control function $u_t(\cdot)$ for
the entire interval $\tau \in [0, T]$ open-loop by solving a parametric
second-order ODE.

NOTE: For this model, we set $T=1$. Thus, the $M$ actions are output at
$\tau \in \{0, \frac{1}{M}, \ldots, \frac{M-1}{M}\}$.

A
[Dynamic Movement Primitive](https://homes.cs.washington.edu/~todorov/courses/amath579/reading/DynamicPrimitives.pdf)
(DMP) is defined by the following coupled parametric ODE system, which is second-order in $u_t$:

$$
\begin{eqnarray} 
\dfrac{d^2 u_t(\tau)}{d \tau^2} &:= &\alpha_u \left(\beta
(g_t - u_t(\tau)) - \dfrac{d u_t(\tau)}{d \tau}\right) + f_t(\phi_t(\tau)),
\quad \tau \in [0, 1)  \\
\dfrac{d \phi_t(\tau)}{d \tau} &:=
&-\alpha_\phi \phi_t(\tau), \quad \tau \in [0, 1).
\end{eqnarray}
$$

where $\alpha_u, \alpha_\phi, \beta \in \mathbb{R}$ are
positive constants, and $\phi_t$ is the phase function with initial condition
$\phi_t(0) = 1$. The function $f_t$, known as the forcing function, takes
the form:

$$
f_t(\phi) = \dfrac{\phi}{\sum_{k=1}^K \psi_k(\phi)} (W_t \psi(\phi)) \circ (g_t - u_t(0))
$$

where $\psi_k$ is a Gaussian radial basis function, $\psi(\phi) :=
(\psi_1(\phi), \ldots, \psi_K(\phi))$, $K$ is the number of basis
functions, and $\circ$ denotes the elementwise product. The parameters defining
this system of equations are the "goal vector" $g_t \in \mathbb{R}^m$ and the
"weights matrix" $W_t \in \mathbb{R}^{m \times K}$.
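
To make the DMP equations concrete, here is a minimal NumPy sketch that integrates them with forward Euler and reads off the $M$ actions, taking $T = 1$ as in the note above. It is not the solver used in `ndp_model.py`. The function names, the placement and widths of the basis functions, and the default constants are all assumptions chosen for illustration.

```python
import numpy as np


def rbf(phi, centers, widths):
    """Gaussian radial basis functions psi_k(phi), k = 1, ..., K."""
    return np.exp(-widths * (phi - centers) ** 2)


def forcing(phi, W, g, u0, centers, widths):
    """f_t(phi) = phi / sum_k psi_k(phi) * (W psi(phi)) * (g - u0), elementwise."""
    psi = rbf(phi, centers, widths)
    return phi / psi.sum() * (W @ psi) * (g - u0)


def integrate_dmp(u0, du0, g, W, M, alpha_u=25.0, beta=6.25, alpha_phi=1.0,
                  substeps=200):
    """Forward-Euler solve; returns u at tau = 0, 1/M, ..., (M-1)/M."""
    K = W.shape[1]
    # Assumed basis layout: centers follow the phase decay, widths are fixed.
    centers = np.exp(-alpha_phi * np.linspace(0.0, 1.0, K))
    widths = np.full(K, float(K) ** 2)
    dt = 1.0 / (M * substeps)
    u, du, phi = u0.astype(float), du0.astype(float), 1.0  # phi_t(0) = 1
    actions = [u.copy()]
    for _ in range(1, M):
        for _ in range(substeps):
            ddu = (alpha_u * (beta * (g - u) - du)
                   + forcing(phi, W, g, u0, centers, widths))
            u = u + dt * du
            du = du + dt * ddu
            phi = phi + dt * (-alpha_phi * phi)
        actions.append(u.copy())
    return np.stack(actions)  # shape (M, m)


m, K, M = 2, 5, 4
u0, g = np.zeros(m), np.ones(m)
W = np.zeros((m, K))  # zero weights: the forcing term vanishes
actions = integrate_dmp(u0, np.zeros(m), g, W, M)
```

With $W_t = 0$ the forcing term is zero and the system reduces to a damped spring pulling $u_t$ toward $g_t$; nonzero weights shape the transient. The phase equation also has the closed-form solution $\phi_t(\tau) = e^{-\alpha_\phi \tau}$, so it need not be integrated numerically.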

The architecture is defined by three modules:

*   an encoder that maps the observations $(s_t, x_t(0))$ to an image embedding $z_{s_t}$ and DMP parameters $\{g_t, W_t\}$ (hence *neural dynamic* policies),
*   a decoder that maps $(z_{s_t}, x_t(0))$ to the DMP ODE's initial conditions $\left(u_t(0), \frac{d u_t(0)}{d\tau}\right)$, and
*   the DMP ODEs defined above. The constants $\{\alpha_u, \alpha_\phi, \beta, K\}$ are left as hyper-parameters.

## [Flax](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html) Module

To define an NDP model, we initialize an `NDP` module object defined in
`ndp_model.py`. Please see the module header for all relevant definitions. We
outline certain key forward-pass methods.

## Computing the Flow

There are three ways of computing the "flow," i.e., the solution of the NDP to
obtain the control actions between time-steps $t$ and $t+1$.

1.  (Batched) Compute all $M$ actions $\mathbf{u}_t$ from $(s_t, x_t(0))$
    using:

    ```python
    model.apply(params, batch_images, batch_hf_obs)
    ```

    where

    *   `params` are the module parameters,
    *   `batch_images` is a batch of image tensors ($s_t$), and
    *   `batch_hf_obs` is a batch of high-frequency observations ($x_t(0)$).

    Each instance in the batched output has shape $M \times m$.

2.  (Batched) Compute the solution $u_t(\cdot)$ at a dense set of times $\tau
    \in [0, 1]$, given by the vector `pred_times` using:

    ```python
    model.apply(params, batch_images, batch_hf_obs, pred_times,
                method=ndp_model.compute_ndp_flow)
    ```

    Each instance in the batched output has shape `len(pred_times)`$\times m$.

3.  (Unbatched) Compute each action in the sequence iteratively, i.e., as a
    policy.

    ```python
    # Extract the step functions for the NDP model
    re_init, step_fwd = model.step_functions

    # Given a new (image, hf_obs) pair, compute the initial action and NDP params
    ndp_state, ndp_args = re_init(params, image, hf_obs)
    # u(0) = ndp_state[:model.action_dim]

    # Compute u(tau_1)...u(tau_{M-1}) step-by-step:
    tau = 0.
    for i in range(1, M):
      ndp_state, tau = step_fwd(params, ndp_state, tau, ndp_args)
      # u(tau_i) = ndp_state[:model.action_dim]
    ```

## Training via Imitation Learning

Let $\hat{\mathbf{u}}_t = \{\hat{u}_t(0), \ldots, \hat{u}_t(\frac{M-1}{M})\}$
be the observed sequence of actions between discrete time-steps $t$ and
$t+1$. Let us define its linear interpolant as the function
$\hat{u}_t(\cdot)$, and the control function generated by the NDP model
(conditioned on observations) as $u_{t,\theta}(\cdot)$, where $\theta$
represents all learnable parameters. We define the imitation loss as the
integral:

$$
I_t(\theta) := \int_{0}^{\frac{M-1}{M}} l(\hat{u}_t(\tau), u_{t,\theta}(\tau))\ d\tau,
$$

where $l : \mathbb{R}^m \times \mathbb{R}^m \rightarrow \mathbb{R}$ is some loss
function penalizing the discrepancy between the observed and predicted actions.
Note that this integral is equal to the terminal value of the following
auxiliary ODE:

$$
  \dfrac{d J(\tau)}{d\tau} = l(\hat{u}_t(\tau), u_{t,\theta}(\tau)), \quad \tau \in [0, \frac{M-1}{M}],\ J(0) = 0.
$$

To compute the loss, we solve an augmented set of ODEs which includes the
NDP-ODE as well as the auxiliary "cost-ODE" defined above.
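
The following NumPy sketch illustrates the cost-ODE idea in isolation. It assumes a squared-error loss for $l$ and takes the predicted control as a given function `pred_fn`, rather than solving the NDP-ODE jointly as the library does. The name `imitation_loss` and its arguments are hypothetical.

```python
import numpy as np


def imitation_loss(true_actions, pred_fn, substeps=100):
    """Approximate I_t by Euler-integrating dJ/dtau = l, J(0) = 0 (T = 1)."""
    M, m = true_actions.shape
    knots = np.arange(M) / M  # tau_0, ..., tau_{M-1}
    taus = np.linspace(0.0, knots[-1], (M - 1) * substeps + 1)
    # Linear interpolant of the observed actions, one dimension at a time.
    u_hat = np.stack(
        [np.interp(taus, knots, true_actions[:, d]) for d in range(m)], axis=1)
    u_pred = np.stack([pred_fn(tau) for tau in taus])
    # Assumed loss: squared error between observed and predicted actions.
    integrand = np.sum((u_hat - u_pred) ** 2, axis=1)
    dt = taus[1] - taus[0]
    J = 0.0
    for value in integrand[:-1]:
        J += dt * value  # forward-Euler step of the cost-ODE
    return J


M, m = 4, 2
true_actions = np.zeros((M, m))
loss = imitation_loss(true_actions, lambda tau: np.ones(m))
```

In this toy case the integrand is constant at $2$, so the terminal value of $J$ is $2 \cdot \frac{M-1}{M} = 1.5$.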

Given a batched set of observations `batch_images`, `batch_hf_obs` and true
actions `batch_true_actions`, where each instance in `batch_true_actions` has
shape $M \times m$, we can compute the vector of losses across the batch as:

```python
batch_pred_actions, batch_losses = model.apply(
  params, batch_images, batch_hf_obs, batch_true_actions,
  method=ndp_model.compute_augmented_flow)
```

Here `batch_pred_actions` is the batch of predicted actions, and `batch_losses`
is the vector of losses across the batch. These can be averaged to yield a final
scalar loss that can be differentiated automatically and incorporated into any
training pipeline. This library uses
[`flax`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState)
along with [`optax`](https://github.com/deepmind/optax); see `ndp_utils.py`.