In [2]:
# Example of training a linear model

# Example Predictor: Linear Predictor

This example contains basic functionality for training and evaluating a linear predictor.

First, a training data set is created from historical case and npi data.

Second, a linear model is trained to predict future cases from prior case data along with prior and future npi data.
The model is an off-the-shelf sklearn Lasso model, that uses a positive weight constraint to enforce the assumption that increased npis has a negative correlation with future cases. An independent model is trained for each number of days into the future we need to predict.

Third, a sample evaluation set is created, and the predictor is applied to this evaluation set to produce prediction results in the correct format.

## Training

In [3]:
import pickle
import numpy as np
import pandas as pd
from sklearn.linear_model import Lasso
from sklearn.model_selection import train_test_split

In [4]:
# Load historical data from URL
URL = 'https://raw.githubusercontent.com/OxCGRT/covid-policy-tracker/master/data/OxCGRT_latest.csv'
df = pd.read_csv(URL, 
                 parse_dates=['Date'],
                 encoding="ISO-8859-1",
                 error_bad_lines=False)

In [5]:
# For testing, restrict training data to that before a hypothetical predictor submission date
HYPOTHETICAL_SUBMISSION_DATE = np.datetime64("2020-07-31")
df = df[df.Date <= HYPOTHETICAL_SUBMISSION_DATE]

In [6]:
# Add RegionID column that combines CountryName and RegionName for easier manipulation of data
df['GeoID'] = df['CountryName'] + '__' + df['RegionName'].astype(str)

In [7]:
# Add new cases column
df['NewCases'] = df.groupby('GeoID').ConfirmedCases.diff().fillna(0)

In [8]:
# Keep only columns of interest
id_cols = ['CountryName',
           'RegionName',
           'GeoID',
           'Date']
cases_col = ['NewCases']
npi_cols = ['C1_School closing',
            'C2_Workplace closing',
            'C3_Cancel public events',
            'C4_Restrictions on gatherings',
            'C5_Close public transport',
            'C6_Stay at home requirements',
            'C7_Restrictions on internal movement',
            'C8_International travel controls',
            'H1_Public information campaigns',
            'H2_Testing policy',
            'H3_Contact tracing']
df = df[id_cols + cases_col + npi_cols]

In [9]:
# Fill any missing case values by interpolation and setting NaNs to 0
df.update(df.groupby('GeoID').NewCases.apply(
    lambda group: group.interpolate()).fillna(0))

In [10]:
# Fill any missing NPIs by assuming they are the same as previous day
for npi_col in npi_cols:
    df.update(df.groupby('GeoID')[npi_col].ffill().fillna(0))

In [11]:
df

Unnamed: 0,CountryName,RegionName,GeoID,Date,NewCases,C1_School closing,C2_Workplace closing,C3_Cancel public events,C4_Restrictions on gatherings,C5_Close public transport,C6_Stay at home requirements,C7_Restrictions on internal movement,C8_International travel controls,H1_Public information campaigns,H2_Testing policy,H3_Contact tracing
0,Aruba,,Aruba__nan,2020-01-01,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,Aruba,,Aruba__nan,2020-01-02,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,Aruba,,Aruba__nan,2020-01-03,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,Aruba,,Aruba__nan,2020-01-04,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,Aruba,,Aruba__nan,2020-01-05,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
61168,Zimbabwe,,Zimbabwe__nan,2020-07-27,78.0,3.0,1.0,2.0,3.0,1.0,2.0,2.0,4.0,2.0,1.0,1.0
61169,Zimbabwe,,Zimbabwe__nan,2020-07-28,192.0,3.0,1.0,2.0,3.0,1.0,2.0,2.0,4.0,2.0,1.0,1.0
61170,Zimbabwe,,Zimbabwe__nan,2020-07-29,113.0,3.0,1.0,2.0,3.0,1.0,2.0,2.0,4.0,2.0,1.0,1.0
61171,Zimbabwe,,Zimbabwe__nan,2020-07-30,62.0,3.0,1.0,2.0,3.0,1.0,2.0,2.0,4.0,2.0,1.0,1.0


In [12]:
# Next 2 cells: Functions for training one model for each day into the future we want to predict

