In [1]:
# Imports. Nothing special here; some plotting imports (matplotlib, seaborn) remain from plotting code that was cleaned out.
import functools
import itertools
import multiprocessing
from pathlib import Path
from typing import List, NamedTuple, Union, Tuple
import warnings

from covid_shared.shell_tools import mkdir
from covid_shared.cli_tools.run_directory import (
    make_run_directory,
    get_run_directory,
)
from covid_shared.cli_tools.metadata import (
    RunMetadata,
    update_with_previous_metadata,
)
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib.lines as mlines
import numpy as np
import pandas as pd
import seaborn as sns
import tqdm
import yaml

from covid_model_seiir_pipeline.lib import (
    cli_tools,
)

# Task-level performance logger from the pipeline CLI tools
# (not referenced elsewhere in the cells shown here).
logger = cli_tools.task_performance_logger

# Silence every warning for the session. NOTE(review): this also hides pandas
# deprecation / FutureWarnings raised by the code in later cells.
warnings.simplefilter(action='ignore')

# Worker-process count handed to the multiprocessing pool for draw-level work.
NUM_CORES = 50


In [3]:
########################
# THIS NEEDS A CLEANUP #
########################
# CODE FOR GENERATING CALIBRATION BETAS

def write_data(data: Union[pd.Series, pd.DataFrame], output_path: Path) -> None:
    """Write a Series or DataFrame to ``output_path`` as gzip-compressed parquet.

    Series are promoted to a single-column frame first, so both input kinds go
    through the same ``DataFrame.to_parquet`` call (fastparquet engine).
    """
    frame = data.to_frame() if isinstance(data, pd.Series) else data
    frame.to_parquet(output_path, engine='fastparquet', compression='gzip')


class DataLoader:
    """Read per-draw SEIR pipeline outputs for a single measure.

    Args:
        fit_path: Root of a SEIR fit version.
        regression_path: Root of a regression version, or of a fit version whose
            case beta should stand in for beta_hat (chosen by the path string).
        measure: Measure prefix used in fit file names (e.g. ``'case'``).
    """

    def __init__(self,
                 fit_path: Path,
                 regression_path: Path,
                 measure: str):
        self.fit_path = fit_path
        self.regression_path = regression_path
        self.measure = measure

    @staticmethod
    def _final_round_beta(fit_beta: pd.DataFrame) -> pd.Series:
        """Return the non-null ``beta`` column, keeping only round 2 when a ``round`` column exists."""
        if 'round' in fit_beta:
            # There are first and second pass betas for non-final betas and we only want the second one.
            fit_beta = fit_beta[fit_beta['round'] == 2]
        return fit_beta['beta'].dropna()

    def load_fit_beta(self, draw: int) -> pd.Series:
        """Load the fitted beta for ``draw`` of this loader's measure (final round only)."""
        fit_beta = pd.read_parquet(self.fit_path / 'beta' / f'{self.measure}_draw_{draw}.parquet')
        return self._final_round_beta(fit_beta)

    def load_regression_beta(self, draw: int) -> pd.Series:
        """Load beta_hat for ``draw``, named ``'beta_hat'`` with nulls dropped.

        The source is chosen from ``regression_path``: a path containing
        'regression' reads the regression ``beta_hat``; otherwise a path
        containing 'fit' reads that fit's case beta.

        Raises:
            ValueError: If ``regression_path`` contains neither 'regression' nor 'fit'.
        """
        path_str = str(self.regression_path)
        if 'regression' in path_str:
            return pd.read_parquet(self.regression_path / 'beta' / f'draw_{draw}.parquet')['beta_hat'].dropna()
        if 'fit' in path_str:
            # NOTE(review): always reads the *case* fit beta, independent of self.measure — confirm intent.
            fit_beta = pd.read_parquet(self.regression_path / 'beta' / f'case_draw_{draw}.parquet')
            return self._final_round_beta(fit_beta).rename('beta_hat')
        raise ValueError(
            f"Cannot tell whether {path_str!r} is a regression or fit version; "
            "expected 'regression' or 'fit' in the path."
        )

    def load_infections(self, draw: int) -> pd.Series:
        """Load the ``daily_total_infections`` column of the posterior epi measures for ``draw``."""
        infections = pd.read_parquet(self.fit_path / 'posterior_epi_measures' / f'{self.measure}_draw_{draw}.parquet')
        return infections['daily_total_infections']

    def load_invasion_dates(self, draw: int, variants: List[str]) -> pd.Series:
        """Per location, the last date the summed variant rho (running max) is below 0.01.

        Locations whose summed rho never drops below 0.01 are absent from the result.
        """
        # Rho is written by measure, but is the same for all measures.
        rho = pd.read_parquet(self.fit_path / 'ode_parameters' / f'{self.measure}_draw_{draw}.parquet')
        # Running max per location makes the series monotone, so a single threshold crossing defines invasion.
        rho = rho.loc[:, [f'rho_{v}_infection' for v in variants]].sum(axis=1).groupby('location_id').cummax()
        invasion_dates = rho[rho < 0.01].reset_index().groupby('location_id').date.max().rename('invasion_dates')
        return invasion_dates


