# Imports

In [1]:
import pandas as pd
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('tutorial').getOrCreate()

Once the Spark session is created, Spark web user interface (Web UI) can be accessed from: http://localhost:4040/. 

# 1. Basic Functions

## Read

In [2]:
cases = spark.read.load(
    "data/Case.csv", format="csv", sep=",", inferSchema="true", header="true"
)

## See a few rows in the file

In [3]:
cases.show()

+--------+--------+---------------+-----+--------------------+---------+---------+----------+
| case_id|province|           city|group|      infection_case|confirmed| latitude| longitude|
+--------+--------+---------------+-----+--------------------+---------+---------+----------+
| 1000001|   Seoul|     Yongsan-gu| true|       Itaewon Clubs|      139|37.538621|126.992652|
| 1000002|   Seoul|      Gwanak-gu| true|             Richway|      119| 37.48208|126.901384|
| 1000003|   Seoul|        Guro-gu| true| Guro-gu Call Center|       95|37.508163|126.884387|
| 1000004|   Seoul|   Yangcheon-gu| true|Yangcheon Table T...|       43|37.546061|126.874209|
| 1000005|   Seoul|      Dobong-gu| true|     Day Care Center|       43|37.679422|127.044374|
| 1000006|   Seoul|        Guro-gu| true|Manmin Central Ch...|       41|37.481059|126.894343|
| 1000007|   Seoul|from other city| true|SMR Newly Planted...|       36|        -|         -|
| 1000008|   Seoul|  Dongdaemun-gu| true|       Dongan Churc

In [4]:
cases.limit(10).toPandas()

Unnamed: 0,case_id,province,city,group,infection_case,confirmed,latitude,longitude
0,1000001,Seoul,Yongsan-gu,True,Itaewon Clubs,139,37.538621,126.992652
1,1000002,Seoul,Gwanak-gu,True,Richway,119,37.48208,126.901384
2,1000003,Seoul,Guro-gu,True,Guro-gu Call Center,95,37.508163,126.884387
3,1000004,Seoul,Yangcheon-gu,True,Yangcheon Table Tennis Club,43,37.546061,126.874209
4,1000005,Seoul,Dobong-gu,True,Day Care Center,43,37.679422,127.044374
5,1000006,Seoul,Guro-gu,True,Manmin Central Church,41,37.481059,126.894343
6,1000007,Seoul,from other city,True,SMR Newly Planted Churches Group,36,-,-
7,1000008,Seoul,Dongdaemun-gu,True,Dongan Church,17,37.592888,127.056766
8,1000009,Seoul,from other city,True,Coupang Logistics Center,25,-,-
9,1000010,Seoul,Gwanak-gu,True,Wangsung Church,30,37.481735,126.930121


## Change Column Names

In [5]:
# We can do this simply using the below command to change a single column:
cases = cases.withColumnRenamed("infection_case", "infection_source")
cases.show()

+--------+--------+---------------+-----+--------------------+---------+---------+----------+
| case_id|province|           city|group|    infection_source|confirmed| latitude| longitude|
+--------+--------+---------------+-----+--------------------+---------+---------+----------+
| 1000001|   Seoul|     Yongsan-gu| true|       Itaewon Clubs|      139|37.538621|126.992652|
| 1000002|   Seoul|      Gwanak-gu| true|             Richway|      119| 37.48208|126.901384|
| 1000003|   Seoul|        Guro-gu| true| Guro-gu Call Center|       95|37.508163|126.884387|
| 1000004|   Seoul|   Yangcheon-gu| true|Yangcheon Table T...|       43|37.546061|126.874209|
| 1000005|   Seoul|      Dobong-gu| true|     Day Care Center|       43|37.679422|127.044374|
| 1000006|   Seoul|        Guro-gu| true|Manmin Central Ch...|       41|37.481059|126.894343|
| 1000007|   Seoul|from other city| true|SMR Newly Planted...|       36|        -|         -|
| 1000008|   Seoul|  Dongdaemun-gu| true|       Dongan Churc

In [6]:
# for all columns
cases = cases.toDF(
    *['case_id', 'province', 'city', 'group', 'infection_case', 'confirmed', 'latitude', 'longitude']
)
cases.show()

+-------+--------+---------------+-----+--------------------+---------+---------+----------+
|case_id|province|           city|group|      infection_case|confirmed| latitude| longitude|
+-------+--------+---------------+-----+--------------------+---------+---------+----------+
|1000001|   Seoul|     Yongsan-gu| true|       Itaewon Clubs|      139|37.538621|126.992652|
|1000002|   Seoul|      Gwanak-gu| true|             Richway|      119| 37.48208|126.901384|
|1000003|   Seoul|        Guro-gu| true| Guro-gu Call Center|       95|37.508163|126.884387|
|1000004|   Seoul|   Yangcheon-gu| true|Yangcheon Table T...|       43|37.546061|126.874209|
|1000005|   Seoul|      Dobong-gu| true|     Day Care Center|       43|37.679422|127.044374|
|1000006|   Seoul|        Guro-gu| true|Manmin Central Ch...|       41|37.481059|126.894343|
|1000007|   Seoul|from other city| true|SMR Newly Planted...|       36|        -|         -|
|1000008|   Seoul|  Dongdaemun-gu| true|       Dongan Church|       17