In [13]:
# Create training data across all countries for predicting nb_days_ahead
def create_training_data(nb_days_ahead, nb_lookback_days, df, cases_col, npi_cols):
    X_cols = cases_col + npi_cols
    y_col = cases_col
    X_samples = []
    y_samples = []
    geo_ids = df.GeoID.unique()
    for g in geo_ids:
        gdf = df[df.GeoID == g]
        all_case_data = np.array(gdf[cases_col])
        all_npi_data = np.array(gdf[npi_cols])

        # Create one sample for each day where we have enough data
        nb_total_days = len(gdf)
        for d in range(nb_lookback_days, nb_total_days - nb_days_ahead):
            X_cases = all_case_data[d-nb_lookback_days:d]

            # Take negative of npis to support positive
            # weight constraint in Lasso.
            X_npis = -all_npi_data[d - nb_lookback_days:d + nb_days_ahead]

            # Flatten all input data so it fits Lasso input format.
            X_sample = np.concatenate([X_cases.flatten(),
                                       X_npis.flatten()])
            y_sample = all_case_data[d]
            X_samples.append(X_sample)
            y_samples.append(y_sample)

    X_samples = np.array(X_samples)
    y_samples = np.array(y_samples).flatten()
    return X_samples, y_samples

In [14]:
# Compute mae
def mae(pred, true):
    return np.mean(np.abs(pred - true))

# Create and train Lasso model
def create_and_train_lasso_model(X_samples, y_samples):
    
    # Split data into train and test sets
    X_train, X_test, y_train, y_test = train_test_split(X_samples,
                                                        y_samples,
                                                        test_size=0.2,
                                                        random_state=301)
    
    # Train Lasso model.
    # Set positive=True to enforce assumption that cases are positively correlated
    # with future cases and npis are negatively correlated.
    model = Lasso(alpha=0.1,
                  precompute=True,
                  max_iter=10000,
                  positive=True,
                  selection='random')
    # Fit model
    model.fit(X_train, y_train)
    
    # Evaluate model
    train_preds = model.predict(X_train)
    train_preds = np.maximum(train_preds, 0) # Don't predict negative cases
    print('Train MAE:', mae(train_preds, y_train))

    test_preds = model.predict(X_test)
    test_preds = np.maximum(test_preds, 0) # Don't predict negative cases
    print('Test MAE:', mae(test_preds, y_test))
    
    return model

In [15]:
# Train a model for each day ahead we want to predict

# Set number of past days to use to make predictions
nb_lookback_days = 30

# Maximum number of days ahead we want to predict
max_days_ahead = 30

models = {}
for nb_days_ahead in range(max_days_ahead):
    print('Days ahead predicting:', nb_days_ahead)
    X_samples, y_samples = create_training_data(nb_days_ahead, nb_lookback_days, df, cases_col, npi_cols)
    model = create_and_train_lasso_model(X_samples, y_samples)
    models[nb_days_ahead] = model

Days ahead predicting: 0
Train MAE: 109.47806852777566
Test MAE: 111.87336135749506
Days ahead predicting: 1
Train MAE: 109.87989278138645
Test MAE: 107.03056706099599
Days ahead predicting: 2
Train MAE: 108.77335959846748
Test MAE: 103.4060623031037
Days ahead predicting: 3
Train MAE: 107.27389932267823
Test MAE: 104.99763555616231
Days ahead predicting: 4
Train MAE: 108.06097864432698
Test MAE: 100.93854022311957
Days ahead predicting: 5
Train MAE: 105.18733580292945
Test MAE: 109.01945050916311
Days ahead predicting: 6
Train MAE: 104.9557286882413
Test MAE: 97.46235679625002
Days ahead predicting: 7
Train MAE: 106.20518516020456
Test MAE: 94.72465914539613
Days ahead predicting: 8
Train MAE: 104.53914988259915
Test MAE: 105.25399864501442
Days ahead predicting: 9
Train MAE: 104.32714988505506
Test MAE: 98.2488237218251
Days ahead predicting: 10
Train MAE: 100.34470505431656
Test MAE: 100.0180125529238
Days ahead predicting: 11
Train MAE: 102.38034366300212
Test MAE: 97.2331484944316

In [16]:
# Inspect the learned feature coefficients for each model
# to see what features they're paying attention to.
for nb_days_ahead in range(max_days_ahead):
    print('Model for predicting {} days ahead'.format(nb_days_ahead))

    # Name features
    x_col_names = []
    for d in range(-nb_lookback_days, 0):
        x_col_names.append('Day ' + str(d) + ' ' + cases_col[0])
    for d in range(-nb_lookback_days, nb_days_ahead):
        for col_name in npi_cols:
            x_col_names.append('Day ' + str(d) + ' ' + col_name)

    # View non-zero coefficients
    model = models[nb_days_ahead]
    for (col, coeff) in zip(x_col_names, list(model.coef_)):
        if coeff != 0.:
            print(col, coeff)
    print('Intercept', model.intercept_)
            
    print()

