# Obtaining the maximum likelihood estimate using an evolutionary algorithm

Sometimes it is useful to start sampling close to the high-likelihood region. In this tutorial, we use a built-in evolutionary algorithm to find the maximum likelihood estimate of the model parameters.


In [None]:
import jax
import jax.numpy as jnp  # JAX NumPy
from jax.scipy.special import logsumexp

from flowMC.utils.EvolutionaryOptimizer import EvolutionaryOptimizer


def target_dualmoon(x, data):
    """Unnormalized log-density of the "dual moon" toy distribution.

    Term 1 concentrates the mass on a thin shell of radius 2 (width 0.1)
    around ``data["data"]``. Terms 2 and 3 separate the distribution into
    modes at +/-3 and smear it along the first and second dimension
    respectively.

    Args:
        x: Parameter vector. Only ``x[0]`` and ``x[1]`` enter terms 2 and 3;
            presumably a 1-D array of length ``n_dim`` (TODO confirm).
        data: Mapping whose ``"data"`` entry is the shell centre; it must
            broadcast against ``x``.

    Returns:
        Scalar log-density (up to an additive constant).
    """
    # Distance from the centre, penalised for deviating from radius 2.
    term1 = 0.5 * ((jnp.linalg.norm(x - data["data"]) - 2) / 0.1) ** 2
    # Gaussian bumps at x[0] = +3 and x[0] = -3 (scale 0.8).
    term2 = -0.5 * ((x[:1] + jnp.array([-3.0, 3.0])) / 0.8) ** 2
    # Gaussian bumps at x[1] = +3 and x[1] = -3 (scale 0.6).
    term3 = -0.5 * ((x[1:2] + jnp.array([-3.0, 3.0])) / 0.6) ** 2
    # logsumexp turns each pair of bumps into a two-component mixture.
    # No print() here: under jit/vmap it would only fire at trace time and
    # just pollute the optimizer/sampler output.
    return -(term1 - logsumexp(term2) - logsumexp(term3))


n_dim = 5
n_loops = 100
popsize = 100
# The same [-10, 10] search interval for every dimension -> shape (n_dim, 2).
bounds = jnp.tile(jnp.array([-10, 10]), (n_dim, 1))

optimizer = EvolutionaryOptimizer(n_dim, popsize=popsize, verbose=True)


def _negative_log_prob(x):
    """Return the negated dual-moon log-density at a single point.

    Negated because the goal is the maximum likelihood point; presumably the
    optimizer minimises its objective (confirm against flowMC docs).
    """
    return -target_dualmoon(x, {"data": jnp.zeros(n_dim)})


# vmap evaluates the whole population at once; jit compiles that batch call.
y = jax.jit(jax.vmap(_negative_log_prob))
state = optimizer.optimize(y, bounds, n_loops=n_loops)
best_fit = optimizer.get_result()[0]

In [None]:
# Show the best-fit parameter vector (first element of optimizer.get_result()).
print(best_fit)

# Let's compare the maximum likelihood estimate to the posterior samples

In [None]:
from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline
from flowMC.sampler.MALA import MALA
from flowMC.sampler.Sampler import Sampler


n_chains = 20
n_loop_training = 5
n_loop_production = 5
n_local_steps = 100
n_global_steps = 100
learning_rate = 0.001
momentum = 0.9
num_epochs = 30
batch_size = 10000

# Target data, defined once and passed to both the Sampler constructor and
# sample() so the two can never disagree (and both follow n_dim).
data = {"data": jnp.zeros(n_dim)}

rng_key, subkey = jax.random.split(jax.random.PRNGKey(42))
# Normalizing-flow proposal; positional args presumably (n_layers,
# hidden sizes, n_bins, key) -- TODO confirm against flowMC docs.
model = MaskedCouplingRQSpline(n_dim, 4, [32, 32], 8, subkey)

rng_key, subkey = jax.random.split(rng_key)
# Standard-normal starting points, one row per chain.
initial_position = jax.random.normal(subkey, shape=(n_chains, n_dim))

# Local MALA kernel; the boolean is presumably a jit flag (TODO confirm).
MALA_Sampler = MALA(target_dualmoon, True, {"step_size": 0.1})

print("Initializing sampler class")

nf_sampler = Sampler(
    n_dim,
    rng_key,
    data,
    MALA_Sampler,
    model,
    n_loop_training=n_loop_training,
    n_loop_production=n_loop_production,
    n_local_steps=n_local_steps,
    n_global_steps=n_global_steps,
    n_chains=n_chains,
    n_epochs=num_epochs,
    learning_rate=learning_rate,
    momentum=momentum,
    batch_size=batch_size,
    use_global=True,
)

nf_sampler.sample(initial_position, data)

In [None]:
import corner
import numpy as np


# Samples from the non-training (production) phase of the sampler; the
# array is presumably (n_chains, n_steps, n_dim) -- TODO confirm.
chains = np.array(nf_sampler.get_sampler_state(training=False)["chains"])

# One LaTeX label per dimension so the plot follows n_dim; the braces keep
# multi-digit subscripts (x_{10}, ...) rendering correctly.
labels = [f"$x_{{{i + 1}}}$" for i in range(n_dim)]
# Plot all chains flattened into one sample set, marking the evolutionary
# optimizer's best fit in red for comparison.
figure = corner.corner(
    chains.reshape(-1, n_dim),
    labels=labels,
    truths=np.asarray(best_fit),
    truth_color="red",
)
figure.set_size_inches(7, 7)
figure.suptitle("Visualize samples")