In [1]:
import numpy as np
import pandas as pd

In [2]:
# Cumulative confirmed COVID-19 cases (JHU CSSE global time series),
# fetched through the HDX HXL proxy. One column per reporting date.
download_link = (
    "https://data.humdata.org/hxlproxy/api/data-preview.csv"
    "?url=https%3A%2F%2Fraw.githubusercontent.com%2FCSSEGISandData%2FCOVID-19"
    "%2Fmaster%2Fcsse_covid_19_data%2Fcsse_covid_19_time_series"
    "%2Ftime_series_covid19_confirmed_global.csv"
    "&filename=time_series_covid19_confirmed_global.csv"
)
confirmed = pd.read_csv(filepath_or_buffer=download_link)

In [3]:
# GDP per capita, loaded from the CSV kept in the local data/ directory.
gdp_csv_path = "data/gdp_per_capita_2018.csv"
gdp_pp = pd.read_csv(gdp_csv_path)

### Merge data sources

In [4]:
# Collapse the province-level rows into one row of summed counts per country.
# The grouping key is renamed with DataFrame.rename: assigning into
# `columns.values[0]` writes to the Index's underlying buffer, which is not a
# supported way to rename a column and can leave label lookups out of sync.
confirmed_cntry = (
    confirmed
    .drop(columns=["Province/State", "Lat", "Long"])
    .groupby("Country/Region")
    .sum()
    .reset_index()
    .rename(columns={"Country/Region": "Country"})
)
confirmed_cntry["Country"] = confirmed_cntry["Country"].apply(str)

# The first column of the GDP table is treated as the country name; give it
# the same key name as the case table, then drop rows with missing values.
gdp_pp = gdp_pp.rename(columns={gdp_pp.columns[0]: "Country"})
gdp_pp["Country"] = gdp_pp["Country"].apply(str)
gdp_pp = gdp_pp.dropna()

# Diagnostic only: countries that have case data but no GDP figure.
# A right merge keeps every country of the case table, so a missing
# "GDP per capita" marks a country that found no match in the GDP table.
gdp_coverage = gdp_pp.merge(confirmed_cntry, on="Country", how="right")
print(gdp_coverage.loc[gdp_coverage["GDP per capita"].isna(), "Country"])

# Table used downstream: only countries present in both sources.
confirmed_full = gdp_pp.merge(confirmed_cntry, on="Country", how="inner")

175    Diamond Princess
176             Eritrea
177            Holy See
178                Iran
179       Liechtenstein
180          MS Zaandam
181          San Marino
182         South Sudan
183               Syria
184             Taiwan*
185           Venezuela
186      Western Sahara
187               Yemen
Name: Country, dtype: object


In [79]:
# Preview the merged table: country, GDP per capita, then one column per date.
confirmed_full.head(n=4)

Unnamed: 0,Country,GDP per capita,1/22/20,1/23/20,1/24/20,1/25/20,1/26/20,1/27/20,1/28/20,1/29/20,...,5/27/20,5/28/20,5/29/20,5/30/20,5/31/20,6/1/20,6/2/20,6/3/20,6/4/20,6/5/20
0,Afghanistan,520.896603,0,0,0,0,0,0,0,0,...,12456,13036,13659,14525,15205,15750,16509,17267,18054,18969
1,Angola,3432.385736,0,0,0,0,0,0,0,0,...,71,74,81,84,86,86,86,86,86,86
2,Albania,5268.848504,0,0,0,0,0,0,0,0,...,1050,1076,1099,1122,1137,1143,1164,1184,1197,1212
3,Andorra,42029.76274,0,0,0,0,0,0,0,0,...,763,763,764,764,764,765,844,851,852,852


In [72]:
# Import the helper by name instead of `import *`, so it is clear where
# `create_lagged_features` comes from and no other names leak into the notebook.
from feature_engineering import create_lagged_features

# Lags 1..27 passed to the helper; the resulting frame has columns
# cases_1 ... cases_27 next to the `cases` target.
LAG_DAYS = range(1, 28)
features = create_lagged_features(confirmed_full, LAG_DAYS)

In [80]:
# Keep only the rows with more than MIN_CASES confirmed cases (note: the
# threshold is 100, not zero), then move every case-count column -- the target
# `cases` and its lagged copies -- to the log1p scale. log1p is used so that
# lag values of 0 stay finite.
MIN_CASES = 100
nonzero = features[features.cases > MIN_CASES].copy()
case_cols = [col for col in nonzero.columns if 'cases' in col]
nonzero[case_cols] = np.log1p(nonzero[case_cols])
nonzero.head()

