In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import seaborn as sns
from sklearn.preprocessing import MinMaxScaler
from sklearn.linear_model import LassoCV, LinearRegression
from sklearn.metrics import mean_squared_error
import warnings
warnings.filterwarnings('ignore')
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['font.size'] = 16

### I. Reusable functions for loading data and running models

In [2]:
'''
Loads data for a single outbreak.
Arguments: Two filenames for CSV files, the first for the reports CSV and the second for the trends CSV.
Outputs: Two dataframes, one containing the reports and one containing the trends
'''
def load_data(reports_fname, trends_fname):
    reports = pd.read_csv(reports_fname, index_col=0)
    reports.index = pd.to_datetime(reports.index)
    if 'Epi Week' in reports.columns:
        reports = reports.drop('Epi Week', axis=1)
    trends = pd.read_csv(trends_fname, index_col=0)
    trends.index = pd.to_datetime(trends.index)
    return reports, trends

In [3]:
'''
Converts a time-series into a supervised learning problem.
Inputs: Epidemiological timeseries (list/array/series), trends dataframe, list of autoregressive lags to be used as 
predictors
Outputs: X, a matrix of predictors, and Y, a vector of targets
'''
# Convert timeseries to supervised learning problem
def to_supervised(y, trends, lags, max_train):
    if lags is not None:
        xcols = []
        for lag in lags:
            xcols.append(y.shift(lag))
        x = pd.concat(xcols, axis=1)
        x.columns = ['lag_' + str(lag) for lag in lags]
        nullrows = x[(pd.isnull(x).any(axis=1))].index
        if trends is not None:
            x = pd.merge(x, trends, on='date')
        x = x.drop(nullrows, axis=0)
        y = y.drop(nullrows, axis=0)
    elif trends is not None:
        x = trends   
    return x, y

In [4]:
'''
Performs train/test split on X and Y matrices.
Inputs: X, a matrix of predictors, Y, a vector of targets, a list of dates in the training period, and a list of dates
in the test period
Outputs: X and Y matrices for each train and test periods
'''
def split(x, y, dates_train, dates_test):
    min_train, max_train = dates_train
    min_test, max_test = dates_test
    x_train = x[(x.index >= min_train) & (x.index <= max_train)]
    x_test = x[(x.index >= min_test) & (x.index <= max_test)]
    y_train = y[(y.index >= min_train) & (y.index <= max_train)]
    y_test = y[(y.index >= min_test) & (y.index <= max_test)]
    return x_train, y_train, x_test, y_test

In [5]:
'''
Given input data (epi timeseries and trends dataframe), a time horizon of prediction, and a set of prediction dates,
use earlier dates to train model and then uses regression to predict for each of the prediction dates.
'''
def run(epi, trends, th, predict_dates, use_lags=True, use_trends=True, weekly=True):
    min_predict_date, max_predict_date = predict_dates[0], predict_dates[1]
    predictions = []
    coefs = []
    days_before = th*7 if weekly else th
    if use_lags and use_trends:
        x, y = to_supervised(epi, trends, [th], min_predict_date - pd.Timedelta(days=days_before))
    elif use_trends and not use_lags:
        x, y = to_supervised(epi, trends, None, min_predict_date - pd.Timedelta(days=days_before))
    elif use_lags and not use_trends:
        x, y = to_supervised(epi, None, [th], min_predict_date - pd.Timedelta(days=days_before))
    x_train, y_train, x_test, y_test = split(x, y, (epi.index[0], min_predict_date - pd.Timedelta(days=days_before)), 
                                            (min_predict_date, max_predict_date))
    lr = LassoCV(max_iter=100000)
    lr.fit(x_train, y_train)
    return predict_dates, lr.predict(x_test), lr.coef_

