In [None]:
import os

import arviz as az
import numpy as np
import pymc as pm
import pytensor

from climepi import epimod

# C++ compiler for PyTensor (`pytensor.config.cxx`). The override is applied
# only when clang++ actually exists at this path. On machines without it,
# PyTensor keeps its own default instead of being pointed at a missing
# binary, so Restart & Run All works on other setups.
CXX_COMPILER_PATH = "/usr/bin/clang++"
if os.path.exists(CXX_COMPILER_PATH):
    pytensor.config.cxx = CXX_COMPILER_PATH

In [None]:
def _gamma_priors(weight, **alpha_beta):
    """Return a dict of lazily evaluated Gamma prior factories.

    Each keyword ``name=(alpha, beta)`` becomes an entry ``name -> factory``.
    Calling ``factory()`` creates
    ``pm.Gamma(name, alpha=weight * alpha, beta=weight * beta)``. Nothing is
    created here; the factories are presumably called later inside a PyMC
    model context owned by climepi (TODO confirm).

    Scaling alpha and beta by the same ``weight`` keeps the Gamma mean
    (alpha / beta) fixed and multiplies the variance (alpha / beta**2) by
    1 / weight. A weight below 1 therefore widens the informative prior.
    """

    def make_factory(name, alpha, beta):
        # A separate function scope pins name/alpha/beta for each closure.
        return lambda: pm.Gamma(name, alpha=weight * alpha, beta=weight * beta)

    return {
        name: make_factory(name, alpha, beta)
        for name, (alpha, beta) in alpha_beta.items()
    }


# Per-trait settings: curve family, Gamma(alpha, beta) prior hyperparameters
# (scaled by the per-trait weight passed first to _gamma_priors), optional
# probability flag, and display attributes.
parameters = {
    "eggs_per_female_per_cycle": {
        "curve_type": "briere",
        "priors": _gamma_priors(
            0.1,
            scale=(29.28512, 268.88795),
            temperature_min=(57.320563, 4.114297),
            temperature_max=(2597.57096, 80.99238),
            noise_precision=(2.238483, 84.707295),
        ),
        "attrs": {"long_name": "Eggs per female per cycle"},
    },
    "egg_to_adult_development_rate": {
        "curve_type": "briere",
        "priors": _gamma_priors(
            0.1,
            scale=(19.69926, 132184.35192),
            temperature_min=(24.824235, 1.641687),
            temperature_max=(5189.7837, 137.7874),
            noise_precision=(7.295192761, 0.007662294),
        ),
        "attrs": {"long_name": "Egg to adult development rate", "units": "per day"},
    },
    "egg_to_adult_survival_probability": {
        "curve_type": "quadratic",
        "probability": True,
        "priors": _gamma_priors(
            0.1,
            scale=(101.3912, 30194.2090),
            temperature_min=(154.75066, 20.14923),
            temperature_max=(3319.22251, 86.63973),
            noise_precision=(3.86616035, 0.01627125),
        ),
        "attrs": {"long_name": "Egg to adult survival probability"},
    },
    "adult_lifespan": {
        "curve_type": "quadratic",
        "priors": _gamma_priors(
            0.01,
            scale=(73.17713, 58.83547),
            temperature_min=(1764.9573, 106.1194),
            temperature_max=(5601.4318, 175.8671),
            noise_precision=(1.904063, 15.663954),
        ),
        "attrs": {"long_name": "Adult lifespan", "units": "days"},
    },
    "biting_rate": {
        "curve_type": "briere",
        "priors": _gamma_priors(
            0.1,
            scale=(15.61913, 57672.66973),
            temperature_min=(42.657272, 2.906991),
            temperature_max=(351.663454, 8.577776),
            noise_precision=(5.49987887, 0.01249048),
        ),
        "attrs": {"long_name": "Biting rate", "units": "per day"},
    },
    "human_to_mosquito_transmission_probability": {
        "curve_type": "briere",
        "priors": _gamma_priors(
            0.5,
            scale=(87.88911, 167941.92094),
            temperature_min=(1.1497051, 0.7603777),
            temperature_max=(853.84870, 24.57488),
            noise_precision=(17.9579092, 0.7864043),
        ),
        "attrs": {
            "long_name": "Human to mosquito transmission probability",
            "units": "per bite",
        },
        "probability": True,
    },
    "mosquito_to_human_transmission_probability": {
        "curve_type": "briere",
        "priors": _gamma_priors(
            0.5,
            scale=(25.96487, 26322.25052),
            temperature_min=(36.029388, 2.989315),
            temperature_max=(2236.38565, 68.19415),
            noise_precision=(10.9259890, 0.4362927),
        ),
        "attrs": {
            "long_name": "Mosquito to human transmission probability",
            "units": "per bite",
        },
        "probability": True,
    },
    "extrinsic_incubation_rate": {
        "curve_type": "briere",
        "priors": _gamma_priors(
            0.5,
            scale=(6.45593, 61855.80928),
            temperature_min=(5.8000327, 0.5044451),
            temperature_max=(118.951801, 3.052254),
            noise_precision=(8.37446874, 0.01684306),
        ),
        "attrs": {"long_name": "Extrinsic incubation rate", "units": "per day"},
    },
}
# Observed trait-vs-temperature data: climepi's example dataset registered
# under the name "mordecai_ae_albopictus".
data = epimod.get_example_temperature_response_data(
    "mordecai_ae_albopictus"
)

