In [1]:
import matplotlib.pyplot as plt
import numpy as np
from fbprophet import *
import pandas as pd
from functools import reduce
import pyspark.sql.functions as F  
from pyspark.sql.functions import col  

def get_mod(mcmc_samples=1000, seasonality_period=30*4, fourier_order=5):
  """Build an unfitted Prophet model used for weekly total-death counts.

  The model has a linear trend with no changepoints. Prophet's built-in daily
  and weekly seasonalities are disabled, and one custom Fourier seasonality
  named "quarterly" is added.

  Args:
    mcmc_samples: passed to Prophet; when > 0 Prophet does full Bayesian
      inference with this many MCMC samples (0 means MAP estimation). Lower it
      for quick iteration.
    seasonality_period: period, in days, of the custom "quarterly" seasonality.
      NOTE(review): the default 30*4 = 120 days is roughly four months, not a
      calendar quarter (~91.3 days). It is kept for backward compatibility;
      confirm whether this was intended.
    fourier_order: number of Fourier terms for the custom seasonality.

  Returns:
    An unfitted ``Prophet`` instance; callers fit it themselves.
  """
  m = Prophet(
      daily_seasonality=False,
      weekly_seasonality=False,
      mcmc_samples=mcmc_samples,
      growth="linear",
      n_changepoints=0,  # no trend changepoints: a single linear trend
  )
  m.add_seasonality(name="quarterly", period=seasonality_period, fourier_order=fourier_order)
  return m

In [2]:
# Read each country's weekly-deaths CSV, then union the per-country frames into one DataFrame

def get_df(c):
  """Load one country's weekly-deaths CSV as a Spark DataFrame.

  `c` is the file-name stem: e.g. "germany" reads
  /covid/data/output-data/historical-deaths/germany_weekly_deaths.csv.
  The first line is used as the header. Schema inference is not requested,
  so Spark reads every column as a string.
  """
  reader = spark.read.format("csv").option("header", "True")
  csv_path = "/covid/data/output-data/historical-deaths/" + c + "_weekly_deaths.csv"
  return reader.load(csv_path)

DISPLAY_SAMPLE_SEED = 42  # fixed seed so the sampled rows shown below are reproducible


def _sum_deaths_by_week(df):
  """Collapse multiple rows per start_date into one by summing both death columns."""
  return df.groupBy("start_date").agg(
      F.sum(col("total_deaths")).alias("total_deaths"),
      F.sum(col("covid_deaths")).alias("covid_deaths"),
  )


# One DataFrame per country with columns (start_date, total_deaths, covid_deaths),
# in that order, because they are later combined positionally with `union`.
alldf = {
"uk":  get_df("britain").filter("region='Britain'").select("start_date","total_deaths","covid_deaths"),
"ger": get_df("germany").select("start_date","total_deaths","covid_deaths"),
"fr":  get_df("france").filter("region='France'").select("start_date","total_deaths","covid_deaths"),
"us":  _sum_deaths_by_week(get_df("united_states")),
"it":  _sum_deaths_by_week(get_df("italy").filter("region='Italy'")),
"nl":  _sum_deaths_by_week(get_df("netherlands").filter("region='Netherlands'")),
}

# Show a ~70% random subsample of rows. NOTE(review): presumably this keeps the
# output under display's row limit; confirm.
display(
    reduce(lambda x, y: x.union(y), [v.withColumn("country", F.lit(k)) for (k, v) in alldf.items()])
    .select("country", "start_date", "total_deaths")
    .filter(F.rand(seed=DISPLAY_SAMPLE_SEED) > 0.3)
)

In [3]:

def get_deltas(df, train_end="2020-01-01", excess_start="2020-02-01"):
  """Posterior samples of cumulative excess deaths for one country.

  Fits a model from get_mod() to the weeks before `train_end` and draws
  posterior predictive samples of total deaths for every week in `df`. For each
  sample it then sums (observed - predicted) over the weeks after `excess_start`.

  Args:
    df: Spark DataFrame with `start_date` and `total_deaths` columns.
    train_end: weeks whose start_date is before this value (exclusive,
      compared as strings, so ISO dates are assumed) form the training set.
    excess_start: weeks strictly after this date are counted as excess.

  Returns:
    1-D numpy array with one cumulative-excess value per posterior sample.
    It is all zeros when no week falls after `excess_start`.
  """
  prophet_input = df.withColumn("ds", col("start_date")).withColumn("y", col("total_deaths"))
  dfp = prophet_input.toPandas()
  dft = prophet_input.filter(col("ds") < train_end).toPandas()
  m = get_mod()
  m.fit(dft)
  # setup_dataframe parses ds to datetimes and returns the rows sorted by ds.
  # yhat's rows follow this frame's order.
  dfp = m.setup_dataframe(dfp)
  yhat = m.sample_posterior_predictive(dfp)["yhat"]  # rows align with dfp; columns are samples
  # Select the excess window by date instead of slicing the last `count()` rows.
  # A count of 0 made the old `[-0:, :]` slice cover every week.
  in_excess_window = (dfp["ds"] > pd.Timestamp(excess_start)).to_numpy()
  residuals = np.asarray(dfp["y"], dtype=float).reshape((-1, 1)) - yhat
  return residuals[in_excess_window, :].sum(axis=0)