In [6]:
'''
Give input data (reports dataframe and trends dataframe) and a time horizon of prediction, runs theoretical linear 
regression model using trends and ground truth time series from data frame.
'''
def run_model_theoretical(reports, trends, th, use_lags=True, use_trends=True, weekly=True):
    epi = reports['groundtruth']
    dates, yhats, coefs = [], [], []
    predict_dates = epi.index[2*(th+1):] if weekly else epi.index[2*(7*(th+1)):]
    for predict_date in predict_dates:
        date, preds, coef = run(epi, trends, th, (predict_date, predict_date), use_lags, use_trends, weekly)
        dates.append(date[0])
        yhats.append(preds[0])
        coefs.append(coef)
    return predict_dates, yhats, coefs

In [7]:
'''
Given input data (reports dataframe and trends dataframe) and a reporting delay (in weeks), runs linear regression 
model for each report, in each case predicting until the end of the next report
'''
def run_model_practical(reports, trends, delay, use_lags=True, use_trends=True, weekly=True):
    reportids = reports.columns[:-1]
    results = {reportid:{'dates':None, 'yhats':None} for reportid in reportids}
    for r, reportid in enumerate(reportids):
        next_reportid = reportids[r+1] if r < len(reportids) - 1 else None
        epi = reports[reportid]
        min_predict_date = epi.index[len(epi.dropna()) - delay]
        max_predict_date = pd.to_datetime(next_reportid) if next_reportid is not None else epi.index[-1]
        predict_dates = epi[(epi.index >= min_predict_date) & (epi.index <= max_predict_date)].index
        dates, yhats = [], []
        for i, predict_date in enumerate(predict_dates):
            date, preds, _ = run(epi, trends, i+1, (predict_date, predict_date), use_lags, use_trends, weekly)
            dates.append(date[0])
            yhats.append(preds[0])
        results[reportid]['dates'] = dates
        results[reportid]['yhats'] = yhats
        results[reportid]['dates_report'] = epi.index[:len(epi.dropna())]
        results[reportid]['report'] = epi[:len(epi.dropna())]
    return results

### II. Run Models - Theoretical Version

In [8]:
def rmse(x, y):
    return np.sqrt(mean_squared_error(x, y))

In [9]:
colors = {'AR':'#658E9C', 'GT':'#8CBA80', 'ARGO':'#CD5555'}
models = ['AR', 'GT', 'ARGO']
countries = ['angola', 'colombia', 'drc', 'madagascar', 'yemen']
titles = ['Yellow Fever in Angola', 'Zika in Colombia', 'Ebola in the DRC', 'Pneumonic Plague in Madagascar', 'Cholera in Yemen']

In [61]:
'''
Given country, plotting axes (optional), and time horizon, run theoretical models for time horizon and plot 
'''
def run_all_theoretical(country, line_plot_ax, heatmap_ax, th, weekly=True, models = ['AR', 'GT', 'ARGO']):
    reports, trends = load_data('data/'  + country + '/' + country + 'reports.csv', 
                                         'data/' + country + '/' + country + 'predictors.csv')
    results = {model:{} for model in models}
    # Get results for each model
    for model in results.keys():
        use_lags = False if model == 'GT' else True
        use_trends = False if model == 'AR' else True
        results[model]['dates'], yhats, results[model]['coefs'] = run_model_theoretical(reports, trends, th, 
                                                                                        use_lags, use_trends, weekly)
        results[model]['corr'] = np.corrcoef(list(reports['groundtruth'].values)[-len(yhats):], yhats)[0][1]
        results[model]['rmse'] = rmse(list(reports['groundtruth'].values)[-len(yhats):], yhats)
        results[model]['yhats'] = yhats
        results_series = reports[['groundtruth']]
        results_series = results_series.merge(pd.DataFrame(results[model], index=results[model]['dates'])[['yhats']],
                                            how='outer', left_index=True, right_index=True)
        results_series = results_series.rename({'yhats':model}, axis='columns')
        results_series.to_csv('results/' + country + '_' + str(th) + '.csv', index=False)
    # Plot of predictions
    if line_plot_ax is not None:
        line_plot_ax.fill_between(reports.index, 0, reports['groundtruth'], color='lightgrey', label='Ground truth')
        for model in results.keys():
            line_plot_ax.plot(results[model]['dates'], results[model]['yhats'], color=colors[model], 
                              label = model + '', linewidth=3)#', corr = %.2f ' % results[model]['corr'], linewidth=3)
        line_plot_ax.spines['right'].set_visible(False)
        line_plot_ax.spines['top'].set_visible(False)
        #line_plot_ax.legend(loc='best')
        line_plot_ax.set_xlim(reports.index[0], reports.index[-1])
        line_plot_ax.set_ylim(0, max(reports['groundtruth'])*1.3)
    if heatmap_ax is not None:
        coefs = results['ARGO']['coefs']
        for i in range(len(reports.index) - len(yhats)):
            to_add = np.zeros(len(results['ARGO']['coefs'][0]))
            coefs = [to_add] + coefs
        coefs = np.array(coefs).T
        #mask = pd.DataFrame(coefs).isnull()
        heatmap_ax.set_xlim(min(reports.index), max(reports.index))
        sns.heatmap(coefs, cmap='RdBu_r', vmin=-1, vmax=1, ax=heatmap_ax, cbar=True)
        heatmap_ax.tick_params(axis='x', which='both', bottom=False, top=False,labelbottom=False)
        heatmap_ax.set_yticklabels(['Autoregressive term'] + list(trends.columns))
        for tick in heatmap_ax.get_yticklabels():
            tick.set_rotation(0)
    return results

