In [None]:
from jax import random
import pandas as pd
from datetime import datetime, timedelta
import numpyro
from numpyro import distributions as dist
from numpyro import infer
import arviz as az
from plotly.express.colors import qualitative as qual_colours
from pathlib import Path
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from estival.sampling import tools as esamp

from emu_renewal.process import CosineMultiCurve
from emu_renewal.distributions import GammaDens
from emu_renewal.renew import RenewalHospModel
from emu_renewal.outputs import get_spaghetti, get_quant_df_from_spaghetti, get_spagh_df_from_dict
from emu_renewal.plotting import plot_post_prior_comparison, plot_spaghetti_calib_comparison
from emu_renewal.calibration import StandardCalib
from emu_renewal.targets import StandardDispTarget

In [None]:
# Paths are resolved from the notebook's working directory; the data folder
# lives under the parent of that directory
PROJECT_PATH = Path.cwd().resolve()
DATA_PATH = PROJECT_PATH.parent / "data" / "covid_aus"

In [None]:
# Read the raw inputs: the WHO global COVID-19 extract and the Australian seroprevalence file
target_data = pd.read_csv(DATA_PATH / "WHO-COVID-19-global-data_21_8_24.csv")
seroprev_data = pd.read_csv(DATA_PATH / "aus_seroprev_data.csv")

# Keep the Australian rows only and index them by report date (day/month/year strings)
is_australia = target_data["Country"] == "Australia"
aust_data = target_data.loc[is_australia]
aust_data.index = pd.to_datetime(aust_data["Date_reported"], format="%d/%m/%Y")

# Cases are put on a Sunday-anchored weekly grid; gaps are linearly interpolated and
# anything still missing afterwards is set to zero
weekly_case_resampler = aust_data["New_cases"].resample("W-SUN")
aust_cases = weekly_case_resampler.interpolate(method="linear").fillna(0.0)
# Deaths are used as reported, with no resampling or gap filling
aust_deaths = aust_data["New_deaths"]

# Seroprevalence: index the table by date and keep the estimate column
seroprev_data.index = pd.to_datetime(seroprev_data["date"])
aust_seroprev = seroprev_data["seroprevalence"]

# Hospitalisation data: the "value" column indexed by date
hosp_table = pd.read_csv(DATA_PATH / "hosp.csv")
hosp_table.index = pd.to_datetime(hosp_table["date"])
aust_hosp = hosp_table["value"]

In [None]:
# Fixed (non-calibrated) settings
proc_update_freq = 14  # presumably in days — confirm against RenewalHospModel
init_time = 50  # length of the initialisation period in days (used as a timedelta below)
pop = 26e6

# Analysis window, and the initialisation period that ends the day before it starts
analysis_start = datetime(2021, 12, 1)
analysis_end = datetime(2022, 10, 1)
init_start = analysis_start - timedelta(days=init_time)
init_end = analysis_start - timedelta(days=1)
# Start calibration targets slightly late so as not to penalise laggy indicators
data_start = analysis_start + timedelta(days=14)

# Calibration targets restricted to the fitting window
select_data = aust_cases.loc[data_start:analysis_end]
select_deaths = aust_deaths.loc[data_start:analysis_end]
hosp_data = aust_hosp.loc[data_start:analysis_end:7]  # every seventh record in the window

# Initialisation series: weekly cases upsampled to daily, interpolated, then divided by seven
daily_cases = aust_cases.resample("D").asfreq().interpolate()
init_data = daily_cases.loc[init_start:init_end] / 7.0

In [None]:
# Prior distributions for the calibrated parameters, keyed by parameter name.
# Every TruncatedNormal prior is bounded below at 1.0; the Beta priors are supported on (0, 1).
priors = {
    "gen_mean": dist.TruncatedNormal(loc=7.3, scale=0.5, low=1.0),
    "gen_sd": dist.TruncatedNormal(loc=3.8, scale=0.5, low=1.0),
    # (16, 40) was noted beside this prior, presumably as an alternative parameterisation
    "cdr": dist.Beta(concentration1=15, concentration0=15),
    "ifr": dist.Beta(concentration1=3, concentration0=200),
    "rt_init": dist.Normal(loc=0.0, scale=0.25),
    "report_mean": dist.TruncatedNormal(loc=8.0, scale=0.5, low=1.0),
    "report_sd": dist.TruncatedNormal(loc=3.0, scale=0.5, low=1.0),
    "death_mean": dist.TruncatedNormal(loc=18.0, scale=0.5, low=1.0),
    "death_sd": dist.TruncatedNormal(loc=5.0, scale=0.5, low=1.0),
    "admit_mean": dist.TruncatedNormal(loc=10.0, scale=1.5, low=1.0),
    "admit_sd": dist.TruncatedNormal(loc=5.0, scale=0.5, low=1.0),
    "stay_mean": dist.TruncatedNormal(loc=10.0, scale=1.5, low=1.0),
    "stay_sd": dist.TruncatedNormal(loc=5.0, scale=0.5, low=1.0),
    "har": dist.Beta(concentration1=5, concentration0=200),
    "shared_dispersion": dist.HalfNormal(scale=0.5),
}