def compute_initial_beta_scaling_parameters_by_draw(
    measure_draw: str,
    fit_path: Path,
    regression_path: Path,
    output_path: Path,
    variants: List[str],
    weighting: str
) -> None:
    """Build and write the calibrated beta series for one measure/draw pair.

    Fitted beta is kept up to each location's invasion date. After that date,
    the regression beta_hat is rescaled by ``beta_shift``. The scale starts at
    ``scale_init`` (fitted beta / beta_hat on the last fitted row). It moves
    towards ``scale_final`` (exp of a weighted mean of log(beta / beta_hat))
    over ``window_size`` rows. The combined series is written to
    ``output_path / f'beta_{measure}' / f'draw_{draw}.parquet'``.

    Args:
        measure_draw: ``'{measure}_{draw}'``, e.g. ``'death_0'``.
        fit_path: Root of the SEIR fit version.
        regression_path: Root of the regression (or fit) version providing beta_hat.
        output_path: Directory under which per-measure beta folders are written.
        variants: Variants whose summed ``rho_{v}_infection`` defines the invasion date.
        weighting: One of ``'infection'``, ``'threshold_one'`` or ``'threshold_five'``.

    Raises:
        ValueError: If ``weighting`` is not one of the supported schemes.
    """
    valid_weightings = ('infection', 'threshold_one', 'threshold_five')
    if weighting not in valid_weightings:
        # Fail before any I/O instead of hitting an unbound `weights` later.
        raise ValueError(f'Unknown weighting {weighting!r}; expected one of {valid_weightings}.')

    # rsplit so a measure name containing '_' still parses.
    measure, draw = measure_draw.rsplit('_', 1)
    draw_id = int(draw)
    measure_root = output_path / f'beta_{measure}'
    mkdir(measure_root, exists_ok=True, parents=True)
    # Only 'window_size' is read below; the remaining keys are currently unused.
    beta_scaling = {
        'window_size': 42,
        'min_avg_window': 21,
        'average_over_min': 7,
        'average_over_max': 42,
        'residual_rescale_upper': 1,
        'residual_rescale_lower': 1,
    }

    beta_scales = []
    loader = DataLoader(fit_path, regression_path, measure)

    fit_beta = loader.load_fit_beta(draw_id)
    invasion_dates = loader.load_invasion_dates(draw_id, variants)
    regression_beta = loader.load_regression_beta(draw_id)
    # Keep beta_hat only up to 30 days past each location's last fitted date.
    fit_end_date = fit_beta.reset_index().groupby('location_id')['date'].max().rename('fit_end_date')
    regression_beta = regression_beta.reset_index('date').join(fit_end_date)
    regression_beta = regression_beta.loc[
        regression_beta['date'] < regression_beta['fit_end_date'] + pd.Timedelta(days=30)
    ]
    regression_beta = (regression_beta
                       .set_index('date', append=True)
                       .sort_index()
                       .loc[:, 'beta_hat'])
    infections = loader.load_infections(draw_id)

    # Fitted beta up to and including the invasion date; locations with no invasion date drop out (NaN compare).
    trimmed_beta = pd.concat([fit_beta, invasion_dates.reindex(fit_beta.index, level='location_id')], axis=1).reset_index()
    trimmed_beta = trimmed_beta.loc[trimmed_beta['date'] <= trimmed_beta['invasion_dates']].set_index(['location_id', 'date'])['beta']
    infections = infections.loc[trimmed_beta.index]

    if weighting == 'infection':
        weights = infections.rename('weight')
    else:
        # transform keeps the original index on every pandas version; a like-indexed
        # groupby.apply prepends the group key to the index on pandas >= 2.0.
        norm_infections = (infections / infections.groupby('location_id').transform('max')).fillna(0.).rename('weight')
        cutoff = 0.01 if weighting == 'threshold_one' else 0.05
        # 1. where normalized infections reach the cutoff, else 0.
        weights = (norm_infections >= cutoff).astype(float)

    betas = pd.concat([trimmed_beta, regression_beta], axis=1)
    betas = betas.loc[betas['beta'].notnull()]

    # Select out the transition day to compute the initial scaling parameter.
    beta_transition = betas.groupby('location_id').last()
    beta_scales.append((beta_transition['beta'] / beta_transition['beta_hat']).rename('scale_init'))
    beta_scales.append(pd.Series(beta_scaling['window_size'], index=beta_transition.index, name='window_size'))

    # Weighted mean log residual, skipping each location's first 60 rows.
    # NOTE(review): rows where beta_hat is missing (NaN residual) are skipped by the
    # numerator's mean but still counted in the denominator — confirm intent.
    log_beta_residual = np.log(betas['beta'] / betas['beta_hat']).rename('log_beta_residual')
    weighted_log_beta_residual_mean = ((weights * log_beta_residual)
                                       .groupby(level='location_id')
                                       .apply(lambda x: x.iloc[60:].mean()))
    total_weights = (weights
                     .groupby(level='location_id')
                     .apply(lambda x: x.iloc[60:].mean()))
    log_beta_residual_mean = ((weighted_log_beta_residual_mean / total_weights)
                              .fillna(0.)
                              .rename('log_beta_residual_mean'))
    beta_scales.append(np.exp(log_beta_residual_mean).rename('scale_final'))

    beta_scales = pd.concat(beta_scales, axis=1)

    # beta_hat strictly after the invasion date, for locations with more than one such row.
    beta_hat = pd.concat([regression_beta, invasion_dates.reindex(regression_beta.index, level='location_id')], axis=1).reset_index()
    beta_hat = beta_hat[beta_hat['date'] > beta_hat['invasion_dates']].drop(columns='invasion_dates')
    beta_hat = beta_hat.loc[beta_hat.groupby('location_id').transform('count')['date'] > 1]

    beta_future = beta_shift(beta_hat, beta_scales)
    beta_future = (beta_future
                   .set_index(['location_id', 'date'])
                   .beta_hat
                   .rename('beta'))
    # pd.concat replaces Series.append, which was removed in pandas 2.0.
    beta = pd.concat([betas['beta'], beta_future]).sort_index()

    write_data(beta, measure_root / f'draw_{draw}.parquet')


