In [None]:
import arviz as az
import pandas as pd
import numpy as np

import plotly.graph_objects as go
import plotly.express as px

from estival.model import BayesianCompartmentalModel
from estival import targets as est
from estival import priors as esp

from autumn.core.project import get_project

from autumn.projects.sm_covid2.common_school.calibration import get_bcm_object

from datetime import datetime
from copy import copy

# Origin date: model time stamps are turned into "days since this date" for plotting
COVID_BASE_DATETIME = datetime(year=2019, month=12, day=31)

In [None]:
# NetCDF file (path relative to the working directory) read with az.from_netcdf in the next cell;
# presumably the output of a calibration run — confirm provenance before reuse
calibration_file = "calib_11May.nc"

In [None]:

# Load the saved inference data and report the shape of the sampling run
idata = az.from_netcdf(calibration_file)
n_chains, chain_length = (idata.sample_stats.sizes[dim] for dim in ("chain", "draw"))
print(f"Found {n_chains} chains, each containing {chain_length} samples.")

In [None]:
# Number of initial draws per chain to discard before analysis (0 keeps every draw)
burn_in = 0

In [None]:
# Discard burn-in: keep draw labels burn_in .. chain_length - 1 in every group that has a "draw" dimension
burnt_idata = idata.sel(draw=range(burn_in, chain_length))  # Discard burn-in
# Posterior group of the retained draws flattened to a dataframe.
# NOTE(review): calib_df is not referenced by later cells in this notebook view — keep only if used elsewhere
calib_df = burnt_idata.to_dataframe(groups="posterior")  # Also get as dataframe

In [None]:
# Report acceptance ratios: per-chain fraction of draws flagged as accepted.
# Reduce over the named "draw" dimension instead of assuming it sits at axis 1.
(idata.sample_stats.accepted.sum(dim="draw") / idata.sample_stats.sizes["draw"]).to_dataframe()

In [None]:
# Summary statistics/diagnostics on the post-burn-in draws, consistent with the trace and posterior plots
az.summary(burnt_idata)

In [None]:
# Trace plots for the post-burn-in draws; figure height scales with the number of posterior variables
az.plot_trace(burnt_idata, compact=False, legend=True, figsize=(16, 3.0 * len(burnt_idata.posterior.data_vars)));

In [None]:
# Marginal posterior distribution of each variable, using the post-burn-in draws
az.plot_posterior(burnt_idata);

In [None]:

def convert_idata_to_df(
    idata: "xarray.Dataset",
    param_names: list,
) -> pd.DataFrame:
    """
    Convert sampled posterior data to a dataframe sorted by chain, then draw.

    Args:
        idata: Object exposing ``to_dataframe()``; in this notebook it is the
            xarray Dataset returned by ``az.extract``, whose index must contain
            levels named "chain" and "draw"
        param_names: String names of the model parameters (columns to keep)

    Returns:
        Dataframe with one column per parameter, rows sorted by chain, then draw,
        then any remaining index levels.
    """
    sampled_idata_df = idata.to_dataframe()[param_names]
    # One sort is enough: pandas orders by chain, then draw, then (sort_remaining=True)
    # the other index levels. A prior sort on "draw" alone would be discarded anyway.
    return sampled_idata_df.sort_index(level=["chain", "draw"])


