# Total time derivative $\frac{d^n z(t)}{d t^n}$ with `jet`

Consider the initial value problem defined by the dynamics
$$ z'(t) = \frac{dz}{dt} = f(z(t),t) $$
and the initial condition $z(t_0) = z_0$.

Its solution is $$z(t) = z_0 + \int_{t_0}^t f(z(s),s)\, ds$$

We are interested in computing the $n$th time derivative of the solution $z(t)$: $$\frac{d^n z(t)}{d t^n}$$

## Autonomous System

To simplify things, let's consider an autonomous system, one whose dynamics do not depend explicitly on time:
$$ z'(t) = \frac{dz}{dt} = f(z(t),t) = g(z(t)) $$
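
Restricting to autonomous systems loses no generality: a non-autonomous system becomes autonomous once time is appended to the state with $t' = 1$. A minimal sketch, where `f_example` is a hypothetical stand-in for the dynamics:

```python
def f_example(z, t):
    # hypothetical non-autonomous dynamics, for illustration only
    return -z + t

def g_augmented(state):
    # autonomous dynamics on the augmented state (z, t): z' = f(z, t), t' = 1
    z, t = state
    return (f_example(z, t), 1.0)

print(g_augmented((2.0, 3.0)))
```

The augmented dynamics take no explicit time argument, so the autonomous machinery below applies to them unchanged.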

To test correctness, we use a simple linear system,
$$g(z) = 0.5 z,$$
which has the closed-form solution
$$z(t) = z_0 \exp\left[0.5\,(t - t_0)\right]$$
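
Before using the closed form as a reference, it can be checked directly. A plain-Python sketch (the step size `h` is an arbitrary choice): confirm $z(t_0) = z_0$ and compare a central finite difference of $z(t)$ with $g(z(t))$.

```python
import math

t0, z0 = 2.0, 1.0

def g(z):
    return 0.5 * z

def sol(t):
    # closed-form solution of z' = 0.5 z with z(t0) = z0
    return math.exp(0.5 * (t - t0)) * z0

t, h = 4.0, 1e-5
# central difference approximation of dz/dt at t
dz_dt = (sol(t + h) - sol(t - h)) / (2 * h)
print(dz_dt, g(sol(t)))
```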

In [1]:
import jax.numpy as np
from jax import grad
from scipy.integrate import odeint

## Initial Value Problem

# Dynamics
def f(z,t):
    return g(z)
def g(z):
    return 0.5*z

# Initial Conditions
t0, z0 = 2., 1.

# Closed form solution
def sol(t):
    return np.exp(0.5*t - 0.5*t0) * z0

# We evaluate our solution at a given time
t_eval = 4.
z_eval = sol(t_eval)

# Confirm the closed form against scipy's numerical ODE solution
assert np.isclose(odeint(f, z0, [t0, t_eval])[1], z_eval)



## Total time-derivatives of the solution
Since we have access to the closed-form solution, we can nest `jax.grad` to compute the $n$th derivative.

In [2]:
from functools import reduce

def nest(f, n):
    """Return the n-fold composition p -> f(f(...f(p)...))."""
    def rfun(p):
        return reduce(lambda x, _: f(x), range(n), p)
    return rfun

def n_grad(f, x, n):
    """n-th derivative of the scalar function f at x (n = 0 gives f(x))."""
    return nest(grad, n)(f)(x)

# This is the primal and first 4 derivatives of the solution
n_grads = [n_grad(sol, t_eval, i) for i in range(0, 5)]
print(n_grads)

[DeviceArray(2.7182817, dtype=float32), DeviceArray(1.3591409, dtype=float32), DeviceArray(0.67957044, dtype=float32), DeviceArray(0.33978522, dtype=float32), DeviceArray(0.16989261, dtype=float32)]
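
For this linear system the derivatives are also known in closed form, $\frac{d^n z}{dt^n} = 0.5^n z(t)$, so each printed value should be half the previous one. A plain-Python check against the float32 values printed above (the relative tolerance is chosen to absorb float32 rounding):

```python
import math

t0, t_eval, z0 = 2.0, 4.0, 1.0
# d^n z / dt^n = 0.5**n * z0 * exp(0.5 (t - t0)) for z' = 0.5 z
exact = [0.5 ** n * z0 * math.exp(0.5 * (t_eval - t0)) for n in range(5)]
printed = [2.7182817, 1.3591409, 0.67957044, 0.33978522, 0.16989261]
print(all(math.isclose(a, b, rel_tol=1e-6) for a, b in zip(exact, printed)))
```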


## Taylor's Method

Unfortunately, we rarely have closed-form solutions. Instead, we only have access to the dynamics $f(z,t) = g(z)$ and the current state `z_eval`. We could differentiate through the numerical solver and nest those gradients to reach order $n$, but we want to avoid this.

Consider the input state $z_0$, which carries its own higher-order sensitivity terms $(z_1, z_2, \dots, z_d)$, so that

$$ z(t) = z_0 + z_1 t + z_2 t^2 + \cdots + z_d t^d $$

Taylor's method then propagates this truncated series through the dynamics, giving

$$g(z(t)) = v_0 + v_1 t + v_2 t^2 + \cdots + v_d t^d + O(t^{d+1})$$

where
$$
\begin{align}
v_0 &= g(z_0)\\
v_1 &= g'(z_0) z_1\\
v_2 &= g'(z_0) z_2 + \frac{1}{2} g''(z_0) z_1 z_1\\
v_3 &= g'(z_0) z_3 + g''(z_0) z_1 z_2 + \frac{1}{6} g'''(z_0)z_1 z_1 z_1\\
\dots
\end{align}
$$

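Computing the coefficients $v_0, \dots, v_d$ is the purpose of `jet`, which shares the work common across the $v_i$, e.g. $g'(z_0)$. Note that `jet` itself works with derivative coefficients rather than Taylor coefficients: it takes $k!\,z_k$ as input terms and returns $k!\,v_k$.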
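
The coefficient formulas above can be checked exactly on a polynomial. A plain-Python sketch, assuming the hypothetical choice $g(z) = z^3$ (so $g' = 3z^2$, $g'' = 6z$, $g''' = 6$) and arbitrary rational coefficients $z_0, \dots, z_3$: expand $g(z(t))$ by truncated polynomial multiplication and compare with $v_0, \dots, v_3$.

```python
from fractions import Fraction as F

def poly_mul(a, b, d):
    """Product of two coefficient lists, truncated to degree d."""
    out = [F(0)] * (d + 1)
    for i, ai in enumerate(a[: d + 1]):
        for j, bj in enumerate(b[: d + 1 - i]):
            out[i + j] += ai * bj
    return out

d = 3
z = [F(1, 2), F(1), F(1, 2), F(1, 4)]  # z_0, z_1, z_2, z_3

# Taylor coefficients of g(z(t)) = z(t)**3, truncated at t**d
series = poly_mul(poly_mul(z, z, d), z, d)

# The same coefficients from the formulas for v_0 ... v_3
z0, z1, z2, z3 = z
g0, g1, g2, g3 = z0**3, 3 * z0**2, 6 * z0, F(6)
v = [
    g0,
    g1 * z1,
    g1 * z2 + F(1, 2) * g2 * z1 * z1,
    g1 * z3 + g2 * z1 * z2 + F(1, 6) * g3 * z1**3,
]
print(series == v)  # True
```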

## Jet

We can use `jet` to compute higher-order sensitivities of a function $f$ at $x$ along sensitivity terms $v = (v_1, v_2, \dots)$. The $n$th term returned by `jet` (index `n - 1` of the returned list) is
$$ \mathrm{jet}_n(f,x,v) = \frac{\partial^n}{\partial \epsilon^n} f\left(x + v_1 \epsilon + v_2 \frac{\epsilon^2}{2!} + \cdots\right)\Big|_{\epsilon = 0}$$
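
A minimal sketch of this convention, assuming a JAX version that ships `jax.experimental.jet` (the values of `x0`, `x1`, `x2` are arbitrary): with input terms $(x_1, x_2)$, the two returned terms should match the chain rule and Faà di Bruno's formula for the first two $\epsilon$-derivatives of $\sin(x(\epsilon))$.

```python
import jax.numpy as jnp
from jax.experimental.jet import jet

x0, x1, x2 = 0.7, 1.3, -0.4  # x(eps) = x0 + x1*eps + x2*eps**2/2!

primal, (f1, f2) = jet(jnp.sin, (x0,), ((x1, x2),))

# d/deps and d^2/deps^2 of sin(x(eps)) at eps = 0
expected_f1 = jnp.cos(x0) * x1
expected_f2 = -jnp.sin(x0) * x1**2 + jnp.cos(x0) * x2
print(primal, f1, f2)
print(expected_f1, expected_f2)
```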

In [3]:
from jax.experimental.jet import jet
import numpy.random as npr
from scipy.special import factorial as fact

def taylor_expansion(f, primals, series):
    """Return eps -> f(x + v_1 eps + v_2 eps^2/2! + ...), the curve jet differentiates."""
    def expansion(eps):
        tayterms = [
            sum([eps**(i + 1) * terms[i] / fact(i + 1) for i in range(len(terms))])
            for terms in series
        ]
        return f(*map(sum, zip(primals, tayterms)))
    return expansion

order = 4
vs = list(npr.randn(order))
jet_primal, jet_coefs = jet(g, (z_eval,), (vs,))

# jet_coefs[k] should equal the (k+1)-th eps-derivative of the expansion at eps = 0
expansion = taylor_expansion(g, (z_eval,), (vs,))
print(np.isclose(expansion(0.), jet_primal))
print(np.isclose(grad(expansion)(0.), jet_coefs[0]))
print(np.isclose(grad(grad(expansion))(0.), jet_coefs[1]))
print(np.isclose(grad(grad(grad(expansion)))(0.), jet_coefs[2]))

True
True
True
True


In [4]:
jet_primal, jet_coefs = jet(g, (z_eval,), ((1., 0., 0.),))
print(jet_coefs)

def k_test(f, x, s, k):
    """k-th Taylor coefficient of t -> f(x + t*s) at t = 0, i.e. (1/k!) d^k/dt^k."""
    expand = lambda t: f(x + t*s)
    ds = nest(grad, k)(expand)(0.)
    return 1/fact(k) * ds

k_test(g, z_eval, 1., 0)

[DeviceArray(0.5, dtype=float32), DeviceArray(0., dtype=float32), DeviceArray(0., dtype=float32)]


DeviceArray(1.3591409, dtype=float32)

## Jet for Total Derivatives

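Calling `jet` with the "standard" sensitivity terms `(1., 0., 0., ...)` corresponds to the expansion $f(x + \epsilon)$, so the returned terms are the ordinary derivatives $f'(x), f''(x), f'''(x), \dots$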
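
To see the equivalence with nested `grad` directly, a self-contained sketch (the test function `h` and the point `x` are arbitrary choices, not from the notebook):

```python
import jax
import jax.numpy as jnp
from jax.experimental.jet import jet

def h(x):
    # arbitrary smooth test function
    return jnp.exp(jnp.sin(x))

x = 0.3
# standard terms (1, 0, 0): jet differentiates eps -> h(x + eps)
primal, terms = jet(h, (x,), ((1.0, 0.0, 0.0),))

# ordinary derivatives h'(x), h''(x), h'''(x) by nesting grad
nested = [jax.grad(h), jax.grad(jax.grad(h)), jax.grad(jax.grad(jax.grad(h)))]
reference = [d(x) for d in nested]
print(terms)
print(reference)
```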

In [5]:
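def shift_jet(old_jet):
    """Turn jet output (primal, [t_1, ..., t_k]) into the series [primal, t_1, ..., t_k]."""
    primal, terms = old_jet
    return [primal] + list(terms)


def std_terms(x, n):
    """Standard series (1, 0, ..., 0) of length n, shaped like x."""
    return [np.ones_like(x)] + [np.zeros_like(x)] * (n - 1)


def zero_terms(x, n):
    """All-zero series of length n, shaped like x."""
    return [np.zeros_like(x)] * n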

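
Why shifting works: for the autonomous system $z' = g(z)$, differentiating the dynamics gives every higher derivative,
$$ z^{(k+1)}(t) = \frac{d^k}{dt^k}\, g(z(t)), $$
so if the input terms are the exact derivatives $(z', z'', \dots, z^{(k)})$ (in `jet`'s derivative-coefficient convention), `jet` returns
$$ \mathrm{jet}\bigl(g, z, (z', \dots, z^{(k)})\bigr) = \bigl(g(z),\ (z'', \dots, z^{(k+1)})\bigr) = \bigl(z',\ (z'', \dots, z^{(k+1)})\bigr). $$
Prepending the primal output, which is what `shift_jet` does, yields $(z', \dots, z^{(k+1)})$, one order longer and ready for the next `jet` call. Since the $n$th output term depends only on input terms $1, \dots, n$, a series whose leading $m$ terms are exact comes back with its leading $m + 1$ terms exact.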

In [32]:
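def make_derivs_sol(primals, order):
    # `order` is unused: three jet/shift passes are hard-coded.
    # The seed (1., 2.) is not the ODE's first derivative g(z), so each pass only
    # makes one more leading term exact (compare thd_vs with n_grads[1:] below).
    fst_vs = jet(g, primals, ((1., 2.),))
    snd_vs = jet(g, primals, (shift_jet(fst_vs),))
    thd_vs = jet(g, primals, (shift_jet(snd_vs),))
    return [fst_vs, snd_vs, thd_vs]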

In [33]:
make_derivs_sol((z_eval,),1)

[(DeviceArray(1.3591409, dtype=float32),
  [DeviceArray(0.5, dtype=float32), DeviceArray(1., dtype=float32)]),
 (DeviceArray(1.3591409, dtype=float32),
  [DeviceArray(0.67957044, dtype=float32),
   DeviceArray(0.25, dtype=float32),
   DeviceArray(0.5, dtype=float32)]),
 (DeviceArray(1.3591409, dtype=float32),
  [DeviceArray(0.67957044, dtype=float32),
   DeviceArray(0.33978522, dtype=float32),
   DeviceArray(0.125, dtype=float32),
   DeviceArray(0.25, dtype=float32)])]

In [21]:
g(z_eval)

DeviceArray(1.3591409, dtype=float32)

In [22]:
n_grads[1:]

[DeviceArray(1.3591409, dtype=float32),
 DeviceArray(0.67957044, dtype=float32),
 DeviceArray(0.33978522, dtype=float32),
 DeviceArray(0.16989261, dtype=float32)]

In [29]:
nest(grad,2)(taylor_expansion(sol,(t_eval,),((1.,1.,0.),)))(0.)

DeviceArray(2.0387113, dtype=float32)
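
The value follows from the chain rule. The terms `(1., 1., 0.)` build the expansion $t(\epsilon) = t_{\text{eval}} + \epsilon + \epsilon^2/2$, so $t'(0) = t''(0) = 1$ and, using the derivatives in `n_grads`,
$$ \frac{d^2}{d\epsilon^2}\, z(t(\epsilon))\Big|_{\epsilon=0} = z''(t_{\text{eval}})\, t'(0)^2 + z'(t_{\text{eval}})\, t''(0) = 0.6796 + 1.3591 = 2.0387. $$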

In [89]:
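# Taylor coefficients of u(t) = u0 + u1 t + u2 t^2 + u3 t^3
# (with u2 = u3 = 0 they coincide with jet's derivative coefficients)
u0 = 2.
u1 = 1.
u2 = 0.
u3 = 0.

# Scaled coefficients k * u_k, i.e. the Taylor coefficients of u'(t)
u1t = u1
u2t = u2*2
u3t = u3*3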

In [90]:
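def total_jet(f, primals, series):
    """Three jet/shift passes, treating f as the dynamics of z' = f(z)."""
    j1 = jet(f, primals, series)
    j2 = jet(f, primals, (shift_jet(j1),))
    j3 = jet(f, primals, (shift_jet(j2),))
    return (j3[0], j3[1][:3])

# Each pass makes one more leading term exact. The seed (1., 0., 0.) is not
# (sin(u0), ...), so only the primal and the first two returned terms
# (z', z'', z''' of z' = sin(z)) are exact; the third term is not.
total_jet(np.sin, (u0,), ((u1, u2, u3),))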

(DeviceArray(0.9092974, dtype=float32),
 [DeviceArray(-0.37840125, dtype=float32),
  DeviceArray(-0.5943564, dtype=float32),
  DeviceArray(1.4922845, dtype=float32)])

In [96]:
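# Taylor-mode recurrences for the Taylor coefficients s_k, c_k of sin(u(t)), cos(u(t)):
#   k s_k =  sum_{j=1}^{k} (j u_j) c_{k-j}
#   k c_k = -sum_{j=1}^{k} (j u_j) s_{k-j}
# Names ending in `t` hold k times the coefficient (e.g. s3t = 3 * s3).
s0 = np.sin(u0)
c0 = np.cos(u0)
c1t = -u1t*s0
s1t = u1t*c0   # d/dt sin(u) = +cos(u) u'

s1 = s1t
c1 = c1t

c2t = -u1t*s1 - u2t*s0
c2 = c2t/2
s3t = u1t*c2 + u2t*c1 + u3t * c0

s2t = u1t*c1 + u2t*c0

In [99]:
s2t/2 * fact(2)

DeviceArray(-0.9092974, dtype=float32)

In [98]:
jet(np.sin,(u0,),((u1,u2,u3),))

(DeviceArray(0.9092974, dtype=float32),
 [DeviceArray(-0.41614684, dtype=float32),
  DeviceArray(-0.9092974, dtype=float32),
  DeviceArray(0.41614684, dtype=float32)])

In [102]:
s3t/3*fact(3)
s1t

DeviceArray(-0.41614684, dtype=float32)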
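
The In [96] cell unrolls, by hand, the standard Taylor-mode recurrences for $\sin$ and $\cos$. Matching coefficients of $t^{k-1}$ in $\frac{d}{dt}\sin u = u'\cos u$ and $\frac{d}{dt}\cos u = -u'\sin u$ gives
$$ k\,s_k = \sum_{j=1}^{k} j\,u_j\,c_{k-j}, \qquad k\,c_k = -\sum_{j=1}^{k} j\,u_j\,s_{k-j}. $$
A plain-Python sketch of the same recurrences to arbitrary order (`sin_cos_taylor` is a hypothetical helper, not part of JAX). Here $u_j, s_k, c_k$ are Taylor coefficients; multiplying $s_k$ by $k!$ converts to `jet`'s derivative-coefficient convention.

```python
import math

def sin_cos_taylor(u, order):
    """Taylor coefficients [s_0..s_order], [c_0..c_order] of sin(u(t)) and cos(u(t)),
    given Taylor coefficients u = [u_0, ..., u_order] of u(t)."""
    s = [math.sin(u[0])]
    c = [math.cos(u[0])]
    for k in range(1, order + 1):
        s.append(sum(j * u[j] * c[k - j] for j in range(1, k + 1)) / k)
        c.append(-sum(j * u[j] * s[k - j] for j in range(1, k + 1)) / k)
    return s, c

# u(t) = 2 + t: k! * s_k should be cos(2), -sin(2), -cos(2), as in the jet output above
s, c = sin_cos_taylor([2.0, 1.0, 0.0, 0.0], 3)
print([s[k] * math.factorial(k) for k in range(1, 4)])
```

As a second check, $u(t) = t^2$ should reproduce $\sin(t^2) = t^2 + O(t^6)$ and $\cos(t^2) = 1 - t^4/2 + O(t^8)$.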