def beta_shift(beta_hat: pd.DataFrame,
               beta_scales: pd.DataFrame) -> pd.DataFrame:
    """Multiply each location's beta_hat by a scale that ramps from initial to final.

    For each location present in both inputs, row ``i`` (in date order) is scaled
    by ``scale_init + (scale_final - scale_init) * i / window_size``. Rows after
    ``window_size`` are held at ``scale_final``. When a location's
    ``window_size`` is null, the constant ``scale_init`` is used instead.

    Args:
        beta_hat: Long frame with ``location_id``, ``date`` and ``beta_hat`` columns.
        beta_scales: Frame indexed by ``location_id`` with ``scale_init``,
            ``scale_final`` and ``window_size`` columns.

    Returns:
        Long frame with ``location_id``, ``date`` and the rescaled ``beta_hat``.
        It is empty (with those columns) when no location overlaps.
    """
    beta_hat = beta_hat.sort_values(['location_id', 'date']).set_index('location_id')
    scale_init = beta_scales['scale_init']
    scale_final = beta_scales['scale_final']
    window_size = beta_scales['window_size']

    beta_final = []
    for location_id in beta_hat.index.intersection(window_size.index).unique():
        # Double brackets keep a DataFrame even when the location has a single row.
        loc_beta_hat = beta_hat.loc[[location_id]].set_index('date', append=True)['beta_hat']
        init = scale_init.at[location_id]
        final = scale_final.at[location_id]
        window = window_size.at[location_id]
        if pd.notnull(window):
            t = np.arange(len(loc_beta_hat)) / window
            scale = init + (final - init) * t
            # Past the ramp, hold the final scale rather than extrapolating.
            scale[int(window) + 1:] = final
        else:
            scale = init
        beta_final.append(loc_beta_hat * scale)

    if not beta_final:
        # pd.concat([]) raises; return an empty frame with the expected columns.
        return pd.DataFrame(columns=['location_id', 'date', 'beta_hat'])
    return pd.concat(beta_final).reset_index()


