In [1]:
# COVID-19 Spread Prediction

In [2]:
%sql
-- List the tables available in the current database before querying them.
SHOW TABLES

database,tableName,isTemporary
default,covid19_spread,False
default,covid_spread,False


In [3]:
%sql
-- Ten countries with the largest `label` on 2020-04-24.
-- The `confirmed` output column is log(label + 1), not the raw value;
-- the ordering itself uses the untransformed `label`.
SELECT
  country,
  date,
  log(label + 1) AS confirmed
FROM covid_spread
WHERE date = '2020-04-24'
ORDER BY label DESC
LIMIT 10

country,date,confirmed
US,2020-04-24T00:00:00.000+0000,13.676787425838866
Spain,2020-04-24T00:00:00.000+0000,12.300314072603523
Italy,2020-04-24T00:00:00.000+0000,12.154642499673624
France,2020-04-24T00:00:00.000+0000,11.979774122522665
Germany,2020-04-24T00:00:00.000+0000,11.942009438677026
United Kingdom,2020-04-24T00:00:00.000+0000,11.844004613571816
Turkey,2020-04-24T00:00:00.000+0000,11.530676970545883
Iran,2020-04-24T00:00:00.000+0000,11.387305551045577
China,2020-04-24T00:00:00.000+0000,11.33721401321588
Russia,2020-04-24T00:00:00.000+0000,11.136383034489144


In [4]:
%sql
-- Rows whose `label` exceeds 10000, shown against `since_100th_confirmed_n_days`.
-- NOTE(review): this cell reads `covid19_spread`, whereas the other cells read
-- `covid_spread` -- confirm that the different table is intended.
SELECT
  country,
  since_100th_confirmed_n_days,
  label AS confirmed
FROM covid19_spread
WHERE label > 10000

country,since_100th_confirmed_n_days,confirmed
China,10,11891
China,11,16630
China,12,19716
China,13,23707
China,14,27440
China,15,30587
China,16,34110
China,17,36814
China,18,39829
China,19,42354


In [5]:
# Split `covid_spread` by date into three DataFrames:
#   train_df - rows dated before 2020-04-17
#   test_df  - rows dated after 2020-04-18 and before 2020-04-24
#   today_df - rows dated exactly 2020-04-24
# NOTE(review): as written, the train filter stops before 2020-04-17 while the
# test filter starts after 2020-04-18, so rows in between belong to neither
# split -- confirm that this gap is intentional.
train_df, test_df, today_df = (
    spark.sql(split_query)
    for split_query in (
        'select * from covid_spread where date < "2020-04-17"',
        'select * from covid_spread where date > "2020-04-18" and date < "2020-04-24"',
        'select * from covid_spread where date = "2020-04-24"',
    )
)

In [6]:
# Mark all three splits for caching; each one is reused by later cells.
for split_df in (train_df, test_df, today_df):
    split_df.cache()

# Print the column names and types of the test split.
test_df.printSchema()

In [7]:
# Summary statistics of the test split, transposed so each column is a row.
test_df.describe().toPandas().T

Unnamed: 0,0,1,2,3,4
summary,count,mean,stddev,min,max
country,534,,,Afghanistan,Venezuela
since_1st_confirmed_n_days,534,57.859550561797754,16.153605205912104,29,92
since_100th_confirmed_n_days,534,36.62359550561798,11.221072268079299,12,92
since_1st_recovered_n_days,534,44.29775280898876,16.876088377291232,13,92
since_100th_recovered_n_days,534,18.091760299625467,13.475466119560908,0,86
since_1st_deaths_n_days,534,37.54494382022472,12.482514701027922,14,92
since_10th_deaths_n_days,534,24.0561797752809,14.467814742637621,0,92
confirmed_n_lag_7d,534,21606.026217228464,70034.1288425759,175,667592
recovered_n_lag_7d,534,5209.250936329588,14857.638586293844,0,78401


In [8]:
# Render the test split as a table to inspect the raw feature and label columns.
display(test_df)

