## Try out a simple optimisation routine with scipy's minimize function
This uses the same model as the previous notebook:
$i_t = R_t \sum_{\tau<t} i_\tau g_{t-\tau}$

In [None]:
import warnings
from typing import Dict, List, Optional

import numpy as np
import pandas as pd
from plotly import graph_objects as go
from plotly.subplots import make_subplots
from scipy.optimize import minimize
from scipy.stats import gamma

In [None]:
from collections import namedtuple

In [None]:
# Container for the simulation outputs. The typename must match the variable
# name: pickle looks classes up by name, so a mismatched typename makes
# instances unpicklable.
Outputs = namedtuple('Outputs', ['incidence', 'suscept', 'r_t', 'process_vals'])

def semimech(gen_time_densities, process_vals, pop, seed, n_times):
    """Run the semi-mechanistic renewal process.

    Incidence at each time step is the generation-time-weighted sum of past
    incidence, scaled by that step's process value and by the proportion of
    the population still susceptible. Incidence is capped at the number of
    remaining susceptibles, so cumulative incidence never exceeds ``pop``.

    Args:
        gen_time_densities: Generation time densities; element ``k`` weights
            incidence from ``k + 1`` steps earlier. Needs at least
            ``n_times - 1`` elements.
        process_vals: Transmission potential at each time step. Needs at
            least ``n_times`` elements; element 0 is not used.
        pop: Total population size.
        seed: Incidence at time 0, which is removed from the susceptibles.
        n_times: Number of time steps to simulate.

    Returns:
        ``Outputs`` with ``incidence`` (numpy array of length ``n_times``),
        ``suscept`` and ``r_t`` (lists of length ``n_times``; ``r_t[0]`` is
        NaN), and the ``process_vals`` that were passed in.

    Raises:
        ValueError: If ``n_times`` is less than 1 or an input sequence is too
            short for the requested number of time steps.
    """
    if n_times < 1:
        raise ValueError(f'n_times must be at least 1, got {n_times}')
    if len(gen_time_densities) < n_times - 1:
        raise ValueError(
            f'gen_time_densities has {len(gen_time_densities)} elements, '
            f'need at least {n_times - 1} for n_times={n_times}'
        )
    if len(process_vals) < n_times:
        raise ValueError(
            f'process_vals has {len(process_vals)} elements, '
            f'need at least {n_times} for n_times={n_times}'
        )

    incidence = np.zeros(n_times)
    incidence[0] = seed
    this_suscept = pop - seed
    suscept = [this_suscept]
    r_t = [np.nan]
    for t in range(1, n_times):
        # Reversed so that element i lines up with incidence[i], i.e. lag t - i
        lagged_densities = gen_time_densities[t - 1::-1]

        # Product of pre-specified process value and proportion susceptible
        infect_modifier = process_vals[t] * this_suscept / pop
        # Product of past incidence and their generation time densities
        infect_contribution_by_day = incidence[:t] * lagged_densities
        raw_inc = infect_contribution_by_day.sum() * infect_modifier

        # Discrete steps can overshoot the susceptible pool; cap incidence so
        # that incidence and susceptible depletion stay consistent.
        this_inc = min(raw_inc, this_suscept)
        this_suscept = max(this_suscept - this_inc, 0.0)  # Guard against float round-off

        # Track quantities of interest
        incidence[t] = this_inc
        suscept.append(this_suscept)
        r_t.append(lagged_densities.sum() * infect_modifier)
    return Outputs(incidence, suscept, r_t, process_vals)

def get_gamma_params_from_mean_sd(req_mean: float, req_sd: float) -> Dict[str, float]:
    """Get parameters for constructing a gamma distribution from a requested mean and standard deviation.

    Method of moments: for shape ``a`` and scale ``theta``, mean = a * theta
    and variance = a * theta ** 2, so theta = variance / mean and
    a = mean / theta.

    Args:
        req_mean: Requested mean; must be positive.
        req_sd: Requested standard deviation; must be positive.

    Returns:
        Keyword arguments ``{'a': shape, 'scale': scale}`` for ``scipy.stats.gamma``.

    Raises:
        ValueError: If ``req_mean`` or ``req_sd`` is not positive.
    """
    if req_mean <= 0.0 or req_sd <= 0.0:
        raise ValueError(f'mean and sd must both be positive, got mean={req_mean}, sd={req_sd}')
    var = req_sd ** 2.0
    scale = var / req_mean
    a = req_mean / scale
    return {'a': a, 'scale': scale}

def get_gamma_densities_from_params(mean: float, sd: float, n_times: Optional[int] = None) -> np.ndarray:
    """Get the gamma probability mass in each unit interval [k, k + 1) for k = 0 .. n_times - 1.

    The result is truncated at ``n_times``, so it sums to the CDF at
    ``n_times`` rather than exactly 1.

    Args:
        mean: Mean of the gamma distribution.
        sd: Standard deviation of the gamma distribution.
        n_times: Number of unit intervals (normally the simulation duration).
            If omitted, falls back to the notebook-level ``n_times`` for
            backward compatibility with the original global lookup.

    Returns:
        Array of length ``n_times`` with the per-interval probability mass.
    """
    if n_times is None:
        # The original version silently read the notebook-level global; keep
        # that behaviour for callers that do not pass the length explicitly.
        n_times = globals()['n_times']
    params = get_gamma_params_from_mean_sd(mean, sd)
    return np.diff(gamma.cdf(range(n_times + 1), **params))