In [63]:
'''
Run theoretical models for each outbreak for time horizons 1 and 2. Save evaluations for heatmaps.
'''
ths = [2, 1]
evaluations = {country:{} for country in countries}
for c, country in enumerate(countries):
    print(country)
    if True:
        df_corr = pd.DataFrame(index=models, columns=ths)
        df_rmse = pd.DataFrame(index=models, columns=ths)
        #fig, ax = plt.subplots(2, figsize=(20, 9))
        for th in [2, 1]:
            #a = ax[0] if th == 1 else ax[1]
            fig, a = plt.subplots(1, figsize=(20, 5))
            weekly = False if country == 'madagascar' else True
            results = run_all_theoretical(country, a, None, th, weekly)
            df_corr[th] = [results[model]['corr'] for model in models]
            df_rmse[th] = [results[model]['rmse'] for model in models]
            a.set_ylabel('Cases')
            #a.set_title('Assuming Reporting Delay of ' + str(th) + ' Weeks')
            a.set_title('Digital Epidemiological Modeling of ' + titles[c], fontsize='x-large')
            a.legend(loc='best')
            plt.tight_layout()
            #plt.subplots_adjust(top=0.88)
            #plt.suptitle('Digital Epidemiological Modeling of ' + titles[c], fontsize='x-large')
            plt.savefig(country + '_' + str(th) + '.png')
        #plt.savefig('newnewfigures/theoretical/' + country + '_both' + '.png')
        evaluations[country]['corr'] = df_corr
        evaluations[country]['rmse'] = df_rmse

angola
colombia
drc
madagascar
yemen


In [12]:
'''
Figure 1 in paper
'''
ths = [2, 1]
evaluations = {country:{} for country in countries}
fig, ax = plt.subplots(5, 2, figsize=(20, 12), sharex=False, sharey=False)
for c, country in enumerate(countries):
    print(country)
    df_corr = pd.DataFrame(index=models, columns=ths)
    df_rmse = pd.DataFrame(index=models, columns=ths)
    for th in [2, 1]:
        a = ax[c][th-1]
        weekly = False if country == 'madagascar' else True
        results = run_all_theoretical(country, a, None, th, weekly)
        df_corr[th] = [results[model]['corr'] for model in models]
        df_rmse[th] = [results[model]['rmse'] for model in models]
        months = mdates.MonthLocator()
        months_format = mdates.DateFormatter('%b %y')
        a.xaxis.set_major_locator(months)
        a.xaxis.set_major_formatter(months_format)
    evaluations[country]['corr'] = df_corr
    evaluations[country]['rmse'] = df_rmse
