# PySpark

## Init Engine

In [1]:
from pyspark.sql import SparkSession

# Reuse the active session if there is one; otherwise start a new session
# registered under the application name 'pyspark-covid'.
session_builder = SparkSession.builder.appName('pyspark-covid')
spark = session_builder.getOrCreate()

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log,Current session?
2,,pyspark,idle,,,✔


SparkSession available as 'spark'.


# Data loading

### Create DataFrame from CSV file

In [None]:
# Reader settings: the first line is a header, column types are inferred from
# the data, fields are comma-separated, and a double quote escapes quotes.
csv_reader = spark.read.options(
    header="true",
    inferSchema="true",
    delimiter=",",
    escape="\"",
)

df = csv_reader.csv("/data/time_series_covid19_deaths_global_narrow.csv")

In [None]:
# Print the column names and the types Spark inferred while reading the CSV.
df.printSchema()

In [None]:
# Display the first row only, without truncating long cell values.
df.show(1, truncate=False)

## Data cleaning
### Column renaming

In [None]:
# Source header -> snake_case column name used by the rest of the notebook.
column_renames = [
    ("ISO 3166-1 Alpha 3-Codes", "iso_country_code"),
    ("Region Code", "region_code"),
    ("Sub-region Code", "sub_region_code"),
    ("Intermediate Region Code", "interm_region_code"),
    ("Province/State", "province_state"),
    ("Country/Region", "country_region"),
]

# Apply the renames one at a time; `df` itself is left untouched.
cleanDf = df
for source_name, target_name in column_renames:
    cleanDf = cleanDf.withColumnRenamed(source_name, target_name)

cleanDf.printSchema()

## Data Validation
### Check overall values

In [None]:
from pyspark.sql.functions import sum, col, desc

column = "sum_casualties"

# Sum of "Value" per ISO country code, largest totals first.
casualties_by_country = cleanDf.groupBy(col("iso_country_code")).agg(
    sum("Value").alias(column)
)
isoDF = casualties_by_country.orderBy(desc(column))

# Grand total across all country codes, shown as a single row.
grand_total = isoDF.agg(sum(column))
grand_total.show(truncate=False)

### Column analysis

#### Numeric values

In [None]:
from pyspark.sql.functions import count, mean, stddev, min, max, sum

# One aggregate per statistic, all computed over the "Lat" column, in the
# same order as the import list above.
lat_stats = [aggregate("Lat") for aggregate in (count, mean, stddev, min, max, sum)]

df.agg(*lat_stats).show()

#### String values

In [None]:
from pyspark.sql.functions import col

# Columns that together identify one location in the dataset.
location_columns = ["iso_country_code", "country_region", "province_state", "Lat", "Long"]

# One row per distinct combination of the location columns.
distinctValuesDF = cleanDf.select(*[col(name) for name in location_columns]).distinct()

print(distinctValuesDF.count())

In [None]:
from pyspark.sql.functions import desc

# Number of distinct location rows per ISO country code, most rows first.
rows_per_country = distinctValuesDF.groupBy("iso_country_code").count()
rows_per_country.orderBy(desc("count")).show()

In [None]:
# Show the distinct location rows that have no ISO country code.
distinctValuesDF.where("iso_country_code is null").show()

In [None]:
from pyspark.sql.functions import col

# Show the distinct location rows whose ISO country code is "CHN".
is_chn = col("iso_country_code") == "CHN"
distinctValuesDF.where(is_chn).show()

#### Date values

In [None]:
from pyspark.sql.functions import min, max

column = "Date"

# Smallest and largest value of the "Date" column, shown as a single row.
date_bounds = [
    min(column).alias("start_date_range"),
    max(column).alias("end_date_range"),
]
df.agg(*date_bounds).show(truncate=False)

## Understand distribution
### Parent child columns

In [None]:
# Number of distinct location rows per ISO country code, in no particular order.
# `col` is not imported in this cell; it must already be bound by an earlier one.
distinctValuesDF.groupBy(col("iso_country_code")).count().show()

In [None]:
# Same per-country row count as above, sorted so the codes with the most
# distinct location rows come first.
country_counts = distinctValuesDF.groupBy(col("iso_country_code")).count()
country_counts.orderBy(desc("count")).show()

