## Neural ODEs algorithm

In this notebook, we will derive the adjoint equation for neural ODEs using Lagrange multipliers. We will then use this adjoint equation to derive the gradient of the loss function with respect to the parameters of the neural network.

### Minimization problem

Here is our minimization problem from neural ODEs:

$$
\textrm{argmin}_{\theta} L(z(t_1))
$$

subject to
$$
\frac{dz}{dt} = f(z(t), t, \theta), \quad z(t_0) = z_0, \quad t_0 < t_1
$$

- $f$ is the neural network with parameters $\theta$
- $z(t_0)$ is the input, $z(t_1)$ is the output, $z(t)$ is the state reached from $z(t_0)$ at time $t\in[t_0, t_1]$
- $L$ is the loss function and it is a function of the output $z(t_1)$

### Example: 

Find $\theta$ that minimizes
$$
L(z(1)) = (z(1) - 1)^2,
$$
where $z(t)$ satisfies
$$
\frac{dz}{dt} = -z(t) + \theta, \quad z(0) = 2, \quad 0< t < 1
$$

#### Solution:
We first solve the ODE using the integrating factor method:
$$
(e^{t} z(t))' = e^{t} (z'(t) + z(t)) = e^{t} \theta
$$
Integrating both sides from $0$ to $t$ gives:
$$
e^{t} z(t) - e^{0} z(0) = \int_{0}^{t} e^{s} \theta \, ds = \theta \left( e^{t} - 1 \right)
$$
Using the initial condition $z(0) = 2$, we get:
$$
z(t) = \theta + (2 - \theta) e^{-t}
$$
Then we compute the loss:
$$
L(z(1)) = (z(1) - 1)^2 = (\theta + (2 - \theta) e^{-1} - 1)^2 = (\theta(1 - e^{-1}) + 2e^{-1} - 1)^2
$$
Since the loss is a square, it attains its minimum value $0$ exactly when
$$
\theta(1 - e^{-1}) + 2e^{-1} - 1 = 0
$$
Solving for $\theta$ gives
$$
\theta = \frac{1 - 2e^{-1}}{1 - e^{-1}}
$$
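As a quick numerical sanity check of the closed-form answer, the sketch below integrates the example ODE with a hand-written fixed-step fourth-order Runge-Kutta (RK4) solver and confirms that $z(1) = 1$ at the computed $\theta$, so the loss vanishes. The solver, its name `rk4`, and the step count `n=200` are assumptions for illustration; any accurate ODE solver would do.

```python
import math


def rk4(f, z0, t0, t1, theta, n=200):
    """Integrate dz/dt = f(z, t, theta) from t0 to t1 with n fixed RK4 steps."""
    h = (t1 - t0) / n
    z, t = z0, t0
    for _ in range(n):
        k1 = f(z, t, theta)
        k2 = f(z + h / 2 * k1, t + h / 2, theta)
        k3 = f(z + h / 2 * k2, t + h / 2, theta)
        k4 = f(z + h * k3, t + h, theta)
        z += h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t += h
    return z


def f(z, t, theta):
    # Right-hand side of the example ODE: dz/dt = -z + theta
    return -z + theta


def loss(z1):
    # L(z(1)) = (z(1) - 1)^2
    return (z1 - 1.0) ** 2


theta_star = (1 - 2 * math.exp(-1)) / (1 - math.exp(-1))
z1_numeric = rk4(f, 2.0, 0.0, 1.0, theta_star)
z1_closed = theta_star + (2 - theta_star) * math.exp(-1)

print(f"theta* = {theta_star:.6f}")
print(f"z(1): RK4 = {z1_numeric:.10f}, closed form = {z1_closed:.10f}")
print(f"loss at theta* = {loss(z1_numeric):.3e}")
```

Note that $\theta^* = \frac{1 - 2e^{-1}}{1 - e^{-1}} = \frac{e - 2}{e - 1} \approx 0.418$.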


### General case

For a general $f(z, t, \theta)$ (represented by a neural network with parameters $\theta$), we usually cannot solve the ODE analytically, so this becomes a constrained nonlinear optimization problem. To solve it with a gradient-based method, we need the gradient of the loss $L(z(t_1))$ with respect to the network parameters $\theta$:
$$
\frac{\partial L(z(t_1))}{\partial \theta}
$$

### Application of neural ODEs:

A restriction of neural ODEs is that the input $z(t_0)$ and the output $z(t_1)$ must have the same dimension. In many applications, however, the input and output have different dimensions. In image classification, for example, the input is an image and the output is the classification result.

To handle this, we place one feedforward neural network before the ODE and another after it. The first network maps the input to the initial condition $z(t_0)$, the ODE evolves this state to $z(t_1)$, and the second network maps $z(t_1)$ to the classification result.

The whole model $y = F(x; \theta, \theta_1, \theta_2)$ (written $F$ to avoid confusion with the ODE right-hand side $f$) is defined by the ODE
$$
\frac{dz}{dt} = f(z(t), t, \theta),
$$
with the initial condition
$$
z(t_0) = g(x, \theta_1)
$$
and the output determined by the solution
$$
y = h(z(t_1), \theta_2)
$$
Here $g$ and $h$ are two feedforward neural networks.
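The composition $x \mapsto z(t_0) \mapsto z(t_1) \mapsto y$ can be sketched in plain Python. Everything here is hypothetical and for illustration only: the untrained toy weight matrices, the choice of a single linear map for each of $g$ and $h$, the one-layer $\tanh$ network for $f$, and the fixed-step RK4 solver `odeint_rk4`. The point is that the input ($\mathbb{R}^3$) and output ($\mathbb{R}^1$) dimensions differ while the ODE state keeps a fixed dimension ($\mathbb{R}^2$).

```python
import math

# Hypothetical toy parameters (illustrative, untrained):
#   THETA_1 for g: R^3 -> R^2, THETA for f: R^2 -> R^2, THETA_2 for h: R^2 -> R^1
THETA_1 = [[0.5, -0.2, 0.1], [0.0, 0.3, -0.4]]
THETA = [[-0.5, 1.0], [-1.0, -0.5]]
THETA_2 = [[1.0, -1.0]]


def matvec(A, v):
    """Matrix-vector product with lists: A is a list of rows."""
    return [sum(a * x for a, x in zip(row, v)) for row in A]


def axpy(u, a, v):
    """Return u + a * v elementwise."""
    return [ui + a * vi for ui, vi in zip(u, v)]


def g(x, theta_1):
    # Input network: maps x to the initial state z(t0)
    return matvec(theta_1, x)


def f(z, t, theta):
    # ODE right-hand side: one linear layer followed by tanh
    return [math.tanh(u) for u in matvec(theta, z)]


def h(z, theta_2):
    # Output network: maps the final state z(t1) to the prediction y
    return matvec(theta_2, z)


def odeint_rk4(f, z0, t0, t1, theta, n=100):
    """Fixed-step classical RK4 for a list-valued state."""
    dt = (t1 - t0) / n
    z, t = list(z0), t0
    for _ in range(n):
        k1 = f(z, t, theta)
        k2 = f(axpy(z, dt / 2, k1), t + dt / 2, theta)
        k3 = f(axpy(z, dt / 2, k2), t + dt / 2, theta)
        k4 = f(axpy(z, dt, k3), t + dt, theta)
        z = [zi + dt / 6 * (a + 2 * b + 2 * c + d)
             for zi, a, b, c, d in zip(z, k1, k2, k3, k4)]
        t += dt
    return z


def model(x, theta, theta_1, theta_2, t0=0.0, t1=1.0):
    z0 = g(x, theta_1)                     # z(t0) = g(x, theta_1)
    z1 = odeint_rk4(f, z0, t0, t1, theta)  # integrate dz/dt = f(z, t, theta)
    return h(z1, theta_2)                  # y = h(z(t1), theta_2)


y = model([1.0, 2.0, 3.0], THETA, THETA_1, THETA_2)
print(y)  # one output value from a three-dimensional input
```

Only $\theta$ enters the ODE; $\theta_1$ and $\theta_2$ are ordinary feedforward parameters, so their gradients come from the usual backpropagation through $g$ and $h$.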

### Direct Method for computing the gradient

Let us take the gradient directly using the chain rule:
$$
\frac{\partial L(z(t_1))}{\partial \theta} = \frac{\partial L(z(t_1))}{\partial z(t_1)} \frac{\partial z(t_1)}{\partial \theta}
$$
The first factor $\frac{\partial L(z(t_1))}{\partial z(t_1)}$ is easy to compute: it is just the gradient of the loss function with respect to its argument, evaluated at $z(t_1)$.

Now we focus on the second factor $\frac{\partial z(t_1)}{\partial \theta}$. To compute $z(t_1)$, we integrate the ODE in time and obtain:
$$
z(t_1) = z(t_0) + \int_{t_0}^{t_1} f(z(t), t, \theta) dt
$$
Next we take the gradient with respect to $\theta$. Notice that:
- $z(t_0)$ is independent of $\theta$;
- $z(t)$ depends on $\theta$, so we need the chain rule:
$$
\frac{\partial z(t_1)}{\partial \theta} = \int_{t_0}^{t_1} \frac{\partial }{\partial \theta} f(z(t), t, \theta) \, dt = \int_{t_0}^{t_1} \left( \frac{\partial f(z(t), t, \theta)}{\partial z(t)} \frac{\partial z(t)}{\partial \theta} + \frac{\partial f(z(t), t, \theta)}{\partial \theta} \right) dt
$$
We have no convenient way to compute $\frac{\partial z(t)}{\partial \theta}$ directly: $z(t)$ depends on $\theta$ only implicitly, through the ODE, and this sensitivity is needed at every time $t \in [t_0, t_1]$.
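For the worked example the closed-form solution $z(1) = \theta + (2 - \theta) e^{-1}$ gives $\frac{\partial z(1)}{\partial \theta} = 1 - e^{-1}$ directly, so the chain rule can be evaluated exactly. The sketch below compares that with a central finite-difference estimate that re-solves the ODE numerically at $\theta \pm \varepsilon$. The finite-difference check, the helper names, and the step size `eps` are assumptions for illustration; note that finite differences need two extra ODE solves per parameter, which does not scale to a network with many parameters.

```python
import math

E_INV = math.exp(-1)


def z1_closed(theta):
    # Closed-form z(1) for dz/dt = -z + theta, z(0) = 2
    return theta + (2 - theta) * E_INV


def direct_gradient(theta):
    # Chain rule: dL/dtheta = dL/dz(1) * dz(1)/dtheta
    dL_dz1 = 2 * (z1_closed(theta) - 1)  # gradient of L = (z(1) - 1)^2
    dz1_dtheta = 1 - E_INV               # from differentiating the closed form
    return dL_dz1 * dz1_dtheta


def z1_rk4(theta, n=200):
    # Numerical z(1) via fixed-step RK4, without using the closed form
    def rhs(z):
        return -z + theta

    h, z = 1.0 / n, 2.0
    for _ in range(n):
        k1 = rhs(z)
        k2 = rhs(z + h / 2 * k1)
        k3 = rhs(z + h / 2 * k2)
        k4 = rhs(z + h * k3)
        z += h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return z


def fd_gradient(theta, eps=1e-5):
    # Central finite difference: two extra ODE solves per parameter
    def loss(th):
        return (z1_rk4(th) - 1) ** 2

    return (loss(theta + eps) - loss(theta - eps)) / (2 * eps)


for theta in (0.0, 0.5, 1.5):
    print(theta, direct_gradient(theta), fd_gradient(theta))
```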

### Lagrange Multipliers

We can use Lagrange multipliers to solve this problem. We will first derive the adjoint equation for the ODE, then use the adjoint equation to derive the gradient of the loss function with respect to the neural network parameters.

### Review of Lagrange Multipliers in calculus

Let us review Lagrange multipliers in calculus. Suppose we have a constrained optimization problem:
$$
\textrm{min}_{x} f(x)
$$
subject to
$$
g(x) = 0
$$

#### Example:
$$
\textrm{min}_{x,y} (x^2 + y^2)
$$
subject to
$$
x + y = 0
$$

We can use Lagrange multipliers to solve this problem. We introduce a Lagrange multiplier $\lambda$ and form the Lagrangian
$$
\mathcal{L}(x,\lambda) = f(x) + \lambda g(x)
$$
A constrained minimizer $x^*$ (provided $\frac{\partial g(x)}{\partial x} \neq 0$ at $x^*$) corresponds to a stationary point of $\mathcal{L}$, which is in general a saddle point rather than a minimum over $\lambda$. So we set the partial derivatives with respect to $x$ and $\lambda$ to zero:
$$
\begin{aligned}
\frac{\partial \mathcal{L}}{\partial x} &= 0 \\
\frac{\partial \mathcal{L}}{\partial \lambda} &= 0
\end{aligned}
$$
or equivalently:
$$
\begin{aligned}
\frac{\partial f(x)}{\partial x} + \lambda \frac{\partial g(x)}{\partial x} &= 0 \\
g(x) &= 0
\end{aligned}
$$
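For the example above, with $f(x, y) = x^2 + y^2$ and $g(x, y) = x + y$ (the variable $x$ in the general statement now stands for the pair $(x, y)$), these conditions read:
$$
\begin{aligned}
2x + \lambda &= 0 \\
2y + \lambda &= 0 \\
x + y &= 0
\end{aligned}
$$
The first two equations give $x = y = -\lambda/2$. Substituting into the constraint gives $-\lambda = 0$, so $\lambda = 0$ and the constrained minimizer is $(x, y) = (0, 0)$ with minimum value $0$. The multiplier is zero here because the unconstrained minimizer already satisfies the constraint.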

### Deriving the adjoint equation for neural ODEs using Lagrange multipliers

We now apply Lagrange multipliers to the neural ODE problem, treating the ODE as a constraint that must hold at every time $t \in [t_0, t_1]$.

We define the Lagrangian function $\psi = \psi(\theta)$:
$$
\psi(\theta) = L(z(t_1)) - \int_{t_0}^{t_1} \lambda(t) (\frac{dz}{dt} - f(z(t), t, \theta)) dt
$$
where $\lambda(t)$ is a time-dependent Lagrange multiplier (a row vector with the same dimension as $z$), and the function $z(t)$ satisfies the ODE:
$$
\frac{dz}{dt} = f(z(t), t, \theta), \quad z(t_0) = z_0, \quad t_0 < t_1
$$

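Because $z(t)$ satisfies the ODE, the integrand $\frac{dz}{dt} - f(z(t), t, \theta)$ is zero for every $t$ and every $\theta$, so the second term of $\psi(\theta)$ is identically zero. It therefore does not contribute to the gradient of $\psi(\theta)$ with respect to $\theta$, whatever $\lambda(t)$ is: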
$$
\frac{\partial \psi(\theta)}{\partial \theta} = \frac{\partial L(z(t_1))}{\partial \theta}
$$

Therefore it suffices to compute $\frac{\partial \psi(\theta)}{\partial \theta}$.

Why introduce the Lagrange multiplier $\lambda(t)$? Since the identity above holds for any $\lambda(t)$, we are free to choose it. The goal is to choose $\lambda(t)$ so that the terms involving $\frac{\partial z(t)}{\partial \theta}$, which are hard to compute, cancel out of $\frac{\partial \psi(\theta)}{\partial \theta}$.

#### Simplify terms

Let us use integration by parts to compute
$$
\begin{aligned}
& \int_{t_0}^{t_1} \lambda(t) (\frac{dz}{dt} - f(z(t), t, \theta)) dt \\
&= \int_{t_0}^{t_1} \lambda(t) \frac{dz}{dt} dt - \int_{t_0}^{t_1} \lambda(t) f(z(t), t, \theta) dt \\
&= \lambda(t_1) z(t_1) - \lambda(t_0) z(t_0) - \int_{t_0}^{t_1}\frac{d \lambda(t)}{dt} z(t) \, dt - \int_{t_0}^{t_1} \lambda(t) f(z(t), t, \theta) dt \\
&= \lambda(t_1) z(t_1) - \lambda(t_0) z_0 - \int_{t_0}^{t_1} (\frac{d \lambda(t)}{dt} z(t) + \lambda(t) f(z(t), t, \theta) ) dt
\end{aligned}
$$
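Now we use the chain rule to compute the gradient with respect to $\theta$, treating $\lambda(t)$ as fixed. In the last step the $\lambda(t_0)$ term drops out because the initial condition does not depend on $\theta$, i.e. $\frac{\partial z_0}{\partial \theta} = 0$: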
$$
\begin{aligned}
& \frac{\partial }{\partial \theta} (\int_{t_0}^{t_1} \lambda(t) (\frac{dz}{dt} - f(z(t), t, \theta)) dt) \\
=& \frac{\partial }{\partial \theta} (\lambda(t_1) z(t_1) - \lambda(t_0) z_0 - \int_{t_0}^{t_1} (\frac{d \lambda(t)}{dt} z(t) + \lambda(t) f(z(t), t, \theta) ) dt) \\
=& \lambda(t_1) \frac{\partial z(t_1)}{\partial \theta} - \lambda(t_0) \frac{\partial z_0}{\partial \theta} \\
&- \int_{t_0}^{t_1} (\frac{d \lambda(t)}{dt} \frac{\partial z(t)}{\partial \theta} + \lambda(t) ( \frac{\partial f(z(t), t, \theta)}{\partial z}\frac{\partial z}{\partial \theta} + \frac{\partial f(z(t), t, \theta)}{\partial \theta} ) ) dt \\
=& \lambda(t_1) \frac{\partial z(t_1)}{\partial \theta} - \int_{t_0}^{t_1} (\frac{d \lambda(t)}{dt} \frac{\partial z(t)}{\partial \theta} + \lambda(t) ( \frac{\partial f(z(t), t, \theta)}{\partial z}\frac{\partial z}{\partial \theta} + \frac{\partial f(z(t), t, \theta)}{\partial \theta} ) ) dt
\end{aligned}
$$

Combining this with the chain rule applied to the first term of $\psi(\theta)$, we obtain:
$$
\begin{aligned}
& \frac{\partial \psi(\theta)}{\partial \theta} \\
=& \frac{\partial L(z(t_1))}{\partial z(t_1)}\frac{\partial z(t_1)}{\partial \theta} \\
&- \lambda(t_1) \frac{\partial z(t_1)}{\partial \theta} + \int_{t_0}^{t_1} (\frac{d \lambda(t)}{dt} \frac{\partial z(t)}{\partial \theta} + \lambda(t) ( \frac{\partial f(z(t), t, \theta)}{\partial z}\frac{\partial z}{\partial \theta} + \frac{\partial f(z(t), t, \theta)}{\partial \theta} ) ) dt \\
=& (\frac{\partial L(z(t_1))}{\partial z(t_1)} - \lambda(t_1)) \frac{\partial z(t_1)}{\partial \theta} \\
& + \int_{t_0}^{t_1} (\frac{d \lambda(t)}{dt}  + \lambda(t)\frac{\partial f(z(t), t, \theta)}{\partial z})\frac{\partial z}{\partial \theta} dt + \int_{t_0}^{t_1} ( \lambda(t) \frac{\partial f(z(t), t, \theta)}{\partial \theta} )  dt
\end{aligned}
$$

In this equation,
- $\frac{\partial L(z(t_1))}{\partial z(t_1)}$ is the gradient of the loss with respect to $z(t_1)$. This is easy to compute.

- $\frac{\partial z(t_1)}{\partial \theta}$ is the sensitivity (Jacobian) of $z(t_1)$ with respect to the parameters. This is the difficult part.

- $\frac{\partial f(z(t), t, \theta)}{\partial z}$ is the Jacobian matrix of the function $f(z(t), t, \theta)$ with respect to $z(t)$. This is easy to compute, e.g. by automatic differentiation of the network.

- $\frac{\partial z(t)}{\partial \theta}$ is the sensitivity of $z(t)$ with respect to the parameters. This is also a difficult part.

- $\frac{\partial f(z(t), t, \theta)}{\partial \theta}$ is the Jacobian matrix of the function $f(z(t), t, \theta)$ with respect to $\theta$. This is easy to compute, e.g. by automatic differentiation of the network.

We therefore choose $\lambda(t)$ so that the coefficients of both difficult terms vanish, which means $\frac{\partial z(t_1)}{\partial \theta}$ and $\frac{\partial z(t)}{\partial \theta}$ are never needed:
$$
\frac{d \lambda}{dt} = -\lambda(t) \frac{\partial f(z(t), t, \theta)}{\partial z}
$$
and
$$
\lambda(t_1) = \frac{\partial L(z(t_1))}{\partial z(t_1)}
$$
Then the gradient will simplify to
$$
\frac{\partial \psi(\theta)}{\partial \theta} = \int_{t_0}^{t_1} ( \lambda(t) \frac{\partial f(z(t), t, \theta)}{\partial \theta} )  dt
$$
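As a check, apply this choice to the worked example, where $f(z, t, \theta) = -z + \theta$, so $\frac{\partial f}{\partial z} = -1$ and $\frac{\partial f}{\partial \theta} = 1$. The conditions on $\lambda(t)$ become
$$
\frac{d \lambda}{dt} = \lambda(t), \quad \lambda(1) = 2(z(1) - 1) \quad \Longrightarrow \quad \lambda(t) = 2(z(1) - 1)\, e^{t - 1},
$$
and the gradient is
$$
\frac{\partial \psi(\theta)}{\partial \theta} = \int_{0}^{1} 2(z(1) - 1)\, e^{t - 1} \, dt = 2(z(1) - 1)(1 - e^{-1}).
$$
This agrees with the direct chain rule, because differentiating the closed-form solution $z(1) = \theta + (2 - \theta) e^{-1}$ gives $\frac{\partial z(1)}{\partial \theta} = 1 - e^{-1}$. The gradient vanishes exactly when $z(1) = 1$, i.e. at the $\theta$ found in the example.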

In summary, the function $\lambda(t)$ should satisfy the following ODE:
$$
\frac{d \lambda}{dt} = -\lambda(t) \frac{\partial f(z(t), t, \theta)}{\partial z}
$$
with the terminal condition at $t_1$:
$$
\lambda(t_1) = \frac{\partial L(z(t_1))}{\partial z(t_1)}
$$
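Since $\lambda(t)$ is known at $t_1$ and solved backward in time, the gradient is naturally written as an integral from $t_1$ to $t_0$: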
$$
\frac{\partial L(z(t_1))}{\partial \theta} = - \int_{t_1}^{t_0}  \lambda(t) \frac{\partial f(z(t), t, \theta)}{\partial \theta}   dt
$$
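The sketch below evaluates these formulas numerically for a scalar problem without using any closed form. Several pieces are assumptions made for illustration: the nonlinear right-hand side $f(z, t, \theta) = \sin(\theta z)$ (hypothetical, not from the text), hand-written partial derivatives standing in for automatic differentiation, a fixed-step RK4 solver, and the choice to recompute $z(t)$ by integrating it backward together with $\lambda(t)$ and the running gradient integral (the augmented-state idea used in the original neural ODE paper) rather than storing the forward trajectory. The result is compared with a central finite difference.

```python
import math


# Hypothetical nonlinear example (not from the text): dz/dt = sin(theta * z),
# z(0) = 2, t in [0, 1], loss L = (z(1) - 1)^2. Partials are written by hand.
def f(z, t, theta):
    return math.sin(theta * z)


def df_dz(z, t, theta):
    return theta * math.cos(theta * z)


def df_dtheta(z, t, theta):
    return z * math.cos(theta * z)


def rk4_step(rhs, state, t, h):
    """One classical RK4 step for a tuple-valued state (h may be negative)."""
    def shift(s, k, c):
        return tuple(si + c * ki for si, ki in zip(s, k))
    k1 = rhs(state, t)
    k2 = rhs(shift(state, k1, h / 2), t + h / 2)
    k3 = rhs(shift(state, k2, h / 2), t + h / 2)
    k4 = rhs(shift(state, k3, h), t + h)
    return tuple(s + h / 6 * (a + 2 * b + 2 * c + d)
                 for s, a, b, c, d in zip(state, k1, k2, k3, k4))


def forward(theta, z0=2.0, t0=0.0, t1=1.0, n=200):
    """Solve dz/dt = f from t0 to t1 and return z(t1)."""
    def rhs(s, t):
        return (f(s[0], t, theta),)
    h, state, t = (t1 - t0) / n, (z0,), t0
    for _ in range(n):
        state = rk4_step(rhs, state, t, h)
        t += h
    return state[0]


def adjoint_gradient(theta, z0=2.0, t0=0.0, t1=1.0, n=200):
    """Return dL/dtheta computed with the adjoint equation."""
    z1 = forward(theta, z0, t0, t1, n)
    lam1 = 2 * (z1 - 1)  # terminal condition: lambda(t1) = dL/dz(t1)

    # Augmented state (z, lambda, g), integrated from t1 back to t0:
    #   dz/dt      = f                   (recovers z(t) during the backward pass)
    #   dlambda/dt = -lambda * df/dz     (adjoint equation)
    #   dg/dt      = -lambda * df/dtheta, with g(t1) = 0,
    # so g(t0) = -int_{t1}^{t0} lambda * df/dtheta dt = dL/dtheta.
    def rhs(s, t):
        z, lam, _ = s
        return (f(z, t, theta),
                -lam * df_dz(z, t, theta),
                -lam * df_dtheta(z, t, theta))

    h, state, t = -(t1 - t0) / n, (z1, lam1, 0.0), t1  # negative step size
    for _ in range(n):
        state = rk4_step(rhs, state, t, h)
        t += h
    return state[2]


theta, eps = 0.7, 1e-6
fd = ((forward(theta + eps) - 1) ** 2 - (forward(theta - eps) - 1) ** 2) / (2 * eps)
print(adjoint_gradient(theta), fd)  # the two estimates should agree closely
```

Recomputing $z(t)$ backward avoids storing the trajectory, but it can lose accuracy when the forward dynamics are strongly contracting; storing or checkpointing the forward solution is the usual alternative.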

### Algorithm summary:

1. Forward: solve the ODE for $z(t)$ from $t_0$ to $t_1$ and get the output $z(t_1)$.

2. Loss calculation: compute the loss $L(z(t_1))$.

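3. Backward: starting from the terminal condition $\lambda(t_1) = \frac{\partial L(z(t_1))}{\partial z(t_1)}$, solve the adjoint ODE for $\lambda(t)$ backward from $t_1$ to $t_0$, and accumulate the gradient $\frac{\partial L(z(t_1))}{\partial \theta} = - \int_{t_1}^{t_0} \lambda(t) \frac{\partial f(z(t), t, \theta)}{\partial \theta} dt$ along the way. The adjoint ODE involves $z(t)$, so $z(t)$ must be available during this pass, either stored from the forward pass or recomputed by solving the original ODE backward together with $\lambda(t)$.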

4. Update the parameters $\theta$ using the gradient and an optimization algorithm (e.g., gradient descent).
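Putting the four steps together for the worked example ($\frac{dz}{dt} = -z + \theta$, $z(0) = 2$, $L = (z(1) - 1)^2$) gives a small training loop. This is a sketch under stated assumptions: plain gradient descent with a hypothetical learning rate `lr = 0.5`, a fixed-step RK4 solver, and a backward pass that integrates $z$, $\lambda$, and the running gradient integral together from $t_1$ to $t_0$. The loop should drive $\theta$ toward $\theta^* = \frac{1 - 2e^{-1}}{1 - e^{-1}} \approx 0.418$ from the example.

```python
import math


# The worked example: dz/dt = -z + theta, z(0) = 2, t in [0, 1], L = (z(1) - 1)^2.
def f(z, t, theta):
    return -z + theta


def df_dz(z, t, theta):      # partial f / partial z
    return -1.0


def df_dtheta(z, t, theta):  # partial f / partial theta
    return 1.0


def rk4_step(rhs, state, t, h):
    """One classical RK4 step for a tuple-valued state (h may be negative)."""
    def shift(s, k, c):
        return tuple(si + c * ki for si, ki in zip(s, k))
    k1 = rhs(state, t)
    k2 = rhs(shift(state, k1, h / 2), t + h / 2)
    k3 = rhs(shift(state, k2, h / 2), t + h / 2)
    k4 = rhs(shift(state, k3, h), t + h)
    return tuple(s + h / 6 * (a + 2 * b + 2 * c + d)
                 for s, a, b, c, d in zip(state, k1, k2, k3, k4))


def train(theta=0.0, lr=0.5, steps=100, z0=2.0, t0=0.0, t1=1.0, n=100):
    dt = (t1 - t0) / n

    def forward_rhs(s, t):
        return (f(s[0], t, theta),)

    def adjoint_rhs(s, t):
        z, lam, _ = s
        return (f(z, t, theta), -lam * df_dz(z, t, theta),
                -lam * df_dtheta(z, t, theta))

    for _ in range(steps):
        # 1. Forward: solve for z(t) from t0 to t1.
        state, t = (z0,), t0
        for _k in range(n):
            state = rk4_step(forward_rhs, state, t, dt)
            t += dt
        z1 = state[0]

        # 2. Loss calculation.
        loss = (z1 - 1.0) ** 2

        # 3. Backward: integrate (z, lambda, g) from t1 to t0 with
        #    lambda(t1) = dL/dz(t1) and g(t1) = 0; then g(t0) = dL/dtheta.
        state, t = (z1, 2.0 * (z1 - 1.0), 0.0), t1
        for _k in range(n):
            state = rk4_step(adjoint_rhs, state, t, -dt)
            t -= dt
        grad = state[2]

        # 4. Gradient-descent update (the nested functions see the new theta).
        theta -= lr * grad
    return theta, loss  # loss is from the last forward pass, before the final update


theta_hat, last_loss = train()
theta_star = (1 - 2 * math.exp(-1)) / (1 - math.exp(-1))
print(theta_hat, theta_star, last_loss)
```

For this quadratic-in-$\theta$ loss each gradient step shrinks the error $\theta - \theta^*$ by a constant factor, so plain gradient descent converges quickly; a neural network $f$ would typically use an optimizer such as SGD or Adam instead.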