In [1]:
%matplotlib inline

# to allow relative imports
import os
from sys_path_util import append_sys_path
append_sys_path()

from lib.experiments.utils.data_repo_api import DataRepoAPI

import pandas as pd
import matplotlib
import seaborn
import matplotlib.pyplot as plt
import numpy as np
from urllib.request import urlopen
import pickle
import datetime
import math
import plotly.express as px
from plotly.graph_objects import Figure
from plotly.subplots import make_subplots

from lib.configuration import DATA_REPO_URL_RAW

In [2]:
# https://censusreporter.org/profiles/05000US09009-new-haven-county-ct/
# https://censusreporter.org/profiles/04000US09-connecticut/
# Population used to normalise the New Haven county counts.
NH_POPULATION = 854757
# Population stored under the extra 'NYC' key of the population lookup.
NYC_POPULATION = 8419000

# Default upper bound, in days since the start date, for every series.
MAX_TIME = 120
# Value matched against the 'county_code' column to select New Haven rows.
NH_CODE = 5

# Paths handed to DataRepoAPI.get_csv.
CT_FILE = 'validation/case-data/ct_covid_count_by_county.csv'
US_FILE = 'validation/case-data/us_covid_count_by_state.csv'
POP_FILE = 'validation/case-data/us_population_by_state.csv'

# Per-state case counts, loaded once here and reused by the cells below.
covid_us = DataRepoAPI.get_csv(US_FILE)

In [3]:
def make_state_population():
    """
    Build a lookup from state code to population.

    Reads POP_FILE through DataRepoAPI, strips thousands separators from the
    'population' column, casts it to float32, discards the 'state_name'
    column and drops rows with missing values.

    Returns:
        dict: one entry per remaining row, keyed by the 'state' column and
        holding that row's 'population', plus an extra 'NYC' entry set to
        NYC_POPULATION.
    """
    df = DataRepoAPI.get_csv(POP_FILE)

    # Values may arrive as strings such as '1,234,567'; leave non-strings untouched.
    df['population'] = df['population'].apply(
        lambda x: x.replace(',', '') if isinstance(x, str) else x
    )
    df = df.astype({'population': 'float32'})

    # `columns=` keyword: the positional axis argument was removed in pandas 2.0.
    records = df.drop(columns='state_name').dropna().to_dict('records')

    dict_out = {x['state']: x['population'] for x in records}

    # not a state, but hey...
    dict_out['NYC'] = NYC_POPULATION

    return dict_out
    
# Population lookup built once at cell execution time and shared by the cells below.
POP_MAP = make_state_population()

In [4]:
def transform_covid(df):
    """
    Parse dates, sort chronologically and cast the count columns to float32.

    Args:
        df: frame with a 'date' column formatted as "%m/%d/%Y" (or already
            datetime) and, optionally, 'tot_cases' / 'new_cases' columns whose
            values may be strings containing thousands separators.

    Returns:
        A new frame sorted by 'date' with a fresh 0..n-1 index, in which every
        present count column is float32 and the string 'NaN' is a real NaN.

    Note:
        The 'date' column of the frame passed in is converted to datetime in
        place. This side effect is kept deliberately: a later cell compares
        ``covid_us['date']`` against datetime bounds and relies on it.
    """
    # transform dates (writes to the caller's frame, see the note above)
    df['date'] = pd.to_datetime(df['date'], format="%m/%d/%Y")
    df = df.sort_values('date').reset_index(drop=True)

    num_cols = ['tot_cases', 'new_cases']
    col_transform = {c: 'float32' for c in num_cols if c in df.columns}

    # strip thousands separators; non-string values pass through unchanged
    for c in col_transform:
        df[c] = df[c].apply(lambda x: x.replace(',', '') if isinstance(x, str) else x)

    # np.nan rather than np.NaN: the capitalised alias was removed in NumPy 2.0
    df = df.replace('NaN', np.nan)

    # transform numbers
    df = df.astype(col_transform)

    return df


def make_first_wave_map(threshold=0.001):
    """
    First date on which tot_cases / population >= threshold, for each state.

    Args:
        threshold: fraction of the population that must be reached
            (default 0.001).

    Returns:
        dict mapping state code to that date formatted as "%Y-%m-%d".
        States missing from POP_MAP, and states whose counts never reach the
        threshold in the available data, are left out of the map.
    """
    df = transform_covid(covid_us)

    first_wave_map = {}

    for state in df.state.unique():
        # no population figure -> no way to compute the threshold
        if state not in POP_MAP:
            continue

        subset = df[df.state == state].sort_values('date')
        threshold_cases = POP_MAP[state] * threshold
        reached = subset[subset.tot_cases >= threshold_cases]

        # Skip instead of failing with an out-of-bounds IndexError on .iloc[0].
        if reached.empty:
            continue

        first_wave_map[state] = reached['date'].iloc[0].strftime("%Y-%m-%d")

    return first_wave_map

