In [None]:
from typing import List, Tuple, Dict
from pathlib import Path
import functools
import yaml

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

from covid_model_seiir_pipeline.pipeline.fit.specification import FitSpecification
from covid_model_seiir_pipeline.pipeline.fit.data import FitDataInterface
from covid_model_seiir_pipeline.lib import (
    cli_tools,
    parallel,
)

# Logger used for progress messages throughout the notebook.
LOGGER = cli_tools.task_performance_logger

# 'whitegrid' is the valid seaborn style name; 'white_grid' raises ValueError.
sns.set_style('whitegrid')

# Number of worker processes for the per-draw parallel loaders below.
NUM_CORES = 50


In [2]:
def create_index(fit_version: str, variants: List[str]) -> Tuple[pd.Index, int]:
    """Build the (location_id, date, variant) index shared by all infection loaders.

    Locations are the most-detailed locations (``most_detailed == 1``) of the
    'pred' hierarchy of the fit version. Dates span the reported epi data,
    padded by 30 days on each side.

    Args:
        fit_version: Root directory of a seir-fit version.
        variants: Variant names used as the innermost index level.

    Returns:
        The product MultiIndex and the number of draws in the fit version.
    """
    specification = FitSpecification.from_version_root(fit_version)
    data_interface = FitDataInterface.from_specification(specification)

    hierarchy = data_interface.load_hierarchy('pred')
    location_ids = hierarchy.loc[hierarchy['most_detailed'] == 1, 'location_id'].to_list()

    # Load the reported epi data once and reuse it for both date bounds.
    reported_dates = data_interface.load_reported_epi_data().reset_index()['date']
    padding = pd.Timedelta(days=30)
    dates = pd.date_range(reported_dates.min() - padding, reported_dates.max() + padding)

    universal_idx = pd.MultiIndex.from_product([location_ids, dates, variants],
                                               names=['location_id', 'date', 'variant'])

    n_draws = data_interface.get_n_draws()

    return universal_idx, n_draws


def load_fit_draw_infections(draw_id: int, measure: str, data_interface: FitDataInterface,
                             variants: List[str], universal_idx: pd.Index) -> pd.DataFrame:
    """Load daily infections by variant for one draw of a seir-fit version.

    Args:
        draw_id: Draw to load. It also names the output column (``draw_{draw_id}``).
        measure: Measure version passed to ``load_compartments``. For anything
            other than ``'final'``, the ``round`` column is read and only
            round 2 is kept.
        data_interface: Object providing ``load_compartments``.
        variants: Variants to load. Low- and high-risk columns are summed per variant.
        universal_idx: (location_id, date, variant) index to align the result to.

    Returns:
        Single-column DataFrame (``draw_{draw_id}``) reindexed to ``universal_idx``.
    """
    lr_columns = [f'Infection_all_{variant}_all_lr' for variant in variants]
    hr_columns = [f'Infection_all_{variant}_all_hr' for variant in variants]
    columns = lr_columns + hr_columns
    if measure != 'final':
        columns += ['round']

    data = data_interface.load_compartments(draw_id, measure_version=measure, columns=columns)
    if 'round' in data:
        data = data.loc[data['round'] == 2].drop('round', axis=1)
    # Sum the low- and high-risk strata for each variant. min_count=1 keeps
    # all-NaN rows NaN. Explicit per-variant sums replace
    # ``groupby(..., axis=1)``, which is deprecated in pandas.
    data = pd.DataFrame(
        {variant: data[[lr, hr]].sum(axis=1, min_count=1)
         for variant, lr, hr in zip(variants, lr_columns, hr_columns)},
        index=data.index,
    )
    # Difference within each location (compartment values are presumably
    # cumulative, so this yields daily counts).
    data = data.groupby('location_id').diff()
    data = pd.melt(data, ignore_index=False, var_name='variant', value_name=f'draw_{draw_id}')
    data = data.set_index('variant', append=True).sort_index()

    return data.reindex(universal_idx)


def load_counterfactual_draw_infections(draw_id: int, counterfactual_version: Path,
                                        variants: List[str], universal_idx: pd.Index) -> pd.Series:
    data = pd.read_parquet(counterfactual_version / 'raw_outputs' / f'draw_{draw_id}.parquet',
                           columns=[f'modeled_infections_{variant}' for variant in variants])
    data = data.rename(columns={col: col.replace('modeled_infections_', '') for col in data.columns})
    data = pd.melt(data, ignore_index=False,
                   var_name='variant', value_name=f'draw_{draw_id}')
    data = data.set_index('variant', append=True)

    return data.reindex(universal_idx)