## Select Columns

In [7]:
cases = cases.select('province','city','infection_case','confirmed')
cases.show()

+--------+---------------+--------------------+---------+
|province|           city|      infection_case|confirmed|
+--------+---------------+--------------------+---------+
|   Seoul|     Yongsan-gu|       Itaewon Clubs|      139|
|   Seoul|      Gwanak-gu|             Richway|      119|
|   Seoul|        Guro-gu| Guro-gu Call Center|       95|
|   Seoul|   Yangcheon-gu|Yangcheon Table T...|       43|
|   Seoul|      Dobong-gu|     Day Care Center|       43|
|   Seoul|        Guro-gu|Manmin Central Ch...|       41|
|   Seoul|from other city|SMR Newly Planted...|       36|
|   Seoul|  Dongdaemun-gu|       Dongan Church|       17|
|   Seoul|from other city|Coupang Logistics...|       25|
|   Seoul|      Gwanak-gu|     Wangsung Church|       30|
|   Seoul|   Eunpyeong-gu|Eunpyeong St. Mar...|       14|
|   Seoul|   Seongdong-gu|    Seongdong-gu APT|       13|
|   Seoul|      Jongno-gu|Jongno Community ...|       10|
|   Seoul|     Gangnam-gu|Samsung Medical C...|        7|
|   Seoul|    

## Sort

In [8]:
cases.sort("confirmed").show()
# note that the cases data frame will not change after performing this command 
# as we don’t assign it to any variable.

+-----------------+---------------+--------------------+---------+
|         province|           city|      infection_case|confirmed|
+-----------------+---------------+--------------------+---------+
|          Jeju-do|              -|contact with patient|        0|
|       Gangwon-do|              -|contact with patient|        0|
|            Seoul|     Gangseo-gu|SJ Investment Cal...|        0|
|            Busan|from other city|Cheongdo Daenam H...|        1|
|     Jeollabuk-do|from other city|  Shincheonji Church|        1|
|            Seoul|from other city|Anyang Gunpo Past...|        1|
|            Seoul|     Gangnam-gu|Gangnam Dongin Ch...|        1|
|           Sejong|from other city|  Shincheonji Church|        1|
|     Jeollanam-do|from other city|  Shincheonji Church|        1|
|          Jeju-do|from other city|       Itaewon Clubs|        1|
|            Seoul|from other city|Daejeon door-to-d...|        1|
|            Seoul|              -|         Orange Life|      

In [9]:
# descending Sort
from pyspark.sql import functions as F
cases.sort(F.desc("confirmed")).show()

+-----------------+---------------+--------------------+---------+
|         province|           city|      infection_case|confirmed|
+-----------------+---------------+--------------------+---------+
|            Daegu|         Nam-gu|  Shincheonji Church|     4511|
|            Daegu|              -|contact with patient|      917|
|            Daegu|              -|                 etc|      747|
| Gyeongsangbuk-do|from other city|  Shincheonji Church|      566|
|      Gyeonggi-do|              -|     overseas inflow|      305|
|            Seoul|              -|     overseas inflow|      298|
|            Daegu|   Dalseong-gun|Second Mi-Ju Hosp...|      196|
| Gyeongsangbuk-do|              -|contact with patient|      190|
|            Seoul|              -|contact with patient|      162|
|            Seoul|     Yongsan-gu|       Itaewon Clubs|      139|
| Gyeongsangbuk-do|              -|                 etc|      133|
|            Daegu|         Seo-gu|Hansarang Convale...|      

## Cast

Though we don’t face it in this dataset, there might be scenarios where Pyspark reads a double as integer or string, In such cases, you can use the cast function to convert types.

In [10]:
from pyspark.sql.types import DoubleType, IntegerType, StringType
from pyspark.sql import functions as F
cases = cases.withColumn('confirmed', F.col('confirmed').cast(IntegerType()))
cases = cases.withColumn('city', F.col('city').cast(StringType()))

## Filter

We can filter a data frame using multiple conditions using AND(&), OR(|) and NOT(~) conditions. 

In [11]:
cases.filter((cases.confirmed>10) & (cases.province=='Daegu')).show()