Model for predicting 0 days ahead
Day -7 NewCases 0.28686188448216304
Day -6 NewCases 0.2692580928539549
Day -5 NewCases 0.010494654470650318
Day -4 NewCases 0.04868890241945707
Day -3 NewCases 0.0014844913663233936
Day -1 NewCases 0.4324817936835496
Day -24 C2_Workplace closing 3.6644078141445746
Day -22 C2_Workplace closing 9.719858818171364
Day -21 C2_Workplace closing 0.2383605215692106
Day -20 C4_Restrictions on gatherings 0.006252405293720856
Intercept 22.363098106310645

Model for predicting 1 days ahead
Day -7 NewCases 0.3389724954417338
Day -6 NewCases 0.2058417946714347
Day -5 NewCases 0.03368959356370514
Day -1 NewCases 0.4781742051217004
Day -30 C3_Cancel public events 0.027330975656086044
Day -27 C6_Stay at home requirements 0.7983104738817173
Day -23 C2_Workplace closing 8.124375550540716
Day -22 C2_Workplace closing 4.105492289043539
Day -19 C2_Workplace closing 3.745146947641244
Intercept 23.029405140487995

Model for predicting 2 days ahead
Day -7 NewCases 0.3234962359

In [17]:
# Save models to file
with open('models.pkl', 'wb') as models_file:
    pickle.dump(models, models_file)

## Evaluation

Now that the predictor has been trained and saved, this section contains the functionality for evaluating it on sample evaluation data.

First, a sample evaluation data set is created of the form that is given to the predictor.

Second, the predictor is evaluated on this data set, and a resulting predictions file is produced.

### Create sample evaluation data

In [18]:
# Create hypothetical evaluation data
nb_eval_days = 10
test_df = pd.read_csv(URL, 
                      parse_dates=['Date'],
                      encoding="ISO-8859-1",
                      error_bad_lines=False)

# Pull out relevant evaluation days
test_df = test_df[(test_df.Date > HYPOTHETICAL_SUBMISSION_DATE) & \
                  (test_df.Date <= HYPOTHETICAL_SUBMISSION_DATE + nb_eval_days)]

# Only include columns we would see during evaluation
test_df = test_df[['CountryName', 'RegionName', 'Date'] + npi_cols]

# Fill any missing NPIs by assuming they are the same as previous day
for npi_col in npi_cols:
    test_df.update(test_df.groupby(['CountryName', 'RegionName'])[npi_col].ffill().fillna(0))

In [19]:
# test_df is now in the form of input to a predictor during evaluation
test_df

Unnamed: 0,CountryName,RegionName,Date,C1_School closing,C2_Workplace closing,C3_Cancel public events,C4_Restrictions on gatherings,C5_Close public transport,C6_Stay at home requirements,C7_Restrictions on internal movement,C8_International travel controls,H1_Public information campaigns,H2_Testing policy,H3_Contact tracing
213,Aruba,,2020-08-01,0.0,1.0,0.0,0.0,0.0,1.0,1.0,3.0,2.0,2.0,1.0
214,Aruba,,2020-08-02,0.0,1.0,0.0,0.0,0.0,1.0,1.0,3.0,2.0,2.0,1.0
215,Aruba,,2020-08-03,0.0,1.0,0.0,0.0,0.0,1.0,1.0,3.0,2.0,2.0,1.0
216,Aruba,,2020-08-04,0.0,1.0,0.0,4.0,0.0,1.0,1.0,3.0,2.0,2.0,1.0
217,Aruba,,2020-08-05,0.0,1.0,0.0,4.0,0.0,1.0,1.0,3.0,2.0,2.0,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
61178,Zimbabwe,,2020-08-06,3.0,1.0,2.0,3.0,1.0,2.0,2.0,4.0,2.0,1.0,1.0
61179,Zimbabwe,,2020-08-07,3.0,1.0,2.0,3.0,1.0,2.0,2.0,4.0,2.0,1.0,1.0
61180,Zimbabwe,,2020-08-08,3.0,1.0,2.0,3.0,1.0,2.0,2.0,4.0,2.0,1.0,1.0
61181,Zimbabwe,,2020-08-09,3.0,1.0,2.0,3.0,1.0,2.0,2.0,4.0,2.0,1.0,1.0


### Apply predictor to the evaluation data