ax[0, 1].legend(loc='upper right', prop={'size': 10})
ax[0, 0].set_ylabel('New Cases (Weekly)', size='small')
ax[1, 0].set_ylabel('New Cases (Weekly)', size='small')
ax[2, 0].set_ylabel('New Cases (Weekly)', size='small')
ax[3, 0].set_ylabel('New Cases (Daily)', size='small')
ax[4, 0].set_ylabel('New Cases (Weekly)', size='small')
for i in range(5):
    ax[i, 1].get_yaxis().set_visible(False)
plt.tight_layout(h_pad=2, rect=[0, 0, 1, 0.97])
plt.figtext(0.5, 0.962, 'Yellow Fever in Angola', ha='center', va='center', size='medium')
plt.figtext(0.5, 0.77, 'Zika in Colombia', ha='center', va='center', size='medium')
plt.figtext(0.5, 0.572, 'Ebola in the DRC', ha='center', va='center', size='medium')
plt.figtext(0.5, 0.38, 'Plague in Madagascar', ha='center', va='center', size='medium')
plt.figtext(0.5, 0.19, 'Cholera in Yemen', ha='center', va='center', size='medium')
plt.figtext(0.25, 0.98, 'Assuming reporting delay of 1 week', ha='center', va='center', size='large')
plt.figtext(0.75, 0.98, 'Assuming reporting delay of 2 weeks', ha='center', va='center', size='large')
#plt.figtext(0.01, 0.99, 'a', ha='center', va='center', size='large', weight='bold')
plt.savefig('results1_part1.png')

angola
colombia
drc
madagascar
yemen


In [35]:
# Print evaluation metrics for latex table in paper
evaluations