+--------+------------+--------------------+---------+
|province|        city|      infection_case|confirmed|
+--------+------------+--------------------+---------+
|   Daegu|      Nam-gu|  Shincheonji Church|     4511|
|   Daegu|Dalseong-gun|Second Mi-Ju Hosp...|      196|
|   Daegu|      Seo-gu|Hansarang Convale...|      124|
|   Daegu|Dalseong-gun|Daesil Convalesce...|      101|
|   Daegu|     Dong-gu|     Fatima Hospital|       39|
|   Daegu|           -|     overseas inflow|       41|
|   Daegu|           -|contact with patient|      917|
|   Daegu|           -|                 etc|      747|
+--------+------------+--------------------+---------+



## GroupBy

In [12]:
from pyspark.sql import functions as F
(cases
 .groupBy(["province", "city"])
 .agg(F.sum("confirmed"), F.max("confirmed"))
 .show()
)

+----------------+---------------+--------------+--------------+
|        province|           city|sum(confirmed)|max(confirmed)|
+----------------+---------------+--------------+--------------+
|Gyeongsangnam-do|       Jinju-si|             9|             9|
|           Seoul|        Guro-gu|           139|            95|
|           Seoul|     Gangnam-gu|            18|             7|
|         Daejeon|              -|           100|            55|
|    Jeollabuk-do|from other city|             6|             3|
|Gyeongsangnam-do|Changnyeong-gun|             7|             7|
|           Seoul|              -|           561|           298|
|         Jeju-do|from other city|             1|             1|
|Gyeongsangbuk-do|              -|           345|           190|
|Gyeongsangnam-do|   Geochang-gun|            18|            10|
|Gyeongsangbuk-do|        Gumi-si|            10|            10|
|         Incheon|from other city|           117|            53|
|           Busan|       

If you don’t like the new column names, you can use the `alias` keyword to rename columns in the `agg` command itself.

In [13]:
(cases
 .groupBy(["province","city"])
 .agg(
     F.sum("confirmed").alias("TotalConfirmed"),
     F.max("confirmed").alias("MaxFromOneConfirmedCase")
 )
 .show()
)

+----------------+---------------+--------------+-----------------------+
|        province|           city|TotalConfirmed|MaxFromOneConfirmedCase|
+----------------+---------------+--------------+-----------------------+
|Gyeongsangnam-do|       Jinju-si|             9|                      9|
|           Seoul|        Guro-gu|           139|                     95|
|           Seoul|     Gangnam-gu|            18|                      7|
|         Daejeon|              -|           100|                     55|
|    Jeollabuk-do|from other city|             6|                      3|
|Gyeongsangnam-do|Changnyeong-gun|             7|                      7|
|           Seoul|              -|           561|                    298|
|         Jeju-do|from other city|             1|                      1|
|Gyeongsangbuk-do|              -|           345|                    190|
|Gyeongsangnam-do|   Geochang-gun|            18|                     10|
|Gyeongsangbuk-do|        Gumi-si|    

## Joins

In [14]:
# We need one more csv file
regions = spark.read.load(
    "data/Region.csv", format="csv", sep=",", inferSchema="true", header="true"
)
regions.limit(10).toPandas()

Unnamed: 0,code,province,city,latitude,longitude,elementary_school_count,kindergarten_count,university_count,academy_ratio,elderly_population_ratio,elderly_alone_ratio,nursing_home_count
0,10000,Seoul,Seoul,37.566953,126.977977,607,830,48,1.44,15.38,5.8,22739
1,10010,Seoul,Gangnam-gu,37.518421,127.047222,33,38,0,4.18,13.17,4.3,3088
2,10020,Seoul,Gangdong-gu,37.530492,127.123837,27,32,0,1.54,14.55,5.4,1023
3,10030,Seoul,Gangbuk-gu,37.639938,127.025508,14,21,0,0.67,19.49,8.5,628
4,10040,Seoul,Gangseo-gu,37.551166,126.849506,36,56,1,1.17,14.39,5.7,1080
5,10050,Seoul,Gwanak-gu,37.47829,126.951502,22,33,1,0.89,15.12,4.9,909
6,10060,Seoul,Gwangjin-gu,37.538712,127.082366,22,33,3,1.16,13.75,4.8,723
7,10070,Seoul,Guro-gu,37.495632,126.88765,26,34,3,1.0,16.21,5.7,741
8,10080,Seoul,Geumcheon-gu,37.456852,126.895229,18,19,0,0.96,16.15,6.7,475
9,10090,Seoul,Nowon-gu,37.654259,127.056294,42,66,6,1.39,15.4,7.4,952


In [15]:
cases = cases.join(regions, ['province', 'city'], how='left')
cases.limit(10).toPandas()

