# Example Predictor: SVR Predictor

This example contains basic functionality for training and evaluating an SVR predictor that rolls out predictions day by day.

First, a training data set is created from historical case and NPI (non-pharmaceutical intervention) data.

Second, a support vector regression (SVR) model is trained to predict future cases from prior case data along with prior and future NPI data.
The model is an off-the-shelf sklearn `SVR` with an RBF kernel (`C=500`). Unlike a Lasso model with a positive weight constraint, it does not enforce the assumption that stricter NPIs are negatively correlated with future cases.

Third, a sample evaluation set is created, and the predictor is applied to this evaluation set to produce prediction results in the correct format.

## Training

In [18]:
import pickle
import numpy as np
import pandas as pd
from sklearn.svm import SVR
from sklearn.model_selection import train_test_split
import time

### Copy the data locally

In [19]:
# Main source for the training data
DATA_URL = 'https://raw.githubusercontent.com/OxCGRT/covid-policy-tracker/master/data/OxCGRT_latest.csv'
# Local file
DATA_FILE = 'data/OxCGRT_latest.csv'

In [20]:
import os
import urllib.request
if not os.path.exists('data'):
    os.mkdir('data')
urllib.request.urlretrieve(DATA_URL, DATA_FILE)

('data/OxCGRT_latest.csv', <http.client.HTTPMessage at 0x7f65b43b49a0>)

In [21]:
# Load historical data from local file
df = pd.read_csv(DATA_FILE, 
                 parse_dates=['Date'],
                 encoding="ISO-8859-1",
                 dtype={"RegionName": str,
                        "RegionCode": str},
                 error_bad_lines=False)

In [22]:
df.columns

Index(['CountryName', 'CountryCode', 'RegionName', 'RegionCode',
       'Jurisdiction', 'Date', 'C1_School closing', 'C1_Flag',
       'C2_Workplace closing', 'C2_Flag', 'C3_Cancel public events', 'C3_Flag',
       'C4_Restrictions on gatherings', 'C4_Flag', 'C5_Close public transport',
       'C5_Flag', 'C6_Stay at home requirements', 'C6_Flag',
       'C7_Restrictions on internal movement', 'C7_Flag',
       'C8_International travel controls', 'E1_Income support', 'E1_Flag',
       'E2_Debt/contract relief', 'E3_Fiscal measures',
       'E4_International support', 'H1_Public information campaigns',
       'H1_Flag', 'H2_Testing policy', 'H3_Contact tracing',
       'H4_Emergency investment in healthcare', 'H5_Investment in vaccines',
       'H6_Facial Coverings', 'H6_Flag', 'M1_Wildcard', 'ConfirmedCases',
       'ConfirmedDeaths', 'StringencyIndex', 'StringencyIndexForDisplay',
       'StringencyLegacyIndex', 'StringencyLegacyIndexForDisplay',
       'GovernmentResponseIndex', 'Gove

In [23]:
# For testing, restrict training data to that before a hypothetical predictor submission date
HYPOTHETICAL_SUBMISSION_DATE = np.datetime64("2020-07-31")
df = df[df.Date <= HYPOTHETICAL_SUBMISSION_DATE]

In [24]:
# Add GeoID column that combines CountryName and RegionName for easier manipulation of data
df['GeoID'] = df['CountryName'] + '__' + df['RegionName'].astype(str)

In [25]:
# Add new cases column
df['NewCases'] = df.groupby('GeoID').ConfirmedCases.diff().fillna(0)

In [26]:
# Keep only columns of interest
id_cols = ['CountryName',
           'RegionName',
           'GeoID',
           'Date']
cases_col = ['NewCases']
npi_cols = ['C1_School closing',
            'C2_Workplace closing',
            'C3_Cancel public events',
            'C4_Restrictions on gatherings',
            'C5_Close public transport',
            'C6_Stay at home requirements',
            'C7_Restrictions on internal movement',
            'C8_International travel controls',
            'H1_Public information campaigns',
            'H2_Testing policy',
            'H3_Contact tracing',
            'H6_Facial Coverings']
df = df[id_cols + cases_col + npi_cols]

In [27]:
# Fill any missing case values by interpolation and setting NaNs to 0
df.update(df.groupby('GeoID').NewCases.apply(
    lambda group: group.interpolate()).fillna(0))

In [28]:
# Fill any missing NPIs by assuming they are the same as previous day
for npi_col in npi_cols:
    df.update(df.groupby('GeoID')[npi_col].ffill().fillna(0))
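
The per-region cleaning steps above (differencing cumulative cases, then forward-filling NPIs) can be checked on a tiny frame. This is a minimal sketch on made-up values, reusing the notebook's column names; it is not the notebook's data.

```python
import numpy as np
import pandas as pd

# Toy frame with two regions; all values are invented for illustration.
toy = pd.DataFrame({
    "GeoID": ["A", "A", "A", "B", "B", "B"],
    "ConfirmedCases": [10.0, 15.0, 22.0, 100.0, 100.0, 130.0],
    "C1_School closing": [1.0, np.nan, 2.0, np.nan, 3.0, np.nan],
})

# Daily new cases: difference within each region, so the first row of
# region B is never compared against the last row of region A.
toy["NewCases"] = toy.groupby("GeoID")["ConfirmedCases"].diff().fillna(0)

# Forward-fill NPIs within each region, then use 0 where a region has
# no earlier value to carry forward.
toy["C1_School closing"] = (
    toy.groupby("GeoID")["C1_School closing"].ffill().fillna(0)
)

print(toy)
```

Grouping by `GeoID` before `diff()` and `ffill()` is what keeps one region's values from leaking into the next region's first rows.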

In [29]:
df

Unnamed: 0,CountryName,RegionName,GeoID,Date,NewCases,C1_School closing,C2_Workplace closing,C3_Cancel public events,C4_Restrictions on gatherings,C5_Close public transport,C6_Stay at home requirements,C7_Restrictions on internal movement,C8_International travel controls,H1_Public information campaigns,H2_Testing policy,H3_Contact tracing,H6_Facial Coverings
0,Aruba,,Aruba__nan,2020-01-01,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,Aruba,,Aruba__nan,2020-01-02,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,Aruba,,Aruba__nan,2020-01-03,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,Aruba,,Aruba__nan,2020-01-04,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,Aruba,,Aruba__nan,2020-01-05,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
87064,Zimbabwe,,Zimbabwe__nan,2020-07-27,78.0,3.0,1.0,2.0,3.0,1.0,2.0,2.0,4.0,2.0,1.0,1.0,4.0
87065,Zimbabwe,,Zimbabwe__nan,2020-07-28,192.0,3.0,1.0,2.0,3.0,1.0,2.0,2.0,4.0,2.0,1.0,1.0,4.0
87066,Zimbabwe,,Zimbabwe__nan,2020-07-29,113.0,3.0,1.0,2.0,3.0,1.0,2.0,2.0,4.0,2.0,1.0,1.0,4.0
87067,Zimbabwe,,Zimbabwe__nan,2020-07-30,62.0,3.0,1.0,2.0,3.0,1.0,2.0,2.0,4.0,2.0,1.0,1.0,4.0


In [98]:
# Set number of past days to use to make predictions
nb_lookback_days = 20

# Create training data across all countries for predicting one day ahead
X_cols = cases_col + npi_cols
y_col = cases_col
X_samples = []
y_samples = []
geo_ids = df.GeoID.unique()
for g in geo_ids:
    gdf = df[df.GeoID == g]
    all_case_data = np.array(gdf[cases_col])
    all_npi_data = np.array(gdf[npi_cols])

    # Create one sample for each day where we have enough data
    # Each sample consists of cases and npis for previous nb_lookback_days
    nb_total_days = len(gdf)
    for d in range(nb_lookback_days, nb_total_days - 1):
        X_cases = all_case_data[d-nb_lookback_days:d]

        # Negate the npis. Negation supported the positive weight
        # constraint of a Lasso model; SVR has no such constraint,
        # so the sign is only a convention here.
        X_npis = -all_npi_data[d - nb_lookback_days:d]

        # Flatten all input data into a single feature vector.
        X_sample = np.concatenate([X_cases.flatten(),
                                   X_npis.flatten()])
        y_sample = all_case_data[d + 1]
        X_samples.append(X_sample)
        y_samples.append(y_sample)

X_samples = np.array(X_samples)
y_samples = np.array(y_samples).flatten()
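
To make the shape of the training samples concrete, here is a small sketch that mirrors the windowing loop above on a made-up series. The helper name `make_samples` is hypothetical. Note that, as in the loop above, the target is read at index `d + 1` while the window ends at index `d - 1`, so day `d` itself is neither a feature nor the target.

```python
import numpy as np

def make_samples(cases, npis, lookback):
    """Build (X, y) for one region, mirroring the notebook's loop.

    cases: array of shape (n_days, 1); npis: array of shape (n_days, n_npis).
    """
    X, y = [], []
    for d in range(lookback, len(cases) - 1):
        window_cases = cases[d - lookback:d]
        window_npis = -npis[d - lookback:d]  # same sign convention as above
        X.append(np.concatenate([window_cases.flatten(),
                                 window_npis.flatten()]))
        y.append(cases[d + 1])
    return np.array(X), np.array(y).flatten()

# Made-up series: 10 days, 1 case column, 2 NPI columns.
cases = np.arange(10, dtype=float).reshape(-1, 1)
npis = np.ones((10, 2))

X, y = make_samples(cases, npis, lookback=3)
print(X.shape, y.shape)  # 6 samples, 3 * (1 + 2) = 9 features each
```

With the notebook's settings (20 lookback days, 1 case column, 12 NPI columns) the same layout gives 20 × 13 = 260 features per sample, which matches the training shape printed further down.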

In [99]:
# Helpful function to compute mae
def mae(pred, true):
    return np.mean(np.abs(pred - true))

In [100]:
# Split data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X_samples,
                                                    y_samples,
                                                    test_size=0.2,
                                                    random_state=301)
print(X_train.shape)

(40704, 260)


In [101]:
# Create and train the SVR model.
# Unlike Lasso, SVR has no positive=True option, so the sign of each
# feature's effect on future cases is not constrained.
start_time = time.time()

model = SVR(kernel='rbf', C=500)
# Fit model
model.fit(X_train, y_train)

interval = time.time() - start_time
print('Total time in seconds:', interval)

Total time in seconds: 655.8811466693878
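
The notebook feeds raw features to the SVR. Because an RBF kernel is distance-based, features on very different scales (daily case counts in the thousands next to NPI levels between 0 and 4) can dominate the kernel, and scikit-learn's documentation recommends scaling inputs for SVMs. The following is a sketch of one way to do that on synthetic data, not what this notebook does; if adopted, the whole pipeline (not just the SVR) would have to be pickled and used at prediction time.

```python
import numpy as np
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVR

# Synthetic data: one large-scale feature and four small-scale ones.
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(200, 5)) * np.array([1000.0, 1.0, 1.0, 1.0, 1.0])
y_demo = 0.5 * X_demo[:, 0] + rng.normal(size=200)

# The scaler is fit on the training data only and reapplied inside predict().
scaled_model = make_pipeline(StandardScaler(), SVR(kernel="rbf", C=500))
scaled_model.fit(X_demo, y_demo)

preds = scaled_model.predict(X_demo)
print(preds.shape)
```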


In [102]:
# Evaluate model
train_preds = model.predict(X_train)
train_preds = np.maximum(train_preds, 0) # Don't predict negative cases
print('Train MAE:', mae(train_preds, y_train))

test_preds = model.predict(X_test)
test_preds = np.maximum(test_preds, 0) # Don't predict negative cases
print('Test MAE:', mae(test_preds, y_test))

Train MAE: 212.76363424255487
Test MAE: 208.2076117431121
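
An MAE of roughly 210 is hard to judge without a reference point. One common sanity check, which is not part of the original notebook, is a persistence baseline that predicts the most recent day's new cases from each window. The sketch below uses made-up windows with the same column layout as `X_samples` (case window first, then negated NPIs); the helper name is hypothetical.

```python
import numpy as np

def mae(pred, true):
    return np.mean(np.abs(pred - true))

def persistence_predictions(X, lookback):
    # Columns 0..lookback-1 hold the case window in chronological order,
    # so column lookback-1 is the most recent day's new cases.
    return X[:, lookback - 1]

# Made-up windows: 3 lookback days of cases, then 3 negated NPI values.
X_toy = np.array([
    [10.0, 12.0, 15.0, -1.0, -1.0, -2.0],
    [100.0, 90.0, 80.0, -3.0, -3.0, -3.0],
])
y_toy = np.array([18.0, 70.0])

baseline_preds = persistence_predictions(X_toy, lookback=3)
print("Baseline MAE:", mae(baseline_preds, y_toy))
```

On the real data this would be called as `persistence_predictions(X_test, nb_lookback_days)` and compared against the test MAE above.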


In [195]:
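# Inspect the learned feature coefficients for the model
# to see what features it's paying attention to.
# Note: SVR only exposes coef_ for kernel='linear'. With the RBF kernel
# trained above, accessing model.coef_ raises AttributeError.
# The printed output under this cell lists 'Day -26' features, so it
# appears to come from an earlier run with a linear model and a longer lookback.

# Give names to the features
x_col_names = []
for d in range(-nb_lookback_days, 0):
    x_col_names.append('Day ' + str(d) + ' ' + cases_col[0])
for d in range(-nb_lookback_days, 0):
    for col_name in npi_cols:
        x_col_names.append('Day ' + str(d) + ' ' + col_name)

# View non-zero coefficients (linear kernel only)
if model.kernel == 'linear':
    for (col, coeff) in zip(x_col_names, np.ravel(model.coef_)):
        if coeff != 0.:
            print(col, coeff)
    print('Intercept', model.intercept_)
else:
    print('coef_ is not available for kernel =', model.kernel)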

Day -7 NewCases 0.0013246240389721805
Day -6 NewCases 0.43932558850808323
Day -5 NewCases 0.21733329048235964
Day -4 NewCases 0.05883044887710127
Day -3 NewCases 0.06951319628060729
Day -2 NewCases 0.052031127397790256
Day -1 NewCases 0.23822431735053784
Day -26 C6_Stay at home requirements 4.314491844181091
Day -22 C2_Workplace closing 9.715799339190195
Day -17 C2_Workplace closing 5.77128206222885
Intercept 26.55971959613379


In [109]:
# Save model to file
if not os.path.exists('models'):
    os.mkdir('models')
with open('models/model.pkl', 'wb') as model_file:
    pickle.dump(model, model_file)
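
The pickle above stores only the model, so the lookback length used at training time has to be repeated by hand wherever features are built for prediction. One option, sketched here as an assumption rather than as the notebook's approach, is to save that setting next to the model and check it on load. The bundle layout and file name are hypothetical, and `predict.py` would need to be changed to read it.

```python
import os
import pickle
import tempfile

import numpy as np
from sklearn.svm import SVR

nb_lookback_days = 20
n_cols_per_day = 13  # 1 case column + 12 NPI columns, as in the notebook

# Tiny stand-in model trained on random data with the notebook's feature width.
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(50, nb_lookback_days * n_cols_per_day))
y_demo = rng.normal(size=50)
demo_model = SVR(kernel="rbf", C=500).fit(X_demo, y_demo)

# Save the model together with the setting that determines its input width.
bundle = {"model": demo_model, "nb_lookback_days": nb_lookback_days}
path = os.path.join(tempfile.mkdtemp(), "model_bundle.pkl")
with open(path, "wb") as f:
    pickle.dump(bundle, f)

# On load, confirm the stored setting agrees with what the model expects.
with open(path, "rb") as f:
    loaded = pickle.load(f)
expected = loaded["nb_lookback_days"] * n_cols_per_day
print(loaded["model"].n_features_in_, expected)
```

`n_features_in_` is set by recent scikit-learn versions during `fit`; older releases may not provide it.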

## Evaluation

Now that the predictor has been trained and saved, this section contains the functionality for evaluating it on sample evaluation data.

In [125]:
# Reload the module to get the latest changes
import predict
from importlib import reload
reload(predict)
from predict import predict_df

In [126]:
%%time
preds_df = predict_df("2020-08-01", "2020-08-21", path_to_ips_file="../../../validation/data/2020-09-30_historical_ip.csv", verbose=True)


Predicting for Aruba__nan
2020-08-01: 8.971073893048015
2020-08-02: 8.938848401658106
2020-08-03: 9.680412154852092
2020-08-04: 12.139685932485008
2020-08-05: 9.611990085453726
2020-08-06: 7.913701155735907
2020-08-07: 8.690127328483868
2020-08-08: 9.040341221814742
2020-08-09: 10.049624138391664
2020-08-10: 10.351802813536779
2020-08-11: 10.863686043048801
2020-08-12: 9.327550679976412
2020-08-13: 7.987241552191335
2020-08-14: 7.355947135531096
2020-08-15: 8.857456291096241
2020-08-16: 8.461162540348596
2020-08-17: 8.9158701977085
2020-08-18: 8.842403671020293
2020-08-19: 7.25229590573872
2020-08-20: 6.703693768575249
2020-08-21: 6.996070635977958

Predicting for Afghanistan__nan
2020-08-01: 223.6423968861891
2020-08-02: 258.0059278881545
2020-08-03: 260.5593185395928
2020-08-04: 233.32108211691593
2020-08-05: 250.25604988187206
2020-08-06: 254.88376667404282
2020-08-07: 260.4993246287877
2020-08-08: 266.25042968259095
2020-08-09: 275.07999021654905
2020-08-10: 261.1131177776524
2020

2020-08-20: 7.138487668820744
2020-08-21: 8.789764574761648

Predicting for Burkina Faso__nan
2020-08-01: 18.383157712272805
2020-08-02: 18.832037099178706
2020-08-03: 19.737924398135874
2020-08-04: 23.205638616265787
2020-08-05: 25.01295358034804
2020-08-06: 22.5110976336382
2020-08-07: 20.469422861326166
2020-08-08: 20.803240903193
2020-08-09: 21.52607875844842
2020-08-10: 23.679936409742368
2020-08-11: 24.287962960788718
2020-08-12: 25.66236537799523
2020-08-13: 24.41480064400821
2020-08-14: 21.855893856580224
2020-08-15: 22.09862539743881
2020-08-16: 22.651591222758725
2020-08-17: 22.31731933187075
2020-08-18: 23.24207842510441
2020-08-19: 22.70153779005159
2020-08-20: 20.797345625214803
2020-08-21: 19.953248282474306

Predicting for Bangladesh__nan
2020-08-01: 2255.8197859772827
2020-08-02: 2323.7364352852255
2020-08-03: 2150.2307740241467
2020-08-04: 1915.2975034185738
2020-08-05: 1989.626749900025
2020-08-06: 2169.334463948994
2020-08-07: 2200.3843979821395
2020-08-08: 2097.9439

2020-08-16: 3.0340515262814733
2020-08-17: 3.3733536481886404
2020-08-18: 3.562858461336873
2020-08-19: 3.7485882068795036
2020-08-20: 3.588302551434026
2020-08-21: 3.3558135658495303

Predicting for Botswana__nan
2020-08-01: 151.4413575416038
2020-08-02: 275.0024258202702
2020-08-03: 158.78156348221455
2020-08-04: 122.27998908687823
2020-08-05: 103.40234912103006
2020-08-06: 216.68178900857856
2020-08-07: 239.97619299873804
2020-08-08: 236.01685198827
2020-08-09: 234.74746458380014
2020-08-10: 225.055014411495
2020-08-11: 170.46925006396123
2020-08-12: 214.36133841247465
2020-08-13: 309.07552886060785
2020-08-14: 283.20237706392527
2020-08-15: 318.7631970604407
2020-08-16: 321.5879488941837
2020-08-17: 276.9602445133196
2020-08-18: 277.8312650601747
2020-08-19: 312.4377452201152
2020-08-20: 330.11165504355904
2020-08-21: 344.8200510649722

Predicting for Central African Republic__nan
2020-08-01: 1.4327445779999834
2020-08-02: 2.5505271670317597
2020-08-03: 2.430355378999593
2020-08-04

2020-08-10: 48.91547274685763
2020-08-11: 46.85145687507975
2020-08-12: 32.68249310712963
2020-08-13: 24.109025804182238
2020-08-14: 35.69812458819433
2020-08-15: 33.86618745706073
2020-08-16: 35.9896021194254
2020-08-17: 36.71048818165036
2020-08-18: 28.078534730786487
2020-08-19: 21.036364414552736
2020-08-20: 21.749600254946927
2020-08-21: 24.747876633633496

Predicting for Costa Rica__nan
2020-08-01: 1621.479058417718
2020-08-02: 1656.0028632130561
2020-08-03: 1454.7435397235267
2020-08-04: 1044.5096636880007
2020-08-05: 1073.7548563669397
2020-08-06: 1935.1303380173913
2020-08-07: 2071.9569397927226
2020-08-08: 2031.2954889501007
2020-08-09: 1766.3745123190974
2020-08-10: 1706.3799323463663
2020-08-11: 1571.89886340239
2020-08-12: 1666.2711872910932
2020-08-13: 2350.087494004888
2020-08-14: 2499.238848091466
2020-08-15: 2311.4395782464344
2020-08-16: 2083.5376749378365
2020-08-17: 1871.9415922603202
2020-08-18: 1812.1290627642848
2020-08-19: 2184.1725930561406
2020-08-20: 2567.544

2020-08-02: 11.95832345472445
2020-08-03: 10.966654449715861
2020-08-04: 12.48035586541846
2020-08-05: 10.165924540178821
2020-08-06: 10.658121400740129
2020-08-07: 12.583483479991628
2020-08-08: 14.701692314814863
2020-08-09: 14.941661695300354
2020-08-10: 12.909039870895867
2020-08-11: 12.227636337995136
2020-08-12: 12.119647121571688
2020-08-13: 11.578687465218536
2020-08-14: 12.464921037879321
2020-08-15: 14.921901346660889
2020-08-16: 12.739209906883843
2020-08-17: 12.384221150916346
2020-08-18: 12.285793403327261
2020-08-19: 11.352692199738158
2020-08-20: 11.930350811426251
2020-08-21: 12.747596652399807

Predicting for Spain__nan
2020-08-01: 8233.841455976642
2020-08-02: 8233.841455976642
2020-08-03: 8233.841455976624
2020-08-04: 8233.841455976644
2020-08-05: 8233.841455976677
2020-08-06: 8233.841448027588
2020-08-07: 8233.841411786048
2020-08-08: 8233.841032409055
2020-08-09: 8233.842471659955
2020-08-10: 8233.833131948199
2020-08-11: 8233.827258205616
2020-08-12: 8233.85395396


Predicting for Georgia__nan
2020-08-01: 4584.614842199116
2020-08-02: 5892.79967342648
2020-08-03: 6776.977038411738
2020-08-04: 6146.121288450302
2020-08-05: 5904.6175715277095
2020-08-06: 5095.776750153476
2020-08-07: 5525.409559253189
2020-08-08: 7106.751790056229
2020-08-09: 8248.467367827356
2020-08-10: 8783.457483893148
2020-08-11: 7771.666821718894
2020-08-12: 7089.9043081422315
2020-08-13: 6583.778010849817
2020-08-14: 7723.335118520908
2020-08-15: 9434.375522008339
2020-08-16: 10458.801578777982
2020-08-17: 10165.498435614722
2020-08-18: 9336.755520564802
2020-08-19: 8417.75202474214
2020-08-20: 8741.862648127826
2020-08-21: 10007.05152716118

Predicting for Ghana__nan
2020-08-01: 85.8934583551918
2020-08-02: 151.24151105896362
2020-08-03: 112.31001043303968
2020-08-04: 108.04956134492204
2020-08-05: 68.09672790332752
2020-08-06: 107.51431752070312
2020-08-07: 99.2046504608279
2020-08-08: 116.5732149731075
2020-08-09: 150.06779881509465
2020-08-10: 111.23516312250285
2020-08-

2020-08-19: 2.572163598453699
2020-08-20: 3.2906779926306626
2020-08-21: 2.8628671931874123

Predicting for Hungary__nan
2020-08-01: 3268.2469628822873
2020-08-02: 3735.2693118883653
2020-08-03: 3582.2640132903307
2020-08-04: 3327.4994918879374
2020-08-05: 3342.5235666327926
2020-08-06: 2857.6195414092763
2020-08-07: 2361.7586817714273
2020-08-08: 2282.692318162105
2020-08-09: 2217.330222941313
2020-08-10: 2024.615948713924
2020-08-11: 1761.6358865455895
2020-08-12: 1664.3457879417774
2020-08-13: 1395.074982758515
2020-08-14: 1276.8188070084643
2020-08-15: 1228.8684191040038
2020-08-16: 1163.3201432844307
2020-08-17: 1020.9296882989056
2020-08-18: 953.0314284322067
2020-08-19: 795.9247535310042
2020-08-20: 723.0526728131163
2020-08-21: 661.1303062408642

Predicting for Indonesia__nan
2020-08-01: 4632.7500781159
2020-08-02: 5254.474045111412
2020-08-03: 5642.184348493548
2020-08-04: 5677.4774759918855
2020-08-05: 4953.477976262942
2020-08-06: 4375.978136776712
2020-08-07: 4582.828194957

2020-08-17: 1052.7250445715736
2020-08-18: 1040.326056552958
2020-08-19: 949.0762682043132
2020-08-20: 859.3875412902844
2020-08-21: 834.4494717669595

Predicting for Kenya__nan
2020-08-01: 1009.3257559674003
2020-08-02: 1158.130945658355
2020-08-03: 1241.5814637775247
2020-08-04: 1115.7985301095287
2020-08-05: 797.81987021447
2020-08-06: 694.8500533634851
2020-08-07: 874.6921482586185
2020-08-08: 1000.359642601863
2020-08-09: 1171.2397331264838
2020-08-10: 1016.7892031146584
2020-08-11: 863.0284188158539
2020-08-12: 751.6923410910522
2020-08-13: 696.4833827109323
2020-08-14: 822.5357028876397
2020-08-15: 950.449289357357
2020-08-16: 930.3939589709562
2020-08-17: 803.3997491501596
2020-08-18: 710.8617513495828
2020-08-19: 619.2027996386187
2020-08-20: 666.1794851200693
2020-08-21: 766.2675926181455

Predicting for Kyrgyz Republic__nan
2020-08-01: 362.243794240464
2020-08-02: 376.44358034528796
2020-08-03: 378.6343493832819
2020-08-04: 358.52732711664885
2020-08-05: 343.64875605442467
2

2020-08-08: 708.8157179046202
2020-08-09: 815.0866380602074
2020-08-10: 823.5982026253641
2020-08-11: 741.6676847978988
2020-08-12: 698.3991158431854
2020-08-13: 705.4423682174975
2020-08-14: 718.6715403788967
2020-08-15: 777.7882817844775
2020-08-16: 811.1799048693529
2020-08-17: 754.7460333712952
2020-08-18: 680.9904201658874
2020-08-19: 655.4622281085494
2020-08-20: 660.2414307893396
2020-08-21: 688.9818731925534

Predicting for Latvia__nan
2020-08-01: 405.82535002324585
2020-08-02: 451.1068434106228
2020-08-03: 480.6099870433309
2020-08-04: 473.6504961905557
2020-08-05: 380.1364777655526
2020-08-06: 279.0475692489963
2020-08-07: 326.87490073897425
2020-08-08: 427.18942859450544
2020-08-09: 484.3829028851278
2020-08-10: 458.1432179019657
2020-08-11: 435.36206307802513
2020-08-12: 382.09811099277067
2020-08-13: 327.3769224699299
2020-08-14: 373.99154030649333
2020-08-15: 446.92243161281567
2020-08-16: 449.9464872098324
2020-08-17: 421.611478536689
2020-08-18: 399.76458190455924
2020-

2020-08-01: 4.429712748149541
2020-08-02: 2.475214989441156
2020-08-03: 1.316816221422414
2020-08-04: 0.8252951807735371
2020-08-05: 1.8243988926351449
2020-08-06: 3.4044234212778974
2020-08-07: 4.933209769264067
2020-08-08: 3.205740601670186
2020-08-09: 0.8262498151962063
2020-08-10: 0
2020-08-11: 0.46217164496556506
2020-08-12: 0.724379926077745
2020-08-13: 2.6060647887679806
2020-08-14: 3.0572792519960785
2020-08-15: 1.5465602231979574
2020-08-16: 0.5006571372268809
2020-08-17: 0.3998065652885998
2020-08-18: 0.751130494840254
2020-08-19: 1.6195262684868794
2020-08-20: 2.2091921548399114
2020-08-21: 1.6550331746148004

Predicting for Malaysia__nan
2020-08-01: 1099.087807203915
2020-08-02: 1226.3614706676544
2020-08-03: 1301.0774311551759
2020-08-04: 1309.6560842844974
2020-08-05: 1391.170448570354
2020-08-06: 1449.4547311889273
2020-08-07: 1323.6584492549691
2020-08-08: 1150.2396853941154
2020-08-09: 1191.8092723826176
2020-08-10: 1262.4348405529308
2020-08-11: 1315.3944465287432
202

2020-08-10: 952.7228888559684
2020-08-11: 885.0422276000354
2020-08-12: 854.0802262143025
2020-08-13: 763.7037422976737
2020-08-14: 761.0279557025342
2020-08-15: 781.0162633911996
2020-08-16: 770.2134706852685
2020-08-17: 713.2726834553932
2020-08-18: 684.6158598808734
2020-08-19: 621.6742956038051
2020-08-20: 592.0907015117837
2020-08-21: 600.7869793616555

Predicting for Peru__nan
2020-08-01: 1726.1895924816126
2020-08-02: 1747.486831268342
2020-08-03: 1721.062946351697
2020-08-04: 1541.5873455048095
2020-08-05: 1242.8177427771307
2020-08-06: 845.2377129213683
2020-08-07: 1101.6875657803303
2020-08-08: 1330.4585471288565
2020-08-09: 1279.9500249888615
2020-08-10: 1121.3270070763283
2020-08-11: 983.9324091994677
2020-08-12: 828.5879679188538
2020-08-13: 695.2274967773992
2020-08-14: 881.2409160410825
2020-08-15: 944.428204556817
2020-08-16: 841.0534448116605
2020-08-17: 742.231835778116
2020-08-18: 660.6478574435723
2020-08-19: 530.6252970447658
2020-08-20: 571.102911081508
2020-08-21

2020-08-18: 8238.659865816046
2020-08-19: 8271.560162863267
2020-08-20: 8432.536919236307
2020-08-21: 8052.248084701132

Predicting for Rwanda__nan
2020-08-01: 41.14532565249647
2020-08-02: 36.66924605339773
2020-08-03: 41.72332780846045
2020-08-04: 44.41207379410935
2020-08-05: 50.296708487951946
2020-08-06: 48.85456181014979
2020-08-07: 43.076966495649685
2020-08-08: 38.888243750394395
2020-08-09: 37.81159409372049
2020-08-10: 42.37641056954453
2020-08-11: 46.21735919346065
2020-08-12: 49.60153563937365
2020-08-13: 45.342509089397026
2020-08-14: 42.4244377353325
2020-08-15: 37.935646980506135
2020-08-16: 37.17594641136384
2020-08-17: 40.51075435626262
2020-08-18: 40.80986005425075
2020-08-19: 39.946780689208026
2020-08-20: 37.21553357297307
2020-08-21: 33.23327644197343

Predicting for Saudi Arabia__nan
2020-08-01: 242.7243723349975
2020-08-02: 221.7260530623371
2020-08-03: 208.7009985356408
2020-08-04: 187.29937411375613
2020-08-05: 169.16030112895078
2020-08-06: 164.30908995052232


2020-08-07: 1.2411292207852966
2020-08-08: 2.289966016478502
2020-08-09: 2.3870430741881137
2020-08-10: 2.366199793326814
2020-08-11: 1.8677048844638193
2020-08-12: 1.0014395874277398
2020-08-13: 0.7204519118167809
2020-08-14: 1.3401482801473321
2020-08-15: 1.87632693640262
2020-08-16: 1.7836262154596625
2020-08-17: 1.640968662180967
2020-08-18: 1.2276018737447885
2020-08-19: 0.8487482033706328
2020-08-20: 1.127893909235354
2020-08-21: 1.6969966493506945

Predicting for Slovak Republic__nan
2020-08-01: 1439.191591179826
2020-08-02: 1573.6870348359435
2020-08-03: 1591.644566206388
2020-08-04: 1410.2781127135559
2020-08-05: 876.1310599932895
2020-08-06: 453.029849718293
2020-08-07: 928.488275711552
2020-08-08: 1316.676047378527
2020-08-09: 1363.3035728753111
2020-08-10: 1167.1847015629246
2020-08-11: 1016.3937755820452
2020-08-12: 713.8218199975527
2020-08-13: 627.0792451624638
2020-08-14: 990.5114575588623
2020-08-15: 1158.7603495076792
2020-08-16: 1051.618564827948
2020-08-17: 931.5192

2020-08-15: 2.4990560540118167
2020-08-16: 2.641377391822971
2020-08-17: 2.7465451661191764
2020-08-18: 2.845237598532549
2020-08-19: 2.928278662251614
2020-08-20: 3.0188312822647276
2020-08-21: 3.1063866190761473

Predicting for Trinidad and Tobago__nan
2020-08-01: 62.272471465959825
2020-08-02: 63.189367644573395
2020-08-03: 66.36775538708571
2020-08-04: 79.46285924454423
2020-08-05: 79.01547749353904
2020-08-06: 62.88018497488065
2020-08-07: 61.72958514482525
2020-08-08: 68.61333404813467
2020-08-09: 70.55507314158785
2020-08-10: 77.70819021034822
2020-08-11: 79.03003065038865
2020-08-12: 77.16597112292311
2020-08-13: 69.6658024442495
2020-08-14: 65.48930256344102
2020-08-15: 67.12708306910281
2020-08-16: 73.09415259794514
2020-08-17: 69.60169740163656
2020-08-18: 71.85267240372195
2020-08-19: 67.30101538029066
2020-08-20: 59.93141997663861
2020-08-21: 60.54502371124818

Predicting for Tunisia__nan
2020-08-01: 1501.9943711494188
2020-08-02: 1442.822500692224
2020-08-03: 1296.8877323

2020-08-10: 4388.674885484383
2020-08-11: 4113.054971746102
2020-08-12: 3415.4980423180214
2020-08-13: 3610.910405703159
2020-08-14: 4803.97497031418
2020-08-15: 5602.2247148824745
2020-08-16: 5893.3642307018
2020-08-17: 5692.777582286077
2020-08-18: 5174.873650433187
2020-08-19: 4703.628444601198
2020-08-20: 5201.635795092094
2020-08-21: 6685.62938633437

Predicting for United States__California
2020-08-01: 8533.97518976503
2020-08-02: 8551.761722320593
2020-08-03: 8412.152076268538
2020-08-04: 8167.434590752669
2020-08-05: 7879.385103933743
2020-08-06: 7656.927267785082
2020-08-07: 7794.193308935023
2020-08-08: 8172.57172521317
2020-08-09: 8343.382135425674
2020-08-10: 8278.039106943914
2020-08-11: 8102.767118851382
2020-08-12: 8062.518937247718
2020-08-13: 7969.152769253098
2020-08-14: 8113.186009317141
2020-08-15: 8319.078200126443
2020-08-16: 8503.589646923432
2020-08-17: 8454.420391360973
2020-08-18: 8281.231276799594
2020-08-19: 8180.196491763141
2020-08-20: 7407.738555065006
20

2020-08-01: 5548.833783442958
2020-08-02: 5701.010313713949
2020-08-03: 5415.042543862814
2020-08-04: 6289.346343538169
2020-08-05: 7085.650479779763
2020-08-06: 7024.473751856515
2020-08-07: 7902.473686741756
2020-08-08: 7463.276433386883
2020-08-09: 8269.060129591484
2020-08-10: 7500.008714861638
2020-08-11: 8102.228659075538
2020-08-12: 8799.862339204774
2020-08-13: 8197.323958188606
2020-08-14: 8797.364844439735
2020-08-15: 8675.500974967743
2020-08-16: 8851.06118294067
2020-08-17: 8527.748359702979
2020-08-18: 9058.281712273387
2020-08-19: 9239.410227814855
2020-08-20: 9027.438022689077
2020-08-21: 9274.944730618257

Predicting for United States__Kentucky
2020-08-01: 3009.0762600493017
2020-08-02: 2895.1710142603524
2020-08-03: 2986.1739195308573
2020-08-04: 2052.1242644913746
2020-08-05: 1458.087192588423
2020-08-06: 1687.882363061285
2020-08-07: 2734.8780913102746
2020-08-08: 2895.115678236095
2020-08-09: 2841.0599600597798
2020-08-10: 2491.2334948966836
2020-08-11: 1941.9360145

2020-08-05: 475.85896551827227
2020-08-06: 371.7847496044433
2020-08-07: 754.5277756508831
2020-08-08: 910.3489041913253
2020-08-09: 893.9067589135138
2020-08-10: 743.9810985089734
2020-08-11: 502.93104999371826
2020-08-12: 429.5571831013758
2020-08-13: 456.39534112473666
2020-08-14: 685.9981925147649
2020-08-15: 773.795141562211
2020-08-16: 715.018060081964
2020-08-17: 588.2850988588216
2020-08-18: 496.7332339248287
2020-08-19: 436.3302467964959
2020-08-20: 563.7323980016781
2020-08-21: 680.1220638742989

Predicting for United States__Nebraska
2020-08-01: 1812.4418608329297
2020-08-02: 1547.6155041138081
2020-08-03: 1197.1306768756258
2020-08-04: 971.529021056188
2020-08-05: 1266.9306209646065
2020-08-06: 1205.2131550407958
2020-08-07: 1412.1868631195202
2020-08-08: 1354.626273278548
2020-08-09: 1208.614007777469
2020-08-10: 911.7252867701345
2020-08-11: 840.1839489227577
2020-08-12: 1185.027793412126
2020-08-13: 1051.6575684320069
2020-08-14: 1074.082905836336
2020-08-15: 1177.749341

2020-08-20: 676.3956672351014
2020-08-21: 800.5297532292379

Predicting for United States__South Dakota
2020-08-01: 984.3421596902053
2020-08-02: 892.5042533323458
2020-08-03: 846.6169187082778
2020-08-04: 639.9965713170177
2020-08-05: 361.6330164724295
2020-08-06: 306.59363710085745
2020-08-07: 644.3238333207773
2020-08-08: 713.1556080363407
2020-08-09: 656.8432788174423
2020-08-10: 544.6048619264502
2020-08-11: 371.15538096697037
2020-08-12: 338.92138737838013
2020-08-13: 396.3040987291306
2020-08-14: 534.6653411057614
2020-08-15: 578.7280206642108
2020-08-16: 540.7329773170968
2020-08-17: 410.2561365781885
2020-08-18: 374.4002037309738
2020-08-19: 356.46353252653535
2020-08-20: 438.29702129041107
2020-08-21: 522.4122422420614

Predicting for United States__Tennessee
2020-08-01: 3414.876792467222
2020-08-02: 3967.17583929137
2020-08-03: 4749.301318627429
2020-08-04: 4989.171503585619
2020-08-05: 4900.42421340773
2020-08-06: 3305.8215102319746
2020-08-07: 4193.932361753716
2020-08-08:

2020-08-13: 640.1420301934832
2020-08-14: 561.28095341674
2020-08-15: 558.7350515603439
2020-08-16: 593.8756299524784
2020-08-17: 613.1424427351076
2020-08-18: 682.9491699414266
2020-08-19: 698.3712753115897
2020-08-20: 623.2416459818005
2020-08-21: 585.3341862720663

Predicting for United States Virgin Islands__nan
2020-08-01: 10.343342089832731
2020-08-02: 9.45910817330332
2020-08-03: 10.546853112755343
2020-08-04: 12.677587009586205
2020-08-05: 11.306029475255855
2020-08-06: 9.31289452022611
2020-08-07: 9.902389302251322
2020-08-08: 11.982017385404106
2020-08-09: 10.837204512121389
2020-08-10: 11.439653645225917
2020-08-11: 11.707632861858656
2020-08-12: 10.283919205981874
2020-08-13: 9.320224341721769
2020-08-14: 10.523333418195762
2020-08-15: 10.548155034508454
2020-08-16: 11.025706761385663
2020-08-17: 10.681131464381906
2020-08-18: 10.30671343323229
2020-08-19: 9.41253836125179
2020-08-20: 8.908594146621908
2020-08-21: 9.271229539186606

Predicting for Vietnam__nan
2020-08-01: 6

In [199]:
# Check the predictions
preds_df.head()

Unnamed: 0,CountryName,RegionName,Date,PredictedDailyNewCases
213,Aruba,,2020-08-01,58.844692
214,Aruba,,2020-08-02,71.323398
215,Aruba,,2020-08-03,78.998212
216,Aruba,,2020-08-04,90.247021
217,Aruba,,2020-08-05,87.814682


# Validation
This is how the predictor is going to be called during the competition.  
!!! PLEASE DO NOT CHANGE THE API !!!

In [127]:
!python predict.py -s 2020-08-01 -e 2020-08-31 -ip ../../../validation/data/2020-09-30_historical_ip.csv -o predictions/2020-08-01_2020-08-31.csv

Generating predictions from 2020-08-01 to 2020-08-31...
Saved predictions to predictions/2020-08-01_2020-08-31.csv
Done!


In [201]:
!head predictions/2020-08-01_2020-08-04.csv

CountryName,RegionName,Date,PredictedDailyNewCases
Aruba,,2020-08-01,58.844691538124245
Aruba,,2020-08-02,71.32339797888373
Aruba,,2020-08-03,78.99821175219239
Aruba,,2020-08-04,90.24702060691985
Afghanistan,,2020-08-01,149.93664407981095
Afghanistan,,2020-08-02,287.2826933086321
Afghanistan,,2020-08-03,277.50656117987995
Afghanistan,,2020-08-04,264.8862561764254
Angola,,2020-08-01,198.70750682167088


# Test cases
We can generate a prediction file. Let's validate a few cases...

In [110]:
import os
print(os.getcwd())

/home/cdrutinus/Cours/SDD/Mini-hackaton/pangolin/covid-xprize/covid_xprize/examples/predictors/SVR


In [111]:
from covid_xprize.validation.predictor_validation import validate_submission

def validate(start_date, end_date, ip_file, output_file):
    # First, delete any potential old file
    try:
        os.remove(output_file)
    except OSError:
        pass
    
    # Then generate the prediction, calling the official API
    !python predict.py -s {start_date} -e {end_date} -ip {ip_file} -o {output_file}
    
    # And validate it
    errors = validate_submission(start_date, end_date, ip_file, output_file)
    if errors:
        for error in errors:
            print(error)
    else:
        print("All good!")
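
The `!python ...` line in `validate` only works inside IPython/Jupyter. For running the same check from a plain script, a sketch using `subprocess` is shown below. The helper names are hypothetical; the `-s`, `-e`, `-ip` and `-o` flags are the ones used in the notebook's own calls to `predict.py`.

```python
import subprocess
import sys

def build_predict_command(start_date, end_date, ip_file, output_file,
                          script="predict.py"):
    """Return the argument list for one predictor call."""
    return [sys.executable, script,
            "-s", start_date,
            "-e", end_date,
            "-ip", ip_file,
            "-o", output_file]

def run_predict(start_date, end_date, ip_file, output_file):
    """Run the predictor and return the completed process.

    A non-zero returncode means the predictor crashed, in which case
    the output file should not be validated.
    """
    cmd = build_predict_command(start_date, end_date, ip_file, output_file)
    return subprocess.run(cmd, capture_output=True, text=True)

# Example (not executed here because it needs predict.py on disk):
# result = run_predict("2020-08-01", "2020-08-04", "ip.csv", "out.csv")
# if result.returncode != 0:
#     print(result.stderr)
```

Checking the return code before calling `validate_submission` would surface a predictor crash directly, instead of a later `FileNotFoundError` on the missing output file.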

## 4 days, no gap
- All countries and regions
- Official number of cases is known up to start_date
- Intervention Plans are the official ones

In [112]:
validate(start_date="2020-08-01",
         end_date="2020-08-04",
         ip_file="../../../validation/data/2020-09-30_historical_ip.csv",
         output_file="predictions/val_4_days.csv")

Generating predictions from 2020-08-01 to 2020-08-04...
Traceback (most recent call last):
  File "predict.py", line 200, in <module>
    predict(args.start_date, args.end_date, args.ip_file, args.output_file)
  File "predict.py", line 52, in predict
    preds_df = predict_df(start_date, end_date, path_to_ips_file, verbose=False)
  File "predict.py", line 139, in predict_df
    pred = model.predict(X.reshape(1, -1))[0]
  File "/home/cdrutinus/anaconda3/lib/python3.8/site-packages/sklearn/svm/_base.py", line 333, in predict
    X = self._validate_for_predict(X)
  File "/home/cdrutinus/anaconda3/lib/python3.8/site-packages/sklearn/svm/_base.py", line 484, in _validate_for_predict
    raise ValueError("X.shape[1] = %d should be equal to %d, "
ValueError: X.shape[1] = 390 should be equal to 260, the number of features at training time


FileNotFoundError: [Errno 2] No such file or directory: 'predictions/val_4_days.csv'
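
The `ValueError` above reports 390 input features against the 260 the model was trained on. With 13 columns per day (1 case column plus 12 NPI columns), the arithmetic below shows that 260 corresponds to a 20-day lookback and 390 to a 30-day lookback. A plausible cause, which cannot be confirmed from this notebook alone, is that `predict.py` builds its features with a 30-day window while training used `nb_lookback_days = 20`. The helper name is hypothetical.

```python
def n_features(lookback_days, n_npi_cols=12, n_case_cols=1):
    """Length of one flattened sample: cases and NPIs for each lookback day."""
    return lookback_days * (n_case_cols + n_npi_cols)

print(n_features(20))  # 260, the width the model was trained on
print(n_features(30))  # 390, the width passed to model.predict
```

The `FileNotFoundError` that follows is a consequence of the same crash: no predictions file was written, so the validator had nothing to read.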

## 1 month in the future
- 2 countries only
- There is a gap between the date of the last known number of cases and start_date
- For future dates, the Intervention Plans contain scenarios for which predictions are requested, to answer the question: what will happen if we apply these plans?

In [205]:
%%time
validate(start_date="2021-01-01",
         end_date="2021-01-31",
         ip_file="../../../validation/data/future_ip.csv",
         output_file="predictions/val_1_month_future.csv")

Generating predictions from 2021-01-01 to 2021-01-31...
Saved predictions to predictions/val_1_month_future.csv
Done!
All good!
CPU times: user 62.5 ms, sys: 109 ms, total: 172 ms
Wall time: 4.08 s


## 180 days, from a future date, all countries and regions
- Prediction start date is 1 week from now. (i.e. assuming submission date is 1 week from now)  
- Prediction end date is 6 months after start date.  
- Prediction is requested for all available countries and regions.  
- Intervention plan scenario: freeze last known intervention plans for each country and region.  

Because case counts between today and the start date are not yet known, and the model relies on them as inputs, the model must first predict them and then feed those predictions back in.  
This is the most demanding test. It should take less than 1 hour to generate the prediction file.
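
The idea of predicting unknown case counts and reusing them as inputs can be sketched as a simple rollout loop. This is an illustration under stated assumptions, not the code in `predict.py`: the `rollout` function and the stand-in model are hypothetical, and the feature layout (case window first, then negated NPIs) follows the training cells.

```python
import numpy as np

class MeanOfWindowModel:
    """Stand-in for a trained regressor: predicts the mean of the case window."""
    def __init__(self, lookback):
        self.lookback = lookback

    def predict(self, X):
        return X[:, :self.lookback].mean(axis=1)

def rollout(model, past_cases, past_npis, future_npis, lookback):
    """Predict one day at a time, appending each prediction to the case history."""
    cases = list(past_cases[-lookback:])
    npis = np.vstack([past_npis[-lookback:], future_npis])
    preds = []
    for step in range(len(future_npis)):
        window_cases = np.array(cases[-lookback:])
        window_npis = -npis[step:step + lookback]  # same sign as in training
        X = np.concatenate([window_cases, window_npis.flatten()])
        pred = max(float(model.predict(X.reshape(1, -1))[0]), 0.0)
        preds.append(pred)
        cases.append(pred)  # the prediction becomes an input for later days
    return preds

# Made-up history: 3 days of cases, 2 NPI columns, 2 days to predict.
past_cases = np.array([2.0, 4.0, 6.0])
past_npis = np.ones((3, 2))
future_npis = np.zeros((2, 2))

preds = rollout(MeanOfWindowModel(3), past_cases, past_npis, future_npis, lookback=3)
print(preds)
```

Because each prediction feeds the next window, errors can compound over a 180-day horizon, which is part of why this test is the most demanding.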

### Generate the scenario

In [56]:
from datetime import datetime, timedelta

start_date = datetime.now() + timedelta(days=7)
start_date_str = start_date.strftime('%Y-%m-%d')
end_date = start_date + timedelta(days=180)
end_date_str = end_date.strftime('%Y-%m-%d')
print(f"Start date: {start_date_str}")
print(f"End date: {end_date_str}")

Start date: 2020-12-01
End date: 2021-05-30


In [57]:
from covid_xprize.validation.scenario_generator import get_raw_data, generate_scenario, NPI_COLUMNS
DATA_FILE = 'data/OxCGRT_latest.csv'
latest_df = get_raw_data(DATA_FILE, latest=True)
scenario_df = generate_scenario(start_date_str, end_date_str, latest_df, countries=None, scenario="Freeze")
scenario_file = "predictions/180_days_future_scenario.csv"
scenario_df.to_csv(scenario_file, index=False)
print(f"Saved scenario to {scenario_file}")

Saved scenario to predictions/180_days_future_scenario.csv


### Check it

In [58]:
%%time
validate(start_date=start_date_str,
         end_date=end_date_str,
         ip_file=scenario_file,
         output_file="predictions/val_6_month_future.csv")

Generating predictions from 2020-12-01 to 2021-05-30...
Saved predictions to predictions/val_6_month_future.csv
Done!
All good!
CPU times: user 7.59 s, sys: 277 ms, total: 7.87 s
Wall time: 2min 49s