def log_scale(y, n_times_upper, n_times_lower):
    """Compress both tails of ``y`` by repeatedly applying ``log(x + 1)``.

    Positive entries are compressed ``n_times_upper`` times. Negative entries are
    compressed in magnitude ``n_times_lower`` times, and their sign is kept.
    Zeros and NaNs are unchanged. ``y`` itself is not modified; a transformed
    copy is returned.
    """
    def _compress_positive(values, n_times):
        # Mutates `values` in place: only strictly positive entries are transformed.
        for _ in range(n_times):
            positive = values > 0
            values[positive] = np.log(values[positive] + 1)
        return values

    upper_done = _compress_positive(y.copy(), n_times_upper)
    # Flip the sign so the negative tail becomes the positive one, compress, flip back.
    return -_compress_positive(-upper_done, n_times_lower)


def compute_initial_beta_scaling_parameters(fit_path: Path,
                                            regression_path: Path,
                                            output_path: Path,
                                            variants: List[str],
                                            weighting: str,
                                            num_cores: int,
                                            progress_bar: bool,
                                            measures: Tuple[str, ...] = ('case', 'death', 'admission'),
                                            n_draws: int = 100) -> None:
    """Run ``compute_initial_beta_scaling_parameters_by_draw`` for every measure/draw pair.

    Each worker writes its own parquet file under ``output_path``; nothing is returned.

    Args:
        fit_path: Root of the SEIR fit version.
        regression_path: Root of the regression (or fit) version providing beta_hat.
        output_path: Directory the per-measure beta outputs are written under.
        variants: Variants whose invasion date marks where fitted beta is replaced.
        weighting: Residual weighting scheme passed to each draw.
        num_cores: Number of worker processes in the pool.
        progress_bar: Whether to show a tqdm progress bar.
        measures: Measures to process (defaults to case, death and admission).
        n_draws: Number of draws per measure, processed as draws ``0 .. n_draws - 1``.
    """
    mkdir(output_path, exists_ok=True, parents=True)
    # Serialization is our bottleneck, so we parallelize draw level data
    # ingestion and computation across multiple processes.
    _runner = functools.partial(
        compute_initial_beta_scaling_parameters_by_draw,
        fit_path=fit_path,
        regression_path=regression_path,
        output_path=output_path,
        variants=variants,
        weighting=weighting
    )
    measure_draws = [f'{m}_{d}' for m, d in itertools.product(measures, range(n_draws))]
    with multiprocessing.Pool(num_cores) as pool:
        # Workers return None; consuming the iterator drives the work and the progress bar.
        list(tqdm.tqdm(pool.imap(_runner, measure_draws), total=len(measure_draws), disable=not progress_bar))

