In [1]:
"""%sh git clone https://github.com/TheEconomist/covid-19-excess-deaths-tracker.git""" 

In [2]:
%python
import matplotlib.pyplot as plt
from fbprophet import *
import pandas as pd

def get_mod():
  """Create a fresh, unfitted Prophet model with the notebook's fixed settings.

  Daily and weekly seasonality are switched off, growth is "linear" with zero
  changepoints, and ``mcmc_samples=400`` is passed to the constructor. One
  custom seasonality component is registered under the name "quarterly".

  Returns:
    The configured, not-yet-fitted ``Prophet`` instance.
  """
  settings = dict(
    growth="linear",
    n_changepoints=0,
    daily_seasonality=False,
    weekly_seasonality=False,
    mcmc_samples=400,
  )
  model = Prophet(**settings)
  # NOTE(review): the component is labelled "quarterly" but its period is
  # 30*4 = 120 — confirm this is the intended cycle length.
  model.add_seasonality(name="quarterly", period=30*4, fourier_order=5)
  return model

In [3]:
from functools import reduce
import pyspark.sql.functions as F  
from pyspark.sql.functions import col  

def get_df(c):
  """Read one dataset's weekly-deaths CSV into a Spark DataFrame.

  Args:
    c: File-name prefix of the dataset (e.g. "britain"); the file read is
      ``<c>_weekly_deaths.csv`` under the historical-deaths directory.

  Returns:
    Spark DataFrame loaded with the CSV reader and the header option enabled.
  """
  csv_path = "/covid/data/output-data/historical-deaths/"+c+"_weekly_deaths.csv"
  reader = spark.read.format("csv").option("header","True")
  return reader.load(csv_path)

# Columns kept for the frames that are used as-is (no aggregation).
_WEEKLY_COLUMNS = ("start_date", "total_deaths", "covid_deaths")


def _sum_by_week(frame):
  """Collapse `frame` to one row per start_date by summing both death columns."""
  return frame.groupBy("start_date").agg(
    F.sum(col("total_deaths")).alias("total_deaths"),
    F.sum(col("covid_deaths")).alias("covid_deaths"),
  )


# Weekly death counts per country, keyed by a short country code.
# uk / ger / fr are selected directly; us / it / nl are summed per start_date.
alldf = {
  "uk":  get_df("britain").filter("region='Britain'").select(*_WEEKLY_COLUMNS),
  "ger": get_df("germany").select(*_WEEKLY_COLUMNS),
  "fr":  get_df("france").filter("region='France'").select(*_WEEKLY_COLUMNS),
  "us":  _sum_by_week(get_df("united_states")),
  "it":  _sum_by_week(get_df("italy").filter("region='Italy'")),
  "nl":  _sum_by_week(get_df("netherlands").filter("region='Netherlands'")),
}

# Preview: tag every country frame with its code, stack them, and show a
# random subset of rows (each row is kept when rand() > 0.3).
labelled_frames = [frame.withColumn("country",F.lit(code)) for code, frame in alldf.items()]
stacked_frames = reduce(lambda acc, nxt: acc.union(nxt), labelled_frames)
display(
  stacked_frames
  .select("country","start_date","total_deaths")
  .filter("rand()>0.3")
)

country,start_date,total_deaths
uk,2014-12-27,13491.0
uk,2015-01-03,17755.0
uk,2015-01-17,15820.0
uk,2015-01-31,13672.0
uk,2015-02-07,13494.0
uk,2015-02-14,13111.0
uk,2015-02-21,13193.0
uk,2015-02-28,13122.0
uk,2015-03-07,12467.0
uk,2015-03-14,12079.0


In [4]:
import matplotlib.pyplot as plt
import numpy as np

def get_deltas(df):
  """Sample the summed gap between observed and model-predicted deaths for recent weeks.

  A model from ``get_mod()`` is fitted on the rows with ``ds < '2020-01-01'``.
  Posterior-predictive samples are then drawn for every row of ``df``, and
  observed-minus-predicted deaths are summed, per sample, over the last
  ``missing`` rows, where ``missing`` is the number of rows with
  ``start_date > '2020-02-01'``.

  Args:
    df: Spark DataFrame with ``start_date`` and ``total_deaths`` columns.

  Returns:
    1-D NumPy array holding one summed difference per posterior-predictive
    sample (all zeros when ``df`` has no rows after '2020-02-01').
  """
  # Build the Prophet-style columns once and reuse them for both pandas frames.
  prepared = df.withColumn("ds",col("start_date")).withColumn("y",col("total_deaths"))
  missing = df.filter("start_date>'2020-02-01'").count()
  dfp = prepared.toPandas()
  dft = prepared.filter("ds<'2020-01-01'").toPandas()
  m = get_mod()
  m.fit(dft)
  # y and yhat are both taken from the frame returned by setup_dataframe so
  # their rows line up. Taking the *last* rows as the recent period assumes
  # that frame is in ascending ds order (Prophet sorts by ds — TODO confirm
  # for the installed version).
  dfp = m.setup_dataframe(dfp)
  yhat = m.sample_posterior_predictive(dfp)["yhat"]
  residuals = np.array(dfp["y"]).reshape((-1,1)) - yhat
  # Slice from an explicit start index: residuals[-missing:] would select
  # *every* row when missing == 0 instead of none.
  first_recent = max(residuals.shape[0] - missing, 0)
  return np.sum(residuals[first_recent:, :], axis=0)