In [29]:
def predict(start_date: str, end_date: str, path_to_ips_file: str):
    """
    Generates a file with daily new cases predictions for the given countries, regions and npis, between
    start_date and end_date, included.
    :param start_date: day from which to start making predictions, as a string, format YYYY-MM-DDD
    :param end_date: day on which to stop making predictions, as a string, format YYYY-MM-DDD
    :param path_to_ips_file: path to a csv file containing the intervention plans between start_date and end_date
    :return: Nothing. Saves a csv file called 'start_date_end_date.csv'
    with columns "CountryName,RegionName,Date,PredictedDailyNewCases"
    """
    
    # Add RegionID column that combines CountryName and RegionName for easier manipulation of data\n",
    test_df['GeoID'] = test_df['CountryName'] + '__' + test_df['RegionName'].astype(str)

    # Copy the test data frame
    pred_df = test_df[id_cols].copy()
    # Keep only the requested prediction period.
    # Note: this period *might* be in the future, and pred_df doesn't necessarily contain the requested rows
    pred_df = pred_df[(pred_df.Date >= start_date) & (pred_df.Date <= end_date)]

    # Load historical data to use in making predictions in the same way 
    # df is loaded above to make training data (just copy the reference here for simplicity)
    hist_df = df
    
        # Load models
    with open('models.pkl', 'rb') as models_file:
        models = pickle.load(models_file)
        
    # Make predictions for each country,region pair

    geo_pred_dfs = []
    for g in test_df.GeoID.unique():
        print('\nPredicting for', g)

        # Pull out all relevant data for country c
        hist_gdf = hist_df[hist_df.GeoID == g]
        test_gdf = test_df[test_df.GeoID == g]
        X_cases = np.array(hist_gdf[cases_col])[-nb_lookback_days:]
        X_hist_npis = np.array(hist_gdf[npi_cols])[-nb_lookback_days:]
        future_npi_data = np.array(test_gdf[npi_cols])

        # Make prediction for each day
        geo_preds = []
        nb_days_to_predict = len(future_npi_data)
        for days_ahead in range(nb_days_to_predict):

            # Prepare data
            X_future_npis = future_npi_data[:days_ahead]
            X_npis = np.concatenate([X_hist_npis, X_future_npis])
            X = np.concatenate([X_cases.flatten(),
                                X_npis.flatten()])

            # Grab the right model
            model = models[days_ahead]

            # Make the prediction (reshape so that sklearn is happy)
            pred = model.predict(X.reshape(1, -1))[0]
            pred = max(0, pred)
            geo_preds.append(pred)
            print(pred)

        # Create geo_pred_df with pred column
        geo_pred_df = test_gdf[id_cols].copy()
        geo_pred_df['PredictedNewCases'] = geo_preds
        geo_pred_dfs.append(geo_pred_df)

    # Combine all predictions into a single dataframe
    pred_df = pd.concat(geo_pred_dfs)
    
    # Drop GeoID column to match expected output format
    pred_df = pred_df.drop(columns=['GeoID'])
    pred_df
    
    # Write predictions to csv
    # Save to expected file name
    output_file_name = start_date + "_" + end_date + ".csv"
    pred_df.to_csv(output_file_name, index=None)
    print(f"Predictions saved to {output_file_name}")


In [30]:
predict(start_date="2020-08-01", end_date="2020-08-04", path_to_ips_file="../2020-08-01_2020-08-04_npis_example.csv")


Predicting for Aruba__nan
36.97432703121546
40.82571889683008
39.00155848076339
37.471016230849514
38.088621080325765
39.18813817161799
37.89193329405863
38.980765359209734
42.76946059656962
42.47295030636633

Predicting for Afghanistan__nan
163.9011800387384
171.64308454558443
165.5573596365428
160.72276166832216
167.5292803113956
169.9907868085004
157.53630261522443
164.21529061005168
176.6024863603904
174.77762537829574

Predicting for Angola__nan
103.011871503107
114.00381418551
108.6334029236115
108.43309295483776
109.0865563121736
113.47696090875584
106.92536024245616
106.06504638603047
108.18449203093203
107.6169791837519

Predicting for Anguilla__nan
22.363098106310645
23.029405140487995
22.34957465421587
21.776692775916786
21.74221449856691
21.965426020333837
22.054977816857047
22.40081883872017
22.98753995515318
23.84126145097349

Predicting for Albania__nan
144.4882287977406
152.3155520954009
148.6207266119369
144.887721874097
147.64778096665253
152.05019819476183
147.49847

808.3458411728479
827.0975636333619
799.7193850995264
775.8443839410697
776.0130318842853

Predicting for Djibouti__nan
52.495258854362305
58.089131183329705
54.42847771416662
53.272454457444226
55.03910218939507
56.99018601415549
52.092717468519915
52.702294850116765
53.42782138229431
55.59165901576136

Predicting for Dominica__nan
22.375602916898085
23.05673611614408
22.34957465421587
21.776692775916786
21.74221449856691
22.389399704527424
22.054977816857047
22.40081883872017
22.98753995515318
23.84126145097349