# compute_initial_beta_scaling_parameters_by_draw(
#     measure_draw='death_0',
#     fit_path=Path('/ihme/covid-19/seir-fit/2022_09_27.03'),
#     regression_path=Path('/ihme/covid-19/seir-regression/2022_09_28.01'),
#     output_path=Path('/mnt/share/covid-19/seir-counterfactual-input/2022_09_29.01'),
#     variants=['alpha', 'beta', 'gamma'],
#     weighting='threshold_five',  # now required: 'infection', 'threshold_one' or 'threshold_five'
# )

In [4]:
def build_output_scenario_version(
    preprocessing_version: Path,
    fit_version: Path,
    beta_hat_version: Path,
    forecast_version: Path,
    variants: List[str],
    weighting: str,
    output_root: Union[str, Path] = '/ihme/covid-19/seir-counterfactual-input/',
):
    """Create a new counterfactual-input run directory and populate its betas.

    This is the only function in the notebook expected to need edits between runs.

    Args:
        preprocessing_version: Preprocessing version whose metadata is recorded.
        fit_version: SEIR fit version supplying fitted beta, infections and rho.
        beta_hat_version: Regression (or fit) version supplying beta_hat.
        forecast_version: Forecast version whose metadata is recorded.
        variants: Variants whose invasion date marks where fitted beta is replaced.
        weighting: Residual weighting scheme ('infection', 'threshold_one', 'threshold_five').
        output_root: Directory under which the new run directory is created.
    """
    # Metadata boilerplate: create the run directory and record upstream versions.
    run_metadata = RunMetadata()
    run_directory = make_run_directory(output_root)
    print('Building counterfactual input version at ', str(run_directory))
    run_metadata['output_path'] = str(run_directory)
    # NOTE(review): fit_version is not in this list, so its metadata is not recorded
    # even though it is a primary input — confirm this is intentional.
    for v in [preprocessing_version, beta_hat_version, forecast_version]:
        run_metadata = update_with_previous_metadata(run_metadata, v)

    # Build the calibrated betas for every measure/draw (this call should not need to change).
    compute_initial_beta_scaling_parameters(
        fit_path=fit_version,
        regression_path=beta_hat_version,
        output_path=run_directory / 'beta',
        variants=variants,
        weighting=weighting,
        num_cores=NUM_CORES,
        progress_bar=True
    )

    # Written last, so a run that fails above leaves no metadata.yaml.
    run_metadata.dump(run_directory / 'metadata.yaml')


In [7]:
# Preprocessing version whose metadata is recorded for each run (shouldn't need to change).
pp = Path('/ihme/covid-19/seir-preprocess/2022_12_13.04/')
# Forecast version: only filler data for postprocessing (shouldn't need to change).
fp = Path('/ihme/covid-19/seir-forecast/2022_11_16.09/')


In [None]:
###########
## GLOBAL ##
###########
# build_output_scenario_version(
#     preprocessing_version=pp,
#     fit_version=Path('/ihme/covid-19/seir-fit/2022_10_19.03'),
#     beta_hat_version=Path('/ihme/covid-19/seir-regression/2022_10_19.03'),
#     forecast_version=fp,
#     variants=['omicron'],
#     weighting='threshold_one',
# )
## INPUTS: /mnt/share/covid-19/seir-counterfactual-input/2022_10_19.01
## OUTPUTS: /mnt/share/covid-19/seir-counterfactual/2022_10_19.01

# build_output_scenario_version(
#     preprocessing_version=pp,
#     fit_version=Path('/ihme/covid-19/seir-fit/2022_10_19.11'),
#     beta_hat_version=Path('/ihme/covid-19/seir-regression/2022_10_19.03'),
#     forecast_version=fp,
#     variants=['ba5'],
#     weighting='threshold_one',
# )
## INPUTS: /mnt/share/covid-19/seir-counterfactual-input/2022_10_19.02
## OUTPUTS: /mnt/share/covid-19/seir-counterfactual/2022_10_19.02


