## set up workspace

In [None]:
import os

import warnings
from datetime import timedelta

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import seaborn as sns
sns.set_style('whitegrid')

from db_queries import get_location_metadata, get_population

from covid_model_deaths.preprocessing import expanding_moving_average_by_location

# Display up to 99 rows / columns when printing DataFrames in the notebook.
pd.set_option('display.max_rows', 99)
pd.set_option('display.max_columns', 99)
# Silence every warning for the remainder of the session.
warnings.simplefilter('ignore')

# Release of /ihme/covid-19/model-inputs/ that provides full_data.csv.
MODEL_INPUTS_VERSION = '2020_05_03.24'
# Release of /ihme/covid-19/snapshot-data/ that provides the JHU daily reports.
SNAPSHOT_VERSION = '2020_05_03.02'
# Only referenced by the commented-out "save cases" cell:
# US_MODEL = '2020_05_03_US'
# GLOBAL_MODEL = '2020_05_03_Europe'


## define smoother

In [None]:
def add_moving_average(data: pd.DataFrame, smoooth_var: str,
                       rate_threshold: float, n_smooths: int = 3) -> pd.DataFrame:
    """Repeatedly smooth a column with ``expanding_moving_average_by_location``.

    The unsmoothed values are preserved in a new ``f'Observed {smoooth_var}'``
    column, and ``smoooth_var`` is replaced by the result of ``n_smooths``
    successive smoothing passes.

    Parameters
    ----------
    data
        Frame containing ``location_id``, ``Date``, ``Days`` and
        ``smoooth_var`` columns.  It is not modified; a new frame is returned.
    smoooth_var
        Name of the column to smooth (the misspelt parameter name is kept so
        existing keyword callers keep working).
    rate_threshold
        Floor applied to the smoothed values after every pass; anything below
        it is raised to it.  Pass ``-np.inf`` to disable the floor.
    n_smooths
        Number of smoothing passes.  With ``0`` the result is a copy of
        ``data`` with only the ``Observed`` column added.

    Returns
    -------
    pd.DataFrame
        A new frame with ``smoooth_var`` replaced by its smoothed values and an
        extra ``f'Observed {smoooth_var}'`` column holding the original values.

    Raises
    ------
    AssertionError
        If any of the required columns is missing.

    """
    required_columns = ['location_id', 'Date', 'Days', smoooth_var]
    assert set(required_columns).issubset(data.columns)
    # Copy first so that adding the "Observed" column does not mutate the
    # caller's frame.
    data = data.copy()
    data[f'Observed {smoooth_var}'] = data[smoooth_var]
    for _ in range(n_smooths):
        moving_average = expanding_moving_average_by_location(data, smoooth_var)
        # Pin values below the threshold to the threshold (NaNs stay NaN).
        moving_average = moving_average.clip(lower=rate_threshold)
        # presumably the helper returns values indexed by (location_id, Date),
        # so they align with this index in the concat below — TODO confirm.
        data = data.set_index(['location_id', 'Date'])
        # NOTE(review): this forward-fill runs over the whole frame, not per
        # location, so a leading gap in one location would be filled from the
        # previous location's last row — confirm the helper never leaves one.
        data = (pd.concat([data.drop(columns=smoooth_var), moving_average], axis=1)
                .ffill()
                .reset_index())

    return data


## load data and smooth

In [None]:
loc_df = get_location_metadata(location_set_version_id=655, location_set_id=111)
# Keep most-detailed locations, except that Washington (location_id 570) is
# kept as a single unit instead of being split into its child locations.
loc_df = loc_df.loc[((loc_df['most_detailed'] == 1) & (loc_df['parent_id'] != 570))
                    | (loc_df['location_id'] == 570)]

wa_pop_df = get_population(decomp_step='step4', gbd_round_id=6,
                           location_id=570, year_id=2019,
                           age_group_id=22, sex_id=3)


# get WA from JHU snapshot: one CSV per report date
data_dir = f'/ihme/covid-19/snapshot-data/{SNAPSHOT_VERSION}/johns_hopkins_repo/COVID-19-master/csse_covid_19_data/csse_covid_19_daily_reports'
date_files = sorted(f for f in os.listdir(data_dir) if f.endswith('.csv'))
wa_dfs = []
for date_file in date_files:
    wa_df = pd.read_csv(f'{data_dir}/{date_file}')
    # Report files may use underscore-style geography column names; normalize
    # them to the slash-style names used below.
    wa_df = wa_df.rename(index=str, columns={'Province_State': 'Province/State',
                                             'Country_Region': 'Country/Region'})
    # NOTE(review): only rows labelled exactly 'Washington' are kept; confirm no
    # report file records WA counts under a different Province/State label.
    wa_df = wa_df.loc[(wa_df['Province/State'] == 'Washington') & (wa_df['Country/Region'] == 'US')]
    # The file name without its '.csv' extension is the report date.
    wa_df['Date'] = pd.to_datetime(os.path.splitext(date_file)[0])
    # Sum all matching rows so there is a single WA total per report date.
    wa_df = wa_df.groupby(['Province/State', 'Country/Region', 'Date'], as_index=False)['Confirmed'].sum()
    wa_dfs.append(wa_df)
