
In this notebook I illustrate one of the key ingredients of the paper: how to compute expectations (in continuous time) with almost no extra computational cost, regardless of the size of the state space.

Consider the following example: There are 10000 state variables $x_i$ that follow the SDEs:

$$
    dx_i = \mu_i(x) dt + \sigma_i(x)^T dZ
$$

where $dZ$ is the increment of a 100-dimensional Brownian motion, so each $\sigma_i(x)$ is a 100-dimensional vector of shock loadings.


Given an arbitrary function $f$, I will illustrate how to compute its expected drift $\mathbb{E}\left[\frac{df}{dt}\right]$. I will use a neural network as the arbitrary function, but any twice-differentiable function would work.
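For reference, applying Ito's lemma to $f$ gives the quantity that the `drift` function below computes. Here $\mu(x)$ stacks the drifts $\mu_i(x)$, $\Sigma(x)$ is the $10000 \times 100$ matrix whose $i$-th row is $\sigma_i(x)^T$, and $\Sigma_{\cdot k}(x)$ is its $k$-th column; this stacked notation is introduced only for this derivation.

$$
\mathbb{E}\left[\frac{df}{dt}\right] = \nabla f(x)^T \mu(x) + \frac{1}{2}\operatorname{tr}\left(\Sigma(x)^T \nabla^2 f(x)\, \Sigma(x)\right) = \nabla f(x)^T \mu(x) + \frac{1}{2}\sum_{k=1}^{100} \Sigma_{\cdot k}(x)^T \left[\nabla^2 f(x)\, \Sigma_{\cdot k}(x)\right]
$$

The first term is a single Jacobian-vector product (`jvp` in the code), and each term of the sum is a Hessian-vector product followed by a dot product (`hvp`), so the full $10000 \times 10000$ Hessian $\nabla^2 f(x)$ is never formed.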

In [None]:
import jax
from jax import jvp, grad, jit, vmap
import numpy as onp
import jax.numpy as np
# stax ships with JAX under jax.example_libraries (formerly jax.experimental.stax)
from jax.example_libraries import stax
from jax.example_libraries.stax import Dense, Tanh

# Problem dimensions
n_shocks = 100
n_states = 10000

# Random numbers seed
rng = jax.random.PRNGKey(0)

# Create a neural network to represent the function f
initializer, f = stax.serial(
    Dense(128), Tanh,
    Dense(64), Tanh,
    Dense(1))
_, Θ = initializer(rng, (-1, n_states))


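# Set up the dynamics of the problem
# (postulate whatever dynamics you want)
def dynamics(x):
    # Drift μ(x): shape (n_states,)
    μ = -0.05 * x
    # Volatility σ(x): shape (n_states, n_shocks); row i is σ_i(x)^T.
    # Here every shock loads on x_i with the same coefficient x_i.
    σ = np.array([x] * n_shocks).T

    return μ, σ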


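# This is the heart of the paper: a generic function that computes the
# drift E[df]/dt of an arbitrary function f, with arbitrary numbers of
# state variables and Brownian shocks, via Ito's lemma
def drift(f, Θ, state, μstate, σstate):
    # Scalar-valued network evaluated at a single state vector
    f_flat = lambda state: np.squeeze(f(Θ, state))

    # First-order term ∇f(x)·μ(x): a single Jacobian-vector product
    first_order = jvp(f_flat, (state,), (μstate, ))[1]

    # Hessian-vector product (forward-over-reverse); never forms the Hessian
    def hvp(f, x, σ):
        return jvp(grad(f), (x, ), (σ, ))[1]

    # Second-order term Σ_k σ_k^T H σ_k: one HVP per shock (column of σstate)
    second_order = np.sum(
        np.array([hvp(f_flat, state, σstate[:, k]) @ σstate[:, k]
                  for k in range(σstate.shape[1])]))

    EdV = first_order + 0.5 * second_order
    return EdV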




In [None]:
# Let's get a sense of how costly it is to evaluate the original function, f,
# for 512 different points picked at random

x = onp.random.normal(size=[512, n_states])

@jit
def compute_f(x):
    return f(Θ, x)

compute_f(x)  # run it once to jit compile it
%timeit compute_f(x).block_until_ready()

100 loops, best of 5: 9.83 ms per loop


In [None]:
# Now let's see how long it takes to compute its drift at the same 512 points

@jit   # compile the batched computation
@vmap  # map the single-point computation over the 512 rows of x
def compute_Edf(x):
    # Dynamics
    μ, σ = dynamics(x)

    # Ito's Lemma
    EdV = drift(f, Θ, x, μ, σ)

    return EdV

compute_Edf(x)  # run it once to jit compile it
%timeit compute_Edf(x).block_until_ready()


10 loops, best of 5: 58.2 ms per loop
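As a sanity check of the recipe in `drift` (an addition, not part of the original notebook), the sketch below re-implements the same JVP-plus-HVP computation as a standalone function, hypothetically named `drift_hvp`, and compares it with the closed-form drift of a quadratic $f(x) = \frac{1}{2} x^T A x$, whose gradient is $A x$ and whose Hessian is $A$. It batches the HVPs over the shock columns with `jax.vmap` instead of a Python loop; the sizes, random seed, and matrix $A$ are arbitrary assumptions.

```python
import jax
import jax.numpy as jnp


def drift_hvp(f, x, mu, sigma):
    """Ito drift E[df]/dt = grad f(x) . mu + 0.5 * sum_k sigma_k^T H(x) sigma_k.

    f: scalar function of a 1-D state x; mu: drift vector of shape (n,);
    sigma: shock loadings of shape (n, m), one column per Brownian shock.
    """
    # First-order term: directional derivative of f along mu (one JVP).
    first_order = jax.jvp(f, (x,), (mu,))[1]

    # Hessian-vector product via forward-over-reverse differentiation.
    def hvp(v):
        return jax.jvp(jax.grad(f), (x,), (v,))[1]

    # Second-order term: one HVP per shock column, batched with vmap.
    second_order = jnp.sum(jax.vmap(lambda v: hvp(v) @ v, in_axes=1)(sigma))
    return first_order + 0.5 * second_order


# Hypothetical test problem: f(x) = 0.5 x^T A x with symmetric A,
# so the drift has the closed form (A x) . mu + 0.5 tr(sigma^T A sigma).
n_test, m_test = 20, 5
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
B = jax.random.normal(k1, (n_test, n_test))
A = 0.5 * (B + B.T)
x_test = jax.random.normal(k2, (n_test,))
mu_test = -0.05 * x_test  # same mean-reverting drift as `dynamics`
sigma_test = jax.random.normal(k3, (n_test, m_test))

f_quad = lambda s: 0.5 * s @ A @ s
closed_form = (A @ x_test) @ mu_test + 0.5 * jnp.trace(sigma_test.T @ A @ sigma_test)
via_hvp = drift_hvp(f_quad, x_test, mu_test, sigma_test)
print(via_hvp, closed_form)  # should agree up to float32 rounding
```

Because `vmap` batches the HVPs into one computation rather than unrolling a Python loop, a function like `drift_hvp` could stand in for the loop inside `drift`; its speed on the 10,000-state problem has not been measured here.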


# Conclusion:
Computing the exact expected drift of $f$ at all 512 points took about 58 ms, compared with about 10 ms to evaluate $f$ itself: roughly a 6× overhead, even though every point has 10,000 state variables and 100 shocks. Notice that we never had to derive a partial derivative by hand or approximate one numerically, let alone form a 10,000 × 10,000 Hessian matrix.
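To make the point about Hessians concrete, the sketch below (an illustrative addition, with a small `tanh` function standing in for the network and arbitrary sizes, seed, and weights `W`) computes the second-order term of Ito's lemma, the sum over shocks of $v_k^T \nabla^2 f(x)\, v_k$ with $v_k$ the $k$-th column of the volatility matrix, in two ways on a 50-state problem: once with Hessian-vector products, as `drift` does, and once by materialising the dense Hessian with `jax.hessian`.

```python
import jax
import jax.numpy as jnp


def second_order_hvp(f, x, sigma):
    # sum_k v_k^T H v_k with one Hessian-vector product per shock column;
    # the n x n Hessian H is never materialised.
    hvp = lambda v: jax.jvp(jax.grad(f), (x,), (v,))[1]
    return jnp.sum(jax.vmap(lambda v: hvp(v) @ v, in_axes=1)(sigma))


def second_order_dense(f, x, sigma):
    # The same quantity, tr(sigma^T H sigma), after forming the full Hessian.
    H = jax.hessian(f)(x)
    return jnp.trace(sigma.T @ H @ sigma)


# Hypothetical stand-in for the network: a small nonlinear scalar function.
n_small, m_small = 50, 10
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(1), 3)
W = jax.random.normal(k1, (n_small, 16)) / jnp.sqrt(n_small)
x_small = jax.random.normal(k2, (n_small,))
sigma_small = jax.random.normal(k3, (n_small, m_small))
f_tanh = lambda s: jnp.sum(jnp.tanh(s @ W))

hvp_result = second_order_hvp(f_tanh, x_small, sigma_small)
dense_result = second_order_dense(f_tanh, x_small, sigma_small)
print(hvp_result, dense_result)  # agree up to float32 rounding

# Memory for one dense float32 Hessian at the notebook's 10,000 states:
print(10_000 ** 2 * 4 / 1e6, "MB")  # 400.0 MB
```

Both routes give the same number, but at the notebook's scale the dense route would need a 10,000 × 10,000 Hessian per state point (about 400 MB in float32), while the HVP route only ever handles vectors of length 10,000.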