In [None]:
#################################
## LOCAL -- 2022 VARIANTS ONLY ##
#################################
# build_output_scenario_version(
#     preprocessing_version=pp,
#     fit_version=Path('/ihme/covid-19/seir-fit/2022_10_19.13'),
#     beta_hat_version=Path('/ihme/covid-19/seir-regression/2022_10_19.04'),
#     forecast_version=fp,
#     variants=['omicron'],
#     weighting='threshold_one',
# )
## INPUTS: /mnt/share/covid-19/seir-counterfactual-input/2022_10_19.04
## OUTPUTS: /mnt/share/covid-19/seir-counterfactual/2022_10_19.04

## DAILY MEDIAN OMICRON
# build_output_scenario_version(
#     preprocessing_version=pp,
#     fit_version=Path('/ihme/covid-19/seir-fit/2022_10_19.16'),
#     beta_hat_version=Path('/ihme/covid-19/seir-regression/2022_10_19.04'),
#     forecast_version=fp,
#     variants=['ba5'],
#     weighting='threshold_one',
# )
## INPUTS: /mnt/share/covid-19/seir-counterfactual-input/2022_10_19.06
## OUTPUTS: /mnt/share/covid-19/seir-counterfactual/2022_10_19.06

# ## CLIPPED MEAN RATIO OMICRON
# build_output_scenario_version(
#     preprocessing_version=pp,
#     fit_version=Path('/ihme/covid-19/seir-fit/2022_10_20.03'),
#     beta_hat_version=Path('/ihme/covid-19/seir-regression/2022_10_19.04'),
#     forecast_version=fp,
#     variants=['ba5'],
#     weighting='threshold_one',
# )
# ## INPUTS: /mnt/share/covid-19/seir-counterfactual-input/2022_10_20.02
# ## OUTPUTS: /mnt/share/covid-19/seir-counterfactual/2022_10_20.02


In [None]:
###########################
## LOCAL -- ALL VARIANTS ##
###########################
# build_output_scenario_version(
#     preprocessing_version=pp,
#     fit_version=Path('/ihme/covid-19/seir-fit/2022_10_19.12'),
#     beta_hat_version=Path('/ihme/covid-19/seir-regression/2022_10_19.04'),
#     forecast_version=fp,
#     variants=['alpha', 'beta', 'gamma'],
#     weighting='threshold_five',
# )
## INPUTS: /mnt/share/covid-19/seir-counterfactual-input/2022_10_19.03
## OUTPUTS: /mnt/share/covid-19/seir-counterfactual/2022_10_19.03

## DAILY MEDIAN ALPHA/BETA/GAMMA
# build_output_scenario_version(
#     preprocessing_version=pp,
#     fit_version=Path('/ihme/covid-19/seir-fit/2022_10_19.15'),
#     beta_hat_version=Path('/ihme/covid-19/seir-regression/2022_10_19.04'),
#     forecast_version=fp,
#     variants=['delta'],
#     weighting='threshold_five',
# )
## INPUTS: /mnt/share/covid-19/seir-counterfactual-input/2022_10_19.05
## OUTPUTS: /mnt/share/covid-19/seir-counterfactual/2022_10_19.05

# build_output_scenario_version(
#     preprocessing_version=pp,
#     fit_version=Path('/ihme/covid-19/seir-fit/2022_10_19.17'),
#     beta_hat_version=Path('/ihme/covid-19/seir-regression/2022_10_19.04'),
#     forecast_version=fp,
#     variants=['omicron'],
#     weighting='threshold_one',
# )
## INPUTS: /mnt/share/covid-19/seir-counterfactual-input/2022_10_20.01
## OUTPUTS: /mnt/share/covid-19/seir-counterfactual/2022_10_20.01

