### Optimisation
Optimise the parameter set, then perform a single model run with the resulting parameters and illustrate the results with basic plots.

### Running on Colab
Uncomment the commands in the following cell to install the required packages in Colab.
Once installation has completed, click the 'Restart runtime' button that appears to restart the Colab environment, then proceed to the subsequent cells.

In [None]:
# !pip uninstall numba -y
# !pip uninstall librosa -y
# !pip install estival==0.4.9 numpy==1.24.3 kaleido

In [None]:
try:
    import google.colab
    on_colab = True
    ! git clone https://github.com/monash-emu/aust-covid.git --branch main
    %cd aust-covid
    %pip install -e ./
    import multiprocessing as mp
    mp.set_start_method('forkserver')
except:
    on_colab = False

In [None]:
import numpy as np
import pymc as pm
from plotly import graph_objects as go

from estival.model import BayesianCompartmentalModel
from estival.wrappers import pymc as epm

from inputs.constants import SUPPLEMENT_PATH, PLOT_START_DATE
from aust_covid.inputs import get_ifrs
from aust_covid.model import build_model
from aust_covid.calibration import get_priors, get_targets
from aust_covid.plotting import plot_single_run_outputs, plot_example_model_matrices
from emutools.tex import DummyTexDoc, StandardTexDoc, add_image_to_doc
from emutools.utils import load_param_info, round_sigfig

In [None]:
# Evaluation budget for the optimiser; passed to pm.find_MAP as `maxeval` below.
max_iterations = 50

In [None]:
# Assemble everything needed for calibration: parameter values, the
# epidemiological model, the priors and the targets.
# NOTE(review): a fresh DummyTexDoc() is passed wherever a document argument is
# required — presumably a stand-in so that no TeX output is produced; confirm.
param_info = load_param_info()
# Merge the IFR values into the 'value' entry in place, before it is converted below.
# NOTE(review): param_info['value'] is presumably a pandas Series keyed by parameter name — confirm.
param_info['value'].update(get_ifrs(DummyTexDoc()))
# Dictionary form of the values, taken after the update above so the IFRs are included.
parameters = param_info['value'].to_dict()
epi_model = build_model(DummyTexDoc(), param_info['abbreviations'], mobility_ext=True, cross_ref=False)
# The meaning of the leading False flag is not visible here — see get_priors.
priors = get_priors(False, param_info['abbreviations'], DummyTexDoc())
prior_names = [p.name for p in priors]
targets = get_targets(DummyTexDoc())
# Bundle the model, parameter values, priors and targets into one calibration object.
bcm = BayesianCompartmentalModel(epi_model, parameters, priors, targets)

In [None]:
# Search for the maximum a posteriori (MAP) parameter set with PyMC.
with pm.Model() as pmc_model:
    # Start from the current parameter values, clipped into the range returned by
    # each prior's bounds(0.99); only parameters that have a prior are included.
    start_params = {k: np.clip(v, *bcm.priors[k].bounds(0.99)) for k, v in parameters.items() if k in bcm.priors}
    variables = epm.use_model(bcm)
    map_params = pm.find_MAP(start=start_params, vars=variables, include_transformed=False, maxeval=max_iterations)
    # Convert the returned values to plain Python floats.
    map_params = {k: float(v) for k, v in map_params.items()}

print('Best candidate parameters found:')
for param, value in map_params.items():
    # Dispersion parameters are not reported.
    if '_dispersion' in param:
        continue
    # Look the prior up by name rather than by position, so the reported bounds
    # cannot be mismatched if find_MAP orders its results differently from `priors`.
    print(f'   {param}: {round_sigfig(value, 4)} (within bound {bcm.priors[param].bounds()})')

# The following cell applies `map_params` to `parameters` and runs the model,
# so that step is not repeated here.

In [None]:
# Overwrite the parameter values with the MAP estimates (in place), then run the
# model once with the combined parameter set.
parameters.update(map_params)
epi_model.run(parameters=parameters)

In [None]:
# Plot the outputs of the model run above together with the calibration targets.
outputs_fig = plot_single_run_outputs(epi_model, targets)
# Bare expression on the cell's last line so the notebook displays the returned object.
outputs_fig