In [None]:
from matplotlib import pyplot
from matplotlib.ticker import FuncFormatter
import pandas as pd
import warnings

from summer.utils import ref_times_to_dti

from autumn.tools.project import get_project
from autumn.tools.plots.utils import REF_DATE
from autumn.tools.utils.pretty import pretty_print
from autumn.models.sm_sir.preprocess.age_specific_params import convert_param_agegroups
from autumn.tools.inputs import get_population_by_agegroup

In [None]:
# Identifiers of the autumn project to load: model name, then region/project name.
MODEL_NAME = "sm_sir"
REGION_NAME = "malaysia"
project = get_project(MODEL_NAME, REGION_NAME)

In [None]:
# Parameters for this run: the project's baseline, optionally with overrides applied via
# `project.param_set.baseline.update(...)` (presumably returns an updated parameter set — confirm).
# Leave `param_overrides` empty to run the plain baseline. A previously explored example:
#     {"contact_rate": 0.1,
#      "infection_fatality": {"multiplier": 2.407957228094271},
#      "mobility": {"microdistancing": {"behaviour": {"parameters": {"max_effect": 0.3}}}}}
param_overrides = {}
custom_params = (
    project.param_set.baseline.update(param_overrides)
    if param_overrides
    else project.param_set.baseline
)
model = project.run_baseline_model(custom_params)
derived_df = model.get_derived_outputs_df()

In [None]:
baseline_params = project.param_set.baseline.to_dict()
age_params = baseline_params["age_stratification"]
age_groups = baseline_params["age_groups"]
# Age breakpoints (steps of 5 years) for the raw, unprocessed parameter inputs.
# The IFR input has one extra band: its last breakpoint is 80 rather than 75.
standard_agegroups = range(0, 80, 5)
ifr_age_groups = range(0, 85, 5)

# Location passed to convert_param_agegroups when mapping raw values onto the modelled age groups;
# looked up once here instead of on every loop iteration.
country_iso3 = baseline_params["country"]["iso3"]
population_region = baseline_params["population"]["region"]

fig = pyplot.figure(figsize=(15, 12))

# Relative susceptibility by age group, as given in the parameters.
# Dict views are converted to lists: matplotlib cannot plot dict_keys / dict_values objects directly.
susceptibility = age_params["susceptibility"]
axis = fig.add_subplot(2, 2, 1)
axis.plot(
    list(susceptibility.keys()),
    list(susceptibility.values()),
    marker="o",
    linestyle="--",
)
axis.set_title("relative susceptibility")
axis.set_xlabel("age group")
axis.set_ylim(bottom=0.)

# Clinical proportions: raw inputs (black) vs the values converted to the modelled age groups.
for i_prop, clin_prop in enumerate(["prop_hospital", "prop_symptomatic", "ifr"]):
    base_agegroups = ifr_age_groups if clin_prop == "ifr" else standard_agegroups
    # Only the symptomatic-proportion panel is capped at 1; the others keep their autoscaled top.
    upper_y = 1. if clin_prop == "prop_symptomatic" else None

    axis = fig.add_subplot(2, 2, 2 + i_prop)
    axis.plot(
        base_agegroups,
        age_params[clin_prop],
        marker="o",
        linestyle="--",
        label="raw",
        color="k",
    )
    axis.plot(
        age_groups,
        convert_param_agegroups(
            age_params[clin_prop],
            country_iso3,
            population_region,
            age_groups,
        ),
        marker="o",
        linestyle="--",
        label="processed",
    )
    axis.set_title(clin_prop.replace("_", " "))
    axis.set_xlabel("age group")
    axis.set_ylim(bottom=0., top=upper_y)
    axis.legend()


In [None]:
# Plot population
# Country code, sub-region and year used to look up population by age group below.
iso3, subregion, year = (
    baseline_params["country"]["iso3"],
    baseline_params["population"]["region"],
    baseline_params["population"]["year"],
)


def yaxis_millions(axis_to_adjust):
    """Show the y-axis of ``axis_to_adjust`` in units of millions.

    Installs a ``FuncFormatter`` on the y-axis that divides each tick value by one
    million, and sets the y-axis label to "millions".

    Args:
        axis_to_adjust: matplotlib Axes whose y-axis is reformatted in place.
    """

    def millions(x, pos):
        # ``pos`` (tick index) is part of the FuncFormatter callback signature but unused here.
        # ``:g`` avoids float noise in tick labels such as "2.5000000000000004".
        return f"{x * 1e-6:g}"

    axis_to_adjust.yaxis.set_major_formatter(FuncFormatter(millions))
    # Label the axis that was passed in, not whatever global ``axis`` happens to exist.
    axis_to_adjust.set_ylabel("millions")


# Population by age: the model's own age groups vs the standard 5-year breakpoints.
model_pops = get_population_by_agegroup(age_groups, iso3, subregion, year)
standard_pops = get_population_by_agegroup(standard_agegroups, iso3, subregion, year)
print(f"Total population simulated is {sum(standard_pops) / 1e6:.3f} million")
print(f"Infectious seed is {baseline_params['infectious_seed']}")

fig = pyplot.figure(figsize=(10, 7))
axis = fig.add_subplot()

for pop_groups, pop_values, pop_label in (
    (age_groups, model_pops, "modelled"),
    (standard_agegroups, standard_pops, "standard"),
):
    axis.plot(
        pop_groups,
        pop_values,
        marker="o",
        linestyle="--",
        label=pop_label,
    )
axis.set_title("population by age group")
axis.set_xlabel("age group")
yaxis_millions(axis)
axis.legend();


In [None]:
# Calibration targets as date-indexed series, keyed by target name,
# so they can be overlaid on the matching model outputs.
targets_dict = {
    target.data.name: pd.Series(
        target.data.data,
        index=ref_times_to_dti(model.ref_date, target.data.index),
    )
    for target in project.calibration.targets
}

In [None]:
# Note: this applies the ggplot style globally, so it persists for any figure drawn afterwards.
pyplot.style.use("ggplot")

outputs_to_plot = ["notifications", "infection_deaths", "icu_occupancy", "hospital_occupancy"]

# Model start/end times converted to dates, used to bound every panel's x-axis.
# NOTE(review): targets_dict is dated from model.ref_date while this uses REF_DATE — confirm they agree.
model_start_time, model_end_time = ref_times_to_dti(
    REF_DATE,
    [project.param_set.baseline["time"]["start"], project.param_set.baseline["time"]["end"]],
)

fig = pyplot.figure(figsize=(15, 12))
# Suppress warnings only while drawing these panels, instead of for the rest of the kernel session.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    for i_out, output in enumerate(outputs_to_plot):
        axis = fig.add_subplot(2, 2, i_out + 1)
        # Calibration targets as points, model outputs as lines; either may be absent.
        if output in targets_dict:
            targets_dict[output].plot(ax=axis, style=".")
        if output in derived_df:
            derived_df[output].plot(ax=axis)
        axis.set_title(output.replace("_", " "))
        axis.set_xlim([model_start_time, model_end_time])
    