In [3]:
import datetime
import os
from math import ceil
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go


from summer.utils import ref_times_to_dti

from autumn.settings.constants import COVID_BASE_DATETIME
from autumn.core.runs import ManagedRun
from autumn.core.project import get_project, load_timeseries



In [4]:
# Region whose project configuration and observed data this notebook analyses.
region: str = "northern_territory"

In [5]:
# Stored run to analyse, as "<model>/<region>/<run>/<commit>" (the components are
# unpacked in that order further down when locating the timeseries file).
run_id = "sm_sir/northern_territory/1662380460/28ff0b7"
# `region` (used for get_project) and the region component of `run_id` (used for the
# timeseries path) are set independently; fail fast if they have drifted apart.
assert run_id.split("/")[1] == region, f"run_id {run_id!r} does not belong to region {region!r}"

In [6]:
# Handle to the stored outputs of `run_id`; the next cell reads its powerbi,
# calibration and full_run databases through it.
mr = ManagedRun(run_id)

In [7]:
# Pull the run's post-processed (powerbi) outputs and its calibration (MCMC) records.
pbi = mr.powerbi.get_db()
# Calibration targets; plotting code tests `output in targets` and uses `targets.index`.
targets = pbi.get_targets()
# Output quantiles with (output, scenario, quantile) columns, as shown by
# `results.columns` further down.
results = pbi.get_uncertainty()
mcmc_params = mr.calibration.get_mcmc_params()
# Number of columns of mcmc_params — presumably one per calibrated parameter; verify.
n_params = mcmc_params.shape[1]
mcmc_runs = mr.calibration.get_mcmc_runs()
chains = mcmc_runs.chain.unique()
# Join parameter values to run metadata on their shared "urun" key.
mcmc_table = mcmc_params.merge(mcmc_runs, on=["urun"])
full_run = mr.full_run.get_derived_outputs()
# NOTE(review): n_params, chains, mcmc_table and full_run are not used by the
# cells visible here.

In [8]:
# Project definition for `region`; its param_set.scenarios provides the scenario
# descriptions used as plot labels.
# NOTE(review): "sm_sir" is hard-coded and should match the model component of run_id.
project = get_project("sm_sir", region, reload=True)

In [9]:
# Load every observed series for this region (not only the calibration targets)
# and re-index each one from model reference times to calendar dates.
# `country` holds the second run_id component, i.e. the region (e.g. northern_territory).
model, country, run, commit = run_id.split("/")
# The repository root is assumed to sit four directories above the notebook's cwd.
project_root = os.path.abspath(os.path.join(os.getcwd(), *[".."] * 4))
project_file_path = os.path.join(
    project_root, "autumn", "projects", model, "australia", country, "timeseries.secret.json"
)
all_targets = load_timeseries(project_file_path)
for target in all_targets:
    target_series = all_targets[target]
    target_series.index = ref_times_to_dti(COVID_BASE_DATETIME, target_series.index)

In [219]:
# Total hospital admissions = indigenous + non-indigenous strata, stored as a new
# ('hospital_admissions', scenario, quantile) column for every scenario/quantile pair.
# Re-running the cell simply overwrites those columns.
scenario_ids = results.columns.unique('scenario')
for scenario_id in scenario_ids:
    for quantile in results.columns.unique('quantile'):
        indigenous = results['hospital_admissionsXindigenous', scenario_id, quantile]
        non_indigenous = results['hospital_admissionsXnon_indigenous', scenario_id, quantile]
        results['hospital_admissions', scenario_id, quantile] = indigenous + non_indigenous

In [245]:
# Inspect the (output, scenario, quantile) column MultiIndex, including the
# aggregated hospital_admissions columns added above.
results.columns