def get_sampled_results(sampled_df, outputs, model_bcm=None):
    """
    Re-run the model for every sampled (chain, draw) and collect derived outputs.

    Args:
        sampled_df: Sampled parameters as returned by convert_idata_to_df. The
            first two index levels are (chain, draw); the rows within one sample
            carry the "random_process.delta_values" vector, and scalar parameters
            are read from the row whose remaining index label is 0.
            NOTE(review): assumes index level order (chain, draw, <delta index>) — confirm
        outputs: Names of the derived outputs to collect from each run
        model_bcm: Calibration object used to run the model; defaults to the
            notebook-level ``bcm`` (kept for backward compatibility)

    Returns:
        Dict mapping each output name to a dataframe indexed by the model's
        reference index, with one column per (chain, draw) pair.
    """
    if model_bcm is None:
        model_bcm = bcm  # fall back to the global calibration object

    ref_idx = model_bcm.model._get_ref_idx()
    # Unique (chain, draw) pairs, in sampled_df's row order
    sample_keys = pd.Index([index[:2] for index in sampled_df.index]).unique()

    sampled_results = {output: pd.DataFrame(index=ref_idx, columns=sample_keys) for output in outputs}

    for chain, draw in sample_keys:
        # Rows for this sample, one per random-process delta value
        sample_rows = sampled_df.loc[(chain, draw)]
        delta_values = sample_rows["random_process.delta_values"]

        # Scalar parameters come from the first delta row; the delta entry is then
        # replaced by the full vector for this sample
        params_dict = sampled_df.loc[(chain, draw, 0)].to_dict()
        params_dict["random_process.delta_values"] = np.array(delta_values)

        run_model = model_bcm.run(params_dict)

        for output in outputs:
            sampled_results[output][(chain, draw)] = run_model.derived_outputs[output]

    return sampled_results


def plot_from_model_runs_df(
    model_results,
    output_name,
    model_bcm=None,
) -> go.Figure:
    """
    Create an interactive plot of one model output across sampled chains and draws.

    Args:
        model_results: Dict mapping output name to a dataframe indexed by model
            time with one column per (chain, draw), as returned by get_sampled_results
        output_name: Name of the derived output to plot
        model_bcm: Calibration object whose targets are overlaid; defaults to the
            notebook-level ``bcm`` (kept for backward compatibility)

    Returns:
        Plotly figure with one line per draw, coloured by chain, plus the
        calibration target as black markers when ``output_name`` is a target.
    """
    if model_bcm is None:
        model_bcm = bcm  # fall back to the global calibration object

    melted = model_results[output_name].melt(ignore_index=False)
    # Name the value column after the function argument, not a notebook loop variable
    melted.columns = ["chain", "draw", output_name]
    # NOTE(review): assumes the model index is datetime-like (required for the subtraction and .days)
    melted.index = (melted.index - COVID_BASE_DATETIME).days

    fig = px.line(melted, y=output_name, color="chain", line_group="draw", hover_data=melted.columns)

    if output_name in model_bcm.targets:
        target_data = model_bcm.targets[output_name].data
        target_x = target_data.index
        # Put datetime-indexed targets on the same "days since base date" axis as the model runs
        if isinstance(target_x, pd.DatetimeIndex):
            target_x = (target_x - COVID_BASE_DATETIME).days
        fig.add_trace(
            go.Scattergl(x=target_x, y=target_data, marker=dict(color="black"), name="target", mode="markers"),
        )

    fig.update_layout(
        title=output_name,
        xaxis_title=f"Days since {COVID_BASE_DATETIME:%Y-%m-%d}",
        yaxis_title=output_name,
    )
    return fig

In [None]:
# Calibration object for France (used via .run, .model and .targets by the helper functions above,
# which read this global `bcm`); presumably an estival BayesianCompartmentalModel — confirm
bcm = get_bcm_object("france")

In [None]:
# All posterior variables are treated as model parameters
param_names = list(burnt_idata.posterior.data_vars)

num_samples_request = 50
# Seeded generator so Restart & Run All re-runs the same random subset of posterior draws
sample_rng = np.random.default_rng(0)
sampled_idata = az.extract(burnt_idata, num_samples=num_samples_request, rng=sample_rng)  # Sample from the inference data
sampled_df = convert_idata_to_df(sampled_idata, param_names)

In [None]:
# Derived outputs collected from each re-run of the model
outputs = ["infection_deaths", "transformed_random_process"]
sampled_results = get_sampled_results(sampled_df=sampled_df, outputs=outputs)

In [None]:
# One interactive figure per collected output
for output in outputs:
    fig = plot_from_model_runs_df(model_results=sampled_results, output_name=output)
    fig.show()