## Imports

In [None]:
from IPython.display import display
import pandas as pd
import pickle
from pathlib import Path
from plotly.subplots import make_subplots
import plotly.express as px
from jinja2 import Template

from autumn.projects.sm_covid2.common_school.output_plots import country_highlight as ch


import exploration_tools as et
pd.options.plotting.backend = "plotly"

# Countries to report on. Keys are used as ISO3 codes and values as display
# names elsewhere in this notebook.
# NOTE: pickle.load can execute arbitrary code; only load trusted, locally
# produced pickle files here.
with open('included_countries.pickle', 'rb') as pickle_file:
    included_countries = pickle.load(pickle_file)

In [None]:
# Display labels for the derived outputs shown in the HTML reports.
# The keys select which outputs are kept in write_html_output_file, and their
# insertion order fixes the column order and the dropdown entry order.
do_names = dict(
    # Headline mortality / incidence / health-system indicators
    infection_deaths='COVID-19 mortality',
    incidence='COVID-19 incidence',
    cumulative_incidence='Cumulative COVID-19 incidence',
    hospital_occupancy='Hospital occupancy indicator',
    cumulative_infection_deaths='Cumulative COVID-19 mortality',
    # Infection history
    ever_infected='Cumulative SARS-CoV-2 infections',
    prop_ever_infected='Proportion ever infected with SARS-CoV-2',
    # Incidence by age group (key suffix is the lower bound of the age band)
    incidenceXagegroup_0='COVID-19 incidence (age 0-14)',
    incidenceXagegroup_15='COVID-19 incidence (age 15-24)',
    incidenceXagegroup_25='COVID-19 incidence (age 25-49)',
    incidenceXagegroup_50='COVID-19 incidence (age 50-69)',
    incidenceXagegroup_70='COVID-19 incidence (age 70 and above)',
    # Cumulative incidence by virus strain
    cumulative_incidenceXstrain_delta='Cumulative incidence (Delta variant)',
    cumulative_incidenceXstrain_omicron='Cumulative incidence (Omicron variant)',
    cumulative_incidenceXstrain_wild_type='Cumulative incidence (Wild-type virus)',
    # Vaccination status
    prop_immune_vaccinated='Proportion fully vaccinated',
    prop_immune_unvaccinated='Proportion unvaccinated',
    # Transmission adjustment
    transformed_random_process='Random-process-based transmission adjustment',
)

In [None]:
# Spacing between consecutive samples on a date x-axis, in milliseconds.
# Assumes the outputs are daily time series with a date-like index — TODO confirm.
_DX_1DAY_MS = 24 * 60 * 60 * 1000

# Quantile columns plotted for each scenario, in the order their traces are
# added to the figure. The dropdown "restyle" payload relies on this order.
_QUANTILE_TRACE_ORDER = ("0.025", "0.25", "0.75", "0.975", "0.5")

# (scenario key, subplot row, light shade, dark shade) for the uncertainty panels.
_SCENARIO_STYLES = (
    ("baseline", 2, "lightskyblue", "cornflowerblue"),
    ("scenario_1", 3, "pink", "coral"),
)


def _add_uncertainty_traces(fig, quantiles_df, row, light_shade, dark_shade):
    """Add the five quantile traces of one scenario to subplot ``row`` of ``fig``.

    ``quantiles_df`` must contain the columns listed in ``_QUANTILE_TRACE_ORDER``.
    The band traces use ``fill="tonexty"`` (fill down to the previously added
    trace), so they must be added in exactly this order.
    """
    trace_styles = {
        # Invisible lower 95% bound that the next band fills towards.
        "0.025": dict(line=dict(width=0, color=light_shade), name="95% CI"),
        "0.25": dict(line=dict(width=0, color=dark_shade), fill="tonexty",
                     fillcolor=light_shade, name="25%"),
        "0.75": dict(line=dict(width=0, color=dark_shade), fill="tonexty",
                     fillcolor=dark_shade, name="75%"),
        "0.975": dict(line=dict(width=0, color=light_shade), fill="tonexty",
                      fillcolor=light_shade, name="95% CI"),
        # Median is added last so it is drawn on top of the shaded bands.
        "0.5": dict(line=dict(width=2, color="black"), name="median"),
    }
    x0 = quantiles_df.index[0]
    for q in _QUANTILE_TRACE_ORDER:
        fig.add_scatter(x0=x0, dx=_DX_1DAY_MS, y=quantiles_df[q],
                        showlegend=False, row=row, col=1, **trace_styles[q])


def _build_dropdown_buttons(bdf, sdf, ubdf, usdf):
    """Build one Plotly "restyle" button per derived output column of ``bdf``.

    Each button replaces the y data of every trace, in the order the traces
    were added: baseline MLE, scenario MLE, then the baseline quantiles and
    the scenario quantiles (each in ``_QUANTILE_TRACE_ORDER``).
    """
    n_traces = 2 + 2 * len(_QUANTILE_TRACE_ORDER)
    buttons = []
    for col in bdf.columns:
        ys = [bdf[col], sdf[col]]
        ys += [ubdf[col][q] for q in _QUANTILE_TRACE_ORDER]
        ys += [usdf[col][q] for q in _QUANTILE_TRACE_ORDER]
        buttons.append(dict(
            method="restyle",
            args=[{"y": ys, "visible": [True] * n_traces}],
            label=do_names[col],
        ))
    return buttons