# State -> first-wave start date, computed once with the default threshold.
FIRST_WAVE_MAP = make_first_wave_map()

In [5]:
def extract_region(df, region_filter):
    """
    Return the rows of `df` that belong to one region.

    Args:
        df: source frame; it is not modified.
        region_filter: dict with keys 'column' (column to test) and 'value'
            (value that column must equal).

    Returns:
        An independent frame holding the matching rows, re-indexed from 0.
    """
    column, wanted = region_filter['column'], region_filter['value']
    matches = df[column] == wanted
    return df[matches].reset_index(drop=True)


def get_data_after_date(df, date: str, max_time=MAX_TIME):
    """
    Keep rows dated on or after `date` and re-express 'date' as a day offset.

    Args:
        df: frame with a datetime 'date' column; it is not modified.
        date: lower bound, in any form accepted by np.datetime64.
        max_time: largest day offset to keep (inclusive).

    Returns:
        A new frame, re-indexed from 0, whose 'date' column holds the number
        of whole days since the earliest remaining date, limited to offsets
        <= max_time. If no row is on or after `date`, the empty frame is
        returned with 'date' left as datetime.
    """
    df = df[df['date'] >= np.datetime64(date)].copy()
    df.reset_index(drop=True, inplace=True)

    # Nothing on or after `date`: return early instead of failing on row 0.
    if df.empty:
        return df

    # Use the minimum rather than the first row, so unsorted input cannot
    # produce negative offsets. Times of day are dropped before subtracting.
    start_day = df['date'].min().normalize()
    df['date'] = (df['date'].dt.normalize() - start_day).dt.days

    return df.loc[df['date'] <= max_time, :].copy()


def calc_new_cases(df):
    """
    Derive per-row new cases from the cumulative 'tot_cases' column.

    Args:
        df: frame with a 'tot_cases' column; it is not modified.

    Returns:
        A copy of `df` with a 'new_cases' column holding the difference
        between consecutive 'tot_cases' values; the first row, which has no
        predecessor, is set to 0.
    """
    first_val = 0

    # Work on a copy so the caller's frame is not changed behind its back.
    out = df.copy()
    out['new_cases'] = out['tot_cases'].diff().fillna(first_val).to_numpy()

    return out


def make_counts_relative(df, population):
    """
    Divide the case counts by `population`, in place.

    Args:
        df: frame with 'new_cases' and 'tot_cases' columns; both are
            overwritten on this same frame.
        population: divisor applied to every value.

    Returns:
        The same `df` object, for call chaining.
    """
    # 'new_cases' first, then 'tot_cases'; the division is done value by
    # value on plain Python numbers, not on the column's own dtype.
    for column in ('new_cases', 'tot_cases'):
        raw_counts = df[column].values.tolist()
        df[column] = [count / population for count in raw_counts]
    return df
    
    
def make_new_cases_rolling(df):
    """
    Add a 7-row rolling mean of 'new_cases' as 'new_cases_rolling', in place.

    The first six rows have no full window and are therefore NaN.

    Returns:
        The same `df` object, for call chaining.
    """
    weekly_mean = df['new_cases'].rolling(7).mean()
    df['new_cases_rolling'] = weekly_mean.values.tolist()
    return df


def make_new_haven_covid(max_time=MAX_TIME):
    """
    Build the relative case series for New Haven county.

    Steps: load CT_FILE, clean it with transform_covid, keep the rows whose
    'county_code' equals NH_CODE, keep the days from the 'CT' entry of
    FIRST_WAVE_MAP up to `max_time`, derive new cases from the totals, divide
    by NH_POPULATION and add the rolling mean.

    Args:
        max_time: largest day offset to keep (inclusive).

    Returns:
        The resulting frame.
    """
    first_wave_start = FIRST_WAVE_MAP['CT']

    county_cases = transform_covid(DataRepoAPI.get_csv(CT_FILE))
    new_haven = extract_region(
        county_cases,
        {'column': 'county_code', 'value': NH_CODE},
    )

    new_haven = get_data_after_date(new_haven, first_wave_start, max_time)
    new_haven = calc_new_cases(new_haven)

    new_haven = make_counts_relative(new_haven, NH_POPULATION)
    return make_new_cases_rolling(new_haven)

