# Plotting Notebook for COVID-19 CSV Datasets

Authors (in alphabetical order): Frederic Poitevin, Joao Rodrigues, Andrea Scaiewicz

In [None]:
# Uncomment agg and comment inline for production
%matplotlib agg
# %matplotlib inline

import pathlib
import re

import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator

import numpy as np
import pandas as pd

## Define which data files to read

In [None]:
# CSV files (relative to `output_dir`, set below) to plot; uncomment entries
# or swap in one of the alternative lists as needed.
# Each file must have been generated by Data-Wrangler_v2.
dslist = [
#     'Data_COVID-19_v2.csv',  # county-level 
    'Data_COVID-19_v2_bycountry.csv',
#     'Data_COVID-19_v2_bycountry_smooth_3.csv',
#     'Data_COVID-19_v2_bycountry_smooth_5.csv',
#     'Data_COVID-19_v2_bycountry_smooth_7.csv'
]
# Alternative: state-level datasets.
# dslist = [
#      'Data_COVID-19_v2.csv',  # county-level
#      'Data_COVID-19_v2_bystate.csv',
#      'Data_COVID-19_v2_bystate_smooth_3.csv',
#      'Data_COVID-19_v2_bystate_smooth_5.csv',
#      'Data_COVID-19_v2_bystate_smooth_7.csv'
#  ]

# Alternative: combined datasets.
# dslist = [
#      'Data_COVID-19_v2.csv',  # combined 
#     'Data_COVID-19_v2.csv',
#     'Data_COVID-19_v2_smooth_3.csv',
#     'Data_COVID-19_v2_smooth_5.csv',
#     'Data_COVID-19_v2_smooth_7.csv'
# ]

# Prefixes added to the generated plot file names (and figure titles).
# They are paired with the `dslist` entries above BY POSITION (via zip), so
# keep both lists the same length: the first prefix goes with the first file,
# the second with the second, and so on. E.g. 'UNS' (presumably "unsmoothed")
# pairs with 'Data_COVID-19_v2_bycountry.csv', and 'SMO3' would presumably
# pair with a '*_smooth_3.csv' file.
# prefixes = ['UNS', 'SMO3', 'SMO7']
# prefixes = ['UNS', 'SMO3']
prefixes = ['UNS']

Set the input (CSV) and output (plot) folders

In [None]:
# Folder holding the input CSV files produced by the data wrangler.
output_dir = pathlib.Path('..', 'output')

# Generated PNG plots go into a 'plots' subfolder, created on demand.
plotdir = output_dir.joinpath('plots')
plotdir.mkdir(exist_ok=True, parents=True)

Get the labels of the columns that hold numerical data (one column per date)

In [None]:
def get_date_columns(dataframe):
    """Return the labels of *dataframe*'s date columns, in their original order.

    A column counts as a date column when its label is a string made of two
    1-2 digit fields and a 2-4 digit year separated by '/', e.g. '1/22/20'
    (presumably month/day/year). The WHOLE label must match, so labels that
    merely start with a date (e.g. '1/22/20_note') are excluded. Non-string
    labels are never date columns.

    Parameters
    ----------
    dataframe : pandas.DataFrame

    Returns
    -------
    list
        Column labels (not positional indexes) of the date columns.
    """
    # Raw string: '\d' in a normal string literal is an invalid escape.
    date_regex = re.compile(r'\d{1,2}/\d{1,2}/\d{2,4}')
    return [
        col for col in dataframe.columns
        if isinstance(col, str) and date_regex.fullmatch(col)
    ]

### Plotting Functions

In [None]:
def setup_figure(title):
    """Create a 3-row, 1-column grid of subplots that share the x-axis.

    The figure is 6x6 inches at 300 dpi, and *title* is shown above the top
    subplot.

    Returns
    -------
    tuple
        ``(figure, axes)`` where ``axes`` holds the three subplots, top to
        bottom.
    """
    figure, axes_column = plt.subplots(3, 1, sharex=True, figsize=(6, 6), dpi=300)
    top_axes = axes_column[0]
    top_axes.set_title(title)
    return figure, axes_column