In [5]:
# Posterior samples of the summed observed-minus-predicted deaths, one array
# per country code.
all_deltas = {k: get_deltas(v) for (k,v) in alldf.items()}

# Sum of the covid_deaths column per country, as {country code: total}.
# The aggregate is aliased and read by column label: integer-position lookups
# on a label-indexed row (row[1][0]) are deprecated in recent pandas.
covid_totals = (
  reduce(lambda x,y: x.union(y),[v.withColumn("country",F.lit(k)) for (k,v) in alldf.items()])
  .groupBy("country")
  .agg(F.sum("covid_deaths").alias("covid_deaths"))
  .toPandas()
)

totaldeath = dict(zip(covid_totals["country"], covid_totals["covid_deaths"]))

In [6]:
# One histogram per country of the sampled values in all_deltas, with the
# summed covid_deaths for that country drawn as a dashed red vertical line.
# squeeze=False keeps `a` a 2-D array even for a single country, so ravel()
# always yields a flat sequence of axes.
f,a = plt.subplots(len(all_deltas),1, figsize=(5,10), squeeze=False)
a = a.ravel()
for country,ax in zip(all_deltas.keys(),a):
    # The samples live in all_deltas; the previous name `allhists` was never
    # defined and raised NameError.
    x = all_deltas[country]
    ax.set_title(country)
    ax.hist(x, 50,density=True,histtype="stepfilled")
    ax.set_xlim(0,100*1000)
    ax.axvline(x=totaldeath[country], color='r', linestyle='dashed', linewidth=2)

plt.tight_layout()

# NOTE(review): `ax` here is the last axes left over from the loop; if the
# host's display() expects a figure, pass `f` instead — confirm.
display(ax)

In [7]:
def get_forecasts(df):
  """Fit the model on pre-2020 rows and return its posterior-predictive draws in long form.

  Args:
    df: Spark DataFrame with ``start_date`` and ``total_deaths`` columns.

  Returns:
    A ``(forecast, params)`` tuple. ``forecast`` is a Spark DataFrame with one
    row per (date, draw): ``pred`` is the sampled prediction, ``ds`` the date
    and ``bid`` the column index of the draw in the sampled ``yhat`` matrix.
    ``params`` is the fitted model's ``params`` attribute.
  """
  # Build the Prophet-style columns once and reuse them for both pandas frames.
  prepared = df.withColumn("ds",col("start_date")).withColumn("y",col("total_deaths"))
  dfp = prepared.toPandas()
  dft = prepared.filter("ds<'2020-01-01'").toPandas()
  m = get_mod()
  m.fit(dft)
  # Keep the frame returned by setup_dataframe and take the dates from it:
  # yhat rows follow that frame, which may be ordered differently from dfp
  # (Prophet sorts by ds — TODO confirm for the installed version).
  scored = m.setup_dataframe(dfp)
  yhat = m.sample_posterior_predictive(scored)["yhat"]
  dates = scored["ds"].to_numpy()

  def extract(i):
    """Long-form frame for draw `i`, i.e. column `i` of yhat."""
    return pd.DataFrame({"pred": yhat[:,i], "ds": dates, "bid": i})

  # Build every per-draw frame first and concatenate once; the previous loop
  # referenced undefined names (`i`, `get`) and re-copied `output` each pass.
  output = pd.concat([extract(i) for i in range(yhat.shape[1])], ignore_index=True)

  return spark.createDataFrame(output), m.params

In [8]:
# Germany: average the forecast draws per date and show them next to the
# reported totals, for dates after 2017-01-01.
# The forecast function is get_forecasts; `getForc` was never defined.
ger_forecast = get_forecasts(alldf["ger"])[0]
ger_actual = (alldf["ger"]
              .withColumn("ds",col("start_date"))
              .withColumn("actual",col("total_deaths")))

display(ger_forecast
       .join(ger_actual,["ds"],"left")
       .filter("ds>'2017-01-01'")
       .groupBy("start_date").agg(F.avg(col("pred")).alias("pred"),F.avg(col("actual")).alias("actual"))
       )

start_date,pred,actual
2017-05-14,17014.650513130284,17394.0
2020-02-26,22430.1741715765,19507.0
2018-12-17,18787.882003133007,18990.0
2017-04-30,17268.80960560382,17186.0
2018-03-12,20637.97633040408,24385.0
2017-10-01,16791.904184906147,16601.0
2017-12-03,18491.44138434458,18254.0
2018-11-26,18655.289271791626,18091.0
2019-01-15,19809.921281608476,18952.0
2018-11-12,17753.44372835163,16842.0