country,date,since_1st_confirmed_n_days,since_100th_confirmed_n_days,since_1st_recovered_n_days,since_100th_recovered_n_days,since_1st_deaths_n_days,since_10th_deaths_n_days,confirmed_n_lag_7d,recovered_n_lag_7d,deaths_n_lag_7d,active_n_lag_7d,confirmed_n_lag_8d,recovered_n_lag_8d,deaths_n_lag_8d,active_n_lag_8d,confirmed_n_lag_9d,recovered_n_lag_9d,deaths_n_lag_9d,active_n_lag_9d,confirmed_n_lag_10d,recovered_n_lag_10d,deaths_n_lag_10d,active_n_lag_10d,label
Afghanistan,2020-04-18T00:00:00.000+0000,54,22,33,0,27,12,555,32,18,505,521,32,15,474,484,32,15,437,444,29,14,401,933
Albania,2020-04-18T00:00:00.000+0000,40,26,28,13,38,21,433,197,23,213,416,182,23,211,409,165,23,221,400,154,22,224,548
Algeria,2020-04-18T00:00:00.000+0000,53,28,37,11,37,29,1825,460,275,1090,1761,405,256,1100,1666,347,235,1084,1572,237,205,1130,2534
Andorra,2020-04-18T00:00:00.000+0000,47,27,37,6,27,18,601,71,26,504,601,71,26,504,583,58,25,500,564,52,23,489,704
Argentina,2020-04-18T00:00:00.000+0000,46,29,35,19,41,22,1975,440,83,1452,1975,375,82,1518,1795,365,72,1358,1715,358,63,1294,2758
Armenia,2020-04-18T00:00:00.000+0000,48,30,32,10,23,9,967,173,13,781,937,149,12,776,921,138,10,773,881,114,9,758,1248
Australia,2020-04-18T00:00:00.000+0000,83,39,79,25,48,23,6303,1806,57,4440,6215,1793,54,4368,6108,1472,51,4585,6010,1080,50,4880,6547
Austria,2020-04-18T00:00:00.000+0000,53,41,40,23,37,27,13806,6604,337,6865,13555,6064,319,7172,13244,5240,295,7709,12942,4512,273,8157,14671
Azerbaijan,2020-04-18T00:00:00.000+0000,48,23,38,9,36,8,1058,200,11,847,991,159,10,822,926,101,9,816,822,63,8,751,1373
Bangladesh,2020-04-18T00:00:00.000+0000,41,12,33,0,31,12,482,36,30,416,424,33,27,364,330,33,21,276,218,33,20,165,2144


In [9]:
from pyspark.ml.feature import VectorAssembler

# Feature columns are everything except the first two columns (country, date)
# and the last column (label).
features_cols = test_df.columns[2:-1]
features_vector = VectorAssembler(inputCols=features_cols, outputCol='features')


def _to_features_and_label(split_df):
    """Assemble the `features` vector for `split_df` and keep only `features` and `label`."""
    assembled = features_vector.transform(split_df)
    return assembled.select(['features', 'label'])


# Apply the same assembler to every split so they share one feature layout.
train, test, today = map(_to_features_and_label, (train_df, test_df, today_df))

test.show(10)

In [10]:
from pyspark.ml.regression import LinearRegression

# Fit a regularised linear regression on the training split; the estimator
# is configured with maxIter=100, regParam=0.3 and elasticNetParam=0.8.
lr = LinearRegression(maxIter=100, regParam=0.3, elasticNetParam=0.8)
lr_model = lr.fit(train)

# Learned parameters.
print(
    "Coefficients: " + str(lr_model.coefficients),
    "Intercept: " + str(lr_model.intercept),
    sep="\n",
)

# Fit quality on the training data itself.
training_summary = lr_model.summary
print(
    "RMSE: %f" % training_summary.rootMeanSquaredError,
    "r2: %f" % training_summary.r2,
    sep="\n",
)

In [11]:
# Score the fitted model on the test split and report its error metrics.
test_result = lr_model.evaluate(test)

