## set up workspace

In [None]:
import warnings
from datetime import timedelta

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import seaborn as sns
sns.set_style('whitegrid')

from db_queries import get_location_metadata

from covid_model_deaths.preprocessing import expanding_moving_average_by_location
from covid_model_deaths.data import add_moving_average_rates

# Let notebook DataFrame output show up to 99 rows and 99 columns before
# truncating, and suppress all warnings for this session.
pd.set_option('display.max_rows', 99)
pd.set_option('display.max_columns', 99)
warnings.simplefilter('ignore')


## load data, smooth it, and plot the effect of smoothing

In [None]:
# Model-inputs versions to compare: (label used in the output PDF filename,
# version directory under /ihme/covid-19/model-inputs).
VERSIONS = [('05_12', '2020_05_12.01'),
            ('05_11', '2020_05_11.02'),
            ('05_10', '2020_05_10.04')]
# Plot the unsmoothed series plus n_smooths = 1..MAX_SMOOTHS.
MAX_SMOOTHS = 10


def _load_full_data(version, loc_df):
    """Read full_data.csv for one model-inputs version and attach location names.

    Zero death rates are replaced by 0.1 deaths / population so that
    `ln(death rate)` is finite. The merge with `loc_df` is pandas' default
    inner join, so rows whose location_id is absent from `loc_df` are dropped.
    """
    df = pd.read_csv(f'/ihme/covid-19/model-inputs/{version}/full_data.csv')
    df['Date'] = pd.to_datetime(df['Date'])
    # RHS Series aligns on the index, so each zero row gets its own population.
    df.loc[df['Death rate'] == 0, 'Death rate'] = 0.1 / df['population']
    df['ln(death rate)'] = np.log(df['Death rate'])
    df['location_id'] = df['location_id'].astype(int)
    df = df.merge(loc_df[['location_id', 'location_name']])
    # index=str is kept from the original cell; NOTE(review): confirm whether
    # add_moving_average_rates relies on a string index.
    return df.rename(index=str, columns={'location_name': 'Location'})


def _deaths(plot_df):
    """Return (cumulative, daily) death counts for one location's rows.

    Cumulative deaths are exp(ln(death rate)) * population. Daily deaths are
    their first differences, so they are one element shorter and line up
    with plot_df['Date'][1:].
    """
    cumulative = (np.exp(plot_df['ln(death rate)']) * plot_df['population']).values
    return cumulative, np.diff(cumulative)


def _date_ticks(dates):
    """Return (weekly major ticks, daily minor ticks) covering `dates`.

    Ticks start at the earliest date and never go past the latest one.
    Empty `dates` (min is NaT) yields no ticks.
    """
    start, end = dates.min(), dates.max()
    if pd.isnull(start) or pd.isnull(end):
        return [], []
    major = pd.date_range(start, end, freq='7D').to_pydatetime()
    minor = pd.date_range(start, end, freq='D').to_pydatetime()
    return major, minor


def _style_panel(panel, major_ticks, minor_ticks, ylabel):
    """Apply the shared date ticks, grid, zero line and y-label to one axes."""
    panel.set_xticks(major_ticks)
    panel.set_xticks(minor_ticks, minor=True)
    panel.grid(axis='y', which='major', color='darkgrey', alpha=0.25, linewidth=2)
    panel.grid(axis='x', which='major', color='darkgrey', alpha=0.25, linewidth=2)
    panel.grid(axis='x', which='minor', color='darkgrey', alpha=0.25, linewidth=0.2)
    panel.set_ylabel(ylabel)
    panel.axhline(0, color='darkgrey', linestyle='--', linewidth=3)
    panel.tick_params(axis='x', rotation=60)


def _plot_location(location, smooth_dfs):
    """Build and return the two-panel figure for one location.

    The left panel shows cumulative deaths and the right panel daily deaths.
    Each entry of `smooth_dfs` is drawn as one line labelled with its index.
    Entry 0 (the unsmoothed data) is drawn in black and also as points.
    """
    fig, ax = plt.subplots(1, 2, figsize=(16.5, 8.5))
    plotted_dates = []
    for n_smooths, smooth_df in enumerate(smooth_dfs):
        plot_df = smooth_df.loc[smooth_df['Location'] == location].reset_index(drop=True)
        cumulative, daily = _deaths(plot_df)
        if n_smooths == 0:
            ax[0].scatter(plot_df['Date'], cumulative, c='black', s=75, alpha=0.5)
            ax[1].scatter(plot_df['Date'][1:], daily, c='black', s=75, alpha=0.5)
            line_style = dict(color='black', linewidth=3, alpha=0.5, label=n_smooths)
        else:
            line_style = dict(linewidth=3, alpha=0.75, label=n_smooths)
        ax[0].plot(plot_df['Date'], cumulative, **line_style)
        ax[1].plot(plot_df['Date'][1:], daily, **line_style)
        plotted_dates.append(plot_df['Date'])

    # Ticks span every plotted date. The previous version stopped after a
    # hard-coded 70 days and used only the last frame's date range.
    major_ticks, minor_ticks = _date_ticks(pd.concat(plotted_dates))
    _style_panel(ax[0], major_ticks, minor_ticks, 'Cumulative deaths')
    _style_panel(ax[1], major_ticks, minor_ticks, 'Daily deaths')

    ax[0].legend(loc='upper left')
    fig.suptitle(location, y=1.0025)
    fig.tight_layout()
    return fig


# Location hierarchy used to name location_ids; keep most-detailed locations only.
loc_df = get_location_metadata(location_set_version_id=664, location_set_id=111)
loc_df = loc_df.loc[loc_df['most_detailed'] == 1]

for date_label, version in VERSIONS:
    print(date_label)
    df = _load_full_data(version, loc_df)
    locations = df['Location'].unique().tolist()

    # smooth_dfs[0] is the unsmoothed data. smooth_dfs[k] is the result of
    # add_moving_average_rates(..., -np.inf, k). NOTE(review): the meaning of
    # the -np.inf argument is defined in covid_model_deaths.data — confirm there.
    smooth_dfs = [df.copy()]
    smooth_dfs += [add_moving_average_rates(df.copy(), 'ln(death rate)', -np.inf, n_smooths)
                   for n_smooths in range(1, MAX_SMOOTHS + 1)]

    with PdfPages(f'/ihme/homes/rmbarber/covid-19/smoothing_effect_{date_label}.pdf') as pdf:
        for location in locations:
            fig = _plot_location(location, smooth_dfs)
            pdf.savefig(fig)
            # Close each page's figure so open figures don't accumulate.
            plt.close(fig)
