In [None]:
import os

import arviz as az
import numpy as np
import pymc as pm
import pytensor

from climepi import epimod

# PyTensor uses `config.cxx` as the C++ compiler for its compiled ops. Prefer
# clang++ at this location when it is installed, but do not force a compiler
# path that does not exist: elsewhere PyTensor keeps its own default so the
# notebook still runs after Restart & Run All on another machine.
CLANG_CXX_PATH = "/usr/bin/clang++"
if os.access(CLANG_CXX_PATH, os.X_OK):
    pytensor.config.cxx = CLANG_CXX_PATH


In [None]:
def _gamma_prior(name, alpha, beta):
    """Return a zero-argument factory registering a PyMC Gamma prior ``name``.

    PyMC's Gamma uses a rate parameterisation, so the prior mean is alpha / beta.
    """
    return lambda: pm.Gamma(name, alpha=alpha, beta=beta)


def _noise_std_prior(upper):
    """Return a zero-argument factory for a Uniform(0, ``upper``) noise prior."""
    return lambda: pm.Uniform("noise_std", lower=0, upper=upper)


def _thermal_priors(scale, temperature_min, temperature_max, noise_std_upper):
    """Build the prior factories for one fitted temperature-response curve.

    ``scale``, ``temperature_min`` and ``temperature_max`` are (alpha, beta)
    pairs for Gamma priors; ``noise_std_upper`` bounds the Uniform noise prior.
    The factories are only called later (presumably inside a PyMC model
    context created by climepi — confirm), so nothing is sampled here.
    """
    return {
        "scale": _gamma_prior("scale", *scale),
        "temperature_min": _gamma_prior("temperature_min", *temperature_min),
        "temperature_max": _gamma_prior("temperature_max", *temperature_max),
        "noise_std": _noise_std_prior(noise_std_upper),
    }


def _aquatic_stage_carrying_capacity_per_m2(temperature=None, precipitation=None):
    """Aquatic-stage carrying capacity per m^2 as a function of precipitation.

    Zero when precipitation is below 0.2; above that it grows with
    precipitation and saturates like p / (250 + p). ``temperature`` is unused
    but accepted so every parameter function shares the same signature.
    NOTE(review): the constants (0.2, 300, 25 * 4.59, 5 + 245) are taken
    as-is — TODO cite their source.
    """
    return (
        (precipitation >= 0.2)
        * precipitation
        * 300
        / (25 * 4.59 * (5 + 245 + precipitation))
    )


def _larval_flush_out_rate(temperature=None, precipitation=None):
    """Larval flush-out rate, taken to equal precipitation (temperature unused)."""
    return precipitation


# Parameter specification for the suitability model: four temperature
# responses fitted to data (Briere or quadratic curves with priors), plus two
# precipitation-dependent quantities supplied directly as functions.
parameters = {
    "eggs_per_female_per_day": {
        "curve_type": "briere",
        "priors": _thermal_priors((2, 100), (10, 1 / 2), (10, 1 / 4), 10),
        "attrs": {"long_name": "Eggs per female per day"},
    },
    "egg_to_adult_development_rate": {
        "curve_type": "briere",
        "priors": _thermal_priors((9, 100000), (7, 1 / 2), (10, 1 / 4), 1),
        "attrs": {"long_name": "Egg to adult development rate", "units": "per day"},
    },
    "egg_to_adult_survival_probability": {
        "curve_type": "quadratic",
        # NOTE(review): presumably makes climepi treat the response as a
        # probability bounded in [0, 1] — confirm.
        "probability": True,
        "priors": _thermal_priors((7, 1000), (7, 1 / 2), (10, 1 / 4), 5),
        "attrs": {"long_name": "Egg to adult survival probability"},
    },
    "adult_lifespan": {
        "curve_type": "quadratic",
        "priors": _thermal_priors((1, 2), (5, 1 / 2), (9, 1 / 5), 50),
        "attrs": {
            "long_name": "Adult lifespan",
            "units": "days",
        },
    },
    "aquatic_stage_carrying_capacity_per_m2": _aquatic_stage_carrying_capacity_per_m2,
    "larval_flush_out_rate": _larval_flush_out_rate,
}
# Example temperature-response dataset bundled with climepi, used to fit the
# temperature-dependent parameters above.
EXAMPLE_DATASET_NAME = "mordecai_ae_aegypti"
data = epimod.get_example_temperature_response_data(EXAMPLE_DATASET_NAME)