def load_infections(version: str, source: str, measure: str, variants: List[str],
                    universal_idx: pd.Index, n_draws: int) -> pd.DataFrame:
    """Load all draws of daily infections by variant for one measure.

    Args:
        version: Root of a seir-fit version (``source='fit'``) or of a
            counterfactual version (``source='counterfactual'``).
        source: Either ``'fit'`` or ``'counterfactual'``.
        measure: Measure version ('case', 'admission', 'death', ...). For
            counterfactuals it is also the sub-directory of ``version`` to read.
        variants: Variants to load.
        universal_idx: Index every draw is aligned to.
        n_draws: Number of draws (0 .. n_draws - 1) to load in parallel.

    Returns:
        DataFrame with one ``draw_{i}`` column per draw.

    Raises:
        ValueError: If ``source`` is not recognized.
    """
    LOGGER.info(f'Reading infections from {version} ({measure})')
    if source == 'fit':
        specification = FitSpecification.from_version_root(version)
        data_interface = FitDataInterface.from_specification(specification)
        _load_draw_infections = functools.partial(load_fit_draw_infections,
                                                  measure=measure,
                                                  data_interface=data_interface,
                                                  variants=variants,
                                                  universal_idx=universal_idx)
    elif source == 'counterfactual':
        _load_draw_infections = functools.partial(load_counterfactual_draw_infections,
                                                  counterfactual_version=Path(version) / measure,
                                                  variants=variants,
                                                  universal_idx=universal_idx)
    else:
        # Without this, an unknown source surfaced as an UnboundLocalError below.
        raise ValueError(f"Unknown infections source '{source}'; "
                         f"expected 'fit' or 'counterfactual'.")
    infections = parallel.run_parallel(
        runner=_load_draw_infections,
        arg_list=range(n_draws),
        num_cores=NUM_CORES,
        progress_bar=True,
    )
    infections = pd.concat(infections, axis=1)

    return infections


def load_draw_sero_points(draw_id: int, measure: str, data_interface: FitDataInterface) -> pd.Series:
    """Count non-outlier seroprevalence points per (location_id, date) for one draw.

    The date of each point is its ``sero_date`` shifted back 14 days. Rows with
    ``is_outlier != 0`` are dropped.

    Args:
        draw_id: Draw to load. It also names the returned Series (``draw_{draw_id}``).
        measure: Measure version passed to ``load_final_seroprevalence``.
        data_interface: Object providing ``load_final_seroprevalence``.

    Returns:
        Integer Series of point counts indexed by (location_id, date).
    """
    sero_points = data_interface.load_final_seroprevalence(draw_id,
                                                           measure_version=measure,
                                                           columns=['location_id', 'sero_date', 'is_outlier'])
    inliers = sero_points.loc[sero_points['is_outlier'] == 0]
    # Build the shifted date with ``assign`` rather than assigning into a
    # filtered slice, which triggers SettingWithCopyWarning.
    counts = (inliers
              .assign(date=inliers['sero_date'] - pd.Timedelta(days=14))
              .groupby(['location_id', 'date'])
              .size()
              .rename(f'draw_{draw_id}'))

    return counts


def load_sero_points(fit_version: str, measure: str, n_draws: int) -> pd.Series:
    """Count non-outlier sero points per location and date across draws.

    Each draw is loaded with ``load_draw_sero_points`` in parallel. Counts are
    combined by taking the maximum across draws.

    Args:
        fit_version: Root directory of a seir-fit version.
        measure: Measure version whose final seroprevalence is read.
        n_draws: Number of draws (0 .. n_draws - 1) to load.

    Returns:
        Series named ``'n'`` indexed by (location_id, date).
    """
    LOGGER.info(f'Reading sero locations from {fit_version}')
    specification = FitSpecification.from_version_root(fit_version)
    data_interface = FitDataInterface.from_specification(specification)
    _load_draw_sero_points = functools.partial(load_draw_sero_points,
                                               measure=measure,
                                               data_interface=data_interface)
    sero_points = parallel.run_parallel(
        runner=_load_draw_sero_points,
        arg_list=range(n_draws),
        num_cores=NUM_CORES,
        progress_bar=True,
    )
    # Draws can differ in which points are kept; use the per-date maximum.
    sero_points = pd.concat(sero_points, axis=1).max(axis=1).rename('n')

    return sero_points


def compile_inputs(reference_version: str, comparator_version: str, variants: List[str],
                   sero_measure: str = 'death',
                   ) -> Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame], pd.Series]:
    """Load reference/comparator infections for each measure, plus sero point counts.

    Args:
        reference_version: Counterfactual version root (the reference).
        comparator_version: seir-fit version root (the comparator). It also
            defines the shared index and draw count.
        variants: Variants to load.
        sero_measure: Measure version of the comparator whose seroprevalence is
            read. It defaults to 'death', which the original code used implicitly
            via the leaked loop variable.

    Returns:
        ``(reference_infections, comparator_infections, sero_points)``. The two
        dicts map 'case'/'admission'/'death' to draw-wide DataFrames.
        ``sero_points`` is restricted to locations present in the shared index.
    """
    universal_idx, n_draws = create_index(comparator_version, variants)

    reference_infections = {}
    comparator_infections = {}
    for measure in ['case', 'admission', 'death']:
        reference_infections[measure] = load_infections(reference_version,
                                                        'counterfactual', measure, variants,
                                                        universal_idx, n_draws)
        comparator_infections[measure] = load_infections(comparator_version,
                                                         'fit', measure, variants,
                                                         universal_idx, n_draws)

    infections_locations = set(universal_idx.get_level_values('location_id').unique().to_list())
    sero_points = load_sero_points(comparator_version, sero_measure, n_draws)
    sero_locations = sero_points.index.get_level_values('location_id').unique().to_list()
    overlap = [location_id for location_id in sero_locations if location_id in infections_locations]
    sero_points = sero_points.loc[overlap]

    return reference_infections, comparator_infections, sero_points