In [None]:
# Define the process fitter and the model. Arguments are passed positionally, in the
# order RenewalHospModel declares them; only discharge_dens is given by keyword.
proc_fitter = CosineMultiCurve()
renew_model = RenewalHospModel(
    pop,
    analysis_start,
    analysis_end,
    proc_update_freq,
    proc_fitter,
    GammaDens(),
    init_time,
    init_data,
    GammaDens(),
    discharge_dens=GammaDens(),
)

In [None]:
# Observed series for each calibration target, keyed by target name
target_series = {
    "weekly_cases": select_data,
    "seropos": aust_seroprev,
    "weekly_deaths": select_deaths,
    "occupancy": hosp_data,
}
# Wrap each series in a StandardDispTarget, then bundle model, priors and targets
calib_data = {name: StandardDispTarget(series) for name, series in target_series.items()}
calib = StandardCalib(renew_model, priors, calib_data)

In [None]:
# Run calibration: NUTS with a dense mass matrix, started from the calibration's own
# initialisation strategy, over four chains with a fixed PRNG key
init_strategy = calib.custom_init(radius=0.5)
kernel = infer.NUTS(calib.calibration, dense_mass=True, init_strategy=init_strategy)
mcmc = infer.MCMC(kernel, num_chains=4, num_samples=1000, num_warmup=1000)
rng_key = random.PRNGKey(1)
mcmc.run(rng_key)

In [None]:
# Grab a reproducible sample of parameter sets from the calibrated posterior.
# group_by_chain=True asks numpyro for the samples grouped per chain before they
# are handed to ArviZ.
idata = az.from_dict(mcmc.get_samples(group_by_chain=True))
# az.extract shuffles the pooled draws before taking num_samples of them; without a
# seed the 20 parameter sets (and every plot built from them) change on each re-run.
idata_sampled = az.extract(idata, num_samples=20, rng=1)
sample_params = esamp.xarray_to_sampleiterator(idata_sampled)

In [None]:
# Tabulate ArviZ's summary of the inference data and display it
posterior_summary = az.summary(idata)
posterior_summary

In [None]:
# Run the calibration over each sampled parameter set, then convert the resulting
# dictionary of outputs to a DataFrame
spaghetti_by_output = get_spaghetti(calib, sample_params)
spaghetti = get_spagh_df_from_dict(spaghetti_by_output)

In [None]:
# Compare sampled trajectories with the targets for the four calibrated outputs
calibrated_outputs = ["weekly_cases", "weekly_deaths", "seropos", "occupancy"]
plot_spaghetti_calib_comparison(spaghetti, calib.targets, calibrated_outputs)

In [None]:
# Sampled trajectories for the "process" and "r_t" outputs
process_outputs = ["process", "r_t"]
plot_spaghetti_calib_comparison(spaghetti, calib.targets, process_outputs)

In [None]:
# Sampled trajectories for incidence, admissions and occupancy
epi_outputs = ["incidence", "admissions", "occupancy"]
plot_spaghetti_calib_comparison(spaghetti, calib.targets, epi_outputs)

In [None]:
# Show the ArviZ summary table once more ahead of the trace plots
summary_table = az.summary(idata)
summary_table

In [None]:
# Trace plots for the sampled parameters; the trailing semicolon stops the notebook
# echoing the returned array of Axes underneath the figure
az.plot_trace(idata);

In [None]:
# List the calibrated parameter names (the keys of the priors dictionary)
priors.keys()

In [None]:
# Compare posterior and prior for a subset of the calibrated parameters
compared_params = ["cdr", "stay_mean", "stay_sd", "ifr"]
plot_post_prior_comparison(idata, compared_params, priors);