<a href="https://colab.research.google.com/github/fmottes/jax-morph/blob/eqx">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" style="width:250px;"/>
  </a>

# Optimization of stochastic simulations

We now turn to optimizing simulations that include sampling steps. Sampling is not differentiable in the usual sense, so gradients cannot be propagated directly through the simulation.

In this notebook we use toy problems to illustrate basic variants of the two main approaches to working around the non-differentiability of the sampling step:

1. **Score function**: Use the score function (the gradient of the log-probability of a sample with respect to the parameters) to estimate the gradient of the expectation of the function of interest. This idea is at the core of the REINFORCE algorithm, which is also the one we adopt for the optimizations carried out in the paper.

2. **Reparametrization**: Rewrite the sampling step so that it becomes differentiable, by "decoupling" the learnable parameters from the randomness. This usually means expressing the random variable as a transformation of parameter-independent noise, where the transformation depends on the learnable parameters and is differentiable with respect to them. In practice, though, this second approach is fairly hard to get working, especially in complex models like ours.
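
As a minimal sketch of the two ideas, consider a toy problem that is an assumption made here for illustration and is not taken from the rest of the notebook: a sample `x` drawn from a unit-variance Gaussian with learnable mean `theta`, and the objective `f(x) = x**2`. The expectation is `theta**2 + 1`, so the exact gradient is `2 * theta`, and both estimators can be checked against it. The names `f`, `log_prob`, `score_function_grad` and `reparam_grad` are hypothetical helpers, and the block uses its own `jnp` alias so that it runs on its own.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Toy objective (assumed for illustration): E[f(x)] = theta**2 + 1.
    return x ** 2


def log_prob(theta, x):
    # Log-density of N(theta, 1), up to an additive constant that does
    # not depend on theta and therefore does not affect the gradient.
    return -0.5 * (x - theta) ** 2


def score_function_grad(theta, key, n_samples=100_000):
    # Score-function (REINFORCE-style) estimator:
    #   d/dtheta E[f(x)] = E[f(x) * d/dtheta log p(x; theta)]
    # The samples are treated as fixed data; only log p is differentiated.
    x = theta + jax.random.normal(key, (n_samples,))
    score = jax.vmap(jax.grad(log_prob), in_axes=(None, 0))(theta, x)
    return jnp.mean(f(x) * score)


def reparam_grad(theta, key, n_samples=100_000):
    # Reparametrization estimator: write x = theta + eps with eps ~ N(0, 1),
    # so the randomness no longer depends on theta and autodiff applies.
    eps = jax.random.normal(key, (n_samples,))

    def expected_f(th):
        return jnp.mean(f(th + eps))

    return jax.grad(expected_f)(theta)


theta = 1.5
key_sf, key_rp = jax.random.split(jax.random.PRNGKey(0))

g_sf = score_function_grad(theta, key_sf)
g_rp = reparam_grad(theta, key_rp)

# Both estimates should be close to the exact gradient 2 * theta = 3.0.
print("score function:  ", g_sf)
print("reparametrization:", g_rp)
```

On this toy problem the score-function estimate is noticeably noisier than the reparametrized one for the same number of samples, but it only needs the log-probability of the samples to be differentiable, not the sampling step itself.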

# Imports

In [None]:
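import jax
import jax.numpy as np  # note: `np` is jax.numpy here, not NumPy
jax.config.update("jax_enable_x64", True)  # use 64-bit floats by default
jax.config.update("jax_debug_nans", True)  # raise an error as soon as a NaN is produced

key = jax.random.PRNGKey(0)  # PRNG key used to seed all random sampling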

import matplotlib.pyplot as plt
plt.rcParams.update({'font.size': 15})

from tqdm import trange