In [None]:
from jax import jit, random
import pandas as pd
from datetime import datetime, timedelta
import numpyro
from numpyro import distributions as dist
from numpyro import infer
import arviz as az
from IPython.display import Markdown
from plotly.express.colors import qualitative as qual_colours
from pathlib import Path
import math
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import Markdown
import matplotlib
import matplotlib.pyplot as plt

from estival.sampling import tools as esamp

from emu_renewal.process import CosineMultiCurve
from emu_renewal.distributions import GammaDens
from emu_renewal.renew import RenewalDeathsModel
from emu_renewal.outputs import get_spaghetti_from_params, get_quant_df_from_spaghetti, plot_spaghetti
from emu_renewal.outputs import plot_uncertainty_patches, PANEL_SUBTITLES, plot_3d_spaghetti, plot_post_prior_comparison
from emu_renewal.calibration import StandardCalib
from emu_renewal.utils import get_adjust_idata_index, adjust_summary_cols

In [2]:
# Resolve paths from the notebook's working directory; input data sit in a sibling "data" tree.
PROJECT_PATH = Path.cwd().resolve()
DATA_PATH = PROJECT_PATH.parent / "data" / "covid_aus"

In [3]:
# Load the raw WHO global case/death counts and the Australian seroprevalence estimates.
target_data = pd.read_csv(DATA_PATH / "WHO-COVID-19-global-data_21_8_24.csv")
seroprev_data = pd.read_csv(DATA_PATH / "aus_seroprev_data.csv")

# Keep Australian rows only and index them by the (day-first) report date.
aust_data = target_data.loc[target_data["Country"] == "Australia"]
aust_data.index = pd.to_datetime(aust_data["Date_reported"], format="%d/%m/%Y")

# Cases on a Sunday-anchored weekly grid: gaps are linearly interpolated, any
# remaining NaNs (e.g. before the first report) become zero.
aust_cases = (
    aust_data["New_cases"]
    .resample("W-SUN")
    .interpolate(method="linear")
    .fillna(0.0)
)
# NOTE(review): deaths are taken as reported, without the weekly resampling and
# gap filling applied to cases above -- confirm the column has no missing values.
aust_deaths = aust_data["New_deaths"]

# Seroprevalence series indexed by survey date.
# NOTE(review): parsed without an explicit format/dayfirst, unlike the WHO dates -- confirm the file's date format.
seroprev_data.index = pd.to_datetime(seroprev_data["date"])
aust_seroprev = seroprev_data["seroprevalence"]

In [4]:
# Specify fixed parameters and get calibration data
proc_update_freq = 14
init_time = 50
pop = 26e6
analysis_start = datetime(2021, 12, 1)
analysis_end = datetime(2022, 10, 1)
init_start = analysis_start - timedelta(init_time)
init_end = analysis_start - timedelta(1)
select_data = aust_cases.loc[analysis_start: analysis_end]
select_deaths = aust_deaths.loc[analysis_start: analysis_end]
init_data = aust_cases.resample("D").asfreq().interpolate().loc[init_start: init_end] / 7.0

In [6]:
# Define model and fitter
proc_fitter = CosineMultiCurve()
renew_model = RenewalDeathsModel(pop, analysis_start, analysis_end, proc_update_freq, proc_fitter, GammaDens(), init_time, init_data, GammaDens())

In [79]:
# Define parameter ranges
priors = {
    "gen_mean": dist.TruncatedNormal(7.3, 0.5, low=1.0),
    "gen_sd": dist.TruncatedNormal(3.8, 0.5, low=1.0),
    "cdr": dist.Beta(15, 15), #(16,40)
    "ifr": dist.Beta(2, 200),
    "rt_init": dist.Normal(0.0, 0.25),
    "report_mean": dist.TruncatedNormal(8, 0.5, low=1.0),
    "report_sd": dist.TruncatedNormal(3, 0.5, low=1.0),
    "death_mean": dist.TruncatedNormal(21.0, 0.5, low=1.0),
    "death_sd": dist.TruncatedNormal(5.0, 0.5, low=1.0),
}

In [80]:
# Define model and fitter
proc_fitter = CosineMultiCurve()
renew_model = RenewalDeathsModel(pop, analysis_start, analysis_end, proc_update_freq, proc_fitter, GammaDens(), init_time, init_data, GammaDens())

In [81]:
# Define calibration and calib data
calib_data = {
    "weekly_sum": select_data,
    "seropos": aust_seroprev,
    "weekly_deaths": select_deaths,
}
calib = StandardCalib(renew_model, priors, calib_data)

In [None]:
# Run calibration
kernel = infer.NUTS(calib.calibration, dense_mass=True, init_strategy=infer.init_to_uniform(radius=0.5))
mcmc = infer.MCMC(kernel, num_chains=4, num_samples=100, num_warmup=100)
mcmc.run(random.PRNGKey(2))

In [120]:
# Grab sample of data from calibrated model outputs
idata = az.from_dict(mcmc.get_samples(True))
idata_sampled = az.extract(idata, num_samples=800)
sample_params = esamp.xarray_to_sampleiterator(idata_sampled)

In [121]:
# get model results and outputs
def get_full_result(**params):
    """Run the renewal model for one parameter set and return its full result.

    All keyword arguments are forwarded unchanged to ``renew_model.renewal_func``.
    """
    result = renew_model.renewal_func(**params)
    return result


