<a href="https://colab.research.google.com/github/fmottes/jax-morph/blob/eqx">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" style="width:250px;"/>
  </a>

# Optimization of stochastic simulations

We now turn to the problem of optimizing a simulation in which some steps involve random sampling. A sampling step is not differentiable in the classical sense, so gradients cannot be propagated through the simulation directly.

In this notebook we will use some toy problems to illustrate the basic variants of the two main ideas used in general to overcome the non-differentiability of the sampling step:

1. **Reparametrization**: The idea is to reparametrize the sampling step in a way that makes it differentiable, by "decoupling" the learnable parameters from the randomness. This usually involves a transformation of the random variable that depends on the learnable parameters, and that is differentiable with respect to them. In practice, though, it is fairly hard to make this approach work, especially in complex models like ours.

2. **Score function**: The idea is to use the score function (the gradient of the log-probability of the sampled value with respect to the parameters) to estimate the gradient of the expectation of the function of interest. This idea is at the core of the REINFORCE algorithm, which we also adopt for the optimizations carried out in the paper.
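
To make the reparametrization idea concrete before moving to discrete variables, here is a minimal sketch for a Gaussian variable, which is not one of the cases treated in this notebook. The sample is written as $z = \mu + \sigma \epsilon$ with $\epsilon \sim \mathcal N(0, 1)$, so the randomness no longer depends on the parameters and gradients can flow through the sample. The function names `reparam_sample` and `expected_square` are hypothetical and only serve the illustration.

```python
import jax
import jax.numpy as np  # same alias as used in this notebook


def reparam_sample(key, mu, log_sigma):
    # the noise eps does not depend on the parameters
    eps = jax.random.normal(key)
    # the sample is a differentiable function of (mu, log_sigma)
    return mu + np.exp(log_sigma) * eps


def expected_square(params, keys):
    # Monte Carlo estimate of E[z^2] = mu^2 + sigma^2
    mu, log_sigma = params
    z = jax.vmap(reparam_sample, in_axes=(0, None, None))(keys, mu, log_sigma)
    return np.mean(z**2)


keys = jax.random.split(jax.random.PRNGKey(0), 10_000)

# gradients flow through the samples themselves
grad_mu, grad_log_sigma = jax.grad(expected_square)((1.0, 0.0), keys)

print(f'grad wrt mu:\t\t{grad_mu}')            # analytic value: 2*mu = 2
print(f'grad wrt log_sigma:\t{grad_log_sigma}')  # analytic value: 2*sigma^2 = 2
```

Both estimates should land close to the analytic value of 2, up to Monte Carlo noise. No such smooth transformation is available for the discrete variables considered below, which is where the difficulty comes from.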


In order to illustrate the problem and the two approaches mentioned above, we will explore two different cases of increasing complexity: sampling from a single **Bernoulli Random Variable** first, and then from a **Categorical Distribution**.

# Imports

In [1]:
import jax
import jax.numpy as np
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_debug_nans", True)

key = jax.random.PRNGKey(0) # random number generator

import matplotlib.pyplot as plt
plt.rcParams.update({'font.size': 15})

from tqdm import trange



# Bernoulli Random Variable

A Bernoulli random variable $X$ is a discrete random variable that takes on two possible outcomes, usually denoted as 0 and 1. The probability mass function of $X$ is given by:

$$
P(X = k) = 
\begin{cases} 
p & \text{if } k = 1 \\
1 - p & \text{if } k = 0
\end{cases}
$$

where $p$ is the probability of the event occurring (i.e., $X = 1$), and $1 - p$ is the probability of the event not occurring (i.e., $X = 0$).


We define the random variable $Z$ as Bernoulli-distributed,
$$ Z \sim \text{Bernoulli}(p) $$
and the outcome of our "simulation" will be a function of the value of $Z$. 


In order to avoid possible confusion arising from 0 values, we also redefine the outcomes of $Z$ as +1 and -1, instead of 1 and 0:

$$
P(Z = k) = 
\begin{cases} 
p & \text{if } k = +1 \\
1 - p & \text{if } k = -1
\end{cases}
$$

In order to make things (slightly) less trivial, we will also suppose that the probability $p$ depends on a parameter $\theta$ that we want to learn. In particular, we will suppose that $p = \sigma(\theta)$, where $\sigma$ is the sigmoid function.

In [2]:
### SIMULATION STEPS

# calculate p given theta
p_fn = lambda theta: jax.nn.sigmoid(theta)


# sample from Bernoulli distribution given p
def sample_bernoulli(subkey, p):

    u = jax.random.uniform(subkey)

    # ±1 encoding instead of {0, 1}: returns +1 if u < p (probability p), -1 otherwise
    return np.sign(p-u)


# deterministic part of the "simulation"
env_fn = lambda z: 2*z

In [3]:
### ENTIRE SIMULATION

def simulation(subkey, theta):

    p = p_fn(theta) # calculate p
    z = sample_bernoulli(subkey, p) # sample z
    outcome = env_fn(z) # simulate outcome

    return outcome

## The Gradient Problem

Now we can try to calculate the linear sensitivity of the model outcome with respect to the input parameter $\theta$.

The analytic expression for the gradient is:

$$
\frac{\partial \text{outcome}}{\partial \theta}  = \frac{\partial \text{outcome}}{\partial Z} \cdot \frac{\partial Z}{\partial p} \cdot \frac{\partial p}{\partial \theta}
$$

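But the derivative
$$ \frac{\partial Z}{\partial p} =\ ?$$
carries no useful information: for a fixed random draw, $Z$ is a piecewise-constant function of $p$, so its derivative is zero everywhere except at the jump, where it is undefined. JAX returns 0 for the derivative of `np.sign`, so the whole expression evaluates to 0, as expected.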

In [4]:
theta = 1. # gives p = 0.73

key, subkey = jax.random.split(key)
o, g = jax.value_and_grad(simulation, argnums=1)(subkey, theta)

print(f'Simulation outcome:\t{o}')
print(f'Gradient wrt theta:\t{g}')

Simulation outcome:	2.0
Gradient wrt theta:	0.0


We can also check each step individually. The first and last steps have well-defined, nonzero gradients, while the sampling step returns a zero gradient:

In [5]:
print(f'p = {p_fn(theta)} \t\t gradient = {jax.grad(p_fn)(theta)}')

print(f'z = {sample_bernoulli(subkey, .5)} \t\t gradient = {jax.grad(sample_bernoulli, argnums=1)(subkey, .5)}')

print(f'outcome = {env_fn(-1.)} \t\t gradient = {jax.grad(env_fn)(-1.)}')

p = 0.7310585786300049 	 gradient = 0.19661193324148185
z = -1.0 	 gradient = 0.0
outcome = -2.0 	 gradient = 2.0


## Solution 1: Straight-Through Estimator

One first fix is to decide on an arbitrary value (usually 1) to be assigned to the ill-defined derivative, so as to allow gradients to flow past the problematic point. This is the idea behind the Straight-Through Estimator and its more mathematically refined versions.

$$ \frac{\partial Z}{\partial p} \approx\ 1$$

We can easily do this in JAX by applying the following trick with `jax.lax.stop_gradient` (check the JAX docs for details).


In [16]:
def ST_sample_bernoulli(subkey, p):

    u = jax.random.uniform(subkey)
    z = np.sign(p-u)

    zero =  p - jax.lax.stop_gradient(p)

    # the forward value is still z; in the backward pass only the dependence on p remains, so dz/dp = 1
    return zero + jax.lax.stop_gradient(z)
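
As a quick sanity check of the trick, the sketch below confirms that the forward value is still the sampled sign, while the derivative with respect to `p` becomes 1 instead of 0. The helper `straight_through_sign` is hypothetical and takes the uniform draw `u` as an explicit argument, so that the check is deterministic.

```python
import jax
import jax.numpy as np  # same alias as used in this notebook


def straight_through_sign(p, u):
    z = np.sign(p - u)
    # forward pass: (p - p) + z = z; backward pass: only the first p carries a gradient
    return p - jax.lax.stop_gradient(p) + jax.lax.stop_gradient(z)


# u < p, so the sampled value is +1
value, grad = jax.value_and_grad(straight_through_sign)(0.7, 0.3)

# for comparison: the plain sign function has zero gradient
plain_grad = jax.grad(lambda p, u: np.sign(p - u))(0.7, 0.3)

print(f'value = {value}, straight-through grad = {grad}, plain grad = {plain_grad}')
```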

In [17]:
###  SIMULATION with ST step
###  Other steps are unchanged

def ST_simulation(subkey, theta):

    p = p_fn(theta) # calculate p
    z = ST_sample_bernoulli(subkey, p) # sample z
    outcome = env_fn(z) # simulate outcome

    return outcome

For one instance:

In [18]:
theta = 1. # gives p = 0.73

key, subkey = jax.random.split(key)
o, g = jax.value_and_grad(ST_simulation, argnums=1)(subkey, theta)

print(f'Simulation outcome:\t{o}')
print(f'Gradient wrt theta:\t{g}')

Simulation outcome:	-2.0
Gradient wrt theta:	0.3932238664829637


Since the outcome of the simulation is stochastic, we can get a better estimate by averaging the gradient over many simulations:

In [19]:
theta = 1. # gives p = 0.73

N_AVG = 1000

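# note: split returns N_AVG keys in total; the first one replaces `key`, so N_AVG - 1 subkeys are used below
key, *subkeys = jax.random.split(key, N_AVG)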
subkeys = np.asarray(subkeys)

o, g = jax.vmap(jax.value_and_grad(ST_simulation, argnums=1), in_axes=(0,None))(subkeys, theta)


print(f'Average simulation outcome:\t{o.mean()}')
print(f'Average gradient wrt theta:\t{g.mean()}')

Average simulation outcome:	0.8428428428428428
Average gradient wrt theta:	0.3932238664829637


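Notice that this agrees exactly with the single-instance calculation! With the straight-through trick the gradient does not depend on the sampled value of $Z$: every sample yields $2 \cdot 1 \cdot \sigma'(\theta)$, so averaging does not change it.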

## Solution 2: Score Function Estimator (REINFORCE)

The second approach is to let go of all hope of calculating the gradient of a single simulation, and calculate the gradient of its **expected value** instead. It can be shown that the gradient of the expected value of a function of a random variable can be expressed in terms of the score function of the distribution of the random variable.

$$
\nabla_{\theta}\ \mathbb{E}_{\theta} [f(Z)] = \mathbb{E}_{\theta}\ [f(Z)\ \nabla_{\theta} \log P_{\theta}(Z)]
$$

That is, we now only need that the probability distribution from which we sample Z (Bernoulli in this specific case) be differentiable with respect to our parameters.

We can estimate the expectation numerically by running a lot of simulations and recording the outcomes and the log-probabilities of the outcomes. Then we can differentiate the log-probabilities instead of the simulations!
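
For completeness, the identity above follows from the log-derivative trick, $\nabla_{\theta} P_{\theta}(z) = P_{\theta}(z)\ \nabla_{\theta} \log P_{\theta}(z)$. A short derivation for a discrete $Z$, assuming that $f$ does not depend on $\theta$ directly:

$$
\begin{aligned}
\nabla_{\theta}\ \mathbb{E}_{\theta} [f(Z)] &= \nabla_{\theta} \sum_z f(z)\ P_{\theta}(z) \\
&= \sum_z f(z)\ \nabla_{\theta} P_{\theta}(z) \\
&= \sum_z f(z)\ P_{\theta}(z)\ \nabla_{\theta} \log P_{\theta}(z) \\
&= \mathbb{E}_{\theta}\ [f(Z)\ \nabla_{\theta} \log P_{\theta}(Z)]
\end{aligned}
$$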


In [9]:
### log-probability of outcome
logp_fn = lambda z, p: np.log(p*(z==1) + (1-p)*(z==-1))


### Modify simulation to return logp too

def simulation_logp(subkey, theta):

    p = p_fn(theta) # calculate p
    z = sample_bernoulli(subkey, p) # sample z
    outcome = env_fn(z) # simulate outcome

    return outcome, logp_fn(z, p)
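
The score itself can be checked against its closed form. For $p = \sigma(\theta)$ one finds $\nabla_{\theta} \log P_{\theta}(Z=+1) = 1 - p$ and $\nabla_{\theta} \log P_{\theta}(Z=-1) = -p$, so the score averages to zero under the distribution. The sketch below verifies this with automatic differentiation; the helper `bernoulli_logp` is hypothetical and mirrors `logp_fn` above, with $\theta$ as its first argument.

```python
import jax
import jax.numpy as np  # same alias as used in this notebook


def bernoulli_logp(theta, z):
    # log-probability of the outcome z in {+1, -1} when p = sigmoid(theta)
    p = jax.nn.sigmoid(theta)
    return np.log(np.where(z == 1, p, 1.0 - p))


score = jax.grad(bernoulli_logp)  # derivative with respect to theta

theta = 1.0
p = jax.nn.sigmoid(theta)

score_plus = score(theta, 1.0)    # analytically 1 - p
score_minus = score(theta, -1.0)  # analytically -p

# the expected score under the distribution vanishes
expected_score = p * score_plus + (1.0 - p) * score_minus

print(f'score(+1) = {score_plus}, score(-1) = {score_minus}')
print(f'expected score = {expected_score}')
```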

In order to make the gradient calculation more concise, we can define a **surrogate loss**, which does not have a meaning *per se* but whose gradient is the one we actually want.

In [10]:
def surrogate_loss(subkey, theta):

    outcome, logp = simulation_logp(subkey, theta)

    return jax.lax.stop_gradient(outcome)*logp

For one instance:

In [11]:
theta = 1. # gives p = 0.73

key, subkey = jax.random.split(key)
o = simulation(subkey, theta)
sl, g = jax.value_and_grad(surrogate_loss, argnums=1)(subkey, theta)

print(f'Simulation outcome:\t{o}')
print(f'surrogate_loss value\t{sl}')
print(f'Gradient wrt theta:\t{g}')

Simulation outcome:	2.0
surrogate_loss value	-0.6265233750364456
Gradient wrt theta:	0.5378828427399902


In order to estimate the gradient of the expected outcome, we now have to average over many simulations:

In [15]:
theta = 1. # gives p = 0.73

N_AVG = 1000

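# note: split returns N_AVG keys in total; the first one replaces `key`, so N_AVG - 1 subkeys are used below
key, *subkeys = jax.random.split(key, N_AVG)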
subkeys = np.asarray(subkeys)

o = jax.vmap(simulation, in_axes=(0,None))(subkeys, theta)
sl, g = jax.vmap(jax.value_and_grad(surrogate_loss, argnums=1), in_axes=(0,None))(subkeys, theta)


print(f'Average simulation outcome:\t{o.mean()}')
print(f'Average surrogate_loss value\t{sl.mean()}')
print(f'Average gradient wrt theta:\t{g.mean()}')

Average simulation outcome:	0.8588588588588588
Average surrogate_loss value	0.30152299510446934
Average gradient wrt theta:	0.8015532928282839


**NOTE:** While the average simulation outcome is the same in both cases up to sampling noise (as it should be), the estimated gradient is different! Straight-through estimators and score function estimators have different characteristics and are subject to different tradeoffs that can be found in the literature.
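
For this toy problem the expectation can be computed in closed form, $\mathbb{E}_{\theta}[2Z] = 2\,(2p - 1)$ with $p = \sigma(\theta)$, so the two estimates can be compared against an exact reference. A minimal sketch, where the helper name `expected_outcome` is hypothetical:

```python
import jax


def expected_outcome(theta):
    p = jax.nn.sigmoid(theta)
    # E[2Z] with Z = +1 with probability p and Z = -1 with probability 1 - p
    return 2.0 * (p * 1.0 + (1.0 - p) * (-1.0))


exact_value, exact_grad = jax.value_and_grad(expected_outcome)(1.0)

print(f'Exact expected outcome:\t{exact_value}')
print(f'Exact gradient wrt theta:\t{exact_grad}')
```

At $\theta = 1$ the exact gradient is $4\,\sigma(\theta)\,(1 - \sigma(\theta)) \approx 0.786$. The score function estimate above is consistent with this value up to sampling noise, whereas the straight-through value equals $2\,\sigma'(\theta) \approx 0.393$, which is half of it: with the ±1 encoding, the choice $\partial Z / \partial p = 1$ underestimates $\partial\, \mathbb{E}[Z] / \partial p = 2$.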

# Categorical Distributions

The same ideas can be translated almost exactly to the case of a categorical distribution. A categorical distribution is a discrete distribution over a finite set of outcomes, each with a given probability. The probability mass function of a categorical distribution is given by:

$$
P(X = k) = 
\begin{cases}
p_k & \text{if } k = 1, \ldots, K \\
0 & \text{otherwise}
\end{cases}
$$

where $p_k$ is the probability of the $k$-th outcome.

This case is the closest to the model presented in the paper, where each category $k$ is a single cell with its associated probability of division $p_k$.


## A toy model (softmax growth)

In order to showcase the application of the previous ideas to the more relevant (for us) case of categorical distributions, we present a toy model with very simple rules that we will optimize with gradient descent.

The rules of the game are the following:
- We have $N$ categories of objects ($N = 3$ in the code below). The **system state** $\bar x = (x_1, \ldots, x_N)$ counts how many objects of each category we have at a given point in time.
- At each time step, one category is chosen and an object of that category is added to the system. Each category is characterized by a "propensity" of division, $\bar \beta = (\beta_1, \ldots, \beta_N)$.
- Category $k$ is chosen with probability 
$$ p_k = \text{softmax}_k(\bar \beta, \bar x) = \frac{\exp[\beta_k x_k]}{\sum_i \exp[\beta_i x_i]} $$
- One game is composed of $T$ rounds.


The aim of the optimization is to choose the propensities $\bar \beta$ in such a way that the system state at the end of the game is as close as possible to a target state $\bar x^*$. We will define the loss function as the squared distance between the final state and the target state:

$$ L_{\bar \beta}(\bar x(T)) = \sum_i (x_i(T) - x_i^*)^2 $$

Note that the loss depends on $\bar \beta$ only through the final state, since 
$$ \bar x(T) = \mathbb F_{\bar \beta}[\bar x(0)] $$

where $\mathbb F_{\bar \beta}$ denotes the (stochastic) dynamics of the system state, which clearly depend on the chosen parameters. 

In [32]:
N = 3 # number of categories

T = 20 # number of time steps per simulation


In [40]:
# build target final state

tx1 = T//3
tx2 = T//7
tx3 = T - tx1 - tx2
target_x = np.array([tx1, tx2, tx3])

print(f'Target final state:\t{target_x}')


# define square loss on state
def loss_x(x, target_x):
    return np.sum((x - target_x)**2)

Target final state:	[ 6  2 12]


In [35]:
#generate probabilities
p_fn = lambda x, betas: jax.nn.softmax(x * betas)

# sample category index
sample_category_idx = lambda subkey, p: jax.random.choice(subkey, len(p), p=p)

# update system state
update_state = lambda x, z: x + jax.nn.one_hot(z, len(x))

In [36]:
def simulation(subkey, x0, betas, T):

    def _sim_step(subkey, x, betas):

        p = p_fn(x, betas)
        z = sample_category_idx(subkey, p)
        x = update_state(x, z)

        return x
    
    subkeys = jax.random.split(subkey, T)

    x = x0
    for k in subkeys:
        x = _sim_step(k, x, betas)

    return x
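
The Python loop above is unrolled step by step, which is fine for small $T$. As an alternative, here is a minimal sketch of the same dynamics written with `jax.lax.scan`, which traces the step function only once. The function name `scan_simulation` is hypothetical, and the step inlines the same softmax, sampling and update rules used above so that the block is self-contained.

```python
import jax
import jax.numpy as np  # same alias as used in this notebook


def scan_simulation(subkey, x0, betas, T):

    def _scan_step(x, k):
        p = jax.nn.softmax(x * betas)              # category probabilities
        z = jax.random.choice(k, x.shape[0], p=p)  # sample category index
        x = x + jax.nn.one_hot(z, x.shape[0])      # add one object to the chosen category
        return x, None                             # (carry, per-step output)

    subkeys = jax.random.split(subkey, T)          # one key per time step
    xT, _ = jax.lax.scan(_scan_step, x0, subkeys)

    return xT


xT_scan = scan_simulation(jax.random.PRNGKey(0), np.ones(3), np.ones(3), 20)
print(f'Final state:\t{xT_scan}')
```

Each step adds exactly one object, so the entries of the final state always sum to the initial total plus $T$.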

In [39]:
### Carry out a trial simulation

x0 = np.ones(N)
betas = np.ones(N)

key, subkey = jax.random.split(key)
xT = simulation(subkey, x0, betas, T)

print(f'Initial state:\t\t{x0}')
print(f'Final state:\t\t{xT}')
print(f'Target final state:\t{target_x}')

Initial state:		[1. 1. 1.]
Final state:		[19.  2.  2.]
Target final state:	[ 6  2 12]


## Straight-through Approach