def get_interp_vals_over_model_time(req: List[float], n_times: int) -> np.ndarray:
    """Linearly interpolate requested values, placed at evenly spaced knots, onto each model time.

    Args:
        req: Values at the knots; ``len(req)`` knots are spaced evenly from 0
            to ``n_times``.
        n_times: Number of model time steps (times 0 .. n_times - 1).

    Returns:
        Array of length ``n_times`` with the interpolated value at each time.
    """
    # NOTE(review): the last knot sits at n_times, one step beyond the final
    # model time (n_times - 1), so the last requested value is never reached
    # exactly. Confirm this is intended before changing it, because it alters calibration.
    return np.interp(range(n_times), np.linspace(0.0, n_times, len(req)), req)

def model_func(gen_time_mean: float, gen_time_sd: float, process_req: List[float], pop: int, seed: int, n_times: int) -> tuple:
    """Build the model inputs from epidemiological parameters and run the semi-mechanistic process.

    Args:
        gen_time_mean: Mean of the gamma-distributed generation time.
        gen_time_sd: Standard deviation of the generation time.
        process_req: Transmission potential values at evenly spaced knots,
            interpolated onto each model time.
        pop: Total population size.
        seed: Incidence at time 0.
        n_times: Number of time steps to simulate.

    Returns:
        The ``Outputs`` namedtuple produced by ``semimech``.
    """
    # Pass n_times explicitly so the density vector always matches the
    # simulation length, instead of depending on a same-named global.
    gen_time_densities = get_gamma_densities_from_params(gen_time_mean, gen_time_sd, n_times)
    process_vals = get_interp_vals_over_model_time(process_req, n_times)
    return semimech(gen_time_densities, process_vals, pop, seed, n_times)

def calib_func(parameters: List[float], pop: int, seed: int, n_times: int, targets: Optional[pd.Series] = None) -> float:
    """Get the loss function from the model: sum of squared errors against target incidence.

    Args:
        parameters: Flat parameter vector as passed by the optimiser:
            generation time mean, generation time standard deviation, then the
            transmission potential value at each interpolation knot.
        pop: Total population size, passed to ``model_func``.
        seed: Incidence at time 0, passed to ``model_func``.
        n_times: Number of time steps to simulate, passed to ``model_func``.
        targets: Observed incidence indexed by integer model time. Defaults to
            the notebook-level ``dummy_data`` for backward compatibility.

    Returns:
        Sum over the target times of (modelled incidence - target) ** 2.
    """
    if targets is None:
        # Original behaviour: read the notebook-level calibration targets.
        targets = dummy_data
    gen_time_mean, gen_time_sd, *process_req = parameters
    incidence = model_func(gen_time_mean, gen_time_sd, process_req, pop, seed, n_times).incidence
    return sum((incidence[t] - d) ** 2 for t, d in targets.items())

In [None]:
# Model parameters
population = 100.0  # total population size
infectious_seed = 1.0  # incidence at the first model time
n_times = 40  # number of simulated time steps

# Calibration targets: observed incidence, indexed by model time step
dummy_data = pd.Series(
    [1.0, 1.0, 1.5, 4.2, 3.8, 2.1],
    index=[5, 10, 15, 25, 30, 35],
)

In [None]:
# Optimise and run results through the model.
# Parameter vector layout: [generation time mean, generation time SD, *transmission potential knots]
n_process_knots = 4  # single source of truth for the knot count used in bounds and initial guess
param_bounds = [[5.0, 6.0], [1.5, 2.0]] + [[0.0, 10.0]] * n_process_knots
initial_guess = [5.0, 1.5] + [2.0] * n_process_knots
result = minimize(
    calib_func,
    initial_guess,
    method='Nelder-Mead',
    args=(population, infectious_seed, n_times),
    bounds=param_bounds,
)
if not result.success:
    # Nelder-Mead can stop on its iteration/evaluation limits; surface that
    # instead of silently plotting an unconverged fit.
    warnings.warn(f'Calibration did not converge: {result.message}')
optimised, suscept, r_t, process_vals = model_func(result.x[0], result.x[1], result.x[2:], population, infectious_seed, n_times)

In [None]:
# Plot the calibration targets, fitted model outputs and susceptible depletion
model_times = pd.Series(range(n_times))
fig = make_subplots(
    3,
    1,
    shared_xaxes=True,
    vertical_spacing=0.05,
    subplot_titles=['incidence', 'reproduction number', 'susceptibles'],
)
# (trace, subplot row) pairs, added to the figure in this order
panel_traces = [
    (go.Scatter(x=dummy_data.index, y=dummy_data, mode='markers', name='targets'), 1),
    (go.Scatter(x=model_times, y=optimised, name='model'), 1),
    (go.Scatter(x=model_times, y=process_vals, name='transmission potential'), 2),
    (go.Scatter(x=model_times, y=r_t, name='Rt'), 2),
    (go.Scatter(x=model_times, y=suscept, name='susceptibles'), 3),
]
for trace, row in panel_traces:
    fig.add_trace(trace, row=row, col=1)
fig.update_layout(height=800, margin={'t': 20, 'b': 5, 'l': 5, 'r': 5})