# build_output_scenario_version(
#     preprocessing_version=pp,
#     fit_version=Path('/ihme/covid-19/seir-fit/2022_10_20.08'),
#     beta_hat_version=Path('/ihme/covid-19/seir-regression/2022_10_19.04'),
#     forecast_version=fp,
#     variants=['ba5'],
#     weighting='threshold_one',
# )
# ## INPUTS: /mnt/share/covid-19/seir-counterfactual-input/2022_10_20.XX
# ## OUTPUTS: /mnt/share/covid-19/seir-counterfactual/2022_10_XX.XX

## CLIPPED MEAN RATIO OMICRON
# build_output_scenario_version(
#     preprocessing_version=pp,
#     fit_version=Path('/ihme/covid-19/seir-fit/2022_10_20.04'),
#     beta_hat_version=Path('/ihme/covid-19/seir-regression/2022_10_19.04'),
#     forecast_version=fp,
#     variants=['delta'],
#     weighting='threshold_five',
# )
## INPUTS: /mnt/share/covid-19/seir-counterfactual-input/2022_10_20.03
## OUTPUTS: /mnt/share/covid-19/seir-counterfactual/2022_10_20.03

# build_output_scenario_version(
#     preprocessing_version=pp,
#     fit_version=Path('/ihme/covid-19/seir-fit/2022_10_20.06'),
#     beta_hat_version=Path('/ihme/covid-19/seir-regression/2022_10_19.04'),
#     forecast_version=fp,
#     variants=['omicron'],
#     weighting='threshold_one',
# )
## INPUTS: /mnt/share/covid-19/seir-counterfactual-input/2022_10_20.04
## OUTPUTS: /mnt/share/covid-19/seir-counterfactual/2022_10_20.04

# build_output_scenario_version(
#     preprocessing_version=pp,
#     fit_version=Path('/ihme/covid-19/seir-fit/2022_10_20.07'),
#     beta_hat_version=Path('/ihme/covid-19/seir-regression/2022_10_19.04'),
#     forecast_version=fp,
#     variants=['ba5'],
#     weighting='threshold_one',
# )
## INPUTS: /mnt/share/covid-19/seir-counterfactual-input/2022_10_20.05
## OUTPUTS: /mnt/share/covid-19/seir-counterfactual/2022_10_20.05


In [None]:
# DAILY MEDIAN BA.5 ==== 2022-11-14 prod update
# Historical run, kept commented out like the earlier cells so Restart & Run All
# does not rebuild it (it would use whatever `pp`/`fp` are currently set to).
# build_output_scenario_version(
#     preprocessing_version=pp,
#     fit_version=Path('/ihme/covid-19/seir-fit/2022_11_16.03'),
#     beta_hat_version=Path('/ihme/covid-19/seir-regression/2022_11_16.02'),
#     forecast_version=fp,
#     variants=['ba5'],
#     weighting='threshold_one',
# )
# INPUTS: /mnt/share/covid-19/seir-counterfactual-input/2022_11_16.01
# OUTPUTS: /mnt/share/covid-19/seir-counterfactual/2022_11_16.01


In [8]:
# DAILY MEDIAN BA.5 ==== 2022-12-12 prod update
build_output_scenario_version(
    # Shared upstream versions configured earlier in the notebook.
    preprocessing_version=pp,
    forecast_version=fp,
    # Inputs specific to the 2022-12-12 production update.
    fit_version=Path('/ihme/covid-19/seir-fit/2022_12_13.02'),
    beta_hat_version=Path('/ihme/covid-19/seir-regression/2022_11_16.06'),
    variants=['ba5'],
    weighting='threshold_one',
)
# INPUTS: /mnt/share/covid-19/seir-counterfactual-input/2022_12_13.01
# OUTPUTS: /mnt/share/covid-19/seir-counterfactual/2022_12_13.01


Building counterfactual input version at  /mnt/share/covid-19/seir-counterfactual-input/2022_12_13.01


100%|█████████████████████████████████████████| 300/300 [03:24<00:00,  1.47it/s]