Unnamed: 0,province,city,infection_case,confirmed,code,latitude,longitude,elementary_school_count,kindergarten_count,university_count,academy_ratio,elderly_population_ratio,elderly_alone_ratio,nursing_home_count
0,Seoul,Yongsan-gu,Itaewon Clubs,139,10210.0,37.532768,126.990021,15.0,13.0,1.0,0.68,16.87,6.5,435.0
1,Seoul,Gwanak-gu,Richway,119,10050.0,37.47829,126.951502,22.0,33.0,1.0,0.89,15.12,4.9,909.0
2,Seoul,Guro-gu,Guro-gu Call Center,95,10070.0,37.495632,126.88765,26.0,34.0,3.0,1.0,16.21,5.7,741.0
3,Seoul,Yangcheon-gu,Yangcheon Table Tennis Club,43,10190.0,37.517189,126.866618,30.0,43.0,0.0,2.26,13.55,5.5,816.0
4,Seoul,Dobong-gu,Day Care Center,43,10100.0,37.668952,127.047082,23.0,26.0,1.0,0.95,17.89,7.2,485.0
5,Seoul,Guro-gu,Manmin Central Church,41,10070.0,37.495632,126.88765,26.0,34.0,3.0,1.0,16.21,5.7,741.0
6,Seoul,from other city,SMR Newly Planted Churches Group,36,,,,,,,,,,
7,Seoul,Dongdaemun-gu,Dongan Church,17,10110.0,37.574552,127.039721,21.0,31.0,4.0,1.06,17.26,6.7,832.0
8,Seoul,from other city,Coupang Logistics Center,25,,,,,,,,,,
9,Seoul,Gwanak-gu,Wangsung Church,30,10050.0,37.47829,126.951502,22.0,33.0,1.0,0.89,15.12,4.9,909.0


# 2. Broadcast/Map Side Joins

Sometimes you might face a scenario where you need to join a very big table(~1B Rows) with a very small table(~100–200 rows). The scenario might also involve increasing the size of your database. Such sort of operations is aplenty in Spark where you might want to apply multiple operations to a particular key. But assuming that the data for each key in the Big table is large, it will involve a lot of data movement. And sometimes so much that the application itself breaks. A small optimization then you can do when joining on such big tables(assuming the other table is small) is to broadcast the small table to each machine/node when you perform a join. You can do this easily using the broadcast keyword. This has been a lifesaver many times with Spark when everything else fails.

In [16]:
from pyspark.sql.functions import broadcast
cases = cases.join(broadcast(regions), ['province', 'city'], how='left')

In [17]:
cases.limit(10).toPandas()

Unnamed: 0,province,city,infection_case,confirmed,code,latitude,longitude,elementary_school_count,kindergarten_count,university_count,...,code.1,latitude.1,longitude.1,elementary_school_count.1,kindergarten_count.1,university_count.1,academy_ratio,elderly_population_ratio,elderly_alone_ratio,nursing_home_count
0,Seoul,Yongsan-gu,Itaewon Clubs,139,10210.0,37.532768,126.990021,15.0,13.0,1.0,...,10210.0,37.532768,126.990021,15.0,13.0,1.0,0.68,16.87,6.5,435.0
1,Seoul,Gwanak-gu,Richway,119,10050.0,37.47829,126.951502,22.0,33.0,1.0,...,10050.0,37.47829,126.951502,22.0,33.0,1.0,0.89,15.12,4.9,909.0
2,Seoul,Guro-gu,Guro-gu Call Center,95,10070.0,37.495632,126.88765,26.0,34.0,3.0,...,10070.0,37.495632,126.88765,26.0,34.0,3.0,1.0,16.21,5.7,741.0
3,Seoul,Yangcheon-gu,Yangcheon Table Tennis Club,43,10190.0,37.517189,126.866618,30.0,43.0,0.0,...,10190.0,37.517189,126.866618,30.0,43.0,0.0,2.26,13.55,5.5,816.0
4,Seoul,Dobong-gu,Day Care Center,43,10100.0,37.668952,127.047082,23.0,26.0,1.0,...,10100.0,37.668952,127.047082,23.0,26.0,1.0,0.95,17.89,7.2,485.0
5,Seoul,Guro-gu,Manmin Central Church,41,10070.0,37.495632,126.88765,26.0,34.0,3.0,...,10070.0,37.495632,126.88765,26.0,34.0,3.0,1.0,16.21,5.7,741.0
6,Seoul,from other city,SMR Newly Planted Churches Group,36,,,,,,,...,,,,,,,,,,
7,Seoul,Dongdaemun-gu,Dongan Church,17,10110.0,37.574552,127.039721,21.0,31.0,4.0,...,10110.0,37.574552,127.039721,21.0,31.0,4.0,1.06,17.26,6.7,832.0
8,Seoul,from other city,Coupang Logistics Center,25,,,,,,,...,,,,,,,,,,
9,Seoul,Gwanak-gu,Wangsung Church,30,10050.0,37.47829,126.951502,22.0,33.0,1.0,...,10050.0,37.47829,126.951502,22.0,33.0,1.0,0.89,15.12,4.9,909.0


# 3. Use SQL with DataFrames