In [None]:
def _padded_upper_limit(*series):
    """Return a y-axis upper limit 10% (plus 1) above the largest finite value.

    NaN/inf values are ignored. A negative or missing maximum is treated as 0,
    so the limit is always >= 1 and never falls below the fixed lower limit 0.
    """
    values = np.concatenate([np.ravel(np.asarray(s, dtype=float)) for s in series])
    finite = values[np.isfinite(values)]
    peak = max(float(finite.max()), 0.0) if finite.size else 0.0
    return peak + 0.1 * peak + 1


def plot_data(ax, x, y1, y2, y1_label='', y2_label='', sharey=False):
    """Plot two series against the same x values on one subplot.

    *y1* is drawn in red on *ax*. *y2* is drawn in black, on a twin y-axis
    created with ``ax.twinx()`` by default, or on *ax* itself when *sharey* is
    True. Every y-axis starts at 0 and ends ~10% above its largest value.

    Parameters
    ----------
    ax : matplotlib Axes
        Subplot to draw on.
    x : sequence
        x values; every 7th one is used as a (rotated) x tick label.
    y1, y2 : sequence of numbers
        Series to plot; presumably the same length as *x* — TODO confirm.
    y1_label, y2_label : str
        Legend entries for *y1* and *y2*.
    sharey : bool
        If True, plot both series against a single y-axis.
    """
    h1, = ax.plot(x, y1, '.-', color='red')  # plot returns a list; unpack the one Line2D

    ax2 = ax if sharey else ax.twinx()
    h2, = ax2.plot(x, y2, '.-', color='black')

    # Set the legend manually. The twin axes is drawn on top of `ax`, so a
    # legend attached to `ax` would be covered by y2's line; attach it to the
    # top-most axes instead (identical to `ax` when sharey is True).
    ax2.legend([h1, h2], [y1_label, y2_label])

    # y-limits: fixed lower bound of 0, padded upper bound.
    if sharey:
        ax.set_ylim((0, _padded_upper_limit(y1, y2)))
    else:
        ax.set_ylim((0, _padded_upper_limit(y1)))
        ax2.set_ylim((0, _padded_upper_limit(y2)))

    ax.grid()

    # Label one x tick per week (every 7th x value).
    ax.set_xticks(list(range(0, len(x), 7)))
    ax.set_xticklabels(x[::7], horizontalalignment='left')
    ax.tick_params(axis='x', labelrotation=-45)

    # At most ~5 integer-valued ticks on each y-axis.
    ax.yaxis.set_major_locator(MaxNLocator(nbins=5, integer=True))
    if ax2 is not ax:
        ax2.yaxis.set_major_locator(MaxNLocator(nbins=5, integer=True))

In [None]:
# Data Transformations
def get_change_per_day(data):
    """Return the day-over-day change of a series as a list.

    The first element is 0 (there is no previous day to compare against), and
    element ``i > 0`` is ``data[i] - data[i - 1]``, so the result has the same
    length as *data*. NumPy arrays and plain sequences are both accepted. An
    empty input gives an empty list (rather than ``[0]``), keeping the result
    aligned with the x values it is plotted against.
    """
    values = np.asarray(data)
    if len(values) == 0:
        return []
    return [0] + (values[1:] - values[:-1]).tolist()

def get_log(data):
    """Return the element-wise base-10 logarithm of *data*.

    A tiny offset (1e-10) is added first, so zero counts map to -10 instead
    of -inf.
    """
    offset = 1e-10
    return np.log10(data + offset)

In [None]:
def _add_full_name_column(df):
    """Insert a 'FullName_Safe' column at position 0, built from the location columns.

    Whichever of Country_Region_Safe, Province_State_Safe and County_Name_Safe
    exist are joined with '_', and leading/trailing '_' are stripped. Missing
    (NaN) parts become empty strings, so e.g. country 'US', no province and
    county 'X' gives 'US__X'. Only real NaN cells are blanked: the substring
    'nan' inside a genuine name (e.g. 'Buchanan') is preserved.
    """
    full_name = pd.Series('', index=df.index, dtype=object)
    for col in ('Country_Region_Safe', 'Province_State_Safe', 'County_Name_Safe'):
        if col in df.columns:
            part = df[col].astype(str).where(df[col].notna(), '')
            full_name = full_name + '_' + part
    df.insert(loc=0, column='FullName_Safe', value=full_name.str.strip('_'))


