# PySpark

## Init Engine

In [1]:
from pyspark.sql import SparkSession

# Reuse an existing session when there is one, otherwise create a session
# registered under the application name 'pyspark-covid'.
builder = SparkSession.builder.appName('pyspark-covid')
spark = builder.getOrCreate()

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log,Current session?
4,,pyspark,idle,,,✔


SparkSession available as 'spark'.


# Data loading

## Create DataFrame from CSV file

In [2]:
# Configure the CSV reader: first row is a header, column types are inferred,
# fields are comma-separated and '"' is the escape character.
csv_reader = (
    spark.read
    .option("header", "true")
    .option("inferSchema", "true")
    .option("delimiter", ",")
    .option("escape", "\"")
)
df = csv_reader.csv("/data/time_series_covid19_deaths_global_narrow.csv")

df.printSchema()

root
 |-- Province/State: string (nullable = true)
 |-- Country/Region: string (nullable = true)
 |-- Lat: double (nullable = true)
 |-- Long: double (nullable = true)
 |-- Date: timestamp (nullable = true)
 |-- Value: integer (nullable = true)
 |-- ISO 3166-1 Alpha 3-Codes: string (nullable = true)
 |-- Region Code: integer (nullable = true)
 |-- Sub-region Code: integer (nullable = true)
 |-- Intermediate Region Code: integer (nullable = true)

## Normalize schema

* Remove special characters from column names

In [3]:
# (original column name, normalized column name) pairs. The new names contain
# no spaces, slashes or dashes.
renamed_columns = [
    ("ISO 3166-1 Alpha 3-Codes", "iso_country_code"),
    ("Region Code", "region_code"),
    ("Sub-region Code", "sub_region_code"),
    ("Intermediate Region Code", "interm_region_code"),
    ("Province/State", "province_state"),
    ("Country/Region", "country_region"),
]

# Apply the renames one by one; `df` itself is left untouched.
normDf = df
for original_name, normalized_name in renamed_columns:
    normDf = normDf.withColumnRenamed(original_name, normalized_name)

normDf.printSchema()

root
 |-- province_state: string (nullable = true)
 |-- country_region: string (nullable = true)
 |-- Lat: double (nullable = true)
 |-- Long: double (nullable = true)
 |-- Date: timestamp (nullable = true)
 |-- Value: integer (nullable = true)
 |-- iso_country_code: string (nullable = true)
 |-- region_code: integer (nullable = true)
 |-- sub_region_code: integer (nullable = true)
 |-- interm_region_code: integer (nullable = true)

# Data cleaning

## Sampling

In [4]:
# Peek at a single row without truncating long cell values.
normDf.show(n=1, truncate=False)

+--------------+--------------+--------+---------+---------------------+-----+----------------+-----------+---------------+------------------+
|province_state|country_region|Lat     |Long     |Date                 |Value|iso_country_code|region_code|sub_region_code|interm_region_code|
+--------------+--------------+--------+---------+---------------------+-----+----------------+-----------+---------------+------------------+
|null          |Afghanistan   |33.93911|67.709953|2020-10-01 00:00:00.0|1458 |AFG             |142        |34             |null              |
+--------------+--------------+--------+---------+---------------------+-----+----------------+-----------+---------------+------------------+
only showing top 1 row

## DataSet Validation
### Check overall values

In [5]:
# NOTE: `sum` here is the Spark SQL aggregate and shadows Python's builtin sum.
from pyspark.sql.functions import sum, col, desc

column = "sum_casualties"

# Sum of "Value" per ISO country code, largest totals first.
per_country_totals = normDf.groupBy(col("iso_country_code")).agg(sum("Value").alias(column))
isoDF = per_country_totals.orderBy(desc(column))

# Grand total over all country codes.
grand_total = isoDF.agg(sum(column))
grand_total.show(truncate=False)

+-------------------+
|sum(sum_casualties)|
+-------------------+
|98961992           |
+-------------------+

### Column analysis

#### Numeric values

In [6]:
# NOTE: min, max and sum below are Spark SQL aggregates that shadow the builtins.
from pyspark.sql.functions import count, mean, stddev, min, max, sum