MultiIndex([('cumulative_hospital_admissionsXindigenous', 0, 0.025),
            ('cumulative_hospital_admissionsXindigenous', 0,  0.25),
            ('cumulative_hospital_admissionsXindigenous', 0,   0.5),
            ('cumulative_hospital_admissionsXindigenous', 0,  0.75),
            ('cumulative_hospital_admissionsXindigenous', 0, 0.975),
            ('cumulative_hospital_admissionsXindigenous', 1, 0.025),
            ('cumulative_hospital_admissionsXindigenous', 1,  0.25),
            ('cumulative_hospital_admissionsXindigenous', 1,   0.5),
            ('cumulative_hospital_admissionsXindigenous', 1,  0.75),
            ('cumulative_hospital_admissionsXindigenous', 1, 0.975),
            ...
            (                      'hospital_admissions', 3, 0.025),
            (                      'hospital_admissions', 3,  0.25),
            (                      'hospital_admissions', 3,   0.5),
            (                      'hospital_admissions', 3,  0.75),
            (     

In [10]:
# Human-readable subplot titles keyed by model output name.
# Each key appears exactly once: a repeated key in a dict literal silently keeps
# only its last value, which previously hid two unused titles.
title_lookup = {
    "notifications": "daily notifications",
    "infection_deaths": "COVID-19-specific deaths",
    "hospital_admissions": "daily hospital admissions",
    "icu_admissions": "daily ICU admissions",
    "proportion_seropositive": "proportion recovered from COVID-19",
    "incidence": "daily new infections",
    "prop_incidence_strain_delta": "proportion of cases due to Delta",
    "hospital_occupancy": "total hospital beds",
    "icu_occupancy": "total ICU beds",
    "prop_ever_infected": "ever infected with Delta or Omicron",
    "cumulative_infection_deaths": "cumulative COVID-19 deaths",
    "cumulative_hospital_admissions": "cumulative hospital admissions",
    "cumulative_icu_admissions": "cumulative ICU admissions",
}

In [289]:
def _subplot_grid(n_outputs, input_shape=None):
    """Return the (rows, cols) grid used to lay out ``n_outputs`` subplots.

    Raises:
        ValueError: if there is nothing to plot or the grid is too small.
    """
    if n_outputs == 0:
        raise ValueError("outputs_with_uncertainty needs at least one output to plot")
    if input_shape:
        rows, cols = input_shape[0], input_shape[1]
    elif n_outputs == 4:
        rows, cols = 2, 2
    else:
        # One row of up to three panels, wrapping onto further rows beyond that
        # (1 -> 1x1, 2 -> 1x2, 3 -> 1x3, 5 -> 2x3, ...).
        cols = min(n_outputs, 3)
        rows = ceil(n_outputs / cols)
    if rows * cols < n_outputs:
        raise ValueError(
            f"Subplot grid {rows}x{cols} has room for {rows * cols} panels "
            f"but {n_outputs} outputs were requested"
        )
    return rows, cols


def _band_traces(x, lower, upper, rgba):
    """Return the pair of Scatter traces that shade the area between ``lower`` and ``upper``."""
    return [
        go.Scatter(x=x, y=lower, line=dict(color=rgba), fillcolor=rgba, showlegend=False),
        # fill="tonexty" shades down to the trace added immediately before it (``lower``).
        go.Scatter(x=x, y=upper, line=dict(color=rgba), fillcolor=rgba,
                   fill="tonexty", showlegend=False),
    ]


def outputs_with_uncertainty(outputs, scenarios,
                             plot_50_CI=True,
                             plot_95_CI=True,
                             input_shape=None,
                             start_date=None,
                             end_date=None):
    """Plot median outputs with 50% / 95% bands for several scenarios, one subplot per output.

    Reads notebook-level state: ``results`` (quantile columns indexed by
    ``(output, scenario, quantile)``), ``targets`` (calibration targets),
    ``all_targets`` (all observed series), ``project`` (scenario descriptions) and
    ``title_lookup`` (subplot titles).

    Args:
        outputs: Output names to plot, one subplot each.
        scenarios: Scenario ids to draw (0 is the baseline). Each id must be a key of
            the colour table below (0-4).
        plot_50_CI: Shade between the 0.25 and 0.75 quantiles.
        plot_95_CI: Shade between the 0.025 and 0.975 quantiles.
        input_shape: Optional ``[rows, cols]`` grid. When omitted it is derived from
            ``len(outputs)``.
        start_date, end_date: x-axis range passed to ``update_xaxes``.

    Raises:
        ValueError: if ``outputs`` is empty or the grid cannot hold every output.
    """
    rows, cols = _subplot_grid(len(outputs), input_shape)
    # Row-major, 1-based (row, col) positions as plotly expects.
    positions = [(r, c) for r in range(1, rows + 1) for c in range(1, cols + 1)]

    # Outputs without a curated title fall back to their raw name instead of raising.
    fig = make_subplots(rows, cols,
                        subplot_titles=[title_lookup.get(o, o) for o in outputs])

    # Partial rgba strings per scenario id; the alpha and closing ")" are appended below.
    colours = {0: 'rgba(26,150,65,',
               1: 'rgba(248,156,116,',
               2: 'rgba(204, 102, 119,',
               3: 'rgba(29, 105, 150,',
               4: 'rgba(221, 204, 119,'
               }
    median_alpha = 1.0
    ci_50_alpha = 0.4
    ci_95_alpha = 0.2

    for (row, col), o in zip(positions, outputs):
        for scenario in scenarios:
            colour = colours[scenario]
            results_df = results[(o, scenario)]
            indices = results_df.index
            label = ("baseline " if scenario == 0
                     else project.param_set.scenarios[scenario - 1]["description"] + ' ')
            # Only the first subplot contributes legend entries, so each scenario is listed once.
            show_legend = (row, col) == (1, 1)

            fig.add_trace(go.Scatter(x=indices, y=results_df[0.500],
                                     line=dict(color=f"{colour}{median_alpha})"),
                                     name=label,
                                     showlegend=show_legend),
                          row=row, col=col)

            if plot_50_CI:
                fig.add_traces(_band_traces(indices, results_df[0.250], results_df[0.750],
                                            f"{colour}{ci_50_alpha})"),
                               rows=row, cols=col)

            if plot_95_CI:
                fig.add_traces(_band_traces(indices, results_df[0.025], results_df[0.975],
                                            f"{colour}{ci_95_alpha})"),
                               rows=row, cols=col)

        # Observed data does not depend on the scenario: add it once per subplot (after
        # the scenario traces, so markers sit on top) rather than once per scenario.
        if o in targets:
            fig.add_trace(go.Scatter(x=targets.index, y=targets[o],
                                     mode='markers',
                                     marker=dict(color="rgba(0,0,0, 0.8)", size=3),
                                     showlegend=False),
                          row=row, col=col)

        if o in all_targets and len(all_targets[o]) > 0:
            fig.add_trace(go.Scatter(x=all_targets[o].index, y=all_targets[o],
                                     mode='markers',
                                     marker=dict(color="rgba(255,255,255, 0.4)", size=4),
                                     showlegend=False),
                          row=row, col=col)

    fig.update_xaxes(range=[start_date, end_date])

    # Shared by the on-screen layout and the PNG export below.
    fig_height, fig_width = 400, 1000

    # Horizontal legend placed above the top-left corner of the plotting area.
    fig.update_layout(legend_x=-0.05, legend_y=1.3, legend_orientation='h',
                      height=fig_height, width=fig_width)

    # The "download as png" button exports at 4x scale for a higher-resolution image.
    config = {
        'toImageButtonOptions': {
            'format': 'png',
            'filename': 'image_1',
            'height': fig_height,
            'width': fig_width,
            'scale': 4,  # increase to raise export resolution
        }
    }

    fig.show(config=config)
    
   

In [291]:
# Scenario ids to compare (0 = baseline).
scenarios = [0, 4]

# Outputs to plot, one subplot each. Without `input_shape` the function picks the
# subplot grid from the number of outputs.
# Other outputs previously plotted here: "hospital_occupancy", "icu_admissions", "infection_deaths".
outputs = ('notifications', "hospital_admissions",)

# x-axis window (descriptive names avoid reusing the scenario loop name `s`).
plot_start_date = "2021-08-12"
plot_end_date = '2023-06-01'

outputs_with_uncertainty(outputs, scenarios,
                         start_date=plot_start_date,
                         end_date=plot_end_date)

In [262]:
# Trial: overlay the calibration targets and the full observed series for a few
# outputs on a single set of axes.

fig_2 = make_subplots()
outputs = ('notifications', "hospital_admissions", "hospital_occupancy",)

for output in outputs:
    if output in targets:
        # Name the trace so the legend identifies the output instead of "trace 0", ...
        fig_2.add_trace(go.Scatter(x=targets.index, y=targets[output],
                                   mode='markers', name=output))

    if output in all_targets and len(all_targets[output]) > 0:
        fig_2.add_trace(go.Scatter(x=all_targets[output].index, y=all_targets[output],
                                   mode='markers',
                                   marker=dict(color="rgba(255,255,255, 0.3)", size=3),
                                   showlegend=False))

fig_2.show()