def _first_row_values(df, row_mask, columns):
    """Return the *columns* values of the first row selected by *row_mask*, or None if none match."""
    selected = df.loc[row_mask, columns].to_numpy()
    return selected[0, :] if len(selected) else None


def generate_plots_for_dataset(dataframe, fname_prefix='', min_deaths=None, plot_dir=None):
    """Save one three-panel PNG per location found in *dataframe*.

    For each location the panels show, top to bottom: daily new confirmed
    cases / deaths, cumulative totals, and log10 of the cumulative totals.
    Figures are written to ``<plot_dir>/<fname_prefix>_<location>_total.png``
    (without a leading '_' when *fname_prefix* is empty).

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Table with a 'Case_Type' column, optional *_Safe location columns, and
        one column per date (see ``get_date_columns``). For each location, the
        first 'Confirmed' row and the first 'Deaths' row are plotted. The
        frame is not modified.
    fname_prefix : str
        Prefix for output file names; also appended to the figure titles.
    min_deaths : number or None
        If given, locations whose deaths value in the last date column is
        below this are skipped.
    plot_dir : path-like or None
        Output folder; defaults to the module-level ``plotdir``.
    """
    if plot_dir is None:
        plot_dir = plotdir
    plot_dir = pathlib.Path(plot_dir)

    # Work on a copy so the caller's frame is not modified, and repeated calls
    # on the same frame don't fail on an already-existing FullName_Safe column.
    df = dataframe.copy()

    date_cols = get_date_columns(df)
    if not date_cols:
        print('\tNo date columns found; nothing to plot.')
        return

    _add_full_name_column(df)
    entries_list = list(df['FullName_Safe'].unique())
    n_entries = len(entries_list)

    for idx, entry in enumerate(entries_list, start=1):
        print(f'\tPlotting: {entry} ({idx} out of {n_entries})')

        mask = df['FullName_Safe'] == entry
        cvals = _first_row_values(df, mask & (df['Case_Type'] == 'Confirmed'), date_cols)
        dvals = _first_row_values(df, mask & (df['Case_Type'] == 'Deaths'), date_cols)
        if cvals is None or dvals is None:
            print('\tSkipped: missing Confirmed or Deaths row')
            continue

        if min_deaths is not None and dvals[-1] < min_deaths:
            print(f'\tSkipped: deaths = {dvals[-1]} < {min_deaths}')
            continue

        fig, axes = setup_figure(f'{entry}_{fname_prefix}'.strip('_'))
        try:
            # Per-day change
            plot_data(
                axes[0],
                date_cols,
                y1=get_change_per_day(cvals),
                y1_label='Daily New Confirmed Cases',
                y2=get_change_per_day(dvals),
                y2_label='Daily New Deaths'
            )

            # Cumulative data
            plot_data(
                axes[1],
                date_cols,
                y1=cvals,
                y1_label='Total Confirmed',
                y2=dvals,
                y2_label='Total Deaths'
            )

            # Log of cumulative data, on a single shared y-axis
            plot_data(
                axes[2],
                date_cols,
                y1=get_log(cvals),
                y1_label='log10(Total Confirmed)',
                y2=get_log(dvals),
                y2_label='log10(Total Deaths)',
                sharey=True
            )

            fig.tight_layout()

            # Strip '_' from the file NAME (not the whole path) so an empty
            # prefix does not leave a leading underscore.
            fig.savefig(plot_dir / f'{fname_prefix}_{entry}_total.png'.strip('_'))
        finally:
            plt.close(fig)  # always release the figure, even if plotting fails
        

## Iterate over the datasets and plot each one

In [None]:
# Skip locations whose deaths value in the last date column is below this.
MIN_DEATHS = 50

# `dslist` and `prefixes` are paired by position. zip() would silently drop
# any dataset without a matching prefix, so fail loudly on a mismatch instead.
if len(dslist) != len(prefixes):
    raise ValueError(
        f'dslist has {len(dslist)} entries but prefixes has {len(prefixes)}; '
        'they must be paired one-to-one.'
    )

for ds, prefix in zip(dslist, prefixes):
    print(f'Dataset: {ds}')
    df_fpath = output_dir / ds
    df = pd.read_csv(df_fpath)
    generate_plots_for_dataset(df, prefix, min_deaths=MIN_DEATHS)