# Inference (on synthetic data) 

In [None]:
from IPython.display import HTML      # For animation.
from matplotlib import animation      # For animation.
from matplotlib import pyplot as plt  # For plotting.
import numpy as np                    # For solving ODE.
from scipy.integrate import odeint    # For solving ODE.
import cma                            # For optimization.

## Generating synthetic data 

In [None]:
# Ground-truth Lotka-Volterra coefficients used to generate the synthetic data.
a, b, c, d = 0.5, 0.028, 0.85, 0.02

# Integration window [t0, t1] and initial populations [prey, predator].
t0, t1 = 0.0, 20.0
y0 = [30, 4]


def F(y, t):
    """Right-hand side of the model: rate of change of both populations.

    `y` holds the current populations (index 0: prey, index 1: predator).
    `t` is the current time; it is accepted because the ODE solver passes
    it, but the system is autonomous and does not use it.
    """
    prey = y[0]
    predator = y[1]
    return [
        prey * (a - b * predator),      # How population 0 changes.
        predator * (-c + d * prey),     # How population 1 changes.
    ]

In [None]:
# Integrate the model with the ground-truth parameters to obtain the
# synthetic observations that the optimizer will later try to reproduce.
data_t = np.linspace(start=t0, stop=t1, num=20)  # Only 20 points!
data_y = odeint(F, y0, t=data_t)

In [None]:
# Time-population plot of the synthetic data: one dashed line with
# circular markers per population (column 0: prey, column 1: predator).
for column, species in enumerate(('prey', 'predator')):
    plt.plot(data_t, data_y[:, column], '--o', label=species)
plt.xlabel('time')
plt.ylabel('population')
plt.grid(linestyle=':')
plt.legend()
plt.show()

## Optimization (finding optimal parameters)

In [None]:
def evaluate(params):
    """Measure how far the Lotka-Volterra solution for `params` is from the data.

    `params` is a sequence of four values `(a, b, c, d)`. The model is
    integrated from the module-level `y0` at the module-level sample times
    `data_t`, and the result is compared with the module-level `data_y`.

    Returns the Frobenius norm of the difference, i.e. the square root of
    the sum of squared differences over all time points and both
    populations.
    """
    alpha, beta, gamma, delta = params

    def rhs(state, _time):
        # Same model as the data generator, but with the candidate parameters.
        prey, predator = state[0], state[1]
        return [
            prey * (alpha - beta * predator),      # How population 0 changes.
            predator * (-gamma + delta * prey),    # How population 1 changes.
        ]

    # Evaluate at the same points the data is available at.
    simulated = odeint(rhs, y0, data_t)

    return np.linalg.norm(simulated - data_y, 'fro')

In [None]:
# Search range and starting point, one row per parameter:
# (lower bound, initial guess, upper bound).
_parameter_table = [
    (0.1, 0.1, 0.9),      # a
    (0.001, 0.05, 0.1),   # b
    (0.1, 0.1, 0.9),      # c
    (0.001, 0.05, 0.1),   # d
]

# Transpose the table into the three per-parameter lists used below.
lower_bounds, initial_guess, upper_bounds = (
    list(column) for column in zip(*_parameter_table)
)

def constrained_evaluation(params):
    """Return `evaluate(params)` plus a penalty for leaving the allowed ranges.

    Each parameter is compared with its entry in the module-level
    `lower_bounds` / `upper_bounds`. The squared distance by which it lies
    outside that range is accumulated, scaled by 1e6, and added to the
    data misfit, so out-of-range candidates score much worse.
    """
    def excess(low, value, high):
        # Squared distance outside [low, high]; zero when inside the range.
        if value < low:
            return (low - value) ** 2  # Too low.
        if value > high:
            return (value - high) ** 2  # Too high.
        return 0

    violation = sum(
        excess(low, value, high)
        for low, value, high in zip(lower_bounds, params, upper_bounds)
    )
    return 1e6 * violation + evaluate(params)

# Run CMA-ES from `initial_guess` with an initial step size of 0.5 and
# 64 candidate solutions per generation, minimizing the penalized misfit.
cma_options = {'popsize': 64}
es = cma.CMAEvolutionStrategy(initial_guess, 0.5, cma_options)
cma_result = es.optimize(constrained_evaluation)

# `best.get()` yields a tuple whose first element is the best parameter vector.
print("\nOptimization done! Best parameters:", cma_result.best.get()[0])

In [None]:
# Best parameter set found by the optimizer, in the order (a, b, c, d).
# It is the first element of `best.get()`; left as a bare expression so the
# notebook displays it as the cell's output.
cma_result.best.get()[0]

In [None]:
# Unpack the best parameter vector found by the optimizer as a_, b_, c_ and d_.
a_, b_, c_, d_ = cma_result.best.get()[0]


def F_(y, t):
    """Lotka-Volterra right-hand side using the fitted parameters.

    Same form as the data-generating model, with (a_, b_, c_, d_) in place
    of the ground-truth coefficients. `t` is accepted for the ODE solver
    but not used.
    """
    return [
        y[0] * (a_ - b_ * y[1]),    # How population 0 changes.
        y[1] * (-c_ + d_ * y[0]),   # How population 1 changes.
    ]


# Solve the fitted model on a finer grid (100 points) than the data.
t = np.linspace(t0, t1, num=100)
y = odeint(F_, y0, t)

In [None]:
# Time-population plot comparing the synthetic data (dashed, with markers)
# against the fitted model (solid), using one colour per population.
series = (
    # (column, colour, data label, fitted-curve label)
    (0, 'tab:blue', 'prey (data)', 'prey'),
    (1, 'tab:orange', 'predator (data)', 'predator'),
)

# Draw both data series first, then both fitted curves, so the legend
# entries appear in that order.
for column, colour, data_label, _ in series:
    plt.plot(data_t, data_y[:, column], '--o', color=colour, label=data_label)
for column, colour, _, fit_label in series:
    plt.plot(t, y[:, column], color=colour, label=fit_label)

plt.xlabel('time')
plt.ylabel('population')
plt.grid(linestyle=':')
plt.legend()
plt.show()