# Summary statistics for the "Lat" column, computed in a single aggregation.
lat = "Lat"
lat_stats = [count(lat), mean(lat), stddev(lat), min(lat), max(lat), sum(lat)]
normDf.agg(*lat_stats).show()

+----------+-----------------+-----------------+--------+--------+------------------+
|count(Lat)|         avg(Lat)| stddev_samp(Lat)|min(Lat)|max(Lat)|          sum(Lat)|
+----------+-----------------+-----------------+--------+--------+------------------+
|     67564|21.07662424812158|24.85792699104551|-51.7963| 71.7069|1424021.0407000864|
+----------+-----------------+-----------------+--------+--------+------------------+

#### String values

In [7]:
from pyspark.sql.functions import col

# Columns that together describe one location in the data set.
location_columns = ["iso_country_code", "country_region", "province_state", "Lat", "Long"]

# One row per distinct combination of the location columns.
distinctValuesDF = normDf.select(*[col(name) for name in location_columns]).distinct()

print(distinctValuesDF.count())

266

In [8]:
from pyspark.sql.functions import desc

# Number of distinct locations per ISO country code, highest counts first.
locations_per_country = distinctValuesDF.groupBy("iso_country_code").count()
locations_per_country.orderBy(desc("count")).show()

+----------------+-----+
|iso_country_code|count|
+----------------+-----+
|             CHN|   31|
|             CAN|   12|
|             AUS|    8|
|            null|    5|
|             SPM|    1|
|             FRA|    1|
|             POL|    1|
|             TCA|    1|
|             LVA|    1|
|             JAM|    1|
|             ZMB|    1|
|             BRA|    1|
|             ARM|    1|
|             MOZ|    1|
|             JOR|    1|
|             CUB|    1|
|             ABW|    1|
|             SOM|    1|
|             BRN|    1|
|             COD|    1|
+----------------+-----+
only showing top 20 rows

In [9]:
# Locations that carry no ISO country code (filter is the method where aliases).
distinctValuesDF.filter("iso_country_code is null").show()

+----------------+----------------+----------------+-------+-------+
|iso_country_code|  country_region|  province_state|    Lat|   Long|
+----------------+----------------+----------------+-------+-------+
|            null|  United Kingdom| Channel Islands|49.3723|-2.3644|
|            null|      MS Zaandam|            null|    0.0|    0.0|
|            null|Diamond Princess|            null|    0.0|    0.0|
|            null|          Canada|Diamond Princess|    0.0|    0.0|
|            null|          Canada|  Grand Princess|    0.0|    0.0|
+----------------+----------------+----------------+-------+-------+

In [10]:
from pyspark.sql.functions import col

# Show the distinct locations recorded under the ISO code "CHN".
is_china = col("iso_country_code") == "CHN"
distinctValuesDF.where(is_china).show()

+----------------+--------------+--------------+------------------+--------+
|iso_country_code|country_region|province_state|               Lat|    Long|
+----------------+--------------+--------------+------------------+--------+
|             CHN|         China|      Xinjiang|           41.1129| 85.2401|
|             CHN|         China|        Shanxi|           37.5777|112.2922|
|             CHN|         China|       Guangxi|           23.8298|108.7881|
|             CHN|         China|         Hebei|            39.549|116.1306|
|             CHN|         China|      Shanghai|31.201999999999998|121.4491|
|             CHN|         China|         Hunan|           27.6104|111.7088|
|             CHN|         China|      Shandong|           36.3427|118.1498|
|             CHN|         China|       Beijing|           40.1824|116.4142|
|             CHN|         China|       Shaanxi|           35.1917|108.8701|
|             CHN|         China|         Tibet|           31.6927| 88.0924|

#### Date values

In [11]:
# NOTE: these Spark SQL aggregates shadow Python's builtin min and max.
from pyspark.sql.functions import min, max

column = "Date"

# Earliest and latest timestamp found in the "Date" column.
date_range = normDf.agg(
    min(column).alias("start_date_range"),
    max(column).alias("end_date_range")
)
date_range.show(truncate=False)

+---------------------+---------------------+
|start_date_range     |end_date_range       |
+---------------------+---------------------+
|2020-01-22 00:00:00.0|2020-10-01 00:00:00.0|
+---------------------+---------------------+