{'angola': {'corr':              2         1
  AR    0.537874  0.878984
  GT    0.796699  0.789052
  ARGO  0.688780  0.881781, 'rmse':               2          1
  AR    62.650579  17.597337
  GT    17.633600  17.658728
  ARGO  20.420224  13.217858}, 'colombia': {'corr':              2         1
  AR    0.778537  0.920216
  GT    0.727812  0.780606
  ARGO  0.819780  0.929856, 'rmse':                 2           1
  AR    1176.744826  644.237268
  GT    1072.013798  997.450605
  ARGO   823.338520  542.392854}, 'drc': {'corr':              2         1
  AR    0.190784  0.573803
  GT    0.497807  0.581774
  ARGO  0.170057  0.580719, 'rmse':               2          1
  AR    28.108858  15.252126
  GT    18.130043  16.980575
  ARGO  27.413349  15.246174}, 'madagascar': {'corr':              2         1
  AR    0.876100  0.912609
  GT    0.675649  0.740227
  ARGO  0.842592  0.922090, 'rmse':               2          1
  AR    11.648131   8.451411
  GT    15.379634  13.605132
  ARGO  11.8507

In [13]:
'''
Summary heatmaps of model performances across models, outbreaks, and time horizons

fig, ax = plt.subplots(2, 4, figsize=(20, 5))
for c, country in enumerate(countries):
    sns.heatmap(evaluations[country]['corr'].T, ax=ax[0, c], cmap='Blues')
    sns.heatmap(evaluations[country]['rmse'].T, ax=ax[1, c], cmap='Greens_r')
    ax[0, c].set_title(titles[c] + ' - corr', size='medium')
    ax[1, c].set_title(titles[c] + ' - RMSE', size='medium')
ax[0, 0].set_ylabel('Delay (Weeks)')
ax[1, 0].set_ylabel('Delay (Weeks)')
plt.tight_layout()
plt.subplots_adjust(top=0.83)
plt.suptitle('Performance of Models Across Outbreaks and Reporting Delays', fontsize='large')
plt.figtext(0.01, 0.96, 'b', ha='center', va='center', size='large', weight='bold')
plt.savefig('results1_part2.png')
'''

"\nSummary heatmaps of model performances across models, outbreaks, and time horizons\n\nfig, ax = plt.subplots(2, 4, figsize=(20, 5))\nfor c, country in enumerate(countries):\n    sns.heatmap(evaluations[country]['corr'].T, ax=ax[0, c], cmap='Blues')\n    sns.heatmap(evaluations[country]['rmse'].T, ax=ax[1, c], cmap='Greens_r')\n    ax[0, c].set_title(titles[c] + ' - corr', size='medium')\n    ax[1, c].set_title(titles[c] + ' - RMSE', size='medium')\nax[0, 0].set_ylabel('Delay (Weeks)')\nax[1, 0].set_ylabel('Delay (Weeks)')\nplt.tight_layout()\nplt.subplots_adjust(top=0.83)\nplt.suptitle('Performance of Models Across Outbreaks and Reporting Delays', fontsize='large')\nplt.figtext(0.01, 0.96, 'b', ha='center', va='center', size='large', weight='bold')\nplt.savefig('results1_part2.png')\n"

In [25]:
'''
Figures with ARGO predictions and heatmaps of predictor importances
'''
plt.rcParams['font.size'] = 16
height_ratios = [[6, 1], [1, 1], [4, 1], [3, 2]]
figsizes = [9, 10, 4, 6]
for c, country in enumerate(countries):
    print(country)
    for th in [2, 1]:
        fig, ax = plt.subplots(2, figsize=(20, figsizes[c]), gridspec_kw = {'height_ratios':height_ratios[c]})
        weekly = False if country == 'madagascar' else True
        results = run_all_theoretical(country, ax[0], ax[1], th, weekly, ['ARGO'])
        plt.tight_layout()
        plt.subplots_adjust(top=0.93)
        #plt.suptitle('Coefficient Analysis for ARGO model: ' + titles[c], fontsize='large')
        plt.suptitle('Assuming Reporting Delay of ' + str(th) + ' Weeks', fontsize='x-large')
        ax[0].set_ylabel('Cases')
        plt.savefig('newnewfigures/coef_analysis/cbar' + country + str(th) + '.png')
        #if country == 'madagascar' and th == 1:
        #    plt.suptitle('Feature Importances for Nowcasting Plague in Madagascar with ARGO (Assuming One Week Reporting Delay)', size='x-large')
        #    plt.savefig('newnewfigures/coef_analysis/' + country + str(th) + 'maintext.png')
        

angola
colombia
drc
madagascar


### III. Run Models - Practical Version

In [14]:
plt.rcParams['font.size'] = 20

In [37]:
def run_all_practical(country, ax, reportids):
    reports, trends = load_data('data/'  + country + '/' + country + 'reports.csv', 
                                'data/' + country + '/' + country + 'predictors.csv')
    for m, model in enumerate(models):
        use_lags = False if model == 'GT' else True
        use_trends = False if model == 'AR' else True
        delay = 14 if country == 'madagascar' else 2
        weekly = False if country == 'madagascar' else True
        results = run_model_practical(reports, trends, delay, use_lags=use_lags, use_trends=use_trends, weekly=weekly)
        all_reportids = [x for x in results.keys() if x != 'groundtruth']
        i = 0
        for r, reportid in enumerate(all_reportids):
            if reportid in reportids:
            #if reportid in ['5/8/16', '5/15/16', '5/23/16', '5/30/16', '6/6/16']:
            #if reportid in ['2/24/16', '4/14/16', '5/5/16', '5/12/16', '5/19/16']:
            #if reportid in ['9/7/18', '9/14/18', '11/29/18', '9/27/18', '10/11/18']:
            #if reportid in ['10/13/17', '10/18/17', '11/6/17', '11/14/17', '11/20/17']:
            #if reportid in ['10/5/17', '10/16/17', '11/9/17']:
            #if reportid in ['2/24/16', '4/14/16', '5/19/16']:
            #if i < 5:
            #if reportid in ['10/5/17', '10/16/17', '11/9/17']:
            #if reportid in ['9/7/18', '11/29/18', '12/28/18']:
                if m == 0:
                    ax[i].fill_between(reports.index, 0, reports['groundtruth'], color='lightgrey', label='Ground Truth')
                    ax[i].plot(results[reportid]['dates_report'], results[reportid]['report'], color='black', 
                              label='Reported Epi Curve', linewidth=3)
                    ax[i].axvline(pd.to_datetime(reportid), dashes=[2, 2], color='black', label='Report Released', linewidth=4)
                    if i != len(all_reportids) - 1:
                        ax[i].axvline(pd.to_datetime(all_reportids[r+1]), dashes=[2, 2], color='gold', 
                                     label='Next Report Released', linewidth=3)
                    ax[i].set_title('Report Released ' + reportid, size='medium')
                    ax[i].set_xlim(reports.index[0], reports.index[-1])
                    ax[i].set_ylim(0, 1.3*max(reports['groundtruth']))
                    ax[i].spines['right'].set_visible(False)
                    ax[i].spines['top'].set_visible(False)
                    months = mdates.MonthLocator()
                    months_format = mdates.DateFormatter('%b %y')
                    ax[i].xaxis.set_major_locator(months)
                    ax[i].xaxis.set_major_formatter(months_format)
                dates_ext = results[reportid]['dates']
                yhats_ext = results[reportid]['yhats']
                dates_ext = [results[reportid]['dates_report'][-(delay + 1)]] + list(results[reportid]['dates'])
                yhats_ext = [results[reportid]['report'][-(delay + 1)]] + list(results[reportid]['yhats'])
                ax[i].plot(dates_ext, yhats_ext, color=colors[model], label=model, linewidth=3)
                i = i + 1;
    #ax[0].legend(loc='best')
    #plt.tight_layout()
    #plt.subplots_adjust(top=top)
    #plt.suptitle('Digital Epidemiological Models vs. Reported Values - ' + titles[countries.index(country)], 
    #            fontsize=30)
    #plt.savefig('newnewfigures/practical/' + country + '_medselection.png')

In [64]:
# Figure 3 in paper
fig, ax = plt.subplots(9, 2, figsize=(20, 18))
print('angola')
run_all_practical('angola', ax[:3, 0], ['5/8/16', '5/23/16', '6/6/16'])
print('colombia')
run_all_practical('colombia', ax[:3, 1], ['2/24/16', '4/14/16', '5/19/16'])
print('drc')
run_all_practical('drc', ax[3:6, 0], ['9/7/18', '11/29/18', '12/28/18'])
print('madagascar')
run_all_practical('madagascar', ax[3:6, 1], ['10/5/17', '10/16/17', '11/9/17'])
print('yemen')
run_all_practical('yemen', ax[6:, 0], ['9/3/17', '10/1/17', '11/5/17'])
ax[0, 1].legend(loc='upper right', prop={'size':10})
for a in [ax[0, 0], ax[1, 0], ax[3, 0], ax[4, 0], ax[0, 1], ax[1, 1], ax[3, 1], ax[4, 1], ax[6, 0], ax[7, 0]]:
    a.get_xaxis().set_visible(False)
plt.tight_layout(rect=[0, 0.04, 1, 0.98])
for a in [ax[3, 0], ax[4, 0], ax[5, 0], ax[3, 1], ax[4, 1], ax[5, 1], ax[6, 0], ax[7, 0], ax[8, 0]]:
    pos1 = a.get_position() 
    pos2 = [pos1.x0, pos1.y0-0.02,  pos1.width, pos1.height] 
    a.set_position(pos2) 
for a in [ax[6, 0], ax[7, 0], ax[8, 0]]:
    pos1 = a.get_position() 
    pos2 = [pos1.x0, pos1.y0-0.02,  pos1.width, pos1.height] 
    a.set_position(pos2) 
for a in [ax[6, 1], ax[7, 1], ax[8, 1]]:
    fig.delaxes(a)
plt.figtext(0.27, 0.973, 'Yellow Fever in Angola', ha='center', va='center', size='large')
plt.figtext(0.76, 0.973, 'Zika in Colombia', ha='center', va='center', size='large')
plt.figtext(0.27, 0.645, 'Ebola in the DRC', ha='center', va='center', size='large')
plt.figtext(0.76, 0.645, 'Plague in Madagascar', ha='center', va='center', size='large')
plt.figtext(0.27, 0.32, 'Cholera in Yemen', ha='center', va='center', size='large')
plt.savefig('results3.png')

angola
colombia
drc
madagascar
yemen


In [41]:
fig, ax = plt.subplots(12, figsize=(20, 30))
run_all_practical('yemen', ax, [x for x in pd.read_csv('data/yemen/yemenreports.csv').columns if '/' in x])
ax[0].legend(loc='best')
plt.tight_layout()
plt.savefig('yemenreports.png')

In [26]:
#figsizes = [(20, 35), (20, 20), (20, 55), (20, 40)]
tops = [0.95, 0.93, 0.96, 0.95]
for c, country in enumerate(['angola', 'colombia']):
    print(country)
    run_all_practical(country, (20, 20), 0.92)

angola
colombia


In [27]:
run_all_practical('drc', (20, 11), 0.86)

### IV. Old code that I'm not deleting

In [None]:
def reshaping(reports, trends, reportid, term):
    epi = reports[reportid]
    trend = trends[term]
    total_cases = np.sum(epi)
    factor = total_cases/np.sum(trend)
    return factor*trend

In [None]:
def run_all_reshaping(reports, trends, term):
    reportids = [col for col in reports.columns if col != 'groundtruth']
    results = pd.DataFrame(index=reports.index)
    for reportid in reportids:
        reshaped = reshaping(reports, trends, reportid, term)
        results[reportid] = list(reshaped.values) + [np.nan for i in range(len(results.index) - len(reshaped))]
    return results

In [None]:
def mae(ytrue, yhat, lims):
    ytrue = ytrue[lims[0]:lims[1]]
    yhat = yhat[lims[0]:lims[1]]
    errors = [np.abs(ytrue[i] - yhat[i]) for i in range(len(ytrue)) if np.isfinite(ytrue[i]) and np.isfinite(yhat[i])]
    return np.mean(errors)

In [None]:
def plot_errors(ax, reports, models, labels, colors):
    reportids = [col for col in reports.columns if col != 'groundtruth']
    x = np.arange(len(reportids))
    width = 0.2
    errors_reports = []
    for m, model in enumerate(models):
        errors = []
        for r, reportid in enumerate(reportids):
            start = np.argwhere(np.isfinite(model[reportid])).flatten()[0]
            end = np.argwhere(np.isfinite(reports[reportid])).flatten()[-1] + 1
            if r > 1:
                errors.append(mae(reports['groundtruth'], model[reportid], (start, end)))
            else:
                errors.append(np.nan)
            if m == 0:
                errors_reports.append(mae(reports['groundtruth'], reports[reportid], (start, end)))
        ax.bar(x + (m+1)*width, errors, width, label=labels[m], color=colors[m])
    ax.bar(x, errors_reports, width, label='Reported Cases', color='grey')
    ax.set_xticklabels(reportids)

In [21]:
from sklearn.preprocessing import MinMaxScaler
import matplotlib.dates as mdates
def eda(ax, reports_fname, predictors_fname, predictor, labels, title, color):
    reports, predictors = load_data(reports_fname, predictors_fname)
    scaler = MinMaxScaler()
    scaled_groundtruth = scaler.fit_transform(reports['groundtruth'].values.reshape(-1, 1)).flatten()
    scaled_searches = scaler.fit_transform(predictors[predictor].values.reshape(-1, 1)).flatten()
    corr = np.corrcoef(scaled_groundtruth, scaled_searches)[0][1]
    ax.fill_between(reports.index, 0, scaled_groundtruth, color='lightgrey', label=labels[0])
    ax.plot(predictors.index, scaled_searches, color=color, label=labels[1], linewidth=4)
    #ax.fill_between(reports.index, 0, reports['groundtruth'], color='lightgrey', label=labels[0])
    ax.legend(loc='upper right')
    ax.set_ylim(0, 1.4)
    #ax.set_ylim(0, 1.4*max(reports['groundtruth']))
    months = mdates.MonthLocator()
    months_format = mdates.DateFormatter('%b %y')
    ax.xaxis.set_major_locator(months)
    ax.xaxis.set_major_formatter(months_format)
    ax.text(0.01, 0.93, 'corr = %.2f' % corr, transform = ax.transAxes)
    ax.set_title(title)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

In [30]:
fig, ax = plt.subplots(1, figsize=(20, 5))
eda(ax, 'data/yemen/yemenreports.csv', 'data/yemen/yemenpredictors.csv', 'cholera', ('Cases of Cholera in Yemen (scaled to range [0, 1])', "Searches for 'cholera' in Yemen (scaled to range [0, 1])"), 'Cholera in Yemen', 'blue')
plt.tight_layout()
plt.savefig('yemeneda.png')

In [34]:
# EDA
plt.rcParams['font.size'] = 17
colors = ['orange', 'mediumseagreen', 'royalblue', 'salmon', 'purple']
fig, ax = plt.subplots(3, 2, figsize=(20, 12))
ax = ax.flatten()
eda(ax[0], 'data/angola/angolareports.csv', 'data/angola/angolapredictors.csv', 'yellow fever', ('Confirmed cases of yellow fever in Angola (scaled to range [0, 1])', "Searches for 'yellow fever' in Angola (scaled to range [0, 1])"), 'Yellow Fever in Angola', colors[0])
eda(ax[1], 'data/colombia/colombiareports.csv', 'data/colombia/colombiapredictors.csv', 'zika', ('Suspected cases of Zika in Colombia (scaled to range [0, 1])', "Searches for 'zika' in Colombia (scaled to range [0, 1])"), 
    'Zika in Colombia', colors[1])
eda(ax[2], 'data/drc/drcreports.csv', 'data/drc/drcpredictors.csv', 'ebola', ('Suspected cases of Ebola in the DRC (scaled to range [0, 1])', "Searches for 'ebola' in the DRC (scaled to range [0, 1])"), 'Ebola in the DRC', colors[2])
eda(ax[3], 'data/madagascar/madagascarreports.csv', 'data/madagascar/madagascarpredictors.csv', 'peste', ('Confirmed cases of Plague in Madagascar (scaled to range [0, 1])', "Searches for 'peste' in Madagascar (scaled to range [0, 1])"), 'Plague in Madagascar', colors[3])
eda(ax[4], 'data/yemen/yemenreports.csv', 'data/yemen/yemenpredictors.csv', 'cholera', ('Suspected cases of Cholera in Yemen (scaled to range [0, 1])', "Searches for 'cholera' in Yemen (scaled to range [0, 1])"), 'Cholera in Yemen', colors[4])
ax[0].set_ylabel('Cases')
plt.tight_layout()
fig.delaxes(ax[5])
plt.savefig('eda_2.png')

In [62]:
months = mdates.MonthLocator()
months_format = mdates.DateFormatter('%b %y')
plt.rcParams['font.size'] = 20
fig, ax = plt.subplots(1, figsize=(20, 5))
reports, _ = load_data('data/madagascar/madagascarreports.csv', 'data/madagascar/madagascarpredictors.csv')
ax.fill_between(reports.index, 0, reports['groundtruth'], color='lightgrey', label='Suspected cases by date of symptoms onset as of Dec. 4')
ax.plot(reports.index, reports['10/13/17'], color='black', label='Suspected cases by date of symptoms onset as of Oct. 13', linewidth=4)
ax.set_ylim(0, 1.1*max(reports['groundtruth']))
months = mdates.MonthLocator()
months_format = mdates.DateFormatter('%b %y')
ax.xaxis.set_major_locator(months)
ax.xaxis.set_major_formatter(months_format)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.legend(loc='center left')
ax.set_title('Plague in Madagascar, September - November 2017')
ax.set_ylabel('New Cases')
plt.tight_layout()
plt.savefig('newnewfigures/madagascar_2')