In [3]:
def calculate_relative_error(reference_infections: pd.DataFrame,
                             comparator_infections: pd.DataFrame,
                             units: str):
    """Ratio of comparator to reference infections, aggregated according to ``units``.

    Both inputs are draw-wide frames indexed by (location_id, date). Cleaning steps:
    - Negative and infinite values are set to NaN.
    - Rows that are entirely NaN are dropped.
    - 0.1 is added as a small offset.
    - Both frames are restricted to their common rows, with NaNs mirrored so
      numerator and denominator cover the same cells.

    Args:
        reference_infections: Reference (denominator) infections.
        comparator_infections: Comparator (numerator) infections.
        units: Aggregation to return:
            - 'global_draws': per-draw ratio of totals (Series over draws).
            - 'global_mean': mean of the 'global_draws' ratios (scalar).
            - 'global_median': median of the 'global_draws' ratios (scalar).
            - 'location_draws': per-location, per-draw ratio of totals (DataFrame).
            - 'clipped_location_means': 'location_draws' clipped to each
              location's 2.5th-97.5th percentile across draws, then averaged
              over draws (Series).
            - 'location_daily_medians': per-location ratio of summed daily
              draw-medians (Series).

    Raises:
        ValueError: If ``units`` is not recognized.
    """
    def _clean(infections: pd.DataFrame) -> pd.DataFrame:
        return (infections
                .where(infections >= 0)
                .replace(np.inf, np.nan)
                .dropna(how='all')) + 0.1

    comparator_infections = _clean(comparator_infections)
    reference_infections = _clean(reference_infections)
    idx = comparator_infections.index.intersection(reference_infections.index)

    comparator_infections = comparator_infections.loc[idx]
    reference_infections = reference_infections.loc[idx]
    # Every remaining non-NaN value is >= 0.1, so these comparisons only mask
    # cells that are NaN in the other frame.
    comparator_infections = comparator_infections.where(reference_infections >= 0)
    reference_infections = reference_infections.where(comparator_infections >= 0)

    if units == 'global_draws':
        return comparator_infections.sum() / reference_infections.sum()
    if units == 'global_mean':
        # The per-draw ratio is a Series; ``.mean(axis=1)`` raised ValueError.
        return (comparator_infections.sum() / reference_infections.sum()).mean()
    if units == 'global_median':
        return (comparator_infections.sum() / reference_infections.sum()).median()
    if units == 'location_draws':
        return (comparator_infections.groupby('location_id').sum()
                / reference_infections.groupby('location_id').sum())
    if units == 'clipped_location_means':
        relative_error = (comparator_infections.groupby('location_id').sum()
                          / reference_infections.groupby('location_id').sum())
        # Clip each location's draws to its 2.5th and 97.5th percentiles.
        return relative_error.clip(
            **dict(zip(['lower', 'upper'], np.quantile(relative_error, [0.025, 0.975], axis=1))),
            axis=0,
        ).mean(axis=1)
    if units == 'location_daily_medians':
        return (comparator_infections.median(axis=1).groupby('location_id').sum()
                / reference_infections.median(axis=1).groupby('location_id').sum())
    raise ValueError(f"Unknown units '{units}'.")


def minimize_scaling(reference_infections: pd.DataFrame,
                     comparator_infections: pd.DataFrame,
                     sero_points: pd.Series,
                     relative_error: pd.Series,
                     N: int = 4):
    """Pull per-location relative errors toward 1 where sero data covers the variant.

    The variant's period is the set of (location_id, date) rows where the
    within-location cumulative sum of comparator + reference infections,
    maximized over draws, is nonzero. Let ``n`` be the number of sero points
    in that period. Its weight is ``clip(n, 1, N) / N``, and the result is
    ``1 - (1 - relative_error) * (1 - weight)``. Locations with no matching
    sero points keep their original ``relative_error``.
    """
    cumulative = (comparator_infections + reference_infections).groupby('location_id').cumsum()
    variant_present = cumulative.max(axis=1).dropna().astype(bool)
    present_idx = variant_present.index[variant_present.to_numpy()]

    matching_idx = sero_points.index.intersection(present_idx)
    n_sero = sero_points.loc[matching_idx].groupby('location_id').sum()
    remaining = 1 - n_sero.clip(1, N) / N
    scaled = 1 - (1 - relative_error) * remaining

    return scaled.fillna(relative_error)


def calculate_kappa_scalars(reference_infections: Dict[str, pd.DataFrame],
                            comparator_infections: Dict[str, pd.DataFrame],
                            sero_points: pd.Series,
                            variants: List[str],
                            units: str,
                            minimize_variants: List[str],
                            infections_threshold: int,):
    """Compute kappa scalars (comparator / reference ratios) per variant and measure.

    For each variant and each of 'case', 'admission' and 'death':
    1. Take the variant's slice of both infection frames.
    2. Drop comparator locations whose mean (across draws) total infections do
       not exceed ``infections_threshold``.
    3. Compute the ratio with ``calculate_relative_error`` using ``units``.
    4. For variants in ``minimize_variants``, pull the ratio toward 1 with
       ``minimize_scaling`` using ``sero_points``.

    Args:
        reference_infections: Measure -> (location_id, date, variant) draw-wide frame.
        comparator_infections: Same structure as ``reference_infections``.
        sero_points: Sero point counts indexed by (location_id, date).
        variants: Variants to compute scalars for.
        units: Aggregation passed to ``calculate_relative_error``.
        minimize_variants: Variants whose scalars are adjusted with sero data.
        infections_threshold: Minimum mean comparator infections per location.

    Returns:
        Nested dict ``relative_error[variant][measure]``.
    """
    relative_error = {}
    for variant in variants:
        LOGGER.info(variant)
        relative_error[variant] = {}
        for measure in ['case', 'admission', 'death']:
            ref = reference_infections[measure].loc[:, :, variant].reset_index('variant', drop=True)
            comp = comparator_infections[measure].loc[:, :, variant].reset_index('variant', drop=True)
            over_threshold = comp.groupby('location_id').sum().mean(axis=1) > infections_threshold
            over_threshold_locations = over_threshold.loc[over_threshold].index
            comp = comp.loc[over_threshold_locations]
            _relative_error = calculate_relative_error(ref, comp, units)
            if variant in minimize_variants:
                _relative_error = minimize_scaling(ref, comp, sero_points, _relative_error)
            relative_error[variant][measure] = _relative_error
    return relative_error