def make_state_covid(state, max_time=MAX_TIME):
    """
    Build the relative case series for one state.

    Steps: clean `covid_us` with transform_covid, keep the rows for `state`,
    keep the days from that state's FIRST_WAVE_MAP date up to `max_time`,
    divide the counts by the state's POP_MAP population and add the rolling
    mean. Unlike the New Haven variant, 'new_cases' is taken from the data
    as provided and not recomputed from the totals.

    Args:
        state: key present in both FIRST_WAVE_MAP and POP_MAP.
        max_time: largest day offset to keep (inclusive).

    Returns:
        The resulting frame.
    """
    first_wave_start = FIRST_WAVE_MAP[state]

    all_states = transform_covid(covid_us)
    state_cases = extract_region(all_states, {'column': 'state', 'value': state})

    state_cases = get_data_after_date(state_cases, first_wave_start, max_time)

    state_cases = make_counts_relative(state_cases, POP_MAP[state])
    return make_new_cases_rolling(state_cases)

## Modelled

In [6]:
def get_df(name):
    """
    Download and unpickle one validation result from the data repository.

    Args:
        name: file name without extension; the object is fetched from
            DATA_REPO_URL_RAW + '/validation/' + name + '.pkl'.

    Returns:
        Whatever object the pickle contains (callers treat it as a DataFrame).
    """
    resource = '/validation/' + name + '.pkl'
    url = DATA_REPO_URL_RAW + resource

    # NOTE(review): unpickling can execute arbitrary code, so this is only
    # safe while DATA_REPO_URL_RAW points at a trusted repository.
    with urlopen(url) as response:
        return pickle.load(response)


def df_group_mean(df):
    """
    Average the numeric columns for every (time, compartment) pair.

    Args:
        df: long-format frame with 'time' and 'compartment' columns; it is
            not modified.

    Returns:
        A new frame with one row per (time, compartment), 'time' and
        'compartment' as ordinary columns, and the mean of each numeric
        column. Non-numeric columns are left out.
    """
    # numeric_only=True: since pandas 2.0 a plain .mean() raises on
    # non-numeric columns instead of silently dropping them.
    grouped = df.groupby(['time', 'compartment']).mean(numeric_only=True)
    return grouped.reset_index()


def ffill_gap(df, max_time):
    """
    Extend experiments that end before `max_time` by forward filling.

    For every experiment whose last recorded time is below `max_time`, the
    value each compartment had at that last time is repeated for every
    integer time step up to and including `max_time`.

    Args:
        df: long-format frame with 'experiment_id', 'time', 'compartment'
            and 'value' columns.
        max_time: last time step every experiment should cover.

    Returns:
        `df` itself when nothing had to be filled; otherwise a new frame with
        the padding rows added, sorted by experiment and time and re-indexed
        from 0.
    """
    pad_ids, pad_times, pad_compartments, pad_values = [], [], [], []

    for exp_id in df.experiment_id.unique():
        run = df[df.experiment_id == exp_id]
        final_time = max(run.time.values)

        # this experiment already reaches max_time: nothing to pad
        if final_time >= max_time:
            continue

        missing_times = list(range(int(final_time) + 1, max_time + 1))
        is_final_step = run['time'] == final_time

        # One run of padding rows per compartment, each carrying the value
        # that compartment had at the experiment's last time step.
        for comp in run.compartment.unique():
            final_value = run[is_final_step & (run['compartment'] == comp)].value.values[0]
            for step in missing_times:
                pad_ids.append(exp_id)
                pad_times.append(step)
                pad_compartments.append(comp)
                pad_values.append(final_value)

    padding = pd.DataFrame({
        'experiment_id': pad_ids,
        'time': pad_times,
        'compartment': pad_compartments,
        'value': pad_values
    })

    if padding.empty:
        return df

    combined = pd.concat([padding, df])
    combined = combined.sort_values(by=['experiment_id', 'time'])
    return combined.reset_index(drop=True)