In [None]:
def _suitability_function(
    eggs_per_female_per_cycle=None,
    egg_to_adult_development_rate=None,
    egg_to_adult_survival_probability=None,
    adult_lifespan=None,
    biting_rate=None,
    human_to_mosquito_transmission_probability=None,
    mosquito_to_human_transmission_probability=None,
    extrinsic_incubation_rate=None,
):
    R0_rel = (
        biting_rate
        * mosquito_to_human_transmission_probability
        * human_to_mosquito_transmission_probability
        * np.exp(-1 / (extrinsic_incubation_rate * adult_lifespan))
        * eggs_per_female_per_cycle  # assume eggs/cycle = (biting rate)*(eggs/day)
        * egg_to_adult_survival_probability
        * egg_to_adult_development_rate
        * (adult_lifespan**3)
    ) ** 0.5
    return R0_rel

In [None]:
# Combine the trait definitions/priors, the observed trait data and the
# relative-R0 function into a single climepi model.
suitability_model = epimod.ParameterizedSuitabilityModel(
    suitability_function=_suitability_function,
    data=data,
    parameters=parameters,
)

In [None]:
# MCMC fit of each trait's temperature response. The result maps each trait
# name to its ArviZ InferenceData. No random seed is passed, so draws
# presumably differ between runs (TODO: check whether a seed can be
# forwarded to the sampler).
idata_dict = suitability_model.fit_temperature_responses(
    tune=10_000,
    draws=25_000,
)

In [None]:
# Fitted response curves on a 500-point temperature grid from 0 to 50,
# arranged four panels per row.
plots = suitability_model.plot_fitted_temperature_responses(
    temperature_vals=np.linspace(0, 50, 500)
)
plots = plots.cols(4)
# Narrow the y-range of the fifth panel. Given the parameter order above this
# is presumably the biting-rate curve; TODO confirm the panel order.
plots[4].opts(ylim=(0, 0.4))
plots

In [None]:
# Visual convergence check for one trait fit. The trailing semicolon
# hides the repr of the returned plot object.
az.plot_trace(data=idata_dict["eggs_per_female_per_cycle"]);

In [None]:
# Posterior standard deviation reduced over draws only (any chain dimension is kept).
(
    idata_dict["eggs_per_female_per_cycle"]
    .posterior
    .std(dim="draw")
)

In [None]:
# R-hat convergence diagnostic (ArviZ default method) for every fitted trait.
# Values close to 1 indicate that the chains agree.
for parameter_name in idata_dict:
    idata = idata_dict[parameter_name]
    print(f"Parameter: {parameter_name}")
    print(az.rhat(idata))

In [None]:
# Effective sample size for every fitted trait.
for parameter_name in idata_dict:
    idata = idata_dict[parameter_name]
    print(f"Parameter: {parameter_name}")
    print(az.ess(idata))

In [None]:
# Sample the suitability curve from the fitted posteriors on a 1000-point grid
# from 0 to 50. The later reduce() calls take no table argument, so the result
# is presumably stored on the model (TODO confirm).
suitability_model.construct_suitability_table(
    temperature_vals=np.linspace(0, 50, 1000),
    num_samples=10_000,
)

In [None]:
# Mean suitability curve (rescaled) overlaid with dashed red 2.5% / 97.5%
# quantile curves. rescale="mean" presumably puts the quantiles on the
# mean curve's scale; TODO confirm.
mean_curve = suitability_model.reduce(stat="mean", rescale=True).plot_suitability()
quantile_curves = suitability_model.reduce(
    stat="quantile", quantile=[0.025, 0.975], rescale="mean"
).plot_suitability(color="red", line_dash="dashed", by="suitability_quantile")
(mean_curve * quantile_curves).opts(xlim=(10, 40))

In [None]:
# Temperature range for the 2.5% suitability quantile with threshold 0,
# presumably where that lower quantile exceeds zero.
(
    suitability_model.reduce(
        stat="quantile", quantile=0.025, suitability_threshold=0
    ).temperature_range
)

In [None]:
# Posterior distributions of the minimum, optimal and maximum temperatures.
ds_posterior_min_optimal_max = (
    suitability_model.get_posterior_min_optimal_max_temperature()
)
# NOTE(review): `.hvplot` on xarray objects needs the hvplot xarray accessor,
# which is not imported in this notebook. It is presumably registered by
# climepi; otherwise add `import hvplot.xarray`.
hist_min, hist_optimal, hist_max = (
    ds_posterior_min_optimal_max[var_name].hvplot.hist()
    for var_name in ("temperature_min", "temperature_optimal", "temperature_max")
)
(hist_min + hist_optimal + hist_max).cols(2)

In [None]:
# Posterior means of the minimum, optimal and maximum temperatures (reduced over all dimensions).
ds_posterior_min_optimal_max.mean(dim=None)