def store_kappa_scalars(relative_error: Dict[str, Dict[str, pd.Series]], variants: List[str],
                        output_dir: str = '../../pipeline/fit/model/kappa_scaling_factors'):
    """Write one YAML file of kappa scaling factors per variant.

    Each file ``{output_dir}/{variant}.yaml`` maps measure -> ``Series.to_dict()``
    of that measure's scalars (keyed by the Series index, presumably location_id).

    Args:
        relative_error: Nested mapping ``relative_error[variant][measure]``, as
            returned by ``calculate_kappa_scalars``.
        variants: Variants to write.
        output_dir: Directory receiving the YAML files.
    """
    for variant in variants:
        # Read from the argument. The original read the notebook-global
        # ``kappa_scalars`` and silently ignored ``relative_error``.
        variant_relative_error_dict = {
            measure: measure_scalars.to_dict()
            for measure, measure_scalars in relative_error[variant].items()
        }
        with open(Path(output_dir) / f'{variant}.yaml', 'w') as file:
            yaml.dump(variant_relative_error_dict, file, default_flow_style=False)


def make_histograms(relative_error: Dict, variants: List[str],
                    sero_locations: Dict[str, List[int]] = None,):
    """Plot log-scale histograms of kappa scalars and summarize their distribution.

    The figure has one row per measure and one column per variant. Values equal
    to exactly 1.0 and NaNs are excluded before the log is taken. The histogram
    is clipped to the 5th-95th percentile of the log values. Vertical lines mark
    the median (dotted red), mean (dotted green) and quartiles (dashed red).

    Args:
        relative_error: Nested mapping ``relative_error[variant][measure]`` of
            Series or DataFrames with a ``location_id`` index level.
        variants: Variants to plot, one column each.
        sero_locations: Optional variant -> list of location_ids. A non-empty
            list restricts that variant to those locations. Missing variants,
            empty lists, or ``None`` mean no filtering.

    Returns:
        Dict variant -> DataFrame (measures x ['p25', 'p50', 'geom_mean', 'p75']).
    """
    measures = ['case', 'admission', 'death']
    sero_locations = sero_locations or {}
    # squeeze=False keeps ``ax`` 2-D even with a single variant.
    fig, ax = plt.subplots(3, len(variants), figsize=(5 * len(variants), 15),
                           sharex='col', squeeze=False)
    summary = {}
    for ii, variant in enumerate(variants):
        _summary = {}
        for i, measure in enumerate(measures):
            axis = ax[i, ii]
            _relative_error = relative_error[variant][measure].copy()
            keep_locations = sero_locations.get(variant)
            if keep_locations:
                # Works for both Series and DataFrame (``Series.query`` does not exist).
                location_ids = _relative_error.index.get_level_values('location_id')
                _relative_error = _relative_error.loc[location_ids.isin(keep_locations)]
            _relative_error = _relative_error.where(_relative_error != 1.0)
            n_locs = _relative_error.index.get_level_values('location_id').unique().size
            values = np.asarray(_relative_error.values, dtype=float).ravel()
            log_error = np.log(values[~np.isnan(values)])

            if ii == 0:
                axis.set_ylabel(measure)
            # Keep the variant name on the top row instead of overwriting it.
            title = f'n={n_locs}'
            if i == 0:
                title = f'{variant}\n{title}'
            axis.set_title(title)

            if log_error.size == 0:
                # Nothing to plot; np.quantile would raise on an empty array.
                _summary[measure] = [np.nan] * 4
                continue

            axis.hist(log_error.clip(*np.quantile(log_error, [0.05, 0.95])), bins=40)
            axis.axvline(np.median(log_error), linestyle=':', color='red')
            axis.axvline(np.mean(log_error), linestyle=':', color='forestgreen')
            for q in np.quantile(log_error, [0.25, 0.75]):
                axis.axvline(q, linestyle='--', color='red')
            scalars = np.exp(log_error)
            _summary[measure] = [np.quantile(scalars, 0.25),
                                 np.quantile(scalars, 0.50),
                                 np.exp(np.mean(log_error)),
                                 np.quantile(scalars, 0.75)]
        summary[variant] = pd.DataFrame(_summary,
                                        index=pd.Index(['p25', 'p50', 'geom_mean', 'p75'])).T
    # Lay out and show the figure once, after all panels are drawn.
    fig.tight_layout()
    fig.show()

    return summary