In [None]:
def _suitability_function(
    eggs_per_female_per_day=None,
    egg_to_adult_development_rate=None,
    egg_to_adult_survival_probability=None,
    adult_lifespan=None,
    aquatic_stage_carrying_capacity_per_m2=None,
    larval_flush_out_rate=None,
):
    aquatic_to_adult_development_rate = (73 / 48) * egg_to_adult_development_rate
    aquatic_stage_death_rate = aquatic_to_adult_development_rate * (
        (1 / egg_to_adult_survival_probability) - 1
    )
    equilibrium_density = aquatic_stage_carrying_capacity_per_m2 * (
        0.5 * aquatic_to_adult_development_rate * adult_lifespan
        - (
            larval_flush_out_rate
            + aquatic_stage_death_rate
            + aquatic_to_adult_development_rate
        )
        / eggs_per_female_per_day
    )
    suitability = equilibrium_density > 0
    return suitability

In [None]:
# Bundle the suitability rule, the example temperature-response data and the
# parameter specification into one climepi model object.
suitability_model = epimod.ParameterizedSuitabilityModel(
    suitability_function=_suitability_function,
    data=data,
    parameters=parameters,
)

In [None]:
# MCMC settings used when fitting each temperature-response curve
# (thin=10 presumably keeps every 10th draw — confirm in climepi).
# NOTE(review): no random seed is set, so the posterior draws and every
# downstream result change between runs; consider passing one if
# `fit_temperature_responses` forwards it to the sampler.
N_TUNE = 10_000
N_DRAWS = 10_000
THIN = 10

idata_dict = suitability_model.fit_temperature_responses(
    tune=N_TUNE,
    draws=N_DRAWS,
    thin=THIN,
)

In [None]:
# Fitted temperature responses on a 0-50 grid, laid out two panels per row.
fit_temperature_grid = np.linspace(0, 50, 500)
plots = (
    suitability_model.plot_fitted_temperature_responses(
        temperature_vals=fit_temperature_grid
    )
    .cols(2)
    .opts(legend_position="top_right")
)

# Per-panel y-limits in layout order.
# NOTE(review): assumes the panels follow the order of `parameters` — confirm.
panel_y_limits = [(0, 15), (0, 0.2), (0, 1), (0, 50)]
for panel_index, y_limits in enumerate(panel_y_limits):
    if panel_index == 0:
        # Only the first panel keeps a legend, placed at the top left.
        plots[panel_index].opts(ylim=y_limits, legend_position="top_left")
    else:
        plots[panel_index].opts(ylim=y_limits, show_legend=False)
plots

In [None]:
# Trace and posterior density plots for the eggs-per-female fit. The trailing
# `;` suppresses the returned axes-array repr, matching the other trace cells.
az.plot_trace(idata_dict["eggs_per_female_per_day"], figsize=(10, 10), show=False);

In [None]:
# Trace and posterior density plots for the egg-to-adult development-rate fit.
development_rate_idata = idata_dict["egg_to_adult_development_rate"]
az.plot_trace(development_rate_idata);

In [None]:
# Trace and posterior density plots for the egg-to-adult survival-probability fit.
survival_probability_idata = idata_dict["egg_to_adult_survival_probability"]
az.plot_trace(survival_probability_idata);

In [None]:
# Trace and posterior density plots for the adult-lifespan fit.
adult_lifespan_idata = idata_dict["adult_lifespan"]
az.plot_trace(adult_lifespan_idata);

In [None]:
# R-hat convergence diagnostic for every fitted curve; values close to 1
# indicate the chains agree.
for parameter_name in idata_dict:
    idata = idata_dict[parameter_name]
    print(f"Parameter: {parameter_name}")
    print(az.rhat(idata))

In [None]:
# Effective sample size for every fitted curve, to check that the thinned
# draws are sufficient.
for parameter_name in idata_dict:
    idata = idata_dict[parameter_name]
    print(f"Parameter: {parameter_name}")
    print(az.ess(idata))

In [None]:
# Tabulate suitability on a 200 x 200 grid of temperature (0-40) and
# precipitation (0-30) values.
table_temperature_grid = np.linspace(0, 40, 200)
table_precipitation_grid = np.linspace(0, 30, 200)
suitability_model.construct_suitability_table(
    temperature_vals=table_temperature_grid,
    precipitation_vals=table_precipitation_grid,
)

In [None]:
# Mean suitability, binned into four colour bands. Because the suitability
# function returns booleans, the mean is presumably the fraction of posterior
# draws predicting suitable conditions — confirm how `reduce` aggregates.
mean_suitability_levels = [0, 0.025, 0.5, 0.975, 1]
mean_suitability_colors = ["green", "yellow", "orange", "red"]
mean_suitability_model = suitability_model.reduce(stat="mean")
mean_suitability_model.plot_suitability(rasterize=True).opts(
    color_levels=mean_suitability_levels,
    cmap=mean_suitability_colors,
)

In [None]:
# Median suitability across the model's samples.
median_suitability_model = suitability_model.reduce(stat="median")
median_suitability_model.plot_suitability(rasterize=True)