## Understand distribution
### Parent-child columns

In [12]:
# Distinct locations per ISO country code, in whatever order Spark returns them.
country_key = col("iso_country_code")
distinctValuesDF.groupBy(country_key).count().show()

+----------------+-----+
|iso_country_code|count|
+----------------+-----+
|             HTI|    1|
|             PSE|    1|
|             POL|    1|
|             BRB|    1|
|             LVA|    1|
|             JAM|    1|
|             ZMB|    1|
|             SPM|    1|
|             BRA|    1|
|             ARM|    1|
|             MOZ|    1|
|             JOR|    1|
|             CUB|    1|
|             ABW|    1|
|             SOM|    1|
|             FRA|    1|
|             TCA|    1|
|             BRN|    1|
|             COD|    1|
|             URY|    1|
+----------------+-----+
only showing top 20 rows

In [13]:
# Same per-country location count, sorted so the largest groups come first.
(
    distinctValuesDF
    .groupBy(col("iso_country_code"))
    .count()
    .orderBy(desc("count"))
    .show()
)

+----------------+-----+
|iso_country_code|count|
+----------------+-----+
|             CHN|   31|
|             CAN|   12|
|             AUS|    8|
|            null|    5|
|             SPM|    1|
|             FRA|    1|
|             POL|    1|
|             TCA|    1|
|             LVA|    1|
|             JAM|    1|
|             ZMB|    1|
|             BRA|    1|
|             ARM|    1|
|             MOZ|    1|
|             JOR|    1|
|             CUB|    1|
|             ABW|    1|
|             SOM|    1|
|             BRN|    1|
|             COD|    1|
+----------------+-----+
only showing top 20 rows

In [14]:
from pyspark.sql.functions import col, count, collect_set

# count() ignores nulls, so "count" is the number of locations with a province.
province_count = count("province_state").alias("count")
province_names = collect_set("province_state").alias("contained_province")

# Keep only the countries that have at least one named province.
(
    distinctValuesDF
    .groupBy(col("country_region"))
    .agg(province_count, province_names)
    .where(col("count") > 0)
    .orderBy("country_region")
    .show(truncate=False)
)

