In [1]:
from pyspark.sql import SparkSession
import time

In [2]:
# Obtain the Spark session used by every cell below; getOrCreate() hands back
# an already-running session if there is one instead of starting another.
spark = (
    SparkSession.builder
    .appName("Caching Demo")
    .getOrCreate()
)

In [3]:
# One million synthetic (id, name, value) rows built on the driver;
# value is always twice the id.
data = [
    (i, f"Name_{i}", 2 * i)
    for i in range(1_000_000)
]
df = spark.createDataFrame(data, schema=["id", "name", "value"])

In [12]:
# Sanity-check the generated rows; show() prints only the first 20 rows by default.
df.show()

+---+-------+-----+
| id|   name|value|
+---+-------+-----+
|  0| Name_0|    0|
|  1| Name_1|    2|
|  2| Name_2|    4|
|  3| Name_3|    6|
|  4| Name_4|    8|
|  5| Name_5|   10|
|  6| Name_6|   12|
|  7| Name_7|   14|
|  8| Name_8|   16|
|  9| Name_9|   18|
| 10|Name_10|   20|
| 11|Name_11|   22|
| 12|Name_12|   24|
| 13|Name_13|   26|
| 14|Name_14|   28|
| 15|Name_15|   30|
| 16|Name_16|   32|
| 17|Name_17|   34|
| 18|Name_18|   36|
| 19|Name_19|   38|
+---+-------+-----+
only showing top 20 rows



In [4]:
def measure_time(operation, df):
    """Time a single call of ``operation(df)``.

    Args:
        operation: Callable that accepts ``df`` as its only argument.
        df: Object handed to ``operation`` unchanged (a DataFrame in this
            notebook, but nothing here depends on its type).

    Returns:
        A ``(elapsed_seconds, result)`` tuple, where ``elapsed_seconds`` is a
        float and ``result`` is whatever ``operation(df)`` returned.

    Any exception raised by ``operation`` propagates to the caller.
    """
    # perf_counter() is monotonic and high-resolution, so the measured
    # interval cannot go negative or jump if the system clock is adjusted
    # mid-run (which time.time() is susceptible to).
    started = time.perf_counter()
    result = operation(df)
    elapsed = time.perf_counter() - started
    return elapsed, result

In [5]:
def sample_operation(df):
    """Count rows per distinct "name" and collect the result.

    Args:
        df: DataFrame-like object exposing ``groupBy``; must have a "name" column.

    Returns:
        Whatever ``collect()`` yields for the grouped counts.
    """
    grouped_by_name = df.groupBy("name")
    counts_per_name = grouped_by_name.count()
    return counts_per_name.collect()

In [6]:
# Baseline: df has not been cached yet, so this run pays the full cost of
# producing the data before aggregating it.
print("Running without caching...")
time_no_cache, result1 = measure_time(operation=sample_operation, df=df)
print(f"Time without caching: {time_no_cache:.2f} seconds")

Running without caching...
Time without caching: 22.89 seconds


In [7]:
# Mark df for caching. This is lazy: nothing is stored until an action runs
# on df (the count() in the next cell does that).
df.cache()

DataFrame[id: bigint, name: string, value: bigint]

In [8]:
# Run an action over every row so the cache requested above is actually
# populated before the timed run.
df.count()

1000000

In [9]:
# Same aggregation as the baseline, now reading from the populated cache.
print("\nRunning with caching...")
time_with_cache, result2 = measure_time(operation=sample_operation, df=df)
print(f"Time with caching: {time_with_cache:.2f} seconds")


Running with caching...
Time with caching: 11.14 seconds


In [10]:
from pyspark.storagelevel import StorageLevel

# df is still cached from the earlier df.cache() cell, and persist() can only
# assign a storage level to a DataFrame that does not already have one.
# Drop the existing cache first so MEMORY_AND_DISK really takes effect.
df.unpersist()
df.persist(StorageLevel.MEMORY_AND_DISK)
# persist() is lazy as well; count() forces the data to be materialised.
df.count()

1000000

In [11]:
# Release df's cached/persisted data so later cells start from an uncached DataFrame.
df.unpersist()

DataFrame[id: bigint, name: string, value: bigint]

In [13]:
# Redistribute df across exactly four partitions and confirm the new count.
repartitioned_df = df.repartition(numPartitions=4)
print(f"Number of partitions: {repartitioned_df.rdd.getNumPartitions()}")

Number of partitions: 4


In [14]:
# Repartition into 3 partitions keyed on the "value" column, so rows with the
# same value land in the same partition. The result is not used by later cells.
repartitioned_by_col = df.repartition(3, "value")

In [15]:
# Rebuild the demo DataFrame with a known partition count (4) so the coalesce
# cell below has a fixed starting point. Each row of `data` is a 3-tuple, so
# all three columns are named here: passing only ["name", "value"] labels the
# columns positionally (ids become "name", names become "value") and leaves
# the third column with an auto-generated name.
# NOTE: this rebinds `df`; cells above that use `df` refer to the earlier one.
df = spark.createDataFrame(data, ["id", "name", "value"]).repartition(4)
print(f"Original partitions: {df.rdd.getNumPartitions()}")

Original partitions: 4


In [16]:
# Shrink from four partitions down to two and confirm the result.
coalesced_df = df.coalesce(numPartitions=2)
print(f"After coalesce: {coalesced_df.rdd.getNumPartitions()}")

After coalesce: 2