def get_wide(name, max_time):
    """
    Load one modelled result and reshape it to one row per time step.

    The long-format data is limited to `max_time`, padded with ffill_gap,
    averaged per (time, compartment) and pivoted so that each compartment
    becomes a column. Three derived columns are added:

    - 'new_cases': step-to-step decrease of column 'S' (0 for the first row,
      values numerically close to zero snapped to 0),
    - 'tot_cases': 'E' + 'I' + 'R',
    - 'new_cases_rolling': 7-row rolling mean of 'new_cases'.

    Args:
        name: result name handed to get_df.
        max_time: last time step to keep.

    Returns:
        The wide frame, with 'time' as an ordinary column.
    """
    # load and restrict the long-format results
    long_df = get_df(name)
    long_df = long_df[long_df.time <= max_time]
    long_df = ffill_gap(long_df, max_time)

    # average, then one column per compartment
    wide = (
        df_group_mean(long_df)
        .pivot(index=['time'], columns=['compartment'], values='value')
        .reset_index()
    )

    # new cases: negated difference of 'S', first row filled with 0
    first_val = 0
    susceptible_drop = (wide['S'].diff().fillna(first_val) * (-1)).values.tolist()
    wide['new_cases'] = [
        int(drop) if np.isclose(drop, 0) else drop
        for drop in susceptible_drop
    ]

    # total cases and smoothed new cases
    wide['tot_cases'] = wide['E'] + wide['I'] + wide['R']
    wide['new_cases_rolling'] = wide['new_cases'].rolling(7).mean().values.tolist()

    return wide

## Comparison

In [34]:
# Layout keyword arguments applied to the comparison figure via fig.update_layout.
px_layout = dict(
    template='seaborn',
    plot_bgcolor='#F5F5F5',
    font_size=16,
    font_color='black',
    width=1000,
    height=700,
    margin=dict(l=25,r=25,b=25,t=25)
)

# Axis titles. NOTE(review): not referenced by any of the visible cells - confirm whether they are still needed.
x_title = 'Days since start of first wave'
y_title_tot = 'Total cases (fraction)'
y_title_new = 'New cases (fraction)'

In [35]:
# Model runs to plot. For each entry, 'name' is the result name passed to
# get_wide (fetched as '<name>.pkl') and 'title' is the subplot title.
data_pre = [
    {'name': 'v_seir_mobility_pre', 'title': 'SEIR, M (Pre)'},
    {'name': 'v_seirq_25_mobility_pre', 'title': 'SEIR_Q (p=0.25), M (Pre)'},
    {'name': 'v_seirq_50_mobility_pre', 'title': 'SEIR_Q (p=0.5), M (Pre)'},
    {'name': 'v_seirq_75_mobility_pre', 'title': 'SEIR_Q (p=0.75), M (Pre)'},
]

# Same structure as data_pre, for the '*_post' result files.
data_post = [
    {'name': 'v_seir_mobility_post', 'title': 'SEIR, M (Post)'},
    {'name': 'v_seirq_25_mobility_post', 'title': 'SEIR_Q (p=0.25), M (Post)'},
    {'name': 'v_seirq_50_mobility_post', 'title': 'SEIR_Q (p=0.5), M (Post)'},
    {'name': 'v_seirq_75_mobility_post', 'title': 'SEIR_Q (p=0.75), M (Post)'},
]

In [61]:
def make_single_subplot(df_all, high_df_all, d, y):
    """
    Build one figure overlaying observed state curves and one modelled curve.

    Args:
        df_all: observed series drawn faintly, one line per 'state'.
        high_df_all: observed series drawn in magenta, one line per 'state'.
        d: dict whose 'name' entry selects the modelled result (see get_wide).
        y: column to plot on the y axis, e.g. 'tot_cases' or 'new_cases'.

    Returns:
        A plotly Figure holding the traces in this order: faint lines,
        highlighted lines, blue model line.
    """
    # faint lines: one per state, hidden from the legend
    background = px.line(df_all, x='date', y=y, color='state')
    background.update_traces(opacity=0.2, showlegend=False)

    # highlighted states
    highlighted = px.line(high_df_all, x='date', y=y, color='state')
    highlighted.update_traces(
        opacity=0.8,
        showlegend=False,
        line=dict(color='magenta', width=2.5),
    )

    # modelled curve, loaded for the full MAX_TIME horizon
    modelled = px.line(get_wide(d['name'], MAX_TIME), x='time', y=y)
    modelled.update_traces(line=dict(color="blue", width=3))

    return Figure(data=background.data + highlighted.data + modelled.data)

