In [None]:
# Spark Session
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder.appName("Understand Caching")
    .master("spark://spark-master:7077")
    .config("spark.executor.memory", "512M")
    .getOrCreate()
)

spark

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/11/11 16:12:09 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [None]:
df = (
    spark.read.format("csv")
    .option("inferSchema", True)
    .option("header", True)
    .load("hdfs://namenode:9000/input/data/employee_records.csv")
)

                                                                                

In [3]:
df.show()

+----------+----------+--------------------+----------+--------------------+--------------------+------+-------------+
|first_name| last_name|           job_title|       dob|               email|               phone|salary|department_id|
+----------+----------+--------------------+----------+--------------------+--------------------+------+-------------+
|   Richard|  Morrison|Public relations ...|1973-05-05|melissagarcia@exa...|       (699)525-4827|512653|            8|
|     Bobby|  Mccarthy|   Barrister's clerk|1974-04-25|   llara@example.net|  (750)846-1602x7458|999836|            7|
|    Dennis|    Norman|Land/geomatics su...|1990-06-24| jturner@example.net|    873.820.0518x825|131900|           10|
|      John|    Monroe|        Retail buyer|1968-06-16|  erik33@example.net|    820-813-0557x624|485506|            1|
|  Michelle|   Elliott|      Air cabin crew|1975-03-31|tiffanyjohnston@e...|       (705)900-5337|604738|            8|
|    Ashley|   Montoya|        Cartographer|1976

In [4]:
# Cache DataFrame (cache or persist)

df_cache = df.where("salary > 100000").cache()

In [None]:
df_cache.count()

[Stage 3:>                                                          (0 + 8) / 8]

In [None]:
df.where("salary > 5000").count()

In [None]:
df.where("salary > 500000").count()

In [None]:
df_cache.where("salary > 500000").count()

In [None]:
# MEMORY_ONLY, MEMORY_AND_DISK, MEMORY_ONLY_SER, MEMORY_AND_DISK_SER, DISK_ONLY, MEMORY_ONLY_2, MEMORY_AND_DISK_2
import pyspark

df_persist = df.persist(pyspark.StorageLevel.MEMORY_ONLY_2)


In [None]:
df_persist.write.format("noop").mode("overwrite").save()

In [None]:
# Remove Cache

spark.catalog.clearCache()

In [None]:
from pyspark.sql import functions as F

emp = (
    spark.read.format("csv")
    .option("inferSchema", True)
    .option("header", True)
    .load("hdfs://namenode:9000/input/data/employee_records.csv")
)

emp.groupby("last_name").count().orderBy("count", ascending=False).show()



+---------+-----+
|last_name|count|
+---------+-----+
|    Smith|21740|
|  Johnson|16928|
| Williams|13943|
|    Jones|12511|
|    Brown|12444|
|   Miller|10248|
|    Davis| 9865|
|   Garcia| 7654|
|Rodriguez| 7335|
| Martinez| 7052|
|   Wilson| 6993|
| Anderson| 6886|
|   Taylor| 6667|
|   Thomas| 6452|
|    Moore| 6393|
|Hernandez| 6363|
|   Martin| 6130|
|  Jackson| 6044|
|    White| 5850|
| Thompson| 5836|
+---------+-----+
only showing top 20 rows



                                                                                

In [None]:
# The real use of caching
# Image a costly operation: join two big df
# and the result of the join is going to be used for later analytics
# in this case, if you cache the join result, you never run the full dag again

# we are going to get possible siblings in department 1

df1 = emp.filter(F.col("department_id") == 1).select(
    *(F.col(x).alias(x + "_1") for x in emp.columns)
)
df2 = emp.filter(F.col("department_id") == 1).select(
    *(F.col(x).alias(x + "_2") for x in emp.columns)
)

siblings = df1.join(df2, on=df1.last_name_1 == df2.last_name_2)

In [6]:
siblings.printSchema()

root
 |-- first_name_1: string (nullable = true)
 |-- last_name_1: string (nullable = true)
 |-- job_title_1: string (nullable = true)
 |-- dob_1: date (nullable = true)
 |-- email_1: string (nullable = true)
 |-- phone_1: string (nullable = true)
 |-- salary_1: integer (nullable = true)
 |-- department_id_1: integer (nullable = true)
 |-- first_name_2: string (nullable = true)
 |-- last_name_2: string (nullable = true)
 |-- job_title_2: string (nullable = true)
 |-- dob_2: date (nullable = true)
 |-- email_2: string (nullable = true)
 |-- phone_2: string (nullable = true)
 |-- salary_2: integer (nullable = true)
 |-- department_id_2: integer (nullable = true)



