## Try out simple optimisation routines with scipy's `minimize` (local) and `shgo` (global) functions
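This uses the same model as the previous notebook:
$$i_t = R_t \sum_{\tau<t} i_\tau \, g_{t-\tau}$$
where $i_t$ is the incidence at time $t$, $R_t$ the reproduction number and $g$ the generation time distribution.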

Possible next steps:
- Allow for truncated generation time distribution
- Make the random walk actually a walk
- Fit to actual data

In [None]:
from collections import namedtuple
from typing import Dict, List, Optional

import numpy as np
import pandas as pd
from plotly import graph_objects as go
from plotly.subplots import make_subplots
from scipy.optimize import minimize, shgo
from scipy.stats import gamma

In [None]:
# Typename matches the bound name so results can be pickled
Outputs = namedtuple('Outputs', ['incidence', 'suscept', 'r_t'])

gen_times_end = 20  # Maximum number of past time steps that contribute to new incidence

def semimech(gen_time_densities, process_vals, pop, seed, n_times) -> Outputs:
    """Run the semimechanistic renewal process.

    New incidence at each step is the sum of past incidence weighted by the
    generation time densities (truncated to the most recent ``gen_times_end``
    steps), multiplied by the process value and the proportion of the
    population still susceptible.

    Args:
        gen_time_densities: Generation time densities; element ``k`` weights
            the incidence from ``k + 1`` time steps earlier.
        process_vals: Process value for each time step.
        pop: Total population size.
        seed: Incidence at the first time step.
        n_times: Number of time steps to simulate.

    Returns:
        Outputs of the incidence array, the list of susceptibles and the list
        of R_t values (``nan`` at the first step). R_t is the process value
        times the proportion susceptible times the generation time mass inside
        the truncation window.
    """
    incidence = np.zeros(n_times)
    incidence[0] = seed
    this_suscept = pop - seed
    suscept = [this_suscept]
    r_t = [np.nan]
    for t in range(1, n_times):
        # Only the most recent gen_times_end steps contribute to new infections
        window = min(t, gen_times_end)
        window_densities = gen_time_densities[:window]

        # Product of pre-specified process value and proportion susceptible
        infect_modifier = process_vals[t] * this_suscept / pop

        # Past incidence (oldest first) times densities reversed so that the most recent
        # incidence is paired with the shortest generation time
        infect_contribution_by_day = incidence[t - window:t] * window_densities[::-1]

        this_inc = infect_contribution_by_day.sum() * infect_modifier
        this_suscept = max(this_suscept - this_inc, 0.0)  # Zero out any small negative values

        # Track quantities of interest
        incidence[t] = this_inc
        suscept.append(this_suscept)
        # Same truncated window as the incidence calculation, so R_t is consistent with it
        r_t.append(window_densities.sum() * infect_modifier)
    return Outputs(incidence, suscept, r_t)

def get_gamma_params_from_mean_sd(req_mean: float, req_sd: float) -> Dict[str, float]:
    """Get values for constructing the gamma distribution from a requested mean and standard deviation.

    Uses the method of moments: ``scale = variance / mean`` and
    ``a = mean / scale`` (i.e. ``mean ** 2 / variance``).

    Args:
        req_mean: Requested mean of the distribution.
        req_sd: Requested standard deviation of the distribution.

    Returns:
        Keyword arguments ``a`` (shape) and ``scale`` for ``scipy.stats.gamma``.
    """
    var = req_sd ** 2.0
    scale = var / req_mean
    a = req_mean / scale
    return {'a': a, 'scale': scale}

def get_gamma_densities_from_params(mean: float, sd: float, n_times: Optional[int] = None) -> np.ndarray:
    """Get the gamma probability mass over each unit interval of the simulation period.

    Element ``k`` of the result is ``CDF(k + 1) - CDF(k)`` for a gamma
    distribution with the requested mean and standard deviation.

    Args:
        mean: Requested mean of the gamma distribution.
        sd: Requested standard deviation of the gamma distribution.
        n_times: Number of unit intervals to return. If omitted, the module-level
            ``n_times`` is used (the original behaviour, kept for existing callers).

    Returns:
        Array of length ``n_times``.
    """
    if n_times is None:
        # Backward compatibility: the parameter shadows the global, so look it up explicitly
        n_times = globals()['n_times']
    params = get_gamma_params_from_mean_sd(mean, sd)
    return np.diff(gamma.cdf(range(n_times + 1), **params))

def get_interp_vals_over_model_time(req: List[float], n_times: int) -> np.ndarray:
    """Linearly interpolate requested values onto the integer model times.

    The requested values are placed at evenly spaced knots from 0 to ``n_times``
    and evaluated at times ``0 .. n_times - 1``.
    """
    # NOTE(review): the last knot sits at n_times, one step past the final model time,
    # so the final requested value is never reached exactly - confirm this is intended.
    return np.interp(range(n_times), np.linspace(0.0, n_times, len(req)), req)

def model_func(gen_time_mean: float, gen_time_sd: float, process_req: List[float], pop: int, seed: int, n_times: int) -> tuple:
    """Run the semimechanistic model from its epidemiological parameters.

    Args:
        gen_time_mean: Mean of the gamma generation time distribution.
        gen_time_sd: Standard deviation of the gamma generation time distribution.
        process_req: Log-scale process values; they are exponentiated and then
            linearly interpolated over the simulation period.
        pop: Total population size.
        seed: Incidence at the first time step.
        n_times: Number of time steps to simulate.

    Returns:
        Tuple of the ``semimech`` outputs and the interpolated process values.
    """
    # Size the densities to this run rather than relying on the notebook-level n_times
    gen_time_densities = get_gamma_densities_from_params(gen_time_mean, gen_time_sd, n_times)
    process_req_exp = np.exp(np.array(process_req))
    process_vals = get_interp_vals_over_model_time(process_req_exp, n_times)
    model_result = semimech(gen_time_densities, process_vals, pop, seed, n_times)
    return model_result, process_vals

def calib_func(parameters: List[float], pop: int, seed: int, n_times: int, targets: dict) -> float:
    """Get the sum-of-squares loss between modelled incidence and the targets.

    Args:
        parameters: Generation time mean, generation time standard deviation,
            then the log-scale process values, in that order.
        pop: Total population size, passed to ``model_func``.
        seed: Incidence at the first time step, passed to ``model_func``.
        n_times: Number of time steps to simulate, passed to ``model_func``.
        targets: Mapping from time index to target incidence.

    Returns:
        Sum of squared differences between modelled incidence and each target.
    """
    gen_time_mean, gen_time_sd, *process_req = parameters
    model_result, _ = model_func(gen_time_mean, gen_time_sd, process_req, pop, seed, n_times)
    incidence = model_result.incidence
    return sum((incidence[t] - d) ** 2 for t, d in targets.items())

In [None]:
# Model parameters: total population, incidence at the first time step, number of simulated time steps
population, infectious_seed, n_times = 100.0, 1.0, 40

In [None]:
# Synthetic calibration targets: run the model at known parameters and keep every time step
model_times = pd.Series(range(n_times))
test_params = [5.5, 1.8, 1.0, 0.5, 1.5, 0.8]
test_data, _ = model_func(*test_params[:2], test_params[2:], population, infectious_seed, n_times)
test_vals = dict(zip(model_times, test_data.incidence))

In [None]:
# Re-run the model at the test parameters
example, _ = model_func(*test_params[:2], test_params[2:], population, infectious_seed, n_times)


In [None]:
# Local optimisation with Nelder-Mead.
# The generation time mean and SD must be strictly positive: zero causes a division by zero
# when converting them to gamma parameters, so their lower bounds are 0.1 rather than 0.0.
param_bounds = [[0.1, 10.0]] + [[0.1, 4.0]] + [[-10.0, 10.0]] * 4
result = minimize(calib_func, [2.0, 1.0] + [0.5] * 4, method='Nelder-Mead', args=(population, infectious_seed, n_times, test_vals), bounds=param_bounds)
model_result, process_vals = model_func(result.x[0], result.x[1], result.x[2:], population, infectious_seed, n_times)
optimised, suscept, r_t = model_result

In [None]:
# Fitted generation time mean and SD, then the fitted log-scale process values
print(result.x[:2], result.x[2:], sep='\n')

In [None]:
# Global optimisation with shgo. The extra calib_func arguments are bound as lambda default
# values instead of being passed through shgo's ``args``, to work around the issue described at
# https://stackoverflow.com/questions/72794609/scipy-issue-passing-arguments-to-optimize-shgo-function
param_bounds = [[0.1, 30.0], [0.1, 10.0]] + [[-10.0, 10.0]] * 4
global_result = shgo(
    lambda x, pop=population, seed=infectious_seed, times=n_times, targets=test_vals: calib_func(x, pop, seed, times, targets),
    param_bounds,
)
model_result, process_vals = model_func(*global_result.x[:2], global_result.x[2:], population, infectious_seed, n_times)
optimised, suscept, r_t = model_result

In [None]:
# Globally fitted generation time mean and SD, then the fitted log-scale process values
print(global_result.x[:2], global_result.x[2:], sep='\n')

In [None]:
# Plot the targets against the fitted model (top), the transmission potential and R_t (middle)
# and the susceptibles (bottom)
fig = make_subplots(
    3, 1,
    shared_xaxes=True,
    vertical_spacing=0.05,
    subplot_titles=['incidence', 'reproduction number', 'susceptibles'],
)
panel_traces = [
    (go.Scatter(x=list(test_vals), y=list(test_vals.values()), mode='markers', name='targets'), 1),
    (go.Scatter(x=model_times, y=optimised, name='model'), 1),
    (go.Scatter(x=model_times, y=process_vals, name='transmission potential'), 2),
    (go.Scatter(x=model_times, y=r_t, name='Rt'), 2),
    (go.Scatter(x=model_times, y=suscept, name='susceptibles'), 3),
]
for trace, panel_row in panel_traces:
    fig.add_trace(trace, row=panel_row, col=1)
fig.update_layout(height=800, margin={'t': 20, 'b': 5, 'l': 5, 'r': 5})