# JIT-compiled wrapper used when evaluating many posterior draws.
full_wrap = jit(get_full_result)

In [122]:
# Model outputs collected for every sampled draw, then summarised as 5/50/95% quantiles.
OUTPUTS = [
    "weekly_deaths",
    "weekly_sum",
    "seropos",
    "susceptible",
    "R",
    "transmission potential",
]
spaghetti = get_spaghetti_from_params(renew_model, sample_params, full_wrap, outputs=OUTPUTS)
quantiles_df = get_quant_df_from_spaghetti(
    renew_model,
    spaghetti,
    quantiles=[0.05, 0.5, 0.95],
    outputs=OUTPUTS,
)

In [None]:
# NOTE(review): plot_main is defined in a later cell; on Restart & Run All this line
# raises NameError -- move the plot_main definition cell above this one.
plot_main(
    model_data=quantiles_df,
    case_data=select_data,
    death_data=select_deaths,
    seroprevalence_data=aust_seroprev,
)

In [None]:
# Posterior summary table (point estimates, intervals and convergence diagnostics) for each parameter.
az.summary(data=idata)

In [None]:
# Compare each parameter's posterior against its prior; trailing ";" suppresses the repr.
plot_post_prior_comparison(idata, list(priors), priors);

In [None]:
# Trace plots for all sampled parameters, saved to disk alongside the notebook.
trace_axes = az.plot_trace(idata)
fig = trace_axes.ravel()[0].get_figure()
fig.savefig("trace_2.png")

In [16]:
def _add_model_trace_pair(fig, model_data, output, name, colour, row, col, quantiles):
    """Add one model output's median line and shaded lower-upper band to a subplot.

    The band is a single closed polygon: forward along the lower quantile, then back
    along the upper quantile, filled with ``fill="toself"``. The line and band share a
    legend group, so clicking the median's legend entry toggles the band as well.
    """
    lower, median, upper = quantiles
    dates = model_data.index
    output_quantiles = model_data[output]
    band_x = dates.to_list() + dates[::-1].to_list()
    band_y = output_quantiles[lower].to_list() + output_quantiles[upper][::-1].to_list()
    fig.add_trace(
        go.Scatter(
            x=dates,
            y=output_quantiles[median],
            mode="lines",
            name=name,
            marker_color=colour,
            legendgroup=name,
        ),
        row=row,
        col=col,
    )
    fig.add_trace(
        go.Scatter(
            x=band_x,
            y=band_y,
            mode="lines",
            name=name,
            line={"width": 0.0, "color": colour},
            fill="toself",
            showlegend=False,
            legendgroup=name,
        ),
        row=row,
        col=col,
    )


def plot_main(model_data, case_data, death_data, seroprevalence_data, quantiles=(0.05, 0.5, 0.95)):
    """Plot the calibrated model's outputs against the calibration targets.

    Builds a 3 x 2 grid of panels with a shared x-axis. Each panel shows one model
    output as a median line with a shaded band between the lower and upper quantiles.
    The cases, deaths and seroprevalence panels also overlay the observed data as
    black markers.

    Args:
        model_data: Quantile frame whose columns are addressed as
            ``model_data[output][quantile]`` and whose index is used as the x values.
            Must contain the outputs "weekly_sum", "weekly_deaths", "seropos",
            "susceptible", "R" and "transmission potential", each at every level
            in ``quantiles``.
        case_data: Series of reported cases; its index is used as the x values.
        death_data: Series of reported deaths; its index is used as the x values.
        seroprevalence_data: Series of observed seroprevalence; its index is used as
            the x values.
        quantiles: ``(lower, median, upper)`` quantile levels to draw. Defaults to
            ``(0.05, 0.5, 0.95)``.

    Returns:
        The plotly figure, with its layout sized to 1000 x 1200 px.
    """
    # (model output, legend name, colour, row, col, observed series or None, observed legend name)
    panels = (
        ("weekly_sum", "Modelled cases", "#636EFA", 1, 1, case_data, "Reported cases"),
        ("weekly_deaths", "Modelled deaths", "#636EFA", 1, 2, death_data, "Reported deaths"),
        ("seropos", "Modelled seroprevalence", "#636EFA", 2, 1, seroprevalence_data, "Reported seroprevalence"),
        ("susceptible", "Susceptible", "#EF553B", 2, 2, None, None),
        ("R", "Rt", "#00CC96", 3, 1, None, None),
        ("transmission potential", "Transmission potential", "#AB63FA", 3, 2, None, None),
    )

    fig = make_subplots(
        rows=3,
        cols=2,
        shared_xaxes=True,
        subplot_titles=(
            "Weekly cases",
            "Weekly deaths",
            "Seroprevalence",
            "Susceptible population",
            "Rt",
            "Transmission potential",
        ),
        horizontal_spacing=0.05,
        vertical_spacing=0.05,
    )

    # Per panel: median line, then uncertainty band, then (where available) observed data.
    for output, name, colour, row, col, observed, observed_name in panels:
        _add_model_trace_pair(fig, model_data, output, name, colour, row, col, quantiles)
        if observed is not None:
            fig.add_trace(
                go.Scatter(
                    x=observed.index,
                    y=observed,
                    mode="markers",
                    name=observed_name,
                    marker_color="black",
                ),
                row=row,
                col=col,
            )

    return fig.update_layout(height=1000, width=1200)