In [None]:
# REMOVE DUPLICATED AND UNLIKELY
siblings = siblings.filter(F.col("email_1") != F.col("email_2"))
siblings = siblings.filter(
    F.abs(F.months_between(F.col("dob_1"), F.col("dob_2"))) < 120
)
siblings = siblings.filter(
    F.concat(F.col("email_1"), F.lit("-"), F.col("email_2"))
    != F.concat(F.col("email_2"), F.lit("-"), F.col("email_1"))
)

# GROUP BY TITLE PAIRS
siblings = siblings.withColumn(
    "pair_title_1", F.least("job_title_1", "job_title_2")
).withColumn("pair_title_2", F.greatest("job_title_1", "job_title_2"))
siblings_grouped = siblings.groupby("pair_title_1", "pair_title_2").count()

In [8]:
siblings_grouped.show()



+--------------------+--------------------+-----+
|        pair_title_1|        pair_title_2|count|
+--------------------+--------------------+-----+
| Colour technologist|Engineer, electrical|   56|
|         Chiropodist|Sound technician,...|  100|
| Designer, furniture|Production design...|   74|
|Environmental con...|   Recycling officer|   96|
|Claims inspector/...|Financial risk an...|  100|
|Clinical cytogene...|Designer, televis...|   66|
|Biochemist, clinical|        Neurosurgeon|   80|
|Financial risk an...|Radio broadcast a...|  106|
|Commercial hortic...|Speech and langua...|   58|
|Community arts wo...|Television/film/v...|   58|
|Amenity horticult...|Production assist...|  100|
|Amenity horticult...|            Musician|   46|
|   Company secretary|Journalist, newsp...|   84|
|Education officer...|Production design...|   66|
|Higher education ...|Museum/gallery co...|  100|
|          Oncologist|    Public librarian|   98|
|Insurance underwr...|          Oncologist|  104|


                                                                                

In [None]:
# WITHOUT CACHING
from datetime import datetime

start = datetime.now()

# Total of possible siblings
print(siblings_grouped.count())

# Total of siblings who are Oncologist or H. Manager
siblings_grouped.filter(
    F.col("pair_title_1").isin("Oncologist", "Hotel manager")
    & F.col("pair_title_2").isin("Oncologist", "Hotel manager")
).select(F.sum(F.col("count"))).show()

# Total of siblings with same role
siblings_grouped.filter(F.col("pair_title_1") == F.col("pair_title_2")).select(
    F.sum(F.col("count"))
).show()

# Total of siblings starting with A
siblings_grouped.filter(F.substring(F.col("pair_title_1"), 1, 1) == F.lit("A")).select(
    F.sum(F.col("count"))
).show()

end = datetime.now()

print((end - start).total_seconds())

                                                                                

204480


                                                                                

+----------+
|sum(count)|
+----------+
|       214|
+----------+



                                                                                

+----------+
|sum(count)|
+----------+
|     22478|
+----------+





+----------+
|sum(count)|
+----------+
|   2454176|
+----------+

33.853596


                                                                                

In [None]:
# WITH CACHING

siblings_grouped.cache()

start = datetime.now()
print(siblings_grouped.count())
siblings_grouped.filter(
    F.col("pair_title_1").isin("Oncologist", "Hotel manager")
    & F.col("pair_title_2").isin("Oncologist", "Hotel manager")
).select(F.sum(F.col("count"))).show()
siblings_grouped.filter(F.col("pair_title_1") == F.col("pair_title_2")).select(
    F.sum(F.col("count"))
).show()
siblings_grouped.filter(F.substring(F.col("pair_title_1"), 1, 1) == F.lit("A")).select(
    F.sum(F.col("count"))
).show()
end = datetime.now()

print((end - start).total_seconds())

                                                                                

204480


                                                                                

+----------+
|sum(count)|
+----------+
|       214|
+----------+



                                                                                

+----------+
|sum(count)|
+----------+
|     22478|
+----------+

+----------+
|sum(count)|
+----------+
|   2454176|
+----------+

22.623305


In [8]:
spark.sparkContext.setCheckpointDir("hdfs://namenode:9000/checkpoint")

In [9]:
siblings_grouped_checkpoint = siblings_grouped.checkpoint()

                                                                                

In [10]:
siblings_grouped_checkpoint.count()

                                                                                

204480

In [11]:
spark.stop()