In [4]:
##################################
## GLOBAL -- 2022 VARIANTS ONLY ##
##################################

# # ancestral, alpha, beta, gamma, and delta scalars applied
# reference_version = '/mnt/share/covid-19/seir-counterfactual/2022_10_19.01'
# comparator_version = '/mnt/share/covid-19/seir-fit/2022_10_19.03'
# variants = ['omicron']
# RUN = 'GLOBAL'

# # ancestral, alpha, beta, gamma, delta, and omicron scalars applied
# reference_version = '/mnt/share/covid-19/seir-counterfactual/2022_10_19.02'
# comparator_version = '/mnt/share/covid-19/seir-fit/2022_10_19.11'
# variants = ['ba5']
# RUN = 'GLOBAL'

#################################
## LOCAL -- 2022 VARIANTS ONLY ##
#################################

# # ancestral, alpha, beta, gamma, and delta scalars applied
# reference_version = '/mnt/share/covid-19/seir-counterfactual/2022_10_19.04'
# comparator_version = '/mnt/share/covid-19/seir-fit/2022_10_19.13'
# variants = ['omicron']
# RUN = 'LOCAL'

## DAILY MEDIAN OMICRON
# # ancestral, alpha, beta, gamma, delta, and omicron scalars applied
# reference_version = '/mnt/share/covid-19/seir-counterfactual/2022_10_19.06'
# comparator_version = '/mnt/share/covid-19/seir-fit/2022_10_19.16'
# variants = ['ba5']
# RUN = 'LOCAL'

## CLIPPED MEAN RATIO OMICRON
# # ancestral, alpha, beta, gamma, delta, and omicron scalars applied
# reference_version = '/mnt/share/covid-19/seir-counterfactual/2022_10_20.02'
# comparator_version = '/mnt/share/covid-19/seir-fit/2022_10_20.03'
# variants = ['ba5']
# RUN = 'LOCAL'

###########################
## LOCAL -- ALL VARIANTS ##
###########################

# # ancestral scalars applied
# reference_version = '/mnt/share/covid-19/seir-counterfactual/2022_10_19.03'
# comparator_version = '/mnt/share/covid-19/seir-fit/2022_10_19.12'
# variants = ['alpha', 'beta', 'gamma']
# RUN = 'LOCAL'

## DAILY MEDIAN ALPHA/BETA/GAMMA
# # ancestral, alpha, beta, and gamma scalars applied
# reference_version = '/mnt/share/covid-19/seir-counterfactual/2022_10_19.05'
# comparator_version = '/mnt/share/covid-19/seir-fit/2022_10_19.15'
# variants = ['delta']
# RUN = 'LOCAL'

# # ancestral, alpha, beta, gamma, and delta scalars applied
# reference_version = '/mnt/share/covid-19/seir-counterfactual/2022_10_20.01'
# comparator_version = '/mnt/share/covid-19/seir-fit/2022_10_19.17'
# variants = ['omicron']
# RUN = 'LOCAL'

# # ancestral, alpha, beta, gamma, delta, and omicron scalars applied
# reference_version = '/mnt/share/covid-19/seir-counterfactual/2022_10_20.06'
# comparator_version = '/mnt/share/covid-19/seir-fit/2022_10_20.08'
# variants = ['ba5']
# RUN = 'LOCAL'

# ##################################################
# # PROD 2022-11-14 UPDATE
# ##################################################
# # ancestral, alpha, beta, gamma, delta, and omicron scalars applied
# reference_version = '/mnt/share/covid-19/seir-counterfactual/2022_11_16.04'
# comparator_version = '/mnt/share/covid-19/seir-fit/2022_11_16.03'
# variants = ['ba5']
# RUN = 'LOCAL'
# ##################################################

##################################################
# PROD 2022-12-12 UPDATE
##################################################
# ancestral, alpha, beta, gamma, delta, and omicron scalars applied
# Active calibration run. The reference is a counterfactual version, the
# comparator is a seir-fit version, and ``variants`` lists the variant(s) to
# calibrate. RUN='LOCAL' writes per-location scalars; RUN='GLOBAL' plots
# distributions.
reference_version, comparator_version = (
    '/mnt/share/covid-19/seir-counterfactual/2022_12_13.01',
    '/mnt/share/covid-19/seir-fit/2022_12_13.02',
)
variants = ['ba5']
RUN = 'LOCAL'
##################################################

## CLIPPED MEAN RATIO ALPHA/BETA/GAMMA
# # ancestral, alpha, beta, and gamma scalars applied
# reference_version = '/mnt/share/covid-19/seir-counterfactual/2022_10_20.03'
# comparator_version = '/mnt/share/covid-19/seir-fit/2022_10_20.04'
# variants = ['delta']
# RUN = 'LOCAL'

# # ancestral, alpha, beta, gamma, and delta scalars applied
# reference_version = '/mnt/share/covid-19/seir-counterfactual/2022_10_20.04'
# comparator_version = '/mnt/share/covid-19/seir-fit/2022_10_20.06'
# variants = ['omicron']
# RUN = 'LOCAL'

# # ancestral, alpha, beta, gamma, delta, and omicron scalars applied
# reference_version = '/mnt/share/covid-19/seir-counterfactual/2022_10_20.05'
# comparator_version = '/mnt/share/covid-19/seir-fit/2022_10_20.07'
# variants = ['ba5']
# RUN = 'LOCAL'