We first register the cases dataframe to a temporary table cases_table on which we can run SQL operations. As you can see, the result of the SQL select statement is again a Spark Dataframe.

In [22]:
from pyspark.sql import SQLContext
sqlContext = SQLContext(spark)
cases = spark.read.load(
    "data/Case.csv", format="csv", sep=",", inferSchema="true", header="true"
)   # Reloading SparkSession for demonstration
cases.registerTempTable('cases_table')
newDF = sqlContext.sql('select * from cases_table where confirmed>100')
newDF.show()



+--------+-----------------+---------------+-----+--------------------+---------+---------+----------+
| case_id|         province|           city|group|      infection_case|confirmed| latitude| longitude|
+--------+-----------------+---------------+-----+--------------------+---------+---------+----------+
| 1000001|            Seoul|     Yongsan-gu| true|       Itaewon Clubs|      139|37.538621|126.992652|
| 1000002|            Seoul|      Gwanak-gu| true|             Richway|      119| 37.48208|126.901384|
| 1000036|            Seoul|              -|false|     overseas inflow|      298|        -|         -|
| 1000037|            Seoul|              -|false|contact with patient|      162|        -|         -|
| 1200001|            Daegu|         Nam-gu| true|  Shincheonji Church|     4511| 35.84008|  128.5667|
| 1200002|            Daegu|   Dalseong-gun| true|Second Mi-Ju Hosp...|      196|35.857375|128.466651|
| 1200003|            Daegu|         Seo-gu| true|Hansarang Convale...|  

# 4. Create New Columns

## Using Spark Native Functions

We can use `.withcolumn` along with PySpark SQL functions to create a new column. In essence, you can find String functions, Date functions, and Math functions already implemented using Spark functions. Our first function, the `F.col` function gives us access to the column. So if we wanted to add 100 to a column, we could use `F.col` as:

In [23]:
import pyspark.sql.functions as F
casesWithNewConfirmed = cases.withColumn("NewConfirmed", 100 + F.col("confirmed"))
casesWithNewConfirmed.show()

+--------+--------+---------------+-----+--------------------+---------+---------+----------+------------+
| case_id|province|           city|group|      infection_case|confirmed| latitude| longitude|NewConfirmed|
+--------+--------+---------------+-----+--------------------+---------+---------+----------+------------+
| 1000001|   Seoul|     Yongsan-gu| true|       Itaewon Clubs|      139|37.538621|126.992652|         239|
| 1000002|   Seoul|      Gwanak-gu| true|             Richway|      119| 37.48208|126.901384|         219|
| 1000003|   Seoul|        Guro-gu| true| Guro-gu Call Center|       95|37.508163|126.884387|         195|
| 1000004|   Seoul|   Yangcheon-gu| true|Yangcheon Table T...|       43|37.546061|126.874209|         143|
| 1000005|   Seoul|      Dobong-gu| true|     Day Care Center|       43|37.679422|127.044374|         143|
| 1000006|   Seoul|        Guro-gu| true|Manmin Central Ch...|       41|37.481059|126.894343|         141|
| 1000007|   Seoul|from other city| t

We can also use math functions like `F.exp` function:

In [24]:
casesWithExpConfirmed = cases.withColumn("ExpConfirmed", F.exp("confirmed"))
casesWithExpConfirmed.show()

+--------+--------+---------------+-----+--------------------+---------+---------+----------+--------------------+
| case_id|province|           city|group|      infection_case|confirmed| latitude| longitude|        ExpConfirmed|
+--------+--------+---------------+-----+--------------------+---------+---------+----------+--------------------+
| 1000001|   Seoul|     Yongsan-gu| true|       Itaewon Clubs|      139|37.538621|126.992652|2.327732040478862E60|
| 1000002|   Seoul|      Gwanak-gu| true|             Richway|      119| 37.48208|126.901384|4.797813327299302E51|
| 1000003|   Seoul|        Guro-gu| true| Guro-gu Call Center|       95|37.508163|126.884387|1.811239082889023...|
| 1000004|   Seoul|   Yangcheon-gu| true|Yangcheon Table T...|       43|37.546061|126.874209|4.727839468229346...|
| 1000005|   Seoul|      Dobong-gu| true|     Day Care Center|       43|37.679422|127.044374|4.727839468229346...|
| 1000006|   Seoul|        Guro-gu| true|Manmin Central Ch...|       41|37.48105

## Using Spark UDFs

To use Spark UDFs, we need to use the `F.udf` function to convert a regular python function to a Spark UDF. We also need to specify the return type of the function. In this example the return type is `StringType()`

In [25]:
# import pyspark.sql.functions as F
# from pyspark.sql.types import *


# def casesHighLow(confirmed):
#     if confirmed < 50:
#         return 'low'
#     else:
#         return 'high'