wa_df = pd.concat(wa_dfs).reset_index(drop=True)
wa_df['population'] = wa_pop_df['population'].item()
wa_df['location_id'] = wa_pop_df['location_id'].item()
wa_df['Confirmed case rate'] = wa_df['Confirmed'] / wa_df['population']
keep_columns = ['location_id', 'Date', 'Confirmed', 'Confirmed case rate', 'population']
wa_df = wa_df[keep_columns]

# get rest of data
df = pd.read_csv(f'/ihme/covid-19/model-inputs/{MODEL_INPUTS_VERSION}/full_data.csv')
df['Date'] = pd.to_datetime(df['Date'])
df = df.loc[df['Confirmed'].notna()]
# DataFrame.append was removed in pandas 2.0; pd.concat is the replacement.
df = pd.concat([df[keep_columns], wa_df], ignore_index=True)
df['ln(case rate)'] = np.log(df['Confirmed case rate'])
# log(0) is -inf, so zero-count rows use a rate of 0.1 cases instead.
df.loc[df['Confirmed'] == 0, 'ln(case rate)'] = np.log(0.1 / df['population'])
# Days since each location's first date (vectorized; no row-wise apply).
df['day0'] = df.groupby('location_id')['Date'].transform('min')
df['Days'] = (df['Date'] - df['day0']).dt.days
df = df[['location_id', 'Date', 'Days', 'ln(case rate)', 'population']]
df = loc_df[['location_id', 'location_name']].merge(df)
locations = df['location_id'].unique().tolist()

# smooth_dfs[k] holds the data after k smoothing passes (k == 0 is unsmoothed).
MAX_SMOOTHS = 10
smooth_dfs = []
for n_smooths in range(MAX_SMOOTHS + 1):
    if n_smooths == 0:
        smooth_dfs.append(df.copy())
    else:
        smooth_dfs.append(add_moving_average(df.copy(), 'ln(case rate)', -np.inf, n_smooths))


## save smoothed cases (TODO: move this into the production pipeline)

In [None]:
# save_df = smooth_dfs[10]
# us_locs = loc_df.loc[loc_df['path_to_top_parent'].str.startswith('102,'), 'location_id'].to_list()
# g_locs = loc_df.loc[~loc_df['path_to_top_parent'].str.startswith('102,'), 'location_id'].to_list()
# save_df.loc[save_df['location_id'].isin(us_locs)].to_csv(f'/ihme/covid-19/deaths/prod/{US_MODEL}/smoothed_cases.csv', 
#                                                          index=False)
# save_df.loc[save_df['location_id'].isin(g_locs)].to_csv(f'/ihme/covid-19/deaths/prod/{GLOBAL_MODEL}/smoothed_cases.csv', 
#                                                         index=False)


## plot observed vs. smoothed cases

In [None]:
with PdfPages('/ihme/homes/rmbarber/covid-19/smoothing_effect_cases_05_03_firstdiff_onestep.pdf') as pdf:
    for location in locations:
        # set up figure: cumulative cases on the left, daily cases on the right
        fig, ax = plt.subplots(1, 2, figsize=(16.5, 8.5))

        # plot the data, one line per number of smoothing passes
        for n_smooths, smooth_df in enumerate(smooth_dfs):
            plot_df = smooth_df.loc[smooth_df['location_id'] == location].reset_index(drop=True)
            location_name = plot_df['location_name'][0]
            # Convert ln(case rate) back to case counts; first differences of
            # the cumulative counts give the daily counts.
            cumulative = np.exp(plot_df['ln(case rate)']) * plot_df['population']
            daily = np.diff(cumulative.values)
            if n_smooths == 0:
                # unsmoothed data: also draw the raw points
                metadata = dict(color='black', linewidth=3, alpha=0.5, label=n_smooths)
                ax[0].scatter(plot_df['Date'], cumulative, c='black', s=75, alpha=0.5)
                ax[1].scatter(plot_df['Date'][1:], daily, c='black', s=75, alpha=0.5)
            else:
                metadata = dict(linewidth=3, alpha=0.75, label=n_smooths)
            ax[0].plot(plot_df['Date'], cumulative, **metadata)
            ax[1].plot(plot_df['Date'][1:], daily, **metadata)

        # major ticks every week, minor ticks every day, over the full date span
        first_date, last_date = plot_df['Date'].min(), plot_df['Date'].max()
        n_days = (last_date - first_date).days + 1
        major_ticks = [first_date + timedelta(days=d) for d in range(0, n_days, 7)]
        minor_ticks = [first_date + timedelta(days=d) for d in range(n_days)]

        # shared axis settings
        for subplot_ax in ax:
            subplot_ax.set_xticks(major_ticks)
            subplot_ax.set_xticks(minor_ticks, minor=True)
            subplot_ax.grid(axis='y', which='major', color='darkgrey', alpha=0.25, linewidth=2)
            subplot_ax.grid(axis='x', which='major', color='darkgrey', alpha=0.25, linewidth=2)
            subplot_ax.grid(axis='x', which='minor', color='darkgrey', alpha=0.25, linewidth=0.2)
            subplot_ax.axhline(0, color='darkgrey', linestyle='--', linewidth=3)
            subplot_ax.tick_params(axis='x', rotation=60)
        ax[0].set_ylabel('Cumulative reported cases')
        ax[1].set_ylabel('Daily reported cases')

        # legend
        ax[0].legend(loc=2)

        # title
        fig.suptitle(location_name, y=1.0025)

        # save, then close so figures do not accumulate across locations
        fig.tight_layout()
        pdf.savefig(fig)
        plt.close(fig)