Predicting for Denmark__nan
108.09769600972575
115.58960861175476
111.65074007888174
110.33970786302517
109.93808084578677
114.94445066857241
116.93626103147558
114.8626049000519
114.49201938361449
112.24501909123387

Predicting for Dominican Republic__nan
1790.9623175018673
1822.295779709369
1750.8707375505564
1804.9159903365467
1799.4590412518457
1834.177746164561
1777.4416932130862
1795.1770989587678
1796.9010737877045
1735.9578667402768

Predicting for Algeria__nan
703.68


Predicting for Kuwait__nan
742.4419251470883
751.6668032694197
755.7496356463288
748.7938427348495
751.5532824287708
748.5731601159584
747.1780016998532
750.6273108375738
762.1482380924308
765.5114702406702

Predicting for Laos__nan
22.64995999079281
23.36837763592973
22.67307089013426
22.06454995274516
22.021123600825085
22.2775570036194
22.337791848994673
22.680124300363623
23.279101432235006
24.103886380812444

Predicting for Lebanon__nan
203.11172704184298
209.02800221697066
204.75609711684933
204.50992351514475
206.6294983420398
208.47705637187994
202.30192434579197
205.60012368879876
210.73620646714724
210.80518926492329

Predicting for Liberia__nan
61.37875340743603
68.04123399058383
64.96062060558086
62.28871243425116
63.97357249265719
66.33408272814226
61.72919487162773
64.40945151321958
72.44063192923143
70.98189610814944

Predicting for Libya__nan
216.61617892353408
217.39507560791066
209.96824384815872
218.1901578497915
212.64069391240642
218.2868029582065
228.463395665827

170.65272513412555
166.2728731806731

Predicting for Singapore__nan
402.2467735966537
388.4867377846885
383.7693841287563
387.8715563288054
396.8849518116657
392.2197369426254
385.6712062275672
401.56131819863606
414.38260602487827
412.22624837213164

Predicting for Solomon Islands__nan
49.60835241408094
55.034096666251166
51.096770127418594
50.54444352877658
52.40480693463275
54.413301609746
48.47593675620203
49.213219176942765
50.1697804494144
52.271472910824436

Predicting for Sierra Leone__nan
33.33283259897333
34.0554927868242
32.95335655458548
33.51334161770107
33.32302757751464
34.14562120812569
33.00310251307461
33.33499472601918
33.70170722227642
34.605049775382255

Predicting for El Salvador__nan
477.04337637724217
485.3992426962765
480.3477953619191
476.94232466414553
482.1950432211367
485.23939169688873
477.4524463937507
483.4024285755679
493.6986263121077
490.84581263678183

Predicting for San Marino__nan
22.363098106310645
23.05673611614408
22.34957465421587
21.7766927759

Predicting for United States__Maryland
1142.436999886291
1216.2956793109634
1189.231632154358
1170.9674441173115
1169.6289012061181
1211.4431091497586
1178.689771743013
1163.540100785179
1158.1369531797354
1135.3522833747813

Predicting for United States__Maine
76.4579120480191
84.04129616889203
80.60380555568248
78.55025370221672
80.32725855726534
82.72149598447774
76.7635139614437
77.9546216295942
82.3761964463384
83.07621521442499

Predicting for United States__Michigan
875.2144092726035
852.5507898871765
862.0143433884043
874.4080275467005
890.0722913401544
859.83677041698
854.4027963896381
870.2455621443281
887.8661182986041
908.772029667472

Predicting for United States__Minnesota
862.9820066977743
880.0584811646356
870.0497824431674
876.4499928469576
880.1381834861866
878.9394543291025
853.4181894611427
859.6311935879787
871.8878551740049
869.6604875798685

Predicting for United States__Missouri
1478.208931456209
1496.867214111188
1516.6169946529476
1500.7672594850028
1491.59762

Predictions saved to 2020-08-01_2020-08-04.csv


In [32]:
# Check that predictions are written correctly
!head 2020-08-01_2020-08-04.csv

CountryName,RegionName,Date,PredictedNewCases
Aruba,,2020-08-01,36.97432703121546
Aruba,,2020-08-02,40.82571889683008
Aruba,,2020-08-03,39.00155848076339
Aruba,,2020-08-04,37.471016230849514
Aruba,,2020-08-05,38.088621080325765
Aruba,,2020-08-06,39.18813817161799
Aruba,,2020-08-07,37.89193329405863
Aruba,,2020-08-08,38.980765359209734
Aruba,,2020-08-09,42.76946059656962