# # convert to a UDF Function by passing in the function and return type of function
# casesHighLowUDF = F.udf(casesHighLow, StringType())
# CasesWithHighLow = cases.withColumn("HighLow", casesHighLowUDF("confirmed"))
# CasesWithHighLow.show()

## Using RDDs

In [26]:
# import math
# from pyspark.sql import Row


# def rowwise_function(row):
#     # convert row to python dictionary:
#     row_dict = row.asDict()
#     # Add a new key in the dictionary with the new column name and value.
#     # This might be a big complex function.
#     row_dict['expConfirmed'] = float(np.exp(row_dict['confirmed']))
#     # convert dict to row back again:
#     newrow = Row(**row_dict)
#     # return new row
#     return newrow


# # convert cases dataframe to RDD
# cases_rdd = cases.rdd

# # apply our function to RDD
# cases_rdd_new = cases_rdd.map(lambda row: rowwise_function(row))

# # Convert RDD Back to DataFrame
# casesNewDf = sqlContext.createDataFrame(cases_rdd_new)

# casesNewDf.show()

## Using Pandas UDF

The way we use it is by using the `F.pandas_udf` decorator. **We assume here that the input to the function will be a pandas data frame.** And we need to return a pandas dataframe in turn from this function.

The only complexity here is that we have to provide a schema for the output Dataframe. We can use the original schema of a dataframe to create the outSchema.

In [27]:
cases.printSchema()

root
 |--  case_id: integer (nullable = true)
 |-- province: string (nullable = true)
 |-- city: string (nullable = true)
 |-- group: boolean (nullable = true)
 |-- infection_case: string (nullable = true)
 |-- confirmed: integer (nullable = true)
 |-- latitude: string (nullable = true)
 |-- longitude: string (nullable = true)



Here I am using Pandas UDF to get normalized confirmed cases grouped by infection_case. The main advantage here is that I get to work with pandas dataframes in Spark.

In [29]:
# from pyspark.sql.types import IntegerType, StringType, DoubleType, BooleanType
# from pyspark.sql.types import StructType, StructField

# # Declare the schema for the output of our function

# outSchema = StructType([StructField('case_id', IntegerType(),True),
#                         StructField('province', StringType(),True),
#                         StructField('city', StringType(),True),
#                         StructField('group', BooleanType(),True),
#                         StructField('infection_case', StringType(),True),
#                         StructField('confirmed', IntegerType(),True),
#                         StructField('latitude', StringType(),True),
#                         StructField('longitude', StringType(),True),
#                         StructField('normalized_confirmed', DoubleType(),True)
#                        ])
# # decorate our function with pandas_udf decorator
# @F.pandas_udf(outSchema, F.PandasUDFType.GROUPED_MAP)
# def subtract_mean(pdf):
#     # pdf is a pandas.DataFrame
#     v = pdf.confirmed
#     v = v - v.mean()
#     pdf['normalized_confirmed'] = v
#     return pdf


# confirmed_groupwise_normalization = cases.groupby("infection_case").apply(subtract_mean)

# confirmed_groupwise_normalization.limit(10).toPandas()

# 5. Spark Window Functions

In [38]:
timeprovince = spark.read.load(
    "data/TimeProvince.csv", format="csv", sep=",", inferSchema="true", header="true"
)
timeprovince.show()

+-------------------+----+-----------------+---------+--------+--------+
|               date|time|         province|confirmed|released|deceased|
+-------------------+----+-----------------+---------+--------+--------+
|2020-01-20 00:00:00|  16|            Seoul|        0|       0|       0|
|2020-01-20 00:00:00|  16|            Busan|        0|       0|       0|
|2020-01-20 00:00:00|  16|            Daegu|        0|       0|       0|
|2020-01-20 00:00:00|  16|          Incheon|        1|       0|       0|
|2020-01-20 00:00:00|  16|          Gwangju|        0|       0|       0|
|2020-01-20 00:00:00|  16|          Daejeon|        0|       0|       0|
|2020-01-20 00:00:00|  16|            Ulsan|        0|       0|       0|
|2020-01-20 00:00:00|  16|           Sejong|        0|       0|       0|
|2020-01-20 00:00:00|  16|      Gyeonggi-do|        0|       0|       0|
|2020-01-20 00:00:00|  16|       Gangwon-do|        0|       0|       0|
|2020-01-20 00:00:00|  16|Chungcheongbuk-do|       

## Ranking

In [39]:
from pyspark.sql.window import Window
windowSpec = (Window()
              .partitionBy(['province'])
              .orderBy(F.desc('confirmed'))
             )
(cases
 .withColumn("rank", F.rank().over(windowSpec))
 .show()
)