Unnamed: 0,Country,GDP per capita,reported_date,cases,cases_1,cases_2,cases_3,cases_4,cases_5,cases_6,...,cases_18,cases_19,cases_20,cases_21,cases_22,cases_23,cases_24,cases_25,cases_26,cases_27
32,China,9770.847088,2020-01-22,6.308098,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
207,China,9770.847088,2020-01-23,6.467699,6.308098,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
382,China,9770.847088,2020-01-24,6.82546,6.467699,6.308098,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
557,China,9770.847088,2020-01-25,7.249215,6.82546,6.467699,6.308098,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
732,China,9770.847088,2020-01-26,7.638198,7.249215,6.82546,6.467699,6.308098,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


## Creating a model
We will try the random forest algorithm.
Let's use the default settings, without cross-validation.

In [74]:
from sklearn.ensemble import RandomForestRegressor

# Chronological split: rows reported before `train_date` are used for fitting,
# rows reported on or after it are held out for evaluation.
train_date = '5/10/2020'

# Columns that are identifiers or the target, not model inputs.
non_feature_cols = ['Country', 'reported_date', 'cases']

train_rows = nonzero[nonzero.reported_date < train_date]
test_rows = nonzero[nonzero.reported_date >= train_date]

X_train = train_rows.drop(columns=non_feature_cols).values
y_train = train_rows.cases.values
X_test = test_rows.drop(columns=non_feature_cols).values
y_test = test_rows.cases.values

In [75]:
# Random forests are stochastic (bootstrap samples, feature sub-sampling);
# a fixed seed makes the fitted model and the score below reproducible on re-run.
RANDOM_STATE = 42
model = RandomForestRegressor(n_jobs=-1, random_state=RANDOM_STATE)
model.fit(X_train, y_train)
# R^2 on the held-out period; the target here is log1p(cases), not raw counts.
model.score(X_test, y_test)

0.9985504200431222

In [77]:
# Attach predictions and absolute errors (both on the log1p scale) to the
# held-out rows and list the 20 worst misses.
pred = model.predict(X_test)
diff = pred - y_test
test_set = (
    nonzero[nonzero.reported_date >= train_date]
    .assign(diff=np.abs(diff), pred=pred)
)
test_set.sort_values(by='diff', ascending=False).head(20)

Unnamed: 0,Country,GDP per capita,reported_date,cases,cases_1,cases_2,cases_3,cases_4,cases_5,cases_6,...,cases_20,cases_21,cases_22,cases_23,cases_24,cases_25,cases_26,cases_27,diff,pred
20663,Benin,901.543871,2020-05-19,4.875197,5.828946,5.828946,5.828946,5.828946,5.828946,5.793014,...,4.174387,4.174387,4.174387,4.174387,4.007333,4.007333,4.007333,4.007333,0.989097,5.864294
21996,Nicaragua,2028.894755,2020-05-26,6.633318,5.63479,5.63479,5.63479,5.63479,5.63479,5.541264,...,2.833213,2.833213,2.772589,2.772589,2.70805,2.70805,2.70805,2.639057,0.819696,5.813622
19978,Central African Republic,475.72125,2020-05-15,5.710427,4.969813,4.969813,4.969813,4.969813,4.969813,4.969813,...,2.833213,2.833213,2.833213,2.70805,2.70805,2.564949,2.564949,2.564949,0.691468,5.018959
22188,Russia,11288.87244,2020-05-27,12.823097,12.800347,12.775435,12.749797,12.724518,12.696029,12.668406,...,12.084814,12.019321,11.953571,11.886342,11.810716,11.72848,11.647736,11.575891,0.578737,13.401834
23334,United Kingdom,42943.90227,2020-06-03,12.547074,12.540375,12.53443,12.528725,12.521686,12.51579,12.508061,...,12.364959,12.350112,12.335969,12.320887,12.303426,12.285466,12.267309,12.245188,0.57343,13.120504
23509,United Kingdom,42943.90227,2020-06-04,12.553485,12.547074,12.540375,12.53443,12.528725,12.521686,12.51579,...,12.380047,12.364959,12.350112,12.335969,12.320887,12.303426,12.285466,12.267309,0.570652,13.124137
22363,Russia,11288.87244,2020-05-28,12.845429,12.823097,12.800347,12.775435,12.749797,12.724518,12.696029,...,12.143452,12.084814,12.019321,11.953571,11.886342,11.810716,11.72848,11.647736,0.567466,13.412894
23684,United Kingdom,42943.90227,2020-06-05,12.559314,12.553485,12.547074,12.540375,12.53443,12.528725,12.521686,...,12.394467,12.380047,12.364959,12.350112,12.335969,12.320887,12.303426,12.285466,0.567172,13.126486
22013,Russia,11288.87244,2020-05-26,12.800347,12.775435,12.749797,12.724518,12.696029,12.668406,12.640145,...,12.019321,11.953571,11.886342,11.810716,11.72848,11.647736,11.575891,11.506907,0.563887,13.364234
21838,Russia,11288.87244,2020-05-25,12.775435,12.749797,12.724518,12.696029,12.668406,12.640145,12.611344,...,11.953571,11.886342,11.810716,11.72848,11.647736,11.575891,11.506907,11.446348,0.560531,13.335966