print(
    "RMSE on test data = %g" % test_result.rootMeanSquaredError,
    "R2 on test data = %g" % test_result.r2,
    sep="\n",
)

In [12]:
# Add the model's `prediction` column to the test split and render it next to
# the true `label` for a row-by-row comparison.
pred = lr_model.transform(test)
display(pred)

features,label,prediction
"List(1, 22, List(), List(54.0, 22.0, 33.0, 0.0, 27.0, 12.0, 555.0, 32.0, 18.0, 505.0, 521.0, 32.0, 15.0, 474.0, 484.0, 32.0, 15.0, 437.0, 444.0, 29.0, 14.0, 401.0))",933,3427.1225815131
"List(1, 22, List(), List(40.0, 26.0, 28.0, 13.0, 38.0, 21.0, 433.0, 197.0, 23.0, 213.0, 416.0, 182.0, 23.0, 211.0, 409.0, 165.0, 23.0, 221.0, 400.0, 154.0, 22.0, 224.0))",548,2211.17430138375
"List(1, 22, List(), List(53.0, 28.0, 37.0, 11.0, 37.0, 29.0, 1825.0, 460.0, 275.0, 1090.0, 1761.0, 405.0, 256.0, 1100.0, 1666.0, 347.0, 235.0, 1084.0, 1572.0, 237.0, 205.0, 1130.0))",2534,7523.854684560882
"List(1, 22, List(), List(47.0, 27.0, 37.0, 6.0, 27.0, 18.0, 601.0, 71.0, 26.0, 504.0, 601.0, 71.0, 26.0, 504.0, 583.0, 58.0, 25.0, 500.0, 564.0, 52.0, 23.0, 489.0))",704,3795.701791568676
"List(1, 22, List(), List(46.0, 29.0, 35.0, 19.0, 41.0, 22.0, 1975.0, 440.0, 83.0, 1452.0, 1975.0, 375.0, 82.0, 1518.0, 1795.0, 365.0, 72.0, 1358.0, 1715.0, 358.0, 63.0, 1294.0))",2758,4046.5941301025114
"List(1, 22, List(), List(48.0, 30.0, 32.0, 10.0, 23.0, 9.0, 967.0, 173.0, 13.0, 781.0, 937.0, 149.0, 12.0, 776.0, 921.0, 138.0, 10.0, 773.0, 881.0, 114.0, 9.0, 758.0))",1248,1056.6619603117715
"List(1, 22, List(), List(83.0, 39.0, 79.0, 25.0, 48.0, 23.0, 6303.0, 1806.0, 57.0, 4440.0, 6215.0, 1793.0, 54.0, 4368.0, 6108.0, 1472.0, 51.0, 4585.0, 6010.0, 1080.0, 50.0, 4880.0))",6547,8703.077346758262
"List(1, 22, List(), List(53.0, 41.0, 40.0, 23.0, 37.0, 27.0, 13806.0, 6604.0, 337.0, 6865.0, 13555.0, 6064.0, 319.0, 7172.0, 13244.0, 5240.0, 295.0, 7709.0, 12942.0, 4512.0, 273.0, 8157.0))",14671,15502.314191046216
"List(1, 22, List(), List(48.0, 23.0, 38.0, 9.0, 36.0, 8.0, 1058.0, 200.0, 11.0, 847.0, 991.0, 159.0, 10.0, 822.0, 926.0, 101.0, 9.0, 816.0, 822.0, 63.0, 8.0, 751.0))",1373,366.9234142812945
"List(1, 22, List(), List(41.0, 12.0, 33.0, 0.0, 31.0, 12.0, 482.0, 36.0, 30.0, 416.0, 424.0, 33.0, 27.0, 364.0, 330.0, 33.0, 21.0, 276.0, 218.0, 33.0, 20.0, 165.0))",2144,3256.9794498253723