+--------+-----------------+---------------+-----+--------------------+---------+--------+---------+----+
| case_id|         province|           city|group|      infection_case|confirmed|latitude|longitude|rank|
+--------+-----------------+---------------+-----+--------------------+---------+--------+---------+----+
| 1100001|            Busan|     Dongnae-gu| true|       Onchun Church|       39|35.21628| 129.0771|   1|
| 1100008|            Busan|              -|false|     overseas inflow|       36|       -|        -|   2|
| 1100010|            Busan|              -|false|                 etc|       30|       -|        -|   3|
| 1100009|            Busan|              -|false|contact with patient|       19|       -|        -|   4|
| 1100002|            Busan|from other city| true|  Shincheonji Church|       12|       -|        -|   5|
| 1100004|            Busan|    Haeundae-gu| true|Haeundae-gu Catho...|        6|35.20599| 129.1256|   6|
| 1100003|            Busan|     Suyeong-gu| t

## Lag Variables

In [41]:
from pyspark.sql.window import Window

windowSpec = (Window()
              .partitionBy(['province'])
              .orderBy('date')
             )

timeprovinceWithLag = (
    timeprovince
    .withColumn("lag_7", F.lag("confirmed", 7).over(windowSpec))
)

(timeprovinceWithLag
 .filter(timeprovinceWithLag.date>'2020-03-10')
 .show()
)

+-------------------+----+--------+---------+--------+--------+-----+
|               date|time|province|confirmed|released|deceased|lag_7|
+-------------------+----+--------+---------+--------+--------+-----+
|2020-03-11 00:00:00|   0|   Busan|       98|      21|       0|   92|
|2020-03-12 00:00:00|   0|   Busan|       99|      29|       0|   92|
|2020-03-13 00:00:00|   0|   Busan|      100|      36|       0|   95|
|2020-03-14 00:00:00|   0|   Busan|      103|      40|       0|   96|
|2020-03-15 00:00:00|   0|   Busan|      106|      52|       1|   96|
|2020-03-16 00:00:00|   0|   Busan|      107|      53|       1|   96|
|2020-03-17 00:00:00|   0|   Busan|      107|      54|       1|   96|
|2020-03-18 00:00:00|   0|   Busan|      107|      58|       1|   98|
|2020-03-19 00:00:00|   0|   Busan|      107|      58|       1|   99|
|2020-03-20 00:00:00|   0|   Busan|      108|      60|       1|  100|
|2020-03-21 00:00:00|   0|   Busan|      108|      67|       1|  103|
|2020-03-22 00:00:00

## Rolling Aggregations

In [42]:
from pyspark.sql.window import Window

windowSpec = (Window()
              .partitionBy(['province'])
              .orderBy('date')
              .rowsBetween(-6, 0)
             )

timeprovinceWithRoll = (
    timeprovince
    .withColumn("roll_7_confirmed", F.mean("confirmed").over(windowSpec))
)

(timeprovinceWithRoll
 .filter(timeprovinceWithLag.date>'2020-03-10')
 .show()
)

+-------------------+----+--------+---------+--------+--------+------------------+
|               date|time|province|confirmed|released|deceased|  roll_7_confirmed|
+-------------------+----+--------+---------+--------+--------+------------------+
|2020-03-11 00:00:00|   0|   Busan|       98|      21|       0| 95.57142857142857|
|2020-03-12 00:00:00|   0|   Busan|       99|      29|       0| 96.57142857142857|
|2020-03-13 00:00:00|   0|   Busan|      100|      36|       0| 97.28571428571429|
|2020-03-14 00:00:00|   0|   Busan|      103|      40|       0| 98.28571428571429|
|2020-03-15 00:00:00|   0|   Busan|      106|      52|       1| 99.71428571428571|
|2020-03-16 00:00:00|   0|   Busan|      107|      53|       1|101.28571428571429|
|2020-03-17 00:00:00|   0|   Busan|      107|      54|       1|102.85714285714286|
|2020-03-18 00:00:00|   0|   Busan|      107|      58|       1|104.14285714285714|
|2020-03-19 00:00:00|   0|   Busan|      107|      58|       1|105.28571428571429|
|202

In [43]:
from pyspark.sql.window import Window

windowSpec = (Window()
              .partitionBy(['province'])
              .orderBy('date')
              .rowsBetween(Window.unboundedPreceding, Window.currentRow)
             )

timeprovinceWithRoll = (
    timeprovince
    .withColumn("cumulative_confirmed", F.sum("confirmed").over(windowSpec))
)

(timeprovinceWithRoll
 .filter(timeprovinceWithLag.date>'2020-03-10')
 .show()
)