In [None]:
from pyspark.sql.functions import col, count, collect_set

# Per country: how many location rows carry a province value, and the set of
# those province names.
provinces_by_country = distinctValuesDF.groupBy(col("country_region")).agg(
    count("province_state").alias("count"),
    collect_set("province_state").alias("contained_province"),
)

# Keep only countries with at least one province, listed alphabetically.
countries_with_provinces = provinces_by_country.where(col("count") > 0)
countries_with_provinces.orderBy("country_region").show(truncate=False)

### Distribution By Time Window

In [None]:
from pyspark.sql.functions import col, sum, desc, window

# Bucket rows into 1-day windows on "Date" and group by location + window.
daily_window = window(col("Date"), "1 day")
grouping_keys = [
    col("iso_country_code"),
    col("country_region"),
    col("province_state"),
    daily_window,
]

# Sum of "Value" per location and day, most recent window first.
casualties_per_window = cleanDf.groupBy(*grouping_keys).agg(
    sum("Value").alias("casualties")
)
tsDF = casualties_per_window.orderBy(desc("window.start"))

tsDF.show(truncate=False)

#### Window value stats

In [None]:
# Import the Spark aggregates explicitly: without this the cell only works if
# an earlier cell has already rebound `min`/`max`; the Python builtins would
# return a plain character here and `.alias` would fail.
from pyspark.sql.functions import min, max

# Earliest and latest window start present in the aggregated time series.
column = "window.start"
tsDF.agg(min(column).alias("start_date_range"), max(column).alias("end_date_range")).show(truncate=False)

### Data preparation & ordering

In [None]:
# Columns to persist; selecting "window.start" / "window.end" yields columns
# named "start" and "end".
export_columns = [
    "iso_country_code",
    "country_region",
    "province_state",
    "window.start",
    "window.end",
    "casualties",
]

flattened = tsDF.select(*export_columns)

# Replace missing country / province names with the literal "unknown".
labelled = flattened.na.fill("unknown", subset=["country_region", "province_state"])

storedDF = labelled.orderBy("iso_country_code", "start")

storedDF.show(1, truncate=False)

## Exporting

### Persist to CSV file

In [None]:
# Collapse to a single partition before writing, replace any previous output,
# and include a header row.
single_partition = storedDF.coalesce(1)

csv_writer = single_partition.write.mode('overwrite')
csv_writer = csv_writer.option("mapreduce.fileoutputcommitter.marksuccessfuljobs", "false")
csv_writer = csv_writer.option("header", "true")

csv_writer.csv("/data/covid-19.csv")

### Save data into Table

In [None]:
# Save the prepared frame as the table "data", replacing any existing one.
table_writer = storedDF.write.mode("overwrite")
table_writer.saveAsTable("data")

### Export dataframe to %%local

In [None]:
pandaDf = tsDF.select("iso_country_code", "window.start", "casualties").toPandas()

In [None]:
# First rows of the pandas DataFrame built in the previous cell.
pandaDf.head()

# Tables

In [None]:
# List the databases visible to this Spark session.
spark.sql("SHOW DATABASES").show()

In [None]:
# List the tables registered in the "default" database.
# NOTE(review): the "data" table written by saveAsTable earlier is presumably
# expected to appear here — confirm it lands in "default".
spark.catalog.listTables("default")

# Visualization

## matplotlib

In [None]:
%%local
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

plt.plot(pdf['time'], pdf['casualties'], color='red')
plt.gcf().autofmt_xdate()
plt.show()

In [None]:
%%local
fig, ax = plt.subplots()
for country in pdf['iso_country_code']:
    dataframe = pdf.loc[pdf['iso_country_code'] == country]
    ax.plot(dataframe['time'], dataframe['casualties'], label=country)

ax.set_title('Casualties per country')
ax.legend(loc='upper right')

## seaborn

In [None]:
%%local
import matplotlib.pyplot as plt
import seaborn as sns   

sns.catplot(x ='time', y ='casualties', data = pdf)
plt.title('COVID-19 Casualities')
plt.show() 