In [4]:
# Posterior samples of cumulative excess deaths, one array per country.
all_deltas = {country: get_deltas(country_df) for country, country_df in alldf.items()}

# Reported COVID deaths per country, summed over all weeks.
reported_covid_deaths = (
    reduce(lambda x, y: x.union(y), [v.withColumn("country", F.lit(k)) for (k, v) in alldf.items()])
    .groupBy("country")
    .agg(F.sum("covid_deaths").alias("covid_deaths"))
    .toPandas()
)
# Look columns up by name. The old `row[1][0]` relied on positional fallback in
# Series.__getitem__, which is deprecated in recent pandas.
totaldeath = dict(zip(reported_covid_deaths["country"], reported_covid_deaths["covid_deaths"]))

In [5]:
# One histogram per country of posterior cumulative excess deaths, with the
# reported COVID death total overlaid as a dashed red line.
n_countries = len(all_deltas)
# squeeze=False keeps `a` a 2-D array even for a single country, so ravel() always works.
f, a = plt.subplots(n_countries, 1, figsize=(5, 10 / 6 * n_countries), squeeze=False)
a = a.ravel()
for country, ax in zip(all_deltas.keys(), a):
    x = all_deltas[country]  # previously `allhists`, which is never defined (NameError)
    ax.set_title(country)
    ax.hist(x, 50, density=True, histtype="stepfilled")
    # NOTE(review): a fixed 100k cap may clip countries whose totals exceed it.
    ax.set_xlim(0, 100*1000)
    ax.axvline(x=totaldeath[country], color='r', linestyle='dashed', linewidth=2)
    ax.set_xlabel("excess deaths (posterior samples)")

plt.tight_layout()

# Display the whole figure; display(ax) would only show the last panel.
display(f)

In [6]:
def get_forecasts(df, train_end="2020-01-01"):
  """Posterior predictive samples of weekly total deaths, as a long Spark DataFrame.

  Fits a model from get_mod() to the weeks of `df` before `train_end` and draws
  posterior predictive samples for every week in `df`.

  Args:
    df: Spark DataFrame with `start_date` and `total_deaths` columns.
    train_end: weeks whose start_date is before this value (exclusive,
      string comparison) form the training set.

  Returns:
    (forecasts, params), where:
      forecasts: Spark DataFrame with columns pred (sampled total deaths),
        ds (week) and bid (posterior-sample index), one row per (sample, week).
      params: the fitted model's `m.params`.
  """
  prophet_input = df.withColumn("ds", col("start_date")).withColumn("y", col("total_deaths"))
  dfp = prophet_input.toPandas()
  dft = prophet_input.filter(col("ds") < train_end).toPandas()
  m = get_mod()
  m.fit(dft)
  # setup_dataframe returns the rows sorted by ds, and yhat's rows follow that
  # order. Take ds from the same frame, not from the unsorted toPandas() result.
  history = m.setup_dataframe(dfp)
  yhat = m.sample_posterior_predictive(history)["yhat"]  # rows: weeks, columns: samples
  n_weeks, n_samples = yhat.shape

  # Sample-major long format: all weeks of sample 0, then all weeks of sample 1,
  # and so on. Built in one step rather than by concatenating per sample in a
  # loop, which was quadratic.
  output = pd.DataFrame({
      "pred": yhat.T.ravel(),
      "ds": np.tile(history["ds"].to_numpy(), n_samples),
      "bid": np.repeat(np.arange(n_samples), n_weeks),
  })
  return spark.createDataFrame(output), m.params

In [7]:
# Germany: posterior-mean forecast vs. actual weekly total deaths since 2017.
ger_forecasts, _ = get_forecasts(alldf["ger"])  # the original called the undefined `getForc`
ger_actuals = (
    alldf["ger"]
    .withColumn("ds", col("start_date").cast("timestamp"))
    .withColumn("actual", col("total_deaths"))
)
display(
    ger_forecasts
    # Give both join keys the same type so the equality does not depend on
    # implicit string/timestamp coercion.
    .withColumn("ds", col("ds").cast("timestamp"))
    .join(ger_actuals, ["ds"], "left")
    .filter(col("ds") > F.lit("2017-01-01").cast("timestamp"))
    # Group on the forecast's own week key. start_date comes from the right
    # side of the left join and is null for weeks without a matching actual.
    .groupBy("ds")
    .agg(F.avg(col("pred")).alias("pred"), F.avg(col("actual")).alias("actual"))
)