# Caching Data

Spark offers the possibility to cache data, which means that it tries to keep (intermediate) results either in memory or on disk. This can be very helpful in iterative algorithms or interactive analysis, where you want to avoid performing the same processing steps over and over again.

### Approach to Caching
Instead of timing individual executions, we again use the `explain()` method to see how the execution plan changes with cached intermediate results.

### Weather Example
We will again use the weather example to understand how caching works.

## Create or Reuse Spark Session

In [1]:
from pyspark.sql import SparkSession
import pyspark.sql.functions as f

if 'spark' not in locals():
    spark = SparkSession.builder \
        .master("local[*]") \
        .config("spark.driver.memory","24G") \
        .getOrCreate()

spark



In [2]:
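# Disable adaptive query execution (AQE) and automatic broadcast joins,
# so that the execution plans stay stable and are easier to compare
spark.conf.set("spark.sql.adaptive.enabled", False)
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)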

# 1 Load Data

First we load the weather data, which consists of the measurement data and some station metadata.

In [3]:
storageLocation = "s3://dimajix-training/data/weather"
# storageLocation = "/dimajix/data/weather-noaa-sample"

## 1.1 Load Measurements

Measurements are stored in multiple directories (one per year). In this example we only load the data for 2003.

In [4]:
raw_weather = spark.read.text(storageLocation + "/2003").withColumn("year", f.lit(2003))    

### Extract Measurements

Measurements are stored in a proprietary text-based format, with some values at fixed positions. We need to extract these values with a simple `SELECT` statement.

In [5]:
weather = raw_weather.select(
    f.col("year"),
    f.substring(f.col("value"),5,6).alias("usaf"),
    f.substring(f.col("value"),11,5).alias("wban"),
    f.substring(f.col("value"),16,8).alias("date"),
    f.substring(f.col("value"),24,4).alias("time"),
    f.substring(f.col("value"),42,5).alias("report_type"),
    f.substring(f.col("value"),61,3).alias("wind_direction"),
    f.substring(f.col("value"),64,1).alias("wind_direction_qual"),
    f.substring(f.col("value"),65,1).alias("wind_observation"),
    (f.substring(f.col("value"),66,4).cast("float") / f.lit(10.0)).alias("wind_speed"),
    f.substring(f.col("value"),70,1).alias("wind_speed_qual"),
    (f.substring(f.col("value"),88,5).cast("float") / f.lit(10.0)).alias("air_temperature"),
    f.substring(f.col("value"),93,1).alias("air_temperature_qual")
)
    
weather.limit(10).toPandas()

| | year | usaf | wban | date | time | report_type | wind_direction | wind_direction_qual | wind_observation | wind_speed | wind_speed_qual | air_temperature | air_temperature_qual |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2003 | 703160 | 25624 | 20030101 | 0 | SY-MT | 10 | 5 | N | 5.2 | 5 | -0.6 | 5 |
| 1 | 2003 | 703160 | 25624 | 20030101 | 17 | FM-16 | 20 | 1 | N | 4.6 | 1 | -2.0 | 1 |
| 2 | 2003 | 703160 | 25624 | 20030101 | 53 | FM-15 | 10 | 5 | N | 5.2 | 5 | -2.8 | 5 |
| 3 | 2003 | 703160 | 25624 | 20030101 | 100 | NSRDB | 999 | 9 | 9 | 999.9 | 9 | 999.9 | 9 |
| 4 | 2003 | 703160 | 25624 | 20030101 | 153 | FM-15 | 10 | 5 | N | 6.2 | 5 | -2.2 | 5 |
| 5 | 2003 | 703160 | 25624 | 20030101 | 200 | NSRDB | 999 | 9 | 9 | 999.9 | 9 | 999.9 | 9 |
| 6 | 2003 | 703160 | 25624 | 20030101 | 253 | FM-15 | 10 | 5 | N | 7.2 | 5 | -3.3 | 5 |
| 7 | 2003 | 703160 | 25624 | 20030101 | 300 | NSRDB | 999 | 9 | 9 | 999.9 | 9 | 999.9 | 9 |
| 8 | 2003 | 703160 | 25624 | 20030101 | 353 | FM-15 | 20 | 5 | N | 6.2 | 5 | -1.1 | 5 |
| 9 | 2003 | 703160 | 25624 | 20030101 | 400 | NSRDB | 999 | 9 | 9 | 999.9 | 9 | 999.9 | 9 |


## 1.2 Load Station Metadata

We also need to load the weather station metadata, which contains information about the geographic location, country, etc. of the individual weather stations.

In [6]:
stations = spark.read \
    .option("header", True) \
    .csv(storageLocation + "/isd-history")

# Display first 10 records    
stations.limit(10).toPandas()

| | USAF | WBAN | STATION NAME | CTRY | STATE | ICAO | LAT | LON | ELEV(M) | BEGIN | END |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 7018 | 99999 | WXPOD 7018 | | | | 0.0 | 0.0 | 7018.0 | 20110309 | 20130730 |
| 1 | 7026 | 99999 | WXPOD 7026 | AF | | | 0.0 | 0.0 | 7026.0 | 20120713 | 20170822 |
| 2 | 7070 | 99999 | WXPOD 7070 | AF | | | 0.0 | 0.0 | 7070.0 | 20140923 | 20150926 |
| 3 | 8260 | 99999 | WXPOD8270 | | | | 0.0 | 0.0 | 0.0 | 20050101 | 20100920 |
| 4 | 8268 | 99999 | WXPOD8278 | AF | | | 32.95 | 65.567 | 1156.7 | 20100519 | 20120323 |
| 5 | 8307 | 99999 | WXPOD 8318 | AF | | | 0.0 | 0.0 | 8318.0 | 20100421 | 20100421 |
| 6 | 8411 | 99999 | XM20 | | | | | | | 20160217 | 20160217 |
| 7 | 8414 | 99999 | XM18 | | | | | | | 20160216 | 20160217 |
| 8 | 8415 | 99999 | XM21 | | | | | | | 20160217 | 20160217 |
| 9 | 8418 | 99999 | XM24 | | | | | | | 20160217 | 20160217 |


# 2 Caching Data

To analyse the impact of caching data, we will use a slightly simplified variant of the weather analysis (only the temperature will be aggregated). We will modify the execution by caching intermediate results and watch how the execution plans change.

## 2.1 Original Execution Plan

First, let's look at the execution plan of the original query as our reference.

In [7]:
joined_weather = weather.join(stations, ["usaf", "wban"])
aggregates = joined_weather.groupBy(joined_weather.CTRY, joined_weather.year).agg(
        f.min(f.when(joined_weather.air_temperature_qual == f.lit(1), joined_weather.air_temperature)).alias('min_temp'),
        f.max(f.when(joined_weather.air_temperature_qual == f.lit(1), joined_weather.air_temperature)).alias('max_temp')
    )

In [8]:
result = joined_weather.join(f.broadcast(aggregates), ["ctry", "year"])

result.explain()

== Physical Plan ==
*(14) Project [CTRY#52, 2003 AS year#2, usaf#6, wban#7, date#8, time#9, report_type#10, wind_direction#11, wind_direction_qual#12, wind_observation#13, wind_speed#14, wind_speed_qual#15, air_temperature#16, air_temperature_qual#17, STATION NAME#51, STATE#53, ICAO#54, LAT#55, LON#56, ELEV(M)#57, BEGIN#58, END#59, min_temp#128, max_temp#130]
+- *(14) SortMergeJoin [CTRY#52], [CTRY#139], Inner
   :- *(6) Sort [CTRY#52 ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(CTRY#52, 200), ENSURE_REQUIREMENTS, [plan_id=153]
   :     +- *(5) Project [usaf#6, wban#7, date#8, time#9, report_type#10, wind_direction#11, wind_direction_qual#12, wind_observation#13, wind_speed#14, wind_speed_qual#15, air_temperature#16, air_temperature_qual#17, STATION NAME#51, CTRY#52, STATE#53, ICAO#54, LAT#55, LON#56, ELEV(M)#57, BEGIN#58, END#59]
   :        +- *(5) SortMergeJoin [usaf#6, wban#7], [USAF#49, WBAN#50], Inner
   :           :- *(2) Sort [usaf#6 ASC NULLS FIRST, wban#7 AS

## 2.2 Caching Weather

First let us simply cache the joined input DataFrame.

In [9]:
joined_weather.cache()

DataFrame[usaf: string, wban: string, year: int, date: string, time: string, report_type: string, wind_direction: string, wind_direction_qual: string, wind_observation: string, wind_speed: double, wind_speed_qual: string, air_temperature: double, air_temperature_qual: string, STATION NAME: string, CTRY: string, STATE: string, ICAO: string, LAT: string, LON: string, ELEV(M): string, BEGIN: string, END: string]

### Forcing physical caching

Like transformations, the `cache()` method works lazily and only marks the DataFrame to be cached. The data is physically cached only once the DataFrame is evaluated by an action. A common and easy way to enforce this is to call `count()` on the DataFrame that should be cached.

In [10]:
%%time

joined_weather.count()

CPU times: user 25.9 ms, sys: 13.5 ms, total: 39.4 ms
Wall time: 9.12 s

1807253

When you now perform `count()` a second time, it should be much faster, since the data is read from the cache.

In [11]:
%%time

joined_weather.count()

CPU times: user 0 ns, sys: 2.25 ms, total: 2.25 ms
Wall time: 356 ms


1807253

### Execution Plan with Cache

Now let us have a look at the execution plan with the cache for the `joined_weather` DataFrame enabled.

In [12]:
aggregates = joined_weather.groupBy(joined_weather.CTRY, joined_weather.year).agg(
        f.min(f.when(joined_weather.air_temperature_qual == f.lit(1), joined_weather.air_temperature)).alias('min_temp'),
        f.max(f.when(joined_weather.air_temperature_qual == f.lit(1), joined_weather.air_temperature)).alias('max_temp')
    )

result = joined_weather.join(f.broadcast(aggregates), ["ctry", "year"])

result.explain()

== Physical Plan ==
*(3) Project [CTRY#52, year#2, usaf#6, wban#7, date#8, time#9, report_type#10, wind_direction#11, wind_direction_qual#12, wind_observation#13, wind_speed#14, wind_speed_qual#15, air_temperature#16, air_temperature_qual#17, STATION NAME#51, STATE#53, ICAO#54, LAT#55, LON#56, ELEV(M)#57, BEGIN#58, END#59, min_temp#1484, max_temp#1486]
+- *(3) BroadcastHashJoin [CTRY#52, year#2], [CTRY#1495, year#1503], Inner, BuildRight, false
   :- *(3) Filter isnotnull(CTRY#52)
   :  +- InMemoryTableScan [usaf#6, wban#7, year#2, date#8, time#9, report_type#10, wind_direction#11, wind_direction_qual#12, wind_observation#13, wind_speed#14, wind_speed_qual#15, air_temperature#16, air_temperature_qual#17, STATION NAME#51, CTRY#52, STATE#53, ICAO#54, LAT#55, LON#56, ELEV(M)#57, BEGIN#58, END#59], [isnotnull(CTRY#52)]
   :        +- InMemoryRelation [usaf#6, wban#7, year#2, date#8, time#9, report_type#10, wind_direction#11, wind_direction_qual#12, wind_observation#13, wind_speed#14, wind_

In [13]:
result.limit(10).toPandas()

| | CTRY | year | usaf | wban | date | time | report_type | wind_direction | wind_direction_qual | wind_observation | ... | STATION NAME | STATE | ICAO | LAT | LON | ELEV(M) | BEGIN | END | min_temp | max_temp |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | AS | 2003 | 954920 | 99999 | 20030101 | 0 | FM-12 | 200 | 1 | N | ... | THARGOMINDAH | | YTGM | -27.986 | 143.811 | 132.0 | 20010705 | 20200205 | 0.3 | 45.8 |
| 1 | AS | 2003 | 954920 | 99999 | 20030101 | 200 | FM-12 | 230 | 1 | N | ... | THARGOMINDAH | | YTGM | -27.986 | 143.811 | 132.0 | 20010705 | 20200205 | 0.3 | 45.8 |
| 2 | AS | 2003 | 954920 | 99999 | 20030101 | 300 | FM-12 | 220 | 1 | N | ... | THARGOMINDAH | | YTGM | -27.986 | 143.811 | 132.0 | 20010705 | 20200205 | 0.3 | 45.8 |
| 3 | AS | 2003 | 954920 | 99999 | 20030101 | 500 | FM-12 | 230 | 1 | N | ... | THARGOMINDAH | | YTGM | -27.986 | 143.811 | 132.0 | 20010705 | 20200205 | 0.3 | 45.8 |
| 4 | AS | 2003 | 954920 | 99999 | 20030101 | 600 | FM-12 | 240 | 1 | N | ... | THARGOMINDAH | | YTGM | -27.986 | 143.811 | 132.0 | 20010705 | 20200205 | 0.3 | 45.8 |
| 5 | AS | 2003 | 954920 | 99999 | 20030101 | 800 | FM-12 | 210 | 1 | N | ... | THARGOMINDAH | | YTGM | -27.986 | 143.811 | 132.0 | 20010705 | 20200205 | 0.3 | 45.8 |
| 6 | AS | 2003 | 954920 | 99999 | 20030101 | 900 | FM-12 | 220 | 1 | N | ... | THARGOMINDAH | | YTGM | -27.986 | 143.811 | 132.0 | 20010705 | 20200205 | 0.3 | 45.8 |
| 7 | AS | 2003 | 954920 | 99999 | 20030101 | 1100 | FM-12 | 210 | 1 | N | ... | THARGOMINDAH | | YTGM | -27.986 | 143.811 | 132.0 | 20010705 | 20200205 | 0.3 | 45.8 |
| 8 | AS | 2003 | 954920 | 99999 | 20030101 | 1200 | FM-12 | 200 | 1 | N | ... | THARGOMINDAH | | YTGM | -27.986 | 143.811 | 132.0 | 20010705 | 20200205 | 0.3 | 45.8 |
| 9 | AS | 2003 | 954920 | 99999 | 20030101 | 1400 | FM-12 | 190 | 1 | N | ... | THARGOMINDAH | | YTGM | -27.986 | 143.811 | 132.0 | 20010705 | 20200205 | 0.3 | 45.8 |


### Remarks

Although the data is already cached, the execution plan still contains all steps. However, the steps that build the cache are not executed again (since the data is already cached); they are only shown for the completeness of the plan. We can verify this in the Spark web UI.

The cache itself is presented as two steps in the execution plan:
* Creating the cache (InMemoryRelation)
* Using the cache (InMemoryTableScan)

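If you look closely at the execution plans and compare them to the original uncached plan, you will notice that certain optimizations are not performed any more:
* The cache contains ALL columns of the `joined_weather` DataFrame, although only a subset is required.
* Filter conditions required by the join (like `isnotnull(CTRY)`) are no longer pushed down into the original data sources. Instead, they are applied to the cached data (see the `Filter` on top of the `InMemoryTableScan`).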

Caching is an optimization barrier. This means that Spark can optimize the plan that builds the cache and the plan that reads from the cache, but no optimization can span both. The rationale is that the DataFrame should be cached exactly as it was specified, without pruning columns or filtering records based on operations that only appear after the cache.

## 2.3 Uncaching Data

Caches occupy resources (memory and/or disk). Once you no longer need a cache, you should free up these resources again. This is easily done with the `unpersist()` method.

In [14]:
joined_weather.unpersist()

DataFrame[usaf: string, wban: string, year: int, date: string, time: string, report_type: string, wind_direction: string, wind_direction_qual: string, wind_observation: string, wind_speed: double, wind_speed_qual: string, air_temperature: double, air_temperature_qual: string, STATION NAME: string, CTRY: string, STATE: string, ICAO: string, LAT: string, LON: string, ELEV(M): string, BEGIN: string, END: string]

### Execution plan after unpersist

Now we'd expect to have the original execution plan again.

In [15]:
result = joined_weather.groupBy(joined_weather.CTRY, joined_weather.year).agg(
        f.min(f.when(joined_weather.air_temperature_qual == f.lit(1), joined_weather.air_temperature)).alias('min_temp'),
        f.max(f.when(joined_weather.air_temperature_qual == f.lit(1), joined_weather.air_temperature)).alias('max_temp')
    )

result.explain(False)

== Physical Plan ==
*(6) HashAggregate(keys=[CTRY#52, 2003#3336], functions=[min(CASE WHEN (cast(air_temperature_qual#17 as int) = 1) THEN air_temperature#16 END), max(CASE WHEN (cast(air_temperature_qual#17 as int) = 1) THEN air_temperature#16 END)])
+- Exchange hashpartitioning(CTRY#52, 2003#3336, 200), ENSURE_REQUIREMENTS, [plan_id=576]
   +- *(5) HashAggregate(keys=[CTRY#52, 2003 AS 2003#3336], functions=[partial_min(CASE WHEN (cast(air_temperature_qual#17 as int) = 1) THEN air_temperature#16 END), partial_max(CASE WHEN (cast(air_temperature_qual#17 as int) = 1) THEN air_temperature#16 END)])
      +- *(5) Project [air_temperature#16, air_temperature_qual#17, CTRY#52]
         +- *(5) SortMergeJoin [usaf#6, wban#7], [USAF#49, WBAN#50], Inner
            :- *(2) Sort [usaf#6 ASC NULLS FIRST, wban#7 ASC NULLS FIRST], false, 0
            :  +- Exchange hashpartitioning(usaf#6, wban#7, 200), ENSURE_REQUIREMENTS, [plan_id=559]
            :     +- *(1) Project [substring(value#0, 5, 6)

### Remarks

As you can see in the execution plan, the cache has been removed and the aggregation again reads directly from the original data sources, just as before we started caching data.

# 3 Cache Levels

Spark supports different cache levels (memory, disk, and a combination of both). These can be specified explicitly by using `persist()` instead of `cache()`. For DataFrames, `cache()` is actually a shortcut for `persist()` with the default level `MEMORY_AND_DISK`.

In [16]:
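from pyspark.storagelevel import StorageLevel

# Only the first call takes effect: the storage level of an already cached
# DataFrame is not changed, Spark only logs a warning.
# To switch levels, call unpersist() first.
joined_weather.persist(StorageLevel.MEMORY_ONLY)
joined_weather.persist(StorageLevel.DISK_ONLY)
joined_weather.persist(StorageLevel.MEMORY_AND_DISK)

# Variants with a replication factor of 2
joined_weather.persist(StorageLevel.MEMORY_ONLY_2)
joined_weather.persist(StorageLevel.DISK_ONLY_2)
joined_weather.persist(StorageLevel.MEMORY_AND_DISK_2)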


23/11/25 17:26:55 WARN CacheManager: Asked to cache already cached data.
23/11/25 17:26:55 WARN CacheManager: Asked to cache already cached data.
23/11/25 17:26:55 WARN CacheManager: Asked to cache already cached data.
23/11/25 17:26:55 WARN CacheManager: Asked to cache already cached data.
23/11/25 17:26:55 WARN CacheManager: Asked to cache already cached data.


DataFrame[usaf: string, wban: string, year: int, date: string, time: string, report_type: string, wind_direction: string, wind_direction_qual: string, wind_observation: string, wind_speed: double, wind_speed_qual: string, air_temperature: double, air_temperature_qual: string, STATION NAME: string, CTRY: string, STATE: string, ICAO: string, LAT: string, LON: string, ELEV(M): string, BEGIN: string, END: string]

### Cache level explanation

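* `MEMORY_ONLY` - stores all records in memory; partitions that do not fit into memory are not cached and are recomputed whenever they are needed
* `DISK_ONLY` - stores all records serialized on disk
* `MEMORY_AND_DISK` - stores records in memory first and spills partitions onto disk when no space is left in memory
* `..._2` - same as the corresponding level without the suffix, but stores each cached partition on two nodes instead of one for additional redundancy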

# 4 Don'ts

Although reading from a cache can be faster than reprocessing data from scratch, especially if that involves reading the original data from slow IO devices (S3) or complex operations (joins), some caution is needed. Caching is not free: it is an optimization barrier, it occupies resources (memory and disk), and it definitely slows down the first query, which has to build the cache.

To limit the required physical resources (RAM and disk), you should reduce the amount of cached data to the bare minimum and even exclude simple calculations from the cache. For example, if we added conversions to mph and °F to our weather data as precalculated columns, it would be wise to exclude them from the cache. They would only increase the overall volume, while being simple and cheap to calculate after reading from the cache. Plus, the optimizer can remove them when a specific query does not need them.

In [41]:
# Remove any previous caches
weather.unpersist()

weather_intl = weather.withColumn("air_temperature_fahrenheit", weather["air_temperature"]*9.0/5.0+32) \
        .withColumn("wind_speed_mph", weather["wind_speed"]*2.236936)

# DON'T !
weather_intl.cache()

DataFrame[year: int, usaf: string, wban: string, date: string, time: string, report_type: string, wind_direction: string, wind_direction_qual: string, wind_observation: string, wind_speed: double, wind_speed_qual: string, air_temperature: double, air_temperature_qual: string, air_temperature_fahrenheit: double, wind_speed_mph: double]

In [42]:
# Remove any previous caches
weather_intl.unpersist()

# Prefer caching the smaller input data set and perform trivial calculations after caching
weather.cache()
weather_intl = weather.withColumn("air_temperature_fahrenheit", weather["air_temperature"]*9.0/5.0+32) \
        .withColumn("wind_speed_mph", weather["wind_speed"]*2.236936)