In [1]:
import functools

import click
import numpy as np
import pandas as pd

from covid_model_seiir_pipeline.lib import (
    cli_tools,
    aggregate,
    parallel,
)
from covid_model_seiir_pipeline.pipeline.forecasting import model
from covid_model_seiir_pipeline.side_analysis.oos_holdout.specification import (
    OOSHoldoutSpecification,
    OOS_HOLDOUT_JOBS,
)
from covid_model_seiir_pipeline.side_analysis.oos_holdout.data import (
    OOSHoldoutDataInterface,
)


# Per-task performance logger from the pipeline CLI tools. None of the cells
# shown in this notebook reference it.
logger = cli_tools.task_performance_logger


In [2]:
def build_dataset(data_interface, dates, scenario, draw_id):
    """Assemble reference, out-of-sample and delta time series for one draw.

    Loads the reference outputs for ``scenario`` and the OOS outputs for
    ``draw_id``, keeps the modeled totals (plus beta) under short measure
    names, restricts each location to its ``[plot_start, plot_end]`` window
    taken from ``dates``, and passes the result to
    ``aggregate.sum_aggregator`` with the ``'pred'`` hierarchy.

    Before aggregation the frame is indexed by (location_id, date) and its
    columns are a (measure, version) MultiIndex where version is one of
    'reference', 'oos' or 'delta' (reference minus oos).
    """
    reference = data_interface.load_raw_outputs(scenario, draw_id)
    out_of_sample = data_interface.load_oos_outputs(draw_id)
    hierarchy = data_interface.load_hierarchy('pred')
    population = data_interface.load_population('total')

    measure_names = {
        'modeled_infections_total': 'infections',
        'modeled_deaths_total': 'deaths',
        'modeled_admissions_total': 'admissions',
        'modeled_cases_total': 'cases',
        'beta': 'beta',
    }
    index_names = ['location_id', 'date', 'measure']

    def _to_long(frame, version):
        # Keep the selected outputs and pivot the measures into the index.
        long = (
            frame.loc[:, list(measure_names)]
            .rename(columns=measure_names)
            .stack()
            .reset_index()
        )
        long.columns = index_names + [version]
        return long.set_index(index_names)[version].sort_index()

    versions = {
        'reference': reference,
        'oos': out_of_sample,
        'delta': reference - out_of_sample,
    }
    combined = pd.concat(
        [_to_long(frame, version) for version, frame in versions.items()],
        axis=1,
    )

    combined = combined.reset_index().merge(dates.reset_index(), on='location_id')
    in_window = combined['plot_start'].le(combined['date']) & combined['date'].le(combined['plot_end'])

    windowed = (
        combined.loc[in_window]
        .drop(columns=dates.columns)
        .set_index(index_names)
        .unstack()
        .reorder_levels([1, 0], axis=1)
        .sort_index(axis=1)
    )
    return aggregate.sum_aggregator(windowed, hierarchy, population)


def load_boundary_dates(data_interface, specification, draw_id):
    """Return per-location (OOS beta-fit end dates, forecast start dates).

    The beta-fit end is the last date with any non-null past compartment,
    moved back by the holdout length (``holdout_weeks`` weeks). The forecast
    start is the last date on which every reported Death/Admission/Case
    column of the location is non-null.
    """
    location_ids = data_interface.load_location_ids()
    compartments = data_interface.load_past_compartments(draw_id).loc[location_ids]
    compartments = compartments.loc[compartments.notnull().any(axis=1)]

    def _last_date(frame):
        return frame.reset_index(level='date')['date'].groupby('location_id').max()

    holdout = pd.Timedelta(days=7 * specification.parameters.holdout_weeks)
    oos_fit_end = _last_date(compartments) - holdout

    # Forecast from the last date where all reported measures that have at least
    # one report in the location are present.
    reported = [name for name in compartments.columns
                if name.split('_')[0] in ('Death', 'Admission', 'Case')]
    complete_rows = compartments[reported].notnull().all(axis=1)
    forecast_start = _last_date(compartments.loc[complete_rows])
    return oos_fit_end, forecast_start


