In [1]:
import pandas as pd
from prophet import Prophet
from pymongo import MongoClient
from prophet.plot import plot_plotly, plot_components_plotly

from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import *

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local").getOrCreate()
print(f'Spark Version: {spark.sparkContext.version}')

df = spark.read.format('json').load('./covid_county.json')

df = df.select('GISJOIN', 'cases', 'deaths', 'date', 'formatted_date')

Spark Version: 3.0.1


DataFrame[GISJOIN: string, cases: bigint, deaths: bigint, date: string, formatted_date: struct<$date:string>]

In [2]:
df.show()

+--------+-----+------+----------+--------------------+
| GISJOIN|cases|deaths|      date|      formatted_date|
+--------+-----+------+----------+--------------------+
|G0100010|    3|     0|2020-03-25|[2020-03-25T00:00...|
|G0100010|    0|     0|2020-03-28|[2020-03-28T00:00...|
|G0100010|    0|     0|2020-03-29|[2020-03-29T00:00...|
|G0100010|    1|     0|2020-03-30|[2020-03-30T00:00...|
|G0100010|    0|     0|2020-03-31|[2020-03-31T00:00...|
|G0100010|    3|     0|2020-04-01|[2020-04-01T00:00...|
|G0100010|    0|     0|2020-04-04|[2020-04-04T00:00...|
|G0100010|    0|     0|2020-04-05|[2020-04-05T00:00...|
|G0100010|    0|     1|2020-04-06|[2020-04-06T00:00...|
|G0100010|    0|     0|2020-04-02|[2020-04-02T00:00...|
|G0100010|    0|     0|2020-04-07|[2020-04-07T00:00...|
|G0100010|    0|     0|2020-04-08|[2020-04-08T00:00...|
|G0100010|    5|     0|2020-04-09|[2020-04-09T00:00...|
|G0100010|    0|     0|2020-04-10|[2020-04-10T00:00...|
|G0100010|    2|     0|2020-04-11|[2020-04-11T00

In [3]:
result_schema = StructType([
    StructField("ds", DateType(), True),
    StructField("yhat", DoubleType(), True),
    StructField("yhat_lower", DoubleType(), True),
    StructField("yhat_upper", DoubleType(), True)
])
print('log: result_schema created')


@pandas_udf(result_schema, PandasUDFType.GROUPED_MAP)
def temp(df0):
    # instantiate the model, configure the parameters
    m = Prophet()
    m.fit(df0)
    df0_future = m.make_future_dataframe(periods=365)
    df0_forecast = m.predict(df0_future)

    return df0_forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']]


df_cases = df.select('GISJOIN', 'date', 'cases').withColumnRenamed('date', 'ds').withColumnRenamed('cases', 'y')
print('log: Showing df_cases')
df_cases.show()

log: result_schema created
log: Showing df_cases
+--------+----------+---+
| GISJOIN|        ds|  y|
+--------+----------+---+
|G0100010|2020-03-25|  3|
|G0100010|2020-03-28|  0|
|G0100010|2020-03-29|  0|
|G0100010|2020-03-30|  1|
|G0100010|2020-03-31|  0|
|G0100010|2020-04-01|  3|
|G0100010|2020-04-04|  0|
|G0100010|2020-04-05|  0|
|G0100010|2020-04-06|  0|
|G0100010|2020-04-02|  0|
|G0100010|2020-04-07|  0|
|G0100010|2020-04-08|  0|
|G0100010|2020-04-09|  5|
|G0100010|2020-04-10|  0|
|G0100010|2020-04-11|  2|
|G0100010|2020-04-12|  0|
|G0100010|2020-04-13|  0|
|G0100010|2020-09-15| 16|
|G0100010|2020-09-16| 18|
|G0100010|2020-09-17|  5|
+--------+----------+---+
only showing top 20 rows



In [4]:
results = (df_cases.groupBy('GISJOIN').apply(temp))

print('log: Showing results');
results.show()



log: Showing results
+----------+-------------------+-------------------+------------------+
|        ds|               yhat|         yhat_lower|        yhat_upper|
+----------+-------------------+-------------------+------------------+
|2020-03-17| -9.917528714739337|-148.52329811613367|127.89335354248482|
|2020-03-18| 2.5070101773115603| -143.1019733275764|151.59268202876288|
|2020-03-19|  24.17156137325841|-119.84131526305589|159.49667096581467|
|2020-03-20|  8.741398303584567|-134.29060909765994|150.37174283875493|
|2020-03-21|  36.17421812539644|-109.69836329824889| 182.7802933851085|
|2020-03-22| -6.010805038355219|-157.09645756816394|135.77816301115692|
|2020-03-23|-16.904925661884718|-164.51731929361455|124.49948201128198|
|2020-03-24| -3.022601101004865|-148.72412465274144|146.21445841590707|
|2020-03-25|   9.40193779094049|-143.59018705402633|152.50635811309732|
|2020-03-26| 31.066488986919392| -111.8707656014515| 177.9503284453844|
|2020-03-27| 15.636325917202164|-135.188519

In [9]:
results.count()

2194486