In [13]:
# Add the model's `prediction` column to the 2020-04-24 rows and render it
# next to the true `label`.
# NOTE(review): this rebinds `pred`, which previously held the test-split
# predictions; a distinct name would keep both results available.
# NOTE(review): the rendered output includes a negative prediction, which
# does not look meaningful for a `label` that is positive in every row shown
# -- consider a log-transformed target or clipping at zero.
pred = lr_model.transform(today)
display(pred)

features,label,prediction
"List(1, 22, List(), List(89.0, 45.0, 85.0, 31.0, 54.0, 29.0, 6522.0, 3808.0, 66.0, 2648.0, 6462.0, 2355.0, 63.0, 4044.0, 6440.0, 2186.0, 63.0, 4191.0, 6415.0, 2186.0, 62.0, 4167.0))",6661,7762.535186218617
"List(1, 22, List(), List(59.0, 47.0, 46.0, 29.0, 43.0, 33.0, 14595.0, 9704.0, 431.0, 4460.0, 14476.0, 8986.0, 410.0, 5080.0, 14336.0, 8098.0, 393.0, 5845.0, 14226.0, 7633.0, 384.0, 6209.0))",15071,15492.811311733069
"List(1, 22, List(), List(89.0, 44.0, 72.0, 31.0, 46.0, 35.0, 32813.0, 10545.0, 1354.0, 20914.0, 30808.0, 9698.0, 1257.0, 19853.0, 28208.0, 8966.0, 1006.0, 18236.0, 27034.0, 8210.0, 899.0, 17925.0))",43353,49401.504769309126
"List(1, 22, List(), List(93.0, 93.0, 93.0, 87.0, 93.0, 93.0, 83760.0, 77552.0, 4636.0, 1572.0, 83403.0, 78401.0, 3346.0, 1656.0, 83356.0, 78311.0, 3346.0, 1699.0, 83306.0, 78200.0, 3345.0, 1761.0))",83885,93533.28799235456
"List(1, 22, List(), List(57.0, 45.0, 49.0, 23.0, 41.0, 34.0, 7268.0, 3571.0, 336.0, 3361.0, 7074.0, 3203.0, 321.0, 3550.0, 6876.0, 2925.0, 309.0, 3642.0, 6706.0, 2689.0, 299.0, 3718.0))",8408,11629.917089326216
"List(1, 22, List(), List(86.0, 42.0, 72.0, 22.0, 34.0, 26.0, 3489.0, 1700.0, 82.0, 1707.0, 3369.0, 1700.0, 75.0, 1594.0, 3237.0, 300.0, 72.0, 2865.0, 3161.0, 300.0, 64.0, 2797.0))",4395,6969.886685763555
"List(1, 22, List(), List(91.0, 55.0, 72.0, 33.0, 69.0, 48.0, 149130.0, 35006.0, 18703.0, 95421.0, 147091.0, 33327.0, 17941.0, 95823.0, 134582.0, 31470.0, 17188.0, 85924.0, 131361.0, 29098.0, 15748.0, 86515.0))",159495,184531.9419659357
"List(1, 22, List(), List(88.0, 54.0, 71.0, 37.0, 46.0, 40.0, 141397.0, 83114.0, 4352.0, 53931.0, 137698.0, 77000.0, 4052.0, 56646.0, 134753.0, 72600.0, 3804.0, 58349.0, 131359.0, 68200.0, 3294.0, 59865.0))",153584,149867.2636274644
"List(1, 22, List(), List(56.0, 43.0, 45.0, 27.0, 40.0, 4.0, 1754.0, 1224.0, 9.0, 521.0, 1739.0, 1144.0, 8.0, 587.0, 1727.0, 1077.0, 8.0, 642.0, 1720.0, 989.0, 8.0, 723.0))",1789,-4356.781499814759
"List(1, 22, List(), List(55.0, 41.0, 38.0, 3.0, 44.0, 29.0, 13980.0, 77.0, 530.0, 13373.0, 13271.0, 77.0, 486.0, 12708.0, 12547.0, 77.0, 444.0, 12026.0, 11479.0, 25.0, 406.0, 11048.0))",17607,25647.171185456576
