# S1 J4: Performance (Cache + Partition)

This notebook demonstrates basic performance techniques: partitioning, caching, and inspecting physical query plans.


In [1]:
try:
    spark
except NameError:
    from pyspark.sql import SparkSession
    spark = SparkSession.builder.appName("performance-cache-partition").getOrCreate()



In [2]:
from pyspark.sql import functions as F

data_path = "../../data/example.csv"

# inferSchema makes Spark read the file an extra time just to infer column types
raw = (
    spark.read
    .option("header", True)
    .option("inferSchema", True)
    .csv(data_path)
)

# Normalize types: parse signup_date as a date and cast spend to double
silver = (
    raw
    .withColumn("signup_date", F.to_date("signup_date"))
    .withColumn("spend", F.col("spend").cast("double"))
)
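
Schema inference has a cost: with `inferSchema` enabled, Spark reads the input before any query runs just to decide the column types. The toy inferrer below is a sketch of why, under its own assumptions: it is not Spark's algorithm, it only knows four type names, and the sample CSV is hypothetical. It widens each column's type row by row, so the type it reports for `spend` depends on a row near the end of the input.

```python
import csv
import io

# Toy model (assumption): NOT Spark's inference algorithm. It only shows that
# a type inferrer has to look at every row before it can commit to a type.


def value_type(text):
    """Classify one raw CSV field as boolean, int, double, or string."""
    if text.lower() in ("true", "false"):
        return "boolean"
    try:
        int(text)
        return "int"
    except ValueError:
        pass
    try:
        float(text)
        return "double"
    except ValueError:
        return "string"


def merge_types(a, b):
    """Widen two column types to one that can hold both values."""
    if a == b:
        return a
    if {a, b} == {"int", "double"}:
        return "double"
    return "string"


def infer_schema(csv_text):
    """Scan all rows, widening each column's type; also count rows scanned."""
    reader = csv.DictReader(io.StringIO(csv_text))
    types = {}
    rows_scanned = 0
    for row in reader:
        rows_scanned += 1
        for column, text in row.items():
            t = value_type(text)
            types[column] = merge_types(types[column], t) if column in types else t
    return types, rows_scanned


# Hypothetical sample: spend only becomes "double" because of the second row
sample = "name,age,spend\nAna,31,10\nBo,27,12.5\n"
schema, scanned = infer_schema(sample)
print(schema)   # → {'name': 'string', 'age': 'int', 'spend': 'double'}
print(scanned)  # → 2
```

Passing an explicit schema (for example with `spark.read.schema(...)`) skips that inference pass; whether it matters depends on how large the input is.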


In [3]:
# Partitioning example: hash-partition rows into 4 partitions by the "plan" column
print("Initial partitions:", silver.rdd.getNumPartitions())
by_plan = silver.repartition(4, "plan")
print("After repartition:", by_plan.rdd.getNumPartitions())


Initial partitions: 1
After repartition: 4
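
`repartition(4, "plan")` shuffles the rows so that every row with the same `plan` value lands in the same partition. The pure-Python sketch below is a toy model of that idea under stated assumptions: it hashes with `zlib.crc32`, whereas Spark uses Murmur3, so the partition numbers will not match Spark's; `hash_partition` and the sample rows and plan values (`free`, `pro`, `team`) are hypothetical.

```python
import zlib
from collections import defaultdict

# Toy model (assumption): hash-partition rows on one key column, the idea
# behind repartition(4, "plan"). Spark uses Murmur3, so actual partition ids
# differ; only the "same key -> same partition" property carries over.


def hash_partition(rows, key, num_partitions):
    partitions = defaultdict(list)
    for row in rows:
        # crc32 is stable across runs, unlike Python's salted str hash
        pid = zlib.crc32(str(row[key]).encode("utf-8")) % num_partitions
        partitions[pid].append(row)
    # Return a dense list so empty partitions are visible too
    return [partitions[i] for i in range(num_partitions)]


rows = [  # hypothetical sample rows
    {"name": "Ana", "plan": "free"},
    {"name": "Bo", "plan": "pro"},
    {"name": "Cy", "plan": "free"},
    {"name": "Di", "plan": "team"},
    {"name": "Ed", "plan": "pro"},
]

parts = hash_partition(rows, "plan", 4)
for pid, part in enumerate(parts):
    print(pid, len(part), sorted({r["plan"] for r in part}))
```

With only three distinct keys and four partitions, at least one partition has to stay empty; likewise, one key that dominates the data produces one oversized partition (skew).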


In [5]:
# Cache to avoid recomputation across actions
by_plan.cache()

# cache() is lazy; the first action (count) actually fills the cache
by_plan.count()

# Confirm cache status
print("Is cached:", by_plan.is_cached)


Is cached: True
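
Because `cache()` is lazy, `is_cached` only says the DataFrame has been marked for caching; the data is stored when the first action (here `count()`) runs. The class below is a toy model of that behaviour, not Spark's `CacheManager`: `LazyDataset` and its `computations` counter are hypothetical names used only to make recomputation visible.

```python
# Toy model (assumption): a lazily evaluated dataset. cache() only sets a flag;
# the first action computes the rows and keeps them. Not Spark's implementation.


class LazyDataset:
    def __init__(self, compute):
        self._compute = compute    # function that produces the rows
        self._cached_rows = None
        self.is_cached = False
        self.computations = 0      # how many times the source was evaluated

    def cache(self):
        self.is_cached = True      # lazy: nothing is computed here
        return self

    def unpersist(self):
        self.is_cached = False
        self._cached_rows = None
        return self

    def _rows(self):
        if self._cached_rows is not None:
            return self._cached_rows
        self.computations += 1
        rows = self._compute()
        if self.is_cached:
            self._cached_rows = rows   # first action fills the cache
        return rows

    def count(self):
        return len(self._rows())

    def collect(self):
        return list(self._rows())


ds = LazyDataset(lambda: [1, 2, 3]).cache()
print(ds.computations)     # → 0 (cache() alone computed nothing)
ds.count()                 # materializes the cache
ds.collect()               # served from the cache
print(ds.computations)     # → 1

plain = LazyDataset(lambda: [1, 2, 3])
plain.count()
plain.collect()
print(plain.computations)  # → 2 (every action recomputes)
```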



In [6]:
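# Explain the physical plan for a simple aggregation.
# by_plan is already hash-partitioned on "plan", so Spark can usually aggregate
# without another shuffle (check for a missing Exchange above the scan).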
agg = (
    by_plan
    .groupBy("plan")
    .agg(
        F.count("*").alias("users"),
        F.round(F.sum("spend"), 2).alias("total_spend"),
    )
)

agg.explain("formatted")
agg.show(truncate=False)


== Physical Plan ==
AdaptiveSparkPlan (10)
+- HashAggregate (9)
   +- HashAggregate (8)
      +- InMemoryTableScan (1)
            +- InMemoryRelation (2)
                  +- AdaptiveSparkPlan (7)
                     +- == Final Plan ==
                        ResultQueryStage (6)
                        +- ShuffleQueryStage (5), Statistics(sizeInBytes=904.0 B, rowCount=10)
                           +- Exchange (4)
                              +- Scan csv  (3)
                     +- == Initial Plan ==
                        Exchange (4)
                        +- Scan csv  (3)


(1) InMemoryTableScan
Output [2]: [plan#21, spend#26]
Arguments: [plan#21, spend#26]

(2) InMemoryRelation
Arguments: [name#17, age#18, city#19, signup_date#25, plan#21, is_active#22, spend#26], StorageLevel(disk, memory, deserialized, 1 replicas)

(3) Scan csv 
Output [7]: [name#17, age#18, city#19, signup_date#20, plan#21, is_active#22, spend#23]
Batched: false
Location: InMemoryFileIndex [file:/home/la
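
In the plan above, the two `HashAggregate` nodes (8 and 9) sit directly on `InMemoryTableScan (1)` with no `Exchange` between them; the only `Exchange (4)` is the one inside the cached `repartition`. That is consistent with Spark reusing the existing hash partitioning on `plan`. The functions below are a toy model of the usual partial-then-final aggregation split, under stated assumptions: `partial_aggregate`, `final_aggregate`, and the data are hypothetical, and Spark's real operators differ in detail.

```python
from collections import defaultdict

# Toy model (assumption): step 1 builds a partial (count, sum) per key inside
# each partition; step 2 merges the partials into the final result.


def partial_aggregate(partition, key, value):
    acc = defaultdict(lambda: [0, 0.0])  # key -> [count, sum]
    for row in partition:
        acc[row[key]][0] += 1
        acc[row[key]][1] += row[value]
    return dict(acc)


def final_aggregate(partials):
    result = defaultdict(lambda: [0, 0.0])
    for partial in partials:
        for k, (count, total) in partial.items():
            result[k][0] += count
            result[k][1] += total
    # Round the sum like F.round(F.sum("spend"), 2) in the notebook
    return {k: (c, round(s, 2)) for k, (c, s) in result.items()}


# Hypothetical rows, already hash-partitioned so each plan lives in one partition
partitions = [
    [{"plan": "free", "spend": 0.0}, {"plan": "free", "spend": 5.5}],
    [{"plan": "pro", "spend": 20.0}],
    [],
    [{"plan": "team", "spend": 99.5}, {"plan": "team", "spend": 0.25}],
]

partials = [partial_aggregate(p, "plan", "spend") for p in partitions]
# Each key appears in exactly one partial, so no shuffle is needed to bring
# matching keys together before the final merge.
print(final_aggregate(partials))
# → {'free': (2, 5.5), 'pro': (1, 20.0), 'team': (2, 99.75)}
```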

In [7]:
# Free cache when done
by_plan.unpersist()


DataFrame[name: string, age: int, city: string, signup_date: date, plan: string, is_active: boolean, spend: double]
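
Forgetting `unpersist()` keeps cached blocks in executor memory (and on disk, given the `StorageLevel(disk, memory, deserialized, 1 replicas)` shown in the plan) until they are evicted or the session ends. One way to tie cache lifetime to a block of code is a context manager. `cached` below is a hypothetical helper, not a PySpark API; it only calls `cache()` and `unpersist()`. `FakeFrame` is a stand-in so the sketch runs without Spark.

```python
from contextlib import contextmanager

# Hypothetical helper (not part of PySpark): cache for the duration of a
# with-block and always unpersist afterwards, even if the block raises.


@contextmanager
def cached(df):
    df.cache()
    try:
        yield df
    finally:
        df.unpersist()


class FakeFrame:
    """Stand-in with cache()/unpersist() so the sketch runs without Spark."""

    def __init__(self):
        self.is_cached = False

    def cache(self):
        self.is_cached = True
        return self

    def unpersist(self):
        self.is_cached = False
        return self


frame = FakeFrame()
with cached(frame) as f:
    inside = f.is_cached
print(inside, frame.is_cached)  # → True False

# With a real SparkSession the same pattern would read, for example:
# with cached(silver.repartition(4, "plan")) as by_plan:
#     by_plan.count()
#     by_plan.groupBy("plan").count().show()
```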