def write_html_output_file(iso3, analysis):
    """Render the interactive HTML report for one country and one analysis.

    Loads the outputs stored under ``<cwd>/<analysis>/<iso3>`` and builds a
    three-row Plotly figure: MLE comparison, historical uncertainty, and
    counterfactual uncertainty. A dropdown switches between the outputs
    listed in ``do_names``. The figure and the difference-quantiles table
    are rendered through the Jinja2 template ``template.html`` (read from
    the current directory) into ``<cwd>/html_outputs/<iso3>_<analysis>.html``.

    Args:
        iso3: Country code; key into ``included_countries``.
        analysis: Analysis name; key into ``et.analysis_names`` and the name
            of the folder that holds the outputs.
    """
    folder_path = Path.cwd() / analysis / iso3
    uncertainty_dfs, diff_quantiles_df, derived_outputs = et.load_analysis_outputs(folder_path)

    # Keep only outputs that have a display label. Uncertainty frames have
    # (output, quantile) column levels, so they are filtered on level 0.
    selected = list(do_names)
    for sc in ("baseline", "scenario_1"):
        derived_outputs[sc] = derived_outputs[sc][selected]
        unc_df = uncertainty_dfs[sc]
        uncertainty_dfs[sc] = unc_df.iloc[:, unc_df.columns.get_level_values(0).isin(selected)]

    bdf = derived_outputs["baseline"].round(decimals=4)
    sdf = derived_outputs["scenario_1"].round(decimals=4)
    rounded_uncertainty = {
        "baseline": uncertainty_dfs["baseline"].round(decimals=4),
        "scenario_1": uncertainty_dfs["scenario_1"].round(decimals=4),
    }

    ### Create main figure
    initial_output = "infection_deaths"
    fig = make_subplots(
        rows=3, cols=1, shared_xaxes=True,
        subplot_titles=("MLE Comparison", "Historical uncertainty",
                        "Counterfactual (schools open) uncertainty"),
    )
    # Each MLE trace starts at its own frame's first index value.
    fig.add_scatter(x0=bdf.index[0], dx=_DX_1DAY_MS, y=bdf[initial_output], mode='lines',
                    name="Historical", row=1, col=1)
    fig.add_scatter(x0=sdf.index[0], dx=_DX_1DAY_MS, y=sdf[initial_output], mode='lines',
                    name="Counterfactual (schools open)", row=1, col=1)

    # Only the two scenarios covered by the dropdown are plotted, so the number
    # and order of traces always match the restyle payload.
    for sc, row, light_shade, dark_shade in _SCENARIO_STYLES:
        _add_uncertainty_traces(fig, rounded_uncertainty[sc][initial_output],
                                row, light_shade, dark_shade)

    buttons = _build_dropdown_buttons(
        bdf, sdf, rounded_uncertainty["baseline"], rounded_uncertainty["scenario_1"]
    )

    fig.update_layout(
        width=1200,
        height=800,
        updatemenus=[dict(
            # Highlight the dropdown entry matching the initially plotted output.
            active=list(bdf.columns).index(initial_output),
            buttons=buttons,
            direction="down",
            pad={"r": 0, "t": 20, "b": 0.0},
            showactive=True,
            x=0.0,
            xanchor="left",
            y=1.08,
            yanchor="top",
        )],
        # Also show tick labels on the top two panels of the shared-x layout.
        xaxis_showticklabels=True,
        xaxis2_showticklabels=True,
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
    )

    ### Make table from the scenario difference quantiles
    reldf = et.get_diff_quantiles_table(diff_quantiles_df)

    ### Write html file
    # Render fully before touching the output path, so that a template error
    # does not leave an empty or partial HTML file behind.
    template = Template(Path("template.html").read_text(encoding="utf-8"))
    plotly_jinja_data = {
        "fig": fig.to_html(full_html=False, include_plotlyjs="cdn"),
        "table": reldf.to_html(),
        "country": included_countries[iso3],
        "analysis": et.analysis_names[analysis],
    }
    html = template.render(plotly_jinja_data)

    output_dir = Path.cwd() / "html_outputs"
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / f"{iso3}_{analysis}.html").write_text(html, encoding="utf-8")


In [None]:
# Render one report per (country, analysis) pair, with countries in the outer
# loop and analyses in the inner loop.
for iso3, analysis in ((c, a) for c in included_countries for a in et.analysis_names):
    print(f"Writing html for {iso3} / {analysis}")
    write_html_output_file(iso3, analysis)

### Write index table 

In [None]:
# Print one markdown table row per country: the country's display name,
# followed by a link to each analysis report.
for iso3 in included_countries:
    link_cells = "".join(
        f"[{analysis_desc}]({iso3}_{analysis}.html) |"
        for analysis, analysis_desc in et.analysis_names.items()
    )
    print(f"| {included_countries[iso3]} |" + link_cells)