########################################################
########################################################
# Load all draws of reference/comparator infections and the sero point counts.
LOGGER.info('Generating infections')
(reference_infections,
 comparator_infections,
 sero_points) = compile_inputs(reference_version, comparator_version, variants)


2022-12-13 23:53:16.987 | INFO     | covid_model_seiir_pipeline.lib.cli_tools.performance_logger.performance_logger:info:26 - Generating infections
2022-12-13 23:53:18.343 | INFO     | covid_model_seiir_pipeline.lib.cli_tools.performance_logger.performance_logger:info:26 - Reading infections from /mnt/share/covid-19/seir-counterfactual/2022_12_13.01 (case)
100%|█████████████████████████████████████████| 100/100 [00:20<00:00,  4.91it/s]
2022-12-13 23:53:48.516 | INFO     | covid_model_seiir_pipeline.lib.cli_tools.performance_logger.performance_logger:info:26 - Reading infections from /mnt/share/covid-19/seir-fit/2022_12_13.02 (case)
100%|█████████████████████████████████████████| 100/100 [00:20<00:00,  4.80it/s]
2022-12-13 23:54:18.658 | INFO     | covid_model_seiir_pipeline.lib.cli_tools.performance_logger.performance_logger:info:26 - Reading infections from /mnt/share/covid-19/seir-counterfactual/2022_12_13.01 (admission)
100%|█████████████████████████████████████████| 100/100 [00:18<

In [None]:
# # reference_version = '/mnt/share/covid-19/seir-counterfactual/2022_10_20.05'
# # comparator_version = '/mnt/share/covid-19/seir-fit/2022_10_20.07'
# # variant = 'ba5'
# # LOGGER.info('Generating infections')
# # reference_infections, comparator_infections, sero_points = compile_inputs(
# #     reference_version, comparator_version, [variant]
# # )

# measure = 'admission'
# location_id = 60382

# ref = reference_infections[measure].loc[:, '2021-10-01':, variant].reset_index('variant', drop=True)
# comp = comparator_infections[measure].loc[:, '2021-10-01':, variant].reset_index('variant', drop=True)


# _comp = (comp
#          .where(comp >= 0)
#          .replace(np.inf, np.nan)
#          .dropna(how='all'))
# _ref = (ref
#         .where(ref >= 0)
#         .replace(np.inf, np.nan)
#         .dropna(how='all'))
# _idx = _comp.index.intersection(_ref.index)


# _comp = _comp.loc[_idx]
# _ref = _ref.loc[_idx]
# _comp = _comp.where(_ref >= 0)
# _ref = _ref.where(_comp >= 0)


# fig, ax = plt.subplots(1, 2, figsize=(16, 4.5))
# ref.loc[_idx].loc[location_id].plot(ax=ax[0], color='darkgrey',
#                                     alpha=0.1, legend=False)
# comp.loc[_idx].loc[location_id].plot(ax=ax[0], color='dodgerblue',
#                                     alpha=0.1, legend=False)
# ref.loc[_idx].loc[location_id].mean(axis=1).plot(ax=ax[0], color='black')
# comp.loc[_idx].loc[location_id].mean(axis=1).plot(ax=ax[0], color='navy')

# ref.loc[_idx].loc[location_id].cumsum().plot(ax=ax[1], color='darkgrey',
#                                     alpha=0.1, legend=False)
# comp.loc[_idx].loc[location_id].cumsum().plot(ax=ax[1], color='dodgerblue',
#                                     alpha=0.1, legend=False)
# ref.loc[_idx].loc[location_id].cumsum().mean(axis=1).plot(ax=ax[1], color='black')
# comp.loc[_idx].loc[location_id].cumsum().mean(axis=1).plot(ax=ax[1], color='navy')
# fig.show()


In [None]:
# def clipped_mean(comparator_infections, reference_infections, idx, pctiles):
#     relative_error = (comparator_infections.loc[idx].groupby('location_id').sum() \
#                               / reference_infections.loc[idx].groupby('location_id').sum())
#     ## CLIP 2.5TH AND 97.5TH PERCENTILE
#     relative_error = relative_error.clip(
#         **dict(zip(['lower', 'upper'], np.quantile(relative_error, pctiles, axis=1))),
#         axis=0,
#     ).mean(axis=1)
#     return relative_error

# {
#     'mean pre': (_comp.loc[_idx].mean(axis=1).groupby('location_id').sum() \
#                  / _ref.loc[_idx].mean(axis=1).groupby('location_id').sum()).loc[location_id],
    
#     'median pre': (_comp.loc[_idx].median(axis=1).groupby('location_id').sum() \
#                  / _ref.loc[_idx].median(axis=1).groupby('location_id').sum()).loc[location_id],

#     'median old way': (_comp.loc[_idx].groupby('location_id').sum().median(axis=1) \
#                  / _ref.loc[_idx].groupby('location_id').sum().median(axis=1)).loc[location_id],
    
#     'clipped mean post (2.5/97.5)': clipped_mean(_comp, _ref, _idx, [0.025, 0.975]).loc[location_id],
    
#     'clipped mean post (5/95)': clipped_mean(_comp, _ref, _idx, [0.05, 0.95]).loc[location_id],
    
#     'median post': (_comp.loc[_idx].groupby('location_id').sum() \
#                     / _ref.loc[_idx].groupby('location_id').sum()).median(axis=1).loc[location_id],
# }


In [5]:
# Dispatch on RUN:
# - 'LOCAL' computes per-location kappa scalars and writes them to YAML.
# - 'GLOBAL' computes per-location, per-draw ratios and plots their distributions.
LOGGER.info(f'Run type: {RUN}')
if RUN == 'LOCAL':
    LOGGER.info('Calculating ratios')
    kappa_scalars = calculate_kappa_scalars(
        reference_infections, comparator_infections, sero_points, variants,
        units='location_daily_medians',
        # units='clipped_location_means',
        infections_threshold=0,
        minimize_variants=['alpha', 'beta', 'gamma', 'delta'],
    )

    LOGGER.info('Writing kappa scaling factors')
    store_kappa_scalars(kappa_scalars, variants)
elif RUN == 'GLOBAL':
    LOGGER.info('Calculating ratios')
    kappa_scalars = calculate_kappa_scalars(
        reference_infections, comparator_infections, sero_points, variants,
        units='location_draws',
        infections_threshold=1e5,
        minimize_variants=[]
    )

    LOGGER.info('Plotting global distributions')
    # Optional per-variant location filters for the histograms. An empty list
    # means all locations are used; candidate locations are kept as comments.
    global_summary = make_histograms(kappa_scalars, variants,
                                     sero_locations={
                                         'alpha': [
                                             # ## ALPHA
                                             # 47,    # Czechia
                                             # 51,    # Poland
                                             # 58,    # Estonia (?)
                                             # 545,   # Michigan
                                             # 4749,  # England
                                             # 434,   # Scotland
                                             # 385,   # Puerto Rico (?)
                                             # 144,   # Jordan
                                             # 180,   # Kenya (?)
                                             # 207,   # Ghana (?)
                                         ],
                                         'beta': [
                                             # 181,   # Madagascar (?)
                                             # 182,   # Malawi (?)
                                             # 196,   # South Africa
                                         ],
                                         'gamma': [
                                             # 98,    # Chile
                                             # 4772,  # Rio Grande do Sul (?)
                                             # 4775,  # São Paulo
                                         ],
                                         'delta': [
                                             # 89,    # Netherlands (?)
                                             # 4749,  # England
                                             # 434,   # Scotland
                                             # 4646,  # Campeche
                                             # 180,   # Kenya (?)
                                             # 182,   # Malawi (?)
                                             # 196,   # South Africa (?)
                                             # 214,   # Nigeria
                                         ]
                                         # + [state for state in range(4840, 4876)
                                         #    if state not in [4840, 4842, 4845, 4850,
                                         #                     4847, 4848, 4858, 4861,
                                         #                     4862, 4863, 4864, 4866,
                                         #                     4869, 4872]]
                                         # + list(range(523, 574))
                                         ,
                                         'omicron': [],
                                         # sero_points.index.get_level_values('location_id').unique().to_list(),
                                         'ba5': [],
                                         # sero_points.index.get_level_values('location_id').unique().to_list(),
                                     },)
    for variant in variants:
        print(variant)
        print(global_summary[variant])
        print('')
else:
    # Previously an unrecognized RUN silently did nothing.
    raise ValueError(f"Unknown RUN type '{RUN}'; expected 'LOCAL' or 'GLOBAL'.")


2022-12-13 23:56:15.183 | INFO     | covid_model_seiir_pipeline.lib.cli_tools.performance_logger.performance_logger:info:26 - Run type: LOCAL
2022-12-13 23:56:15.185 | INFO     | covid_model_seiir_pipeline.lib.cli_tools.performance_logger.performance_logger:info:26 - Calculating ratios
2022-12-13 23:56:15.187 | INFO     | covid_model_seiir_pipeline.lib.cli_tools.performance_logger.performance_logger:info:26 - ba5
2022-12-13 23:57:10.737 | INFO     | covid_model_seiir_pipeline.lib.cli_tools.performance_logger.performance_logger:info:26 - Writing kappa scaling factors


In [None]:
# omicron (1e5)
#                 p25       p50  geom_mean       p75
# case       0.152473  0.245849   0.227270  0.367569
# admission  0.152440  0.220426   0.223511  0.312732
# death      0.112856  0.187334   0.188816  0.325152

# omicron (1e6)
#                 p25       p50  geom_mean       p75
# case       0.165986  0.266392   0.255385  0.398874
# admission  0.203510  0.285317   0.304645  0.439119
# death      0.145871  0.258278   0.252940  0.435830

####################################################

# ba5 (1e5)
#                 p25       p50  geom_mean       p75
# case       0.056931  0.130835   0.109367  0.221797
# admission  0.212842  0.412113   0.382877  0.730584
# death      0.075324  0.138294   0.135091  0.265208


# ba5 (1e6)
#                 p25       p50  geom_mean       p75
# case       0.077849  0.170130   0.142058  0.243022
# admission  0.386437  0.630114   0.576358  0.887026
# death      0.086445  0.173559   0.169608  0.326595



## Pre-omicron global calibration (uses case-based infections to get risk ratios for IHR and IFR)

In [None]:
def loader(draw_id: int, version: str, measure: str, columns: List[str]) -> pd.DataFrame:
    data = pd.read_parquet(f'{version}/compartments/{measure}_draw_{draw_id}.parquet',
                           columns=columns)
    data = data.loc[data['round'] == 2].drop('round', axis=1)
    data = data.rename(columns={col: col.split('_all_')[1] for col in data})
    data = data.groupby(data.columns, axis=1).sum(min_count=1)
    data = data.groupby('location_id').diff()
    data['draw_id'] = draw_id
    data = data.set_index('draw_id', append=True).sort_index()

    return data


def get_crude_ratio_dist(
    measure: str,
    prior_variants: List[str],
    variant: str,
    cf_version: str,
    base_version: str,
    threshold: int,
    n_draws: int = 100,
):
    """Per-location, per-draw ratio of base to counterfactual infections for ``variant``.

    Inputs:
    - Counterfactual infections come from the ``'case'`` compartments of ``cf_version``.
    - Base infections come from the ``measure`` compartments of ``base_version``.

    The ``variant`` ratio is divided by the combined ratio of ``prior_variants``.
    Only locations whose mean (across draws) counterfactual ``variant``
    infections exceed ``threshold`` are kept. +inf and NaN ratios are dropped.
    """
    all_variants = prior_variants + [variant]
    columns = ([f'Infection_all_{v}_all_lr' for v in all_variants]
               + [f'Infection_all_{v}_all_hr' for v in all_variants]
               + ['round'])

    def _load_all_draws(version: str, compartment_measure: str) -> pd.DataFrame:
        # The runner stays a partial of the module-level ``loader`` so it can be
        # shipped to worker processes.
        draw_loader = functools.partial(loader,
                                        version=version,
                                        measure=compartment_measure,
                                        columns=columns)
        draws = parallel.run_parallel(
            runner=draw_loader,
            arg_list=range(n_draws),
            num_cores=NUM_CORES,
            progress_bar=False,
        )
        return pd.concat(draws)

    cumulative_cf = _load_all_draws(cf_version, 'case').groupby(['location_id', 'draw_id']).sum()
    cumulative_base = _load_all_draws(base_version, measure).groupby(['location_id', 'draw_id']).sum()

    prior_ratio = (cumulative_base[prior_variants].sum(axis=1)
                   / cumulative_cf[prior_variants].sum(axis=1))
    ratio = cumulative_base[variant] / (cumulative_cf[variant] * prior_ratio)

    mean_cf = cumulative_cf[variant].groupby('location_id').mean()
    kept_locations = mean_cf.index[(mean_cf > threshold).to_numpy()]

    return ratio.loc[kept_locations].replace(np.inf, np.nan).dropna()


def evaluate_variant(variant: str, prior_variants: List[str],
                     cf_version: str, base_version: str,
                     threshold: int):
    """Log quartile/mean summaries and plot admission/death error ratios for ``variant``.

    Ratios come from ``get_crude_ratio_dist``. Histograms are clipped to each
    measure's 5th-95th percentile, with the median (red) and mean (green) marked.
    """
    measures = ['admission', 'death']
    error_ratio = pd.concat(
        [
            get_crude_ratio_dist(
                measure=m,
                prior_variants=prior_variants,
                variant=variant,
                cf_version=cf_version,
                base_version=base_version,
                threshold=threshold,
            ).rename(m)
            for m in measures
        ],
        axis=1,
    )
    summary = error_ratio.describe().loc[['25%', '50%', 'mean', '75%'], :].T
    LOGGER.info('\n' + '\n' + variant.upper() + '\n' + str(summary))

    fig, ax = plt.subplots(1, 2, figsize=(16, 5))
    for panel, m in zip(ax, measures):
        values = error_ratio[m].dropna()
        lower, upper = np.quantile(values, [0.05, 0.95])
        panel.hist(values.clip(lower, upper).values, bins=40)
        panel.axvline(error_ratio[m].median(), linestyle='--', color='firebrick')
        panel.axvline(error_ratio[m].mean(), linestyle='--', color='mediumseagreen')
        panel.set_title(m)
    fig.suptitle(variant)
    fig.show()


In [None]:
# for variant in ['alpha', 'beta', 'gamma']:
#     evaluate_variant(
#         variant=variant,
#         prior_variants=['ancestral'],
#         cf_version='/ihme/covid-19/seir-fit/2022_10_18.03',
#         base_version='/ihme/covid-19/seir-fit/2022_10_18.04',
#         threshold=1e5,
#     )

# evaluate_variant(
#     variant='delta',
#     prior_variants=['ancestral', 'alpha', 'beta', 'gamma'],
#     cf_version='/ihme/covid-19/seir-fit/2022_10_18.05',
#     base_version='/ihme/covid-19/seir-fit/2022_10_18.06',
#     threshold=1e5,
# )


In [None]:
## Full calibration --
# Locations by calibration status (location_id in parentheses).
# Kazakhstan (36)
# Kyrgyzstan (37)
# Alagoas (4751)
# Amapá (4753)
# Mato Grosso (4762)
# Mato Grosso do Sul (4761)
# Pará (4763)
# Mizoram (4863)
# Sikkim (4869)


## Broken --
# Qatar (151)
# Lakshadweep (4858)
# Seychelles (186) [not public, but screws up aggregate]
# Central African Republic (169) - splice
# Djibouti (177) - splice
