# Checkpointing DataFrames

Execution plans can grow very long, at which point Spark may run into trouble. Typical cases are iterative algorithms, such as machine learning or graph algorithms, whose big outer loop transforms a DataFrame again and again. Every iteration adds steps to the plan, so the execution plan becomes huge.

In these cases you can use `cache()` or `persist()` to improve performance. Otherwise every iteration would execute all previous steps of the loop again from the very beginning, leading to a runtime of O(n^2) for n iterations. Caching does not cut off the lineage, though, so the execution plan keeps growing.
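The quadratic runtime follows from simple counting. The sketch below assumes, purely for illustration, that each of the n iterations adds exactly one transformation step and triggers one action; the function names are hypothetical.

```python
def steps_without_cache(n: int) -> int:
    """Steps executed when the action in iteration i recomputes steps 1..i."""
    return sum(range(1, n + 1))  # 1 + 2 + ... + n = n * (n + 1) / 2


def steps_with_cache(n: int) -> int:
    """Steps executed when each iteration reuses the cached previous result."""
    return n  # only one new step per iteration


for n in (10, 100, 1000):
    print(f"n={n}: without cache {steps_without_cache(n)}, with cache {steps_with_cache(n)}")
```

The work without caching grows with n^2, while the cached variant grows linearly. The plan itself still has n steps in both cases, which is the part only checkpointing addresses.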

Checkpointing is the right solution for these cases. It persists the data of a DataFrame in reliable distributed storage (most commonly HDFS) and cuts off the lineage.

## Create or Reuse Spark Session

In [1]:
from pyspark.sql import SparkSession

if 'spark' not in locals():
    spark = SparkSession.builder \
        .master("local[*]") \
        .config("spark.driver.memory","24G") \
        .getOrCreate()

spark

In [None]:
spark.conf.set("spark.sql.adaptive.enabled", False)
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)

# 1 Load Data

We will load the weather data again for this example.

In [1]:
storageLocation = "s3://dimajix-training/data/weather"

## 1.1 Load Measurements

In [2]:
from pyspark.sql.functions import *

raw_weather = spark.read.text(storageLocation + "/2003").withColumn("year", lit(2003))

### Extract Measurements

Measurements are stored in a proprietary text-based format, with some values at fixed positions. We extract these values with a simple `select` that applies `substring` to the raw `value` column.
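Spark's `substring(str, pos, len)` counts positions from 1, unlike Python slicing, which is easy to get wrong when reading a fixed-width layout. The plain Python sketch below mimics that behaviour for positive positions only; the record is a hypothetical stand-in (four filler characters followed by four fields), not a real line from the data set.

```python
def substring(value: str, pos: int, length: int) -> str:
    """Mimic Spark's substring(str, pos, len) for positions counted from 1."""
    start = pos - 1  # convert the 1-based position to a 0-based index
    return value[start:start + length]


# Hypothetical record: 4 filler characters, then usaf, wban, date and time
record = "0000" + "703160" + "25624" + "20030101" + "0053"

usaf = substring(record, 5, 6)
wban = substring(record, 11, 5)
date = substring(record, 16, 8)
time = substring(record, 24, 4)
print(usaf, wban, date, time)
```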

In [3]:
weather = raw_weather.select(
    col("year"),
    substring(col("value"),5,6).alias("usaf"),
    substring(col("value"),11,5).alias("wban"),
    substring(col("value"),16,8).alias("date"),
    substring(col("value"),24,4).alias("time"),
    substring(col("value"),42,5).alias("report_type"),
    substring(col("value"),61,3).alias("wind_direction"),
    substring(col("value"),64,1).alias("wind_direction_qual"),
    substring(col("value"),65,1).alias("wind_observation"),
    (substring(col("value"),66,4).cast("float") / lit(10.0)).alias("wind_speed"),
    substring(col("value"),70,1).alias("wind_speed_qual"),
    (substring(col("value"),88,5).cast("float") / lit(10.0)).alias("air_temperature"),
    substring(col("value"),93,1).alias("air_temperature_qual")
)
    
weather.limit(10).toPandas()

| | year | usaf | wban | date | time | report_type | wind_direction | wind_direction_qual | wind_observation | wind_speed | wind_speed_qual | air_temperature | air_temperature_qual |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2003 | 703160 | 25624 | 20030101 | 0 | SY-MT | 10 | 5 | N | 5.2 | 5 | -0.6 | 5 |
| 1 | 2003 | 703160 | 25624 | 20030101 | 17 | FM-16 | 20 | 1 | N | 4.6 | 1 | -2.0 | 1 |
| 2 | 2003 | 703160 | 25624 | 20030101 | 53 | FM-15 | 10 | 5 | N | 5.2 | 5 | -2.8 | 5 |
| 3 | 2003 | 703160 | 25624 | 20030101 | 100 | NSRDB | 999 | 9 | 9 | 999.9 | 9 | 999.9 | 9 |
| 4 | 2003 | 703160 | 25624 | 20030101 | 153 | FM-15 | 10 | 5 | N | 6.2 | 5 | -2.2 | 5 |
| 5 | 2003 | 703160 | 25624 | 20030101 | 200 | NSRDB | 999 | 9 | 9 | 999.9 | 9 | 999.9 | 9 |
| 6 | 2003 | 703160 | 25624 | 20030101 | 253 | FM-15 | 10 | 5 | N | 7.2 | 5 | -3.3 | 5 |
| 7 | 2003 | 703160 | 25624 | 20030101 | 300 | NSRDB | 999 | 9 | 9 | 999.9 | 9 | 999.9 | 9 |
| 8 | 2003 | 703160 | 25624 | 20030101 | 353 | FM-15 | 20 | 5 | N | 6.2 | 5 | -1.1 | 5 |
| 9 | 2003 | 703160 | 25624 | 20030101 | 400 | NSRDB | 999 | 9 | 9 | 999.9 | 9 | 999.9 | 9 |


## 1.2 Load Station Metadata

In [4]:
stations = spark.read \
    .option("header", True) \
    .csv(storageLocation + "/isd-history")

# Display first 10 records    
stations.limit(10).toPandas()

| | USAF | WBAN | STATION NAME | CTRY | STATE | ICAO | LAT | LON | ELEV(M) | BEGIN | END |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 7005 | 99999 | CWOS 07005 | | | | | | | 20120127 | 20120127 |
| 1 | 7011 | 99999 | CWOS 07011 | | | | | | | 20111025 | 20121129 |
| 2 | 7018 | 99999 | WXPOD 7018 | | | | 0.0 | 0.0 | 7018.0 | 20110309 | 20130730 |
| 3 | 7025 | 99999 | CWOS 07025 | | | | | | | 20120127 | 20120127 |
| 4 | 7026 | 99999 | WXPOD 7026 | AF | | | 0.0 | 0.0 | 7026.0 | 20120713 | 20141120 |
| 5 | 7034 | 99999 | CWOS 07034 | | | | | | | 20121024 | 20121106 |
| 6 | 7037 | 99999 | CWOS 07037 | | | | | | | 20111202 | 20121125 |
| 7 | 7044 | 99999 | CWOS 07044 | | | | | | | 20120127 | 20120127 |
| 8 | 7047 | 99999 | CWOS 07047 | | | | | | | 20120613 | 20120717 |
| 9 | 7052 | 99999 | CWOS 07052 | | | | | | | 20121129 | 20121130 |


# 2 Join Data

Now we perform the join between the station master data and the measurements, as we did before.

In [5]:
joined_weather = weather.join(stations, (weather["usaf"] == stations["usaf"]) & (weather["wban"] == stations["wban"]))

# 3 Truncating Execution Plans

Now we want to understand the effect of checkpointing. First we will use the traditional aggregation and print the execution plan.

## 3.1 Traditional Aggregation

In [6]:
result = joined_weather.groupBy(joined_weather["ctry"], joined_weather["year"]).agg(
        min(when(joined_weather.air_temperature_qual == lit(1), joined_weather.air_temperature)).alias('min_temp'),
        max(when(joined_weather.air_temperature_qual == lit(1), joined_weather.air_temperature)).alias('max_temp')
)

result.explain(True)

== Parsed Logical Plan ==
'Aggregate [ctry#45, year#2], [ctry#45, year#2, min(CASE WHEN (air_temperature_qual#16 = 1) THEN air_temperature#15 END) AS min_temp#173, max(CASE WHEN (air_temperature_qual#16 = 1) THEN air_temperature#15 END) AS max_temp#175]
+- AnalysisBarrier
      +- Join Inner, ((usaf#5 = usaf#42) && (wban#6 = wban#43))
         :- Project [year#2, substring(value#0, 5, 6) AS usaf#5, substring(value#0, 11, 5) AS wban#6, substring(value#0, 16, 8) AS date#7, substring(value#0, 24, 4) AS time#8, substring(value#0, 42, 5) AS report_type#9, substring(value#0, 61, 3) AS wind_direction#10, substring(value#0, 64, 1) AS wind_direction_qual#11, substring(value#0, 65, 1) AS wind_observation#12, (cast(cast(substring(value#0, 66, 4) as float) as double) / cast(10.0 as double)) AS wind_speed#13, substring(value#0, 70, 1) AS wind_speed_qual#14, (cast(cast(substring(value#0, 88, 5) as float) as double) / cast(10.0 as double)) AS air_temperature#15, substring(value#0, 93, 1) AS air_tempe

## 3.2 Reliable Checkpointing

Now we first checkpoint the joined weather data set and then perform the aggregation on the checkpointed DataFrame.

### Set Checkpoint directory

First we need to specify a checkpoint directory on a reliable shared file system.

In [7]:
spark.sparkContext.setCheckpointDir("/tmp/checkpoints")

### Create checkpoint

Now we can create a checkpoint for the joined weather data. Note that this takes some time: with `eager=True` (the default), checkpointing is not a lazy operation but is executed immediately. This is also conceptually necessary, because one aspect of checkpointing is that the whole lineage gets cut off. So there is no way around executing the computation to materialize the DataFrame inside the checkpoint directory.

In [8]:
cp_weather = joined_weather.checkpoint(eager=True)

### Inspect Checkpoint directory

In [11]:
%%bash
hdfs dfs -ls /tmp/checkpoints

Found 1 items
drwxr-xr-x   - hadoop hadoop          0 2018-10-28 07:37 /tmp/checkpoints/1e08381c-ddda-4d02-876b-07ba3427c9f8


### Inspect execution plan

Let us have a look at the execution plan of the checkpointed DataFrame.

In [9]:
cp_weather.explain(True)

== Parsed Logical Plan ==
AnalysisBarrier
   +- LogicalRDD [year#2, usaf#5, wban#6, date#7, time#8, report_type#9, wind_direction#10, wind_direction_qual#11, wind_observation#12, wind_speed#13, wind_speed_qual#14, air_temperature#15, air_temperature_qual#16, USAF#42, WBAN#43, STATION NAME#44, CTRY#45, STATE#46, ICAO#47, LAT#48, LON#49, ELEV(M)#50, BEGIN#51, END#52], false

== Analyzed Logical Plan ==
year: int, usaf: string, wban: string, date: string, time: string, report_type: string, wind_direction: string, wind_direction_qual: string, wind_observation: string, wind_speed: double, wind_speed_qual: string, air_temperature: double, air_temperature_qual: string, USAF: string, WBAN: string, STATION NAME: string, CTRY: string, STATE: string, ICAO: string, LAT: string, LON: string, ELEV(M): string, BEGIN: string, END: string
LogicalRDD [year#2, usaf#5, wban#6, date#7, time#8, report_type#9, wind_direction#10, wind_direction_qual#11, wind_observation#12, wind_speed#13, wind_speed_qual#14, 

As you can see, the lineage got lost and is replaced by a `LogicalRDD` (shown as `Scan ExistingRDD` in the physical plan), which refers to the data in the checkpoint directory.

### Perform aggregation

Now we can perform the aggregation with the checkpointed variant of the joined weather DataFrame.

In [10]:
result = cp_weather.groupBy(cp_weather["ctry"], cp_weather["year"]).agg(
        min(when(cp_weather.air_temperature_qual == lit(1), cp_weather.air_temperature)).alias('min_temp'),
        max(when(cp_weather.air_temperature_qual == lit(1), cp_weather.air_temperature)).alias('max_temp')
)

result.explain(True)

== Parsed Logical Plan ==
'Aggregate [ctry#45, year#2], [ctry#45, year#2, min(CASE WHEN (air_temperature_qual#16 = 1) THEN air_temperature#15 END) AS min_temp#247, max(CASE WHEN (air_temperature_qual#16 = 1) THEN air_temperature#15 END) AS max_temp#249]
+- AnalysisBarrier
      +- LogicalRDD [year#2, usaf#5, wban#6, date#7, time#8, report_type#9, wind_direction#10, wind_direction_qual#11, wind_observation#12, wind_speed#13, wind_speed_qual#14, air_temperature#15, air_temperature_qual#16, USAF#42, WBAN#43, STATION NAME#44, CTRY#45, STATE#46, ICAO#47, LAT#48, LON#49, ELEV(M)#50, BEGIN#51, END#52], false

== Analyzed Logical Plan ==
ctry: string, year: int, min_temp: double, max_temp: double
Aggregate [ctry#45, year#2], [ctry#45, year#2, min(CASE WHEN (cast(air_temperature_qual#16 as int) = 1) THEN air_temperature#15 END) AS min_temp#247, max(CASE WHEN (cast(air_temperature_qual#16 as int) = 1) THEN air_temperature#15 END) AS max_temp#249]
+- LogicalRDD [year#2, usaf#5, wban#6, date#7, ti

As expected, the execution plan now essentially only contains the aggregation in three steps (partial aggregation, shuffle, final aggregation). The lineage of the join is not present any more.

## 3.3 Unreliable Checkpointing

In addition to *reliable* checkpointing, Spark also supports *unreliable* (local) checkpointing, where the data is not stored in HDFS but on the local worker nodes, using the caching backend.

Note that using unreliable checkpointing is strongly discouraged with dynamic allocation, where executors can be released again: the checkpointed data held by a released executor is lost.

In [12]:
cpu_weather = joined_weather.localCheckpoint(eager=True)

### Inspect Checkpoint data

Now you can see the checkpointed data in the "Storage" section of the web interface.

### Perform aggregation

Now we can perform the aggregation with the locally checkpointed variant of the joined weather DataFrame.

In [13]:
result = cpu_weather.groupBy(cpu_weather["ctry"], cpu_weather["year"]).agg(
        min(when(cpu_weather.air_temperature_qual == lit(1), cpu_weather.air_temperature)).alias('min_temp'),
        max(when(cpu_weather.air_temperature_qual == lit(1), cpu_weather.air_temperature)).alias('max_temp')
)

result.explain(True)

== Parsed Logical Plan ==
'Aggregate [ctry#45, year#2], [ctry#45, year#2, min(CASE WHEN (air_temperature_qual#16 = 1) THEN air_temperature#15 END) AS min_temp#308, max(CASE WHEN (air_temperature_qual#16 = 1) THEN air_temperature#15 END) AS max_temp#310]
+- AnalysisBarrier
      +- LogicalRDD [year#2, usaf#5, wban#6, date#7, time#8, report_type#9, wind_direction#10, wind_direction_qual#11, wind_observation#12, wind_speed#13, wind_speed_qual#14, air_temperature#15, air_temperature_qual#16, USAF#42, WBAN#43, STATION NAME#44, CTRY#45, STATE#46, ICAO#47, LAT#48, LON#49, ELEV(M)#50, BEGIN#51, END#52], false

== Analyzed Logical Plan ==
ctry: string, year: int, min_temp: double, max_temp: double
Aggregate [ctry#45, year#2], [ctry#45, year#2, min(CASE WHEN (cast(air_temperature_qual#16 as int) = 1) THEN air_temperature#15 END) AS min_temp#308, max(CASE WHEN (cast(air_temperature_qual#16 as int) = 1) THEN air_temperature#15 END) AS max_temp#310]
+- LogicalRDD [year#2, usaf#5, wban#6, date#7, ti

In [15]:
result.limit(5).toPandas()

| | ctry | year | min_temp | max_temp |
|---|---|---|---|---|
| 0 | DA | 2003 | -16.0 | 30.0 |
| 1 | EZ | 2003 | -16.0 | 37.0 |
| 2 | JA | 2003 | -0.7 | 34.2 |
| 3 | NL | 2003 | -14.3 | 36.0 |
| 4 | LU | 2003 | -13.0 | 37.4 |


## 3.4 Checkpoint cleanup

Spark can automatically remove checkpoint data that is no longer referenced if the configuration property `spark.cleaner.referenceTracking.cleanCheckpoints` is set to `true` (the default is `false` as of Spark 2.3). Otherwise you have to remove checkpoint data from HDFS manually.