def get_plot_start_end(data_interface,
                       specification):
    """Return per-location plot window bounds (columns plot_start, plot_end).

    Boundaries always come from draw 0, whatever draw is being processed.
    The window opens 120 days before the OOS beta-fit end and closes half a
    holdout period after forecast start + holdout.
    """
    oos_fit_end, forecast_start = load_boundary_dates(data_interface, specification, 0)

    holdout = pd.Timedelta(days=7 * specification.parameters.holdout_weeks)
    lead = pd.Timedelta(days=120)
    trail = pd.Timedelta(days=holdout.days // 2)

    plot_start = (oos_fit_end - lead).rename('plot_start')
    plot_end = (forecast_start + holdout + trail).rename('plot_end')
    return pd.concat([plot_start, plot_end], axis=1)
    


def load_dates(data_interface,
               specification,
               draw_id):
    """Return per-location plot and OOS window dates for ``draw_id``.

    Columns: plot_start / plot_end (from ``get_plot_start_end``), oos_start
    (this draw's OOS beta-fit end) and oos_end (forecast start + holdout).
    """
    plot_window = get_plot_start_end(data_interface, specification)
    oos_fit_end, forecast_start = load_boundary_dates(data_interface, specification, draw_id)
    holdout = pd.Timedelta(days=7 * specification.parameters.holdout_weeks)
    return pd.concat(
        [
            plot_window,
            oos_fit_end.rename('oos_start'),
            (forecast_start + holdout).rename('oos_end'),
        ],
        axis=1,
    )


def build_oos_metrics(data_interface,
                      data,
                      dates,
                      draw_id):
    """Compute out-of-sample error metrics per location and measure for one draw.

    For each measure, the OOS window ``[oos_start, oos_end]`` from ``dates`` is
    shifted by the measure's exposure delay, read from the draw's fit ODE
    parameters (0 for infections and beta). Within that window the
    ``reference`` and ``delta`` series of ``data`` are summed per location and
    aggregated up the ``'pred'`` hierarchy. Four metrics are then derived:

    * Total Error: summed delta
    * Absolute Error: ``|Total Error|``
    * Percent Error: ``100 * delta / reference``
    * Absolute Percent Error: ``|Percent Error|``

    A zero reference total yields an infinite (or NaN) percent error.

    Returns:
        Series named ``draw_<draw_id>`` indexed by (location_id, metric, measure).
    """
    hierarchy = data_interface.load_hierarchy('pred')
    pop = data_interface.load_population('total')
    params = data_interface.load_fit_ode_params(draw_id).filter(like='exposure').iloc[0]
    delay = {
        'infections': 0,
        'deaths': params.loc['exposure_to_death'],
        'admissions': params.loc['exposure_to_admission'],
        'cases': params.loc['exposure_to_case'],
        'beta': 0,
    }

    def _window_total(window_data, column):
        """Sum ``column`` over the window per location, then aggregate up the hierarchy."""
        totals = window_data[column].groupby('location_id').sum().to_frame()
        # Attach a constant dummy date so the totals carry the same
        # (location_id, date) index used with sum_aggregator in build_dataset.
        totals['date'] = 1
        totals = aggregate.sum_aggregator(totals.set_index('date', append=True), hierarchy, pop)
        return totals.reset_index(level='date', drop=True)[column]

    err = []
    for measure, lag in delay.items():
        # Shift every window boundary by the measure's reporting delay.
        shifted_dates = (dates + pd.Timedelta(days=lag)).reset_index()
        measure_data = data.loc[:, measure].reset_index().merge(shifted_dates, on='location_id')
        in_oos_frame = (
            (measure_data['oos_start'] <= measure_data['date'])
            & (measure_data['date'] <= measure_data['oos_end'])
        )
        oos_data = measure_data.loc[in_oos_frame].set_index(['location_id', 'date'])

        delta = _window_total(oos_data, 'delta')
        ref = _window_total(oos_data, 'reference')
        p = 100 * delta / ref

        measure_err = pd.concat([
            delta.rename('Total Error'),
            np.abs(delta).rename('Absolute Error'),
            p.rename('Percent Error'),
            np.abs(p).rename('Absolute Percent Error'),
        ], axis=1).stack().reset_index()
        measure_err.columns = ['location_id', 'metric', 'value']
        measure_err['measure'] = measure
        err.append(measure_err)
    return pd.concat(err).set_index(['location_id', 'metric', 'measure'])['value'].rename(f'draw_{draw_id}')


def build_oos_data(draw_id,
                   oos_holdout_version,
                   scenario=None):
    """Build the OOS comparison series, error metrics and window dates for one draw.

    Args:
        draw_id: Draw to process.
        oos_holdout_version: Root directory of the OOS holdout run.
        scenario: Forecast scenario to use as the reference. When None, the
            ``seir_forecast_scenario`` from the run's specification is used.

    Returns:
        ``(data, metrics, dates)``:
        ``data`` is a Series named ``draw_<draw_id>`` indexed by
        (location_id, version, measure, date), where version is one of
        'reference', 'oos' or 'delta'. ``metrics`` comes from
        ``build_oos_metrics`` and ``dates`` from ``load_dates``.
    """
    specification = OOSHoldoutSpecification.from_version_root(oos_holdout_version)
    data_interface = OOSHoldoutDataInterface.from_specification(specification)

    # An explicitly requested scenario wins; otherwise use the one the holdout
    # run was configured with.
    if scenario is None:
        scenario = specification.data.seir_forecast_scenario

    dates = load_dates(data_interface, specification, draw_id)
    data = build_dataset(data_interface, dates, scenario, draw_id)
    metrics = build_oos_metrics(data_interface, data, dates, draw_id)

    draw_col = f'draw_{draw_id}'
    # build_dataset produces (measure, version) columns. stack() pops the
    # innermost column level first (version), then measure, so the long frame
    # is ordered (location_id, date, version, measure).
    long = data.stack().stack().reset_index()
    long.columns = ['location_id', 'date', 'version', 'measure', draw_col]
    # Version precedes measure so plotting code can select
    # ``.loc[(version, measure)]`` after picking a location.
    long = long.set_index(['location_id', 'version', 'measure', 'date'])[draw_col].sort_index()

    return long, metrics, dates


In [3]:
# Analysis inputs: the reference forecast scenario and the OOS holdout run to read.
scenario = 'reference'
oos_holdout_version = '/ihme/covid-19/seir-oos-analysis/2022_05_22.02/'

In [4]:
%%time

# Smoke-test the per-draw pipeline on draw 0 before fanning out to all draws.
df, metrics, dates = build_oos_data(
    0,
    oos_holdout_version=oos_holdout_version,
    scenario=scenario,
)

CPU times: user 23.7 s, sys: 7.26 s, total: 30.9 s
Wall time: 33.3 s


In [5]:
%%time 

# Process draws 0-99 in parallel; each call returns a (data, metrics, dates) tuple.
runner = functools.partial(
    build_oos_data,
    oos_holdout_version=oos_holdout_version,
    scenario=scenario,
)
arg_list = list(range(100))

dfs = parallel.run_parallel(runner, arg_list=arg_list, num_cores=26, progress_bar=True)

 ... (more hidden) ...

CPU times: user 8.38 s, sys: 11.3 s, total: 19.7 s
Wall time: 4min 12s





In [6]:
%%time 
# Split the per-draw tuples into three sequences, one per output kind.
data_draws, metric_draws, dates = zip(*dfs)
# Draw series sit side by side as columns before aggregate.summarize reduces them.
data = aggregate.summarize(pd.concat(data_draws, axis=1))
metrics = pd.concat(metric_draws, axis=1)

CPU times: user 1min 15s, sys: 3.3 s, total: 1min 18s
Wall time: 1min 18s


In [7]:
# Use the first draw's dates, reindexed to every location that has metrics.
# Gaps are filled conservatively: earliest start / latest end across locations.
dates = dates[0].reindex(metrics.reset_index().location_id.unique())
dates = dates.fillna({
    column: dates[column].min() if 'start' in column else dates[column].max()
    for column in dates.columns
})

In [8]:
%%time 
specification = OOSHoldoutSpecification.from_version_root(oos_holdout_version)
data_interface = OOSHoldoutDataInterface.from_specification(specification)

hierarchy = data_interface.load_hierarchy('pred')
# Reuse the hierarchy loaded above rather than reading it a second time.
name_map = hierarchy.set_index('location_id').location_name
# One argument tuple per location, in the order make_plot unpacks them:
# (location_id, location_name, data, metrics, dates, hierarchy).
inputs = [
    (
        location_id,
        name_map.loc[location_id],
        data.loc[location_id],
        metrics.loc[location_id],
        dates.loc[location_id],
        hierarchy,
    )
    for location_id in data.reset_index().location_id.unique()
]


CPU times: user 1.1 s, sys: 10.2 ms, total: 1.11 s
Wall time: 1.11 s


In [19]:
metrics_summary = aggregate.summarize(metrics)
# Percent-error summary by measure for location_id 1 (presumably Global — confirm).
metrics_summary.loc[1].loc['Percent Error']

Unnamed: 0_level_0,mean,upper,lower
measure,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
infections,17.484087,37.438804,-15.309616
deaths,10.582283,35.971891,-31.017571
admissions,11.016187,35.114426,-31.666891
cases,-0.984777,30.950364,-53.618969
beta,19.709617,31.284307,8.577656


In [9]:
def make_plot(inp, output_dir=None):
    """Render the out-of-sample diagnostic page for a single location.

    ``inp`` is a ``(location_id, location_name, data, metrics, dates, hierarchy)``
    tuple:

    * ``data`` is selected as ``data.loc[(version, measure)]`` and has
      ``mean``, ``upper`` and ``lower`` columns indexed by date.
    * ``metrics`` has (metric, measure) index levels with one column per draw.
    * ``dates`` holds ``plot_start``, ``plot_end``, ``oos_start`` and ``oos_end``.
    * ``hierarchy`` carries ``location_id`` and ``most_detailed`` columns.

    The left grid shows, per measure, the reference (solid) and OOS (dashed)
    daily values and their daily difference, with vertical lines at the
    lag-shifted OOS window. The right grid shows boxplots of per-draw Total
    Error and Percent Error. Non-finite values, and values outside the
    5th-95th percentile, are dropped from the boxplots.

    When ``output_dir`` is given, the figure is saved to
    ``<output_dir>/oos_<location_id>.pdf`` and closed; otherwise it is shown.
    """
    # Imports are local so the function is self-contained when it is handed to
    # parallel.run_parallel.
    from pathlib import Path
    import matplotlib.pyplot as plt
    import matplotlib.dates as mdates
    import matplotlib.lines as mlines
    import numpy as np
    import pandas as pd
    import seaborn as sns

    sns.set_style('whitegrid')

    FILL_ALPHA = 0.2            # uncertainty band on the daily-value panels
    DELTA_FILL_ALPHA = 0.3      # uncertainty band on the daily-difference panels
    LINE_WIDTH = 3
    LEGEND_LINE_WIDTH = 2.5
    BOX_LINE_WIDTH = 2
    BOX_WIDTH = 0.7
    AX_LABEL_FONTSIZE = 14
    Y_LABEL_FONTSIZE = 16
    SUBPLOT_TITLE_FONTSIZE = 20
    TITLE_FONTSIZE = 24
    FIG_SIZE = (30, 15)
    GRID_SPEC_MARGINS = {'top': 0.92, 'bottom': 0.08}
    WINDOW_LINE_COLOR = 'darkslategrey'
    TRIM_QUANTILES = (0.05, 0.95)

    # measure -> (line color, lag in days applied to the OOS window markers)
    measure_styles = {
        'infections': ('k', 0),
        'deaths': ('darkred', 26),
        'cases': ('darkgreen', 12),
        'admissions': ('navy', 12),
        # 'beta': ('darkgoldenrod', 0),
    }

    def make_axis_legend(axis, elements: dict):
        handles = [mlines.Line2D([], [], label=e_name, linewidth=LEGEND_LINE_WIDTH, **e_props)
                   for e_name, e_props in elements.items()]
        axis.legend(handles=handles,
                    loc='upper left',
                    fontsize=AX_LABEL_FONTSIZE,
                    frameon=False)

    def format_date_axis(ax, start=None, end=None):
        date_locator = mdates.AutoDateLocator(maxticks=15)
        date_formatter = mdates.ConciseDateFormatter(date_locator, show_offset=False)
        ax.set_xlim(start, end)
        ax.xaxis.set_major_locator(date_locator)
        ax.xaxis.set_major_formatter(date_formatter)

    def write_or_show(fig, plot_file: str):
        if plot_file:
            fig.savefig(plot_file, bbox_inches='tight')
            plt.close(fig)
        else:
            plt.show()

    location_id, location_name, data, metrics, dates, h = inp

    def show_measure(measure):
        """Beta is only drawn for most-detailed locations; everything else always is."""
        if measure != 'beta':
            return True
        # The hierarchy stores location_id as a column, so look it up explicitly
        # instead of treating it as a row label.
        return h.set_index('location_id').loc[location_id, 'most_detailed'] == 1

    def trim_tails(values):
        """Drop non-finite values and those outside TRIM_QUANTILES (strictly)."""
        values = np.asarray(values, dtype=float)
        finite = values[np.isfinite(values)]
        if finite.size == 0:
            return finite
        lower, upper = np.quantile(finite, TRIM_QUANTILES)
        trimmed = finite[(lower < finite) & (finite < upper)]
        # Constant or tiny samples trim to nothing; fall back to the finite values.
        return trimmed if trimmed.size else finite

    # Layout: time series on the left, error boxplots on the right.
    fig = plt.figure(figsize=FIG_SIZE, tight_layout=True)
    grid_spec = fig.add_gridspec(
        nrows=1, ncols=2,
        width_ratios=[5, 2],
        wspace=0.1,
    )
    grid_spec.update(**GRID_SPEC_MARGINS)
    gs_line = grid_spec[0, 0].subgridspec(5, 2)
    gs_hist = grid_spec[0, 1].subgridspec(5, 2)

    start, end = dates['plot_start'], dates['plot_end']
    oos_start, oos_end = dates['oos_start'], dates['oos_end']

    align = []
    for row, (measure, (color, lag)) in enumerate(measure_styles.items()):
        if not show_measure(measure):
            continue
        shift = pd.Timedelta(days=lag)
        for col, plot_type in enumerate(['abs', 'delta']):
            ax = fig.add_subplot(gs_line[row, col])
            if plot_type == 'abs':
                ref = data.loc[('reference', measure)]
                oos = data.loc[('oos', measure)]
                ax.plot(ref.index, ref['mean'], linewidth=LINE_WIDTH, color=color)
                ax.fill_between(ref.index, ref['upper'], ref['lower'],
                                alpha=FILL_ALPHA, color=color)
                ax.plot(oos.index, oos['mean'], linewidth=LINE_WIDTH, color=color, linestyle='dashed')
                ax.fill_between(oos.index, oos['upper'], oos['lower'],
                                alpha=FILL_ALPHA, color=color)
                ax.set_ylabel(measure.title(), fontsize=Y_LABEL_FONTSIZE)
                make_axis_legend(ax, {'forecast': {'linestyle': 'solid', 'color': color},
                                      'oos': {'linestyle': 'dashed', 'color': color}})
                if row == 0:
                    ax.set_title('Daily Values', fontsize=SUBPLOT_TITLE_FONTSIZE)
            else:
                delta = data.loc[('delta', measure)]
                ax.plot(delta.index, delta['mean'], linewidth=LINE_WIDTH, color=color)
                ax.fill_between(delta.index, delta['upper'], delta['lower'],
                                alpha=DELTA_FILL_ALPHA, color=color)
                if row == 0:
                    ax.set_title('Daily Difference', fontsize=SUBPLOT_TITLE_FONTSIZE)
            # Mark the lag-shifted OOS window once on every time-series panel.
            for boundary in (oos_start, oos_end):
                ax.axvline(boundary + shift, linewidth=LINE_WIDTH, color=WINDOW_LINE_COLOR)
            format_date_axis(ax, start, end)
            if col == 0:
                align.append(ax)
    fig.align_ylabels(align)

    metrics = metrics.reorder_levels(['measure', 'metric'])
    error_columns = [['Total Error'], ['Percent Error']]
    for row, measure in enumerate(measure_styles):
        if not show_measure(measure):
            continue
        color = measure_styles[measure][0]
        props = dict(color=color, linewidth=BOX_LINE_WIDTH)
        err = metrics.loc[measure]
        for col, metric_names in enumerate(error_columns):
            ax = fig.add_subplot(gs_hist[row, col])
            for position, metric in enumerate(metric_names):
                ax.boxplot(
                    trim_tails(err.loc[metric].values),
                    positions=[position],
                    widths=[BOX_WIDTH],
                    boxprops=props,
                    capprops=props,
                    whiskerprops=props,
                    flierprops={**props, 'markeredgecolor': color},
                    medianprops=props,
                    labels=[metric],
                )
        if row == 0:
            ax.set_title('Error', fontsize=SUBPLOT_TITLE_FONTSIZE)

    sns.despine(fig=fig, left=True, bottom=True)
    fig.suptitle(f'{location_name} ({location_id})',
                 x=0.5,
                 fontsize=TITLE_FONTSIZE,
                 ha='center')
    plot_file = str(Path(output_dir) / f'oos_{location_id}.pdf') if output_dir else None
    write_or_show(fig, plot_file)

In [10]:
def make_plot_(inputs, output_dir):
    """Best-effort wrapper around ``make_plot`` for batch runs.

    Returns None on success. If plotting raises, returns the location name
    (``inputs[1]``) so callers can collect failed locations without aborting
    the batch. Only ``Exception`` subclasses are caught, so KeyboardInterrupt
    and SystemExit still stop the run.
    """
    try:
        make_plot(inputs, output_dir)
    except Exception:
        return inputs[1]
    return None

In [13]:
# Render one PDF page per location in parallel. make_plot is called directly,
# so any plotting error propagates.
plot_dir = "/mnt/share/homes/collijk/code/covid/dev/covid-model-seiir-pipeline/notebooks/loc_plots/"
runner = functools.partial(make_plot, output_dir=plot_dir)

_ = parallel.run_parallel(runner, inputs, num_cores=26, progress_bar=True)


 ... (more hidden) ...


In [14]:
from pathlib import Path
from typing import List

import pandas as pd
from PyPDF2 import PdfFileMerger as PdfFileMerger_


class PdfFileMerger(PdfFileMerger_):
    """Context-manager shim over PyPDF2's ``PdfFileMerger``.

    The upstream class is not a context manager; this subclass adds the
    protocol so the merger is always closed when a ``with`` block exits.
    """

    def __enter__(self):
        # Bind the merger itself to the ``as`` target.
        merger = self
        return merger

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Release the merger; returning None lets any in-flight exception propagate.
        self.close()
        return None


def merge_pdfs(plot_cache: Path, output_path: Path, hierarchy: pd.DataFrame):
    """Merge together all pdfs in the plot cache and write to the output path.

    The final pdf has locations ordered by a depth-first search of the
    provided hierarchy, with nodes at the same level sorted alphabetically.
    Locations without a cached ``oos_<location_id>.pdf`` are skipped. Each
    page is bookmarked, nested under its parent's bookmark when the parent's
    page was merged earlier. Any existing file at ``output_path`` is replaced.
    """
    parent_map = hierarchy.set_index('location_id').parent_id
    name_map = hierarchy.set_index('location_id').location_ascii_name

    sorted_locations = get_locations_dfs(hierarchy)
    # Bookmark handles keyed by location_id. Children are nested via these
    # handles rather than by title, because locations may share a display name
    # and a title lookup would attach children to the first match.
    bookmarks = {}
    with PdfFileMerger() as merger:
        current_page = 0
        for location_id in sorted_locations:
            result_page_path = plot_cache / f'oos_{location_id}.pdf'
            if not result_page_path.exists():
                # We didn't model the location for some reason.
                continue

            # Add the results page (assumes each cached pdf is a single page).
            merger.merge(current_page, str(result_page_path))

            # Bookmark it under its parent's bookmark if the parent was merged.
            parent_id = parent_map.loc[location_id]
            parent = bookmarks.get(parent_id) if parent_id != location_id else None
            bookmarks[location_id] = merger.addBookmark(
                name_map.loc[location_id], current_page, parent
            )
            current_page += 1

        if output_path.exists():
            output_path.unlink()
        merger.write(str(output_path))


def get_locations_dfs(hierarchy: pd.DataFrame) -> List[int]:
    """Return location ids sorted by a depth first search of the hierarchy.

    Locations at the same level are sorted alphabetically by
    ``location_ascii_name``; ties keep their original row order. Roots are
    rows whose ``parent_id`` equals their own ``location_id``. Child lists are
    built in one pass, so the hierarchy is not rescanned for every node.
    """
    # A single stable sort up front leaves every child list in alphabetical order.
    ordered = hierarchy.sort_values('location_ascii_name', kind='mergesort')

    roots: List[int] = []
    children = {}
    for location_id, parent_id in zip(ordered['location_id'], ordered['parent_id']):
        if location_id == parent_id:
            roots.append(location_id)
        else:
            children.setdefault(parent_id, []).append(location_id)

    def _visit(location_id) -> List[int]:
        """Pre-order walk: the location itself, then each child's subtree."""
        locs = [location_id]
        for child_id in children.get(location_id, []):
            locs.extend(_visit(child_id))
        return locs

    locations = []
    for root_id in roots:
        locations.extend(_visit(root_id))
    return locations

In [15]:
# Combine the per-location pages into a single bookmarked report next to the cache.
plot_cache = Path("/mnt/share/homes/collijk/code/covid/dev/covid-model-seiir-pipeline/notebooks/loc_plots/")
merge_pdfs(
    plot_cache,
    plot_cache.parent / "oos_analysis.pdf",
    hierarchy,
)