+--------------+-----+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|country_region|count|contained_province                                                                                                                                                                                                                                                                                        |
+--------------+-----+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|Australia     |8    |[Queensland,

### Distribution By Time Window

In [16]:
from pyspark.sql.functions import col, sum, desc, window

# Grouping keys: the location columns plus a one-day window over "Date".
location_keys = [col("iso_country_code"), col("country_region"), col("province_state")]
daily_window = window(col("Date"), "1 day")

# Sum "Value" per location and day, most recent window first.
tsDF = (
    normDf
    .groupBy(*(location_keys + [daily_window]))
    .agg(sum("Value").alias("casualties"))
    .orderBy(desc("window.start"))
)

tsDF.show(truncate=False)

+----------------+----------------------+--------------------------------+---------------------------------------------+----------+
|iso_country_code|country_region        |province_state                  |window                                       |casualties|
+----------------+----------------------+--------------------------------+---------------------------------------------+----------+
|BIH             |Bosnia and Herzegovina|null                            |[2020-10-01 00:00:00.0,2020-10-02 00:00:00.0]|861       |
|SVK             |Slovakia              |null                            |[2020-10-01 00:00:00.0,2020-10-02 00:00:00.0]|48        |
|ITA             |Italy                 |null                            |[2020-10-01 00:00:00.0,2020-10-02 00:00:00.0]|35918     |
|MMR             |Burma                 |null                            |[2020-10-01 00:00:00.0,2020-10-02 00:00:00.0]|321       |
|SGP             |Singapore             |null                            |[2

#### Window value stats

In [17]:
# min/max are the pyspark.sql.functions aggregates imported in an earlier cell.
column = "window.start"

# First and last window start present in the aggregated frame.
window_range = tsDF.agg(
    min(column).alias("start_date_range"),
    max(column).alias("end_date_range")
)
window_range.show(truncate=False)

+---------------------+---------------------+
|start_date_range     |end_date_range       |
+---------------------+---------------------+
|2020-01-22 00:00:00.0|2020-10-01 00:00:00.0|
+---------------------+---------------------+

### Data preparation & ordering

In [18]:
# Flatten the window struct into its "start" and "end" fields for export.
export_columns = ["iso_country_code", "country_region", "province_state", "window.start", "window.end", "casualties"]

storedDF = (
    tsDF
    .select(*export_columns)
    # Replace missing country/province names with a placeholder string.
    .na.fill("unknown", subset=["country_region", "province_state"])
    .orderBy("iso_country_code", "start")
)

storedDF.show(n=1, truncate=False)

+----------------+--------------+--------------+---------------------+---------------------+----------+
|iso_country_code|country_region|province_state|start                |end                  |casualties|
+----------------+--------------+--------------+---------------------+---------------------+----------+
|null            |MS Zaandam    |unknown       |2020-01-22 00:00:00.0|2020-01-23 00:00:00.0|0         |
+----------------+--------------+--------------+---------------------+---------------------+----------+
only showing top 1 row

## Exporting

### Persist to CSV file

In [19]:
# Collapse to a single partition, then write CSV with a header row,
# replacing any previous output at the target path.
csv_writer = (
    storedDF.coalesce(1)
    .write.mode('overwrite')
    .option("mapreduce.fileoutputcommitter.marksuccessfuljobs", "false")
    .option("header", "true")
)
csv_writer.csv("/data/covid-19-by-country.csv")

### Save data into Table

In [20]:
# Persist the prepared frame as the table "data", replacing any existing one.
table_writer = storedDF.write.mode("overwrite")
table_writer.saveAsTable("data")

### Export dataframe using .toPandas()

In [None]:
# Keep only the columns needed downstream, then convert to a pandas DataFrame.
daily_casualties = tsDF.select("iso_country_code", "window.start", "casualties")
pandaDf = daily_casualties.toPandas()

In [None]:
# Preview the first rows of the pandas export (5 is the default for head()).
pandaDf.head(n=5)

# Tables

In [22]:
# List the databases known to the session catalog.
databases = spark.sql("SHOW DATABASES")
databases.show()

+------------+
|databaseName|
+------------+
|     default|
+------------+

In [23]:
# List the tables registered in the "default" database.
spark.catalog.listTables(dbName="default")

[Table(name=u'data', database=u'default', description=None, tableType=u'MANAGED', isTemporary=False)]

### SparkSQL - Data query

In [24]:
# Casualties for ISO code 'FRA' read back from the "data" table, oldest first.
france_query = """
SELECT start as time, iso_country_code, casualties
FROM data
WHERE iso_country_code = 'FRA'
ORDER BY time
"""

# Two fixes compared with the previous version of this cell:
#  * a stray line-continuation backslash after the closing parenthesis glued
#    the next statement onto the assignment, which is a SyntaxError;
#  * the plotting cells below use `pdf` like a pandas DataFrame
#    (pdf['time'], pdf.loc[...]), so the Spark result is converted here.
pdf = spark.sql(france_query).toPandas()
pdf.head(5)


VBox(children=(HBox(children=(HTML(value='Type:'), Button(description='Table', layout=Layout(width='70px'), st…

Output()

# Visualization

## matplotlib

In [None]:
# Install matplotlib with pip, invoked through the interpreter running this
# kernel (sys.executable); --user requests a per-user install.
import sys  
!{sys.executable} -m pip install --user matplotlib

In [None]:
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

plt.plot(pdf['time'], pdf['casualties'], color='red')
plt.gcf().autofmt_xdate()
plt.show()

In [None]:
fig, ax = plt.subplots()

# Draw one line per distinct country code. Iterating the raw column would
# visit every row, re-plotting the same country once per row and filling the
# legend with duplicate entries. Null codes are skipped because an equality
# filter on them would select no rows.
for country in pdf['iso_country_code'].dropna().unique():
    dataframe = pdf.loc[pdf['iso_country_code'] == country]
    ax.plot(dataframe['time'], dataframe['casualties'], label=country)

ax.set_title('Casualties per country')
ax.legend(loc='upper right')

## seaborn

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Categorical plot of casualties against time, using the pandas frame `pdf`.
sns.catplot(x='time', y='casualties', data=pdf)
# Title spelling corrected (was 'Casualities').
plt.title('COVID-19 Casualties')
plt.show()