+-------------------+----+--------+---------+--------+--------+--------------------+
|               date|time|province|confirmed|released|deceased|cumulative_confirmed|
+-------------------+----+--------+---------+--------+--------+--------------------+
|2020-03-11 00:00:00|   0|   Busan|       98|      21|       0|                1408|
|2020-03-12 00:00:00|   0|   Busan|       99|      29|       0|                1507|
|2020-03-13 00:00:00|   0|   Busan|      100|      36|       0|                1607|
|2020-03-14 00:00:00|   0|   Busan|      103|      40|       0|                1710|
|2020-03-15 00:00:00|   0|   Busan|      106|      52|       1|                1816|
|2020-03-16 00:00:00|   0|   Busan|      107|      53|       1|                1923|
|2020-03-17 00:00:00|   0|   Busan|      107|      54|       1|                2030|
|2020-03-18 00:00:00|   0|   Busan|      107|      58|       1|                2137|
|2020-03-19 00:00:00|   0|   Busan|      107|      58|       1|  

# 6. Pivot Dataframes

In [44]:
pivotedTimeprovince = (timeprovince
                       .groupBy('date')
                       .pivot('province')
                       .agg(
                           F.sum('confirmed').alias('confirmed'), 
                           F.sum('released').alias('released')
                       )
                      )
pivotedTimeprovince.limit(10).toPandas()

Unnamed: 0,date,Busan_confirmed,Busan_released,Chungcheongbuk-do_confirmed,Chungcheongbuk-do_released,Chungcheongnam-do_confirmed,Chungcheongnam-do_released,Daegu_confirmed,Daegu_released,Daejeon_confirmed,...,Jeollabuk-do_confirmed,Jeollabuk-do_released,Jeollanam-do_confirmed,Jeollanam-do_released,Sejong_confirmed,Sejong_released,Seoul_confirmed,Seoul_released,Ulsan_confirmed,Ulsan_released
0,2020-04-05,122,90,45,28,135,102,6768,4773,37,...,16,7,15,4,46,15,552,145,40,27
1,2020-05-20,144,131,59,44,144,141,6872,6471,44,...,21,19,18,14,47,47,752,592,49,42
2,2020-06-21,150,142,61,55,159,144,6899,6680,82,...,23,20,20,18,49,47,1219,733,53,48
3,2020-06-04,147,138,60,49,146,142,6885,6634,46,...,21,19,20,17,47,47,909,644,52,46
4,2020-05-21,144,131,59,44,145,141,6872,6498,44,...,21,19,18,16,47,47,756,596,49,42
5,2020-05-01,137,116,45,41,143,128,6852,6144,40,...,18,12,15,12,46,39,634,455,43,37
6,2020-03-16,107,53,31,6,115,12,6066,734,22,...,7,4,4,2,40,0,253,52,28,7
7,2020-05-06,138,119,45,42,143,128,6856,6217,40,...,18,12,16,12,46,45,637,494,44,37
8,2020-06-07,147,140,61,49,148,143,6887,6640,46,...,21,19,20,17,47,47,974,651,53,46
9,2020-05-25,144,131,59,46,145,141,6874,6519,45,...,21,19,18,17,47,47,774,607,50,42


# 7. Unpivot/Stack Dataframes

Please take a look at the original post in Medium for this section

In [45]:
# we need to replace "-" with "_" in the column names
newColnames = [x.replace("-", "_") for x in pivotedTimeprovince.columns]
pivotedTimeprovince = pivotedTimeprovince.toDF(*newColnames)

In [51]:
expression = ""
cnt = 0
for column in pivotedTimeprovince.columns:
    if column != 'date':
        cnt += 1
        expression += f"'{column}' , {column},"
exprs = f"stack({cnt}, {expression[:-1]}) as (Type,Value)"

In [52]:
unpivotedTimeprovince = pivotedTimeprovince.select('date', F.expr(exprs))
unpivotedTimeprovince.show()

+-------------------+--------------------+-----+
|               date|                Type|Value|
+-------------------+--------------------+-----+
|2020-04-05 00:00:00|     Busan_confirmed|  122|
|2020-04-05 00:00:00|      Busan_released|   90|
|2020-04-05 00:00:00|Chungcheongbuk_do...|   45|
|2020-04-05 00:00:00|Chungcheongbuk_do...|   28|
|2020-04-05 00:00:00|Chungcheongnam_do...|  135|
|2020-04-05 00:00:00|Chungcheongnam_do...|  102|
|2020-04-05 00:00:00|     Daegu_confirmed| 6768|
|2020-04-05 00:00:00|      Daegu_released| 4773|
|2020-04-05 00:00:00|   Daejeon_confirmed|   37|
|2020-04-05 00:00:00|    Daejeon_released|   18|
|2020-04-05 00:00:00|Gangwon_do_confirmed|   45|
|2020-04-05 00:00:00| Gangwon_do_released|   24|
|2020-04-05 00:00:00|   Gwangju_confirmed|   27|
|2020-04-05 00:00:00|    Gwangju_released|   15|
|2020-04-05 00:00:00|Gyeonggi_do_confi...|  572|
|2020-04-05 00:00:00|Gyeonggi_do_released|  208|
|2020-04-05 00:00:00|Gyeongsangbuk_do_...| 1314|
|2020-04-05 00:00:00