In [62]:
def make_comparison_subplots(data, y, high_states, img_name):
    """
    Draw a two-column grid comparing observed state curves with model runs.

    Args:
        data: list of dicts with 'name' and 'title', one per subplot.
        y: column to plot on the y axis, e.g. 'tot_cases' or 'new_cases'.
        high_states: states whose observed curve is highlighted; every other
            state in POP_MAP is drawn faintly.
        img_name: file name (without extension) of the PNG written under
            'graphics/'.

    Side effects:
        Writes 'graphics/<img_name>.png', creating the directory if needed,
        and displays the figure.
    """
    dfs = []
    high_dfs = []

    for state in POP_MAP:
        # 'NYC' is an extra POP_MAP entry, not a state, and is not plotted
        if state == 'NYC':
            continue
        target = high_dfs if state in high_states else dfs
        target.append(make_state_covid(state))

    df_all = pd.concat(dfs)
    high_df_all = pd.concat(high_dfs)

    # Size the grid from the number of entries instead of assuming four.
    n_cols = 2
    n_rows = max(1, math.ceil(len(data) / n_cols))

    fig = make_subplots(
        rows=n_rows, cols=n_cols,
        subplot_titles=[d['title'] for d in data],
        horizontal_spacing = 0.075,
        vertical_spacing = 0.15,
    )

    # fill the grid row by row; plotly rows and columns are 1-based
    for position, d in enumerate(data):
        row, col = divmod(position, n_cols)
        _plot = make_single_subplot(df_all, high_df_all, d, y)
        fig.add_traces(_plot.data, rows=row + 1, cols=col + 1)

    fig.update_layout(**px_layout)

    # write_image does not create missing directories
    img_path = 'graphics/' + img_name + '.png'
    os.makedirs(os.path.dirname(img_path) or '.', exist_ok=True)
    fig.write_image(img_path)

    fig.show()

In [63]:
# Total cases vs. the '*_pre' model runs, with CT highlighted.
make_comparison_subplots(data_pre, 'tot_cases', ['CT'], 'validation-pre-total')

In [64]:
# Total cases vs. the '*_post' model runs, with CT highlighted.
make_comparison_subplots(data_post, 'tot_cases', ['CT'], 'validation-post-total')

In [65]:
# New cases vs. the '*_pre' model runs, with CT highlighted.
make_comparison_subplots(data_pre, 'new_cases', ['CT'], 'validation-pre-new')

In [66]:
# New cases vs. the '*_post' model runs, with CT highlighted.
make_comparison_subplots(data_post, 'new_cases', ['CT'], 'validation-post-new')

In [67]:
# Scratch check: relative series for a single state ('AR').
ar = make_state_covid('AR')

In [69]:
# Scratch check: dump the relative new-case values for that state.
ar.new_cases.values

array([ 3.77758131e-05,  2.18702076e-05,  2.35270415e-05,  2.08761073e-05,
        1.82251730e-05,  2.05447405e-05,  1.95506401e-05,  1.25919377e-05,
        8.94690311e-06,  3.81071799e-05,  1.78938062e-05,  2.71720761e-05,
        7.85339273e-05,  0.00000000e+00,  1.95506401e-05,  4.00953806e-05,
        2.38584083e-05,  3.71130796e-05,  3.81071799e-05,  3.81071799e-05,
        5.99773875e-05,  1.78938062e-05,  3.64503460e-05,  2.65093426e-05,
        1.50771886e-04,  5.10304844e-05,  5.40127855e-05,  4.87109169e-05,
        3.54562457e-05,  5.00363841e-05,  3.21425778e-05,  8.64867301e-05,
        7.91966609e-05,  7.82025605e-05,  7.95280277e-05,  6.29596886e-05,
        1.24262543e-04,  8.25103287e-05,  1.18629308e-04,  7.48888927e-05,
        1.49115052e-04,  1.07694204e-04,  1.04049170e-04,  1.12664706e-04,
        9.54336332e-05,  1.48452318e-04,  2.42229118e-04,  1.81588997e-04,
        1.34534913e-04,  1.37848581e-04,  9.07944983e-05,  1.37517215e-04,
        1.06700104e-04,  

In [80]:
# Date window to inspect.
start = np.datetime64('2020-02-01')
end = np.datetime64('2020-12-31')

# Rows reporting negative new cases inside the window.
# NOTE: the date comparison needs covid_us['date'] to be datetime already. That only holds
# because transform_covid(covid_us) converts the column in place when an earlier cell runs.
covid_us[(covid_us['new_cases'] < 0) & (covid_us['date'] <= end) & (covid_us['date'] >= start)]

Unnamed: 0,date,state,tot_cases,new_cases
353,2020-06-11,GU,182,-1
409,2020-12-24,NC,502214,-16304
1503,2020-05-05,MT,456,-1
3012,2020-12-21,ME,19285,-2
4322,2020-07-24,ME,3757,-1
5203,2020-10-06,VI,1321,-1
5388,2020-05-26,ND,2422,-35
6730,2020-08-17,OR,23451,-293
6889,2020-06-19,LA,48515,-119
9640,2020-07-19,NJ,178541,-10
