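## Imports and Spark session setup

This notebook compares three ways to read only the latest `event_date` partition of a partitioned Hive table. The baseline computes `max()` over the full table and joins it back. The other two variants derive the maximum from `SHOW PARTITIONS` metadata instead.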

In [1]:
import time
from contextlib import contextmanager
from logging import Logger
from pathlib import Path
from typing import Optional

import pyspark
from pyspark.sql import SparkSession
from pyspark.sql import DataFrame
from pyspark.sql import functions as f

### Utility functions

In [2]:
@contextmanager
def timer(operation_name: str, logger: Optional[Logger] = None):
    """Time the enclosed ``with`` block and report how long it took.

    The report ``"<operation_name> completed in <seconds>s"`` is emitted from a
    ``finally`` clause. It is therefore produced even when the block raises; the
    exception then propagates unchanged.

    Args:
        operation_name: Label placed at the start of the report message.
        logger: If given, the message goes to ``logger.info``; otherwise it is
            printed to stdout.
    """
    # perf_counter is monotonic and high-resolution; time.time() can jump if the
    # system clock is adjusted, which corrupts measured durations.
    start = time.perf_counter()
    try:
        yield
    finally:
        duration = time.perf_counter() - start
        msg = f"{operation_name} completed in {duration:.2f}s"
        if logger is not None:
            logger.info(msg)
        else:
            print(msg)

In [3]:
spark_warehouse_dir = "./spark-warehouse"
spark_events_dir = "./spark-events"

# Create the event-log directory up front; event logging below writes into it.
Path(spark_events_dir).mkdir(parents=True, exist_ok=True)

spark = (
    SparkSession.builder
    .appName("partition-issue")
    .master("local[*]")
    .config("spark.sql.warehouse.dir", spark_warehouse_dir)
    .config("spark.eventLog.enabled", "true")
    .config("spark.eventLog.dir", spark_events_dir)  # choose any directory
    .config("spark.ui.enabled", "false")  # optional: disables live UI completely
    # Hive settings belong here, with the `spark.hadoop.` prefix. A runtime
    # `SET hive.*` only logs a "might not work" warning (see the data-generation
    # cell's output), and this prefix is the fix that warning recommends.
    .config("spark.hadoop.hive.exec.dynamic.partition", "true")
    .config("spark.hadoop.hive.exec.dynamic.partition.mode", "nonstrict")
    .enableHiveSupport()
    .getOrCreate()
)

Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/12/01 17:42:28 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


## Create the partitioned table (if it does not already exist)

In [4]:
TABLE_NAME = "sales_orders"

# Parquet-backed table partitioned by event_date; IF NOT EXISTS makes this cell re-runnable.
create_table_ddl = f"""
    CREATE TABLE IF NOT EXISTS {TABLE_NAME} (
        id          INT,
        name        STRING,
        amount      DOUBLE
    )
    PARTITIONED BY (event_date DATE)
    STORED AS PARQUET
"""
spark.sql(create_table_ddl)

25/12/01 17:42:31 WARN ObjectStore: Version information not found in metastore. hive.metastore.schema.verification is not enabled so recording the schema version 2.3.0
25/12/01 17:42:31 WARN ObjectStore: setMetaStoreSchemaVersion called but recording version is disabled: version = 2.3.0, comment = Set by MetaStore mc8max@10.169.177.112


DataFrame[]

## Generate dummy data (only when the table is empty)

In [5]:
# -----------------------------
# Config
# -----------------------------
TOTAL_ROWS = 25_000_000       # ~25M rows
NUM_DAYS = 365                # one event_date value per day -> 365 partitions (> 300)
BASE_DATE = "2024-01-01"      # starting date for event_date
NUM_GEN_PARTITIONS = 200      # parallelism of spark.range(); tune based on your machine/cluster
RAND_SEED = 42                # seeds f.rand() so `amount` is reproducible across runs
CREATED = True                # NOTE(review): not referenced anywhere in this notebook

# These runtime SETs log a "might not work" warning (see output below); the
# reliable form is spark.hadoop.hive.* passed when building the session.
spark.sql("SET hive.exec.dynamic.partition = true")
spark.sql("SET hive.exec.dynamic.partition.mode = nonstrict")

# Generate only on the first run so re-executing this cell never duplicates rows.
if spark.table(TABLE_NAME).isEmpty():
    orders_df = (
        spark.range(start=0, end=TOTAL_ROWS, step=1, numPartitions=NUM_GEN_PARTITIONS)
        .withColumn("id", f.col("id").cast("int"))                       # business id; 25M fits in INT
        .withColumn("name", f.expr("concat('customer_', id % 100000)"))
        .withColumn("amount", f.rand(RAND_SEED) * 1000)                  # rand() already returns DOUBLE
        .withColumn("day_offset", (f.col("id") % NUM_DAYS).cast("int"))
        .withColumn("event_date", f.expr(f"date_add(date('{BASE_DATE}'), day_offset)"))
        .select("id", "name", "amount", "event_date")  # insertInto matches columns by position
    )
    orders_df.write.mode("append").insertInto(TABLE_NAME)

25/12/01 17:42:31 WARN SetCommand: 'SET hive.exec.dynamic.partition=true' might not work, since Spark doesn't support changing the Hive config dynamically. Please pass the Hive-specific config by adding the prefix spark.hadoop (e.g. spark.hadoop.hive.exec.dynamic.partition) when starting a Spark application. For details, see the link: https://spark.apache.org/docs/latest/configuration.html#dynamically-loading-spark-properties.
25/12/01 17:42:31 WARN SetCommand: 'SET hive.exec.dynamic.partition.mode=nonstrict' might not work, since Spark doesn't support changing the Hive config dynamically. Please pass the Hive-specific config by adding the prefix spark.hadoop (e.g. spark.hadoop.hive.exec.dynamic.partition.mode) when starting a Spark application. For details, see the link: https://spark.apache.org/docs/latest/configuration.html#dynamically-loading-spark-properties.
                                                                                

In [6]:
def noop(df: DataFrame):
    """Execute ``df`` completely by writing it to Spark's ``noop`` data source.

    The noop sink discards every row. This makes it a way to force full
    evaluation of a query (e.g. for timing) without writing any output.
    """
    writer = df.write.format("noop")
    writer.mode("overwrite").save()

    
def get_max_partition_from_table(spark: SparkSession, table_name: str, partition_col: str) -> DataFrame:
    """Return the rows of ``table_name`` whose ``partition_col`` equals its maximum.

    This is the baseline approach. The maximum is computed by aggregating over the
    table itself, then broadcast-joined back to keep only the matching rows.
    The helper ``max_date`` column is dropped from the result.
    """
    source_df = spark.table(table_name)
    max_value_df = source_df.select(f.max(f.col(partition_col)).alias("max_date"))
    is_latest = f.col(partition_col) == f.col("max_date")
    joined_df = source_df.join(f.broadcast(max_value_df), is_latest)
    return joined_df.drop("max_date")


# Placeholder Spark/Hive uses as the partition value for NULL partition keys.
_HIVE_DEFAULT_PARTITION = "__HIVE_DEFAULT_PARTITION__"


def _partition_dates_from_metadata(spark: SparkSession, table_name: str, partition_col: str) -> DataFrame:
    """Parse the values of ``partition_col`` out of ``SHOW PARTITIONS`` as DATEs.

    A SHOW PARTITIONS row looks like ``"event_date=2024-01-01"``. With several
    partition columns it looks like ``"region=eu/event_date=2024-01-01"``.
    The value is extracted wherever ``partition_col`` appears in the spec, not only
    in the first ``/`` segment.

    Two kinds of rows are dropped before ``to_date``, so they neither become
    spurious NULLs nor (under ANSI mode) fail the conversion:
    - rows that lack the key (``regexp_extract`` returns ``""`` for them);
    - the NULL-partition placeholder.

    Returns:
        A single-column DataFrame named ``partition_col``, built from the
        SHOW PARTITIONS command output rather than from the table's data.
    """
    raw_value = f.regexp_extract("partition", fr"(?:^|/){partition_col}=([^/]+)", 1)
    return (
        spark.sql(f"SHOW PARTITIONS {table_name}")
        .select(raw_value.alias(partition_col))
        .filter(~f.col(partition_col).isin("", _HIVE_DEFAULT_PARTITION))
        .select(f.to_date(partition_col).alias(partition_col))
    )


def get_max_partition_from_table_with_showpartitions(spark: SparkSession, table_name: str, partition_col: str) -> DataFrame:
    """Return the rows in the latest ``partition_col`` partition, via a literal filter.

    The maximum partition value is read from ``SHOW PARTITIONS`` metadata and
    collected to the driver. It is then applied as a literal equality filter on
    ``partition_col``. If the table has no (non-NULL) partitions, an empty
    DataFrame with the table's schema is returned.
    """
    max_value = (
        _partition_dates_from_metadata(spark, table_name, partition_col)
        .agg(f.max(partition_col).alias("max_date"))
        .collect()[0]["max_date"]  # a global aggregate always yields exactly one row
    )
    table_df = spark.table(table_name)
    if max_value is None:  # max() over zero partitions is NULL -> nothing to select
        return table_df.filter(f.lit(False))
    return table_df.filter(f.col(partition_col) == f.lit(max_value))


def get_max_partition_from_table_with_showpartitions_with_join(spark: SparkSession, table_name: str, partition_col: str) -> DataFrame:
    """Return the rows in the latest ``partition_col`` partition, via a broadcast join.

    This works like ``get_max_partition_from_table``, except the maximum comes from
    ``SHOW PARTITIONS`` metadata instead of an aggregate over the table's data.
    There is no driver-side collect; the selection happens inside the query plan.
    With no partitions, ``max_date`` is NULL and ``==`` never matches, so the
    result is empty.
    """
    max_date_df = (
        _partition_dates_from_metadata(spark, table_name, partition_col)
        .agg(f.max(partition_col).alias("max_date"))
    )
    table_df = spark.table(table_name)
    return (
        table_df
        .join(f.broadcast(max_date_df), f.col(partition_col) == f.col("max_date"))
        .drop("max_date")
    )

In [7]:
spark.sql("CLEAR CACHE;")  # drop Spark's cached data so every variant is timed from the same state
with timer("get_max_partition_from_table_with_showpartitions"):
    df = get_max_partition_from_table_with_showpartitions(
        spark,
        TABLE_NAME,
        partition_col="event_date",
    )
    noop(df)  # force full execution of the selected rows

get_max_partition_from_table_with_showpartitions completed in 0.70s


In [8]:
spark.sql("CLEAR CACHE;")  # drop Spark's cached data so every variant is timed from the same state
with timer("get_max_partition_from_table_with_showpartitions_with_join"):
    df = get_max_partition_from_table_with_showpartitions_with_join(
        spark,
        TABLE_NAME,
        partition_col="event_date",
    )
    noop(df)  # force full execution of the selected rows

get_max_partition_from_table_with_showpartitions_with_join completed in 0.38s


In [9]:
spark.sql("CLEAR CACHE;")  # drop Spark's cached data so every variant is timed from the same state
with timer("get_max_partition_from_table"):
    df = get_max_partition_from_table(
        spark,
        TABLE_NAME,
        partition_col="event_date",
    )
    noop(df)  # force full execution of the selected rows



get_max_partition_from_table completed in 13.76s


                                                                                

In [10]:
spark.sql("CLEAR CACHE;")  # same starting state as the other timed cells
# Time only the max(event_date) aggregation over the whole table, i.e. the first
# step of get_max_partition_from_table, to isolate the cost of that step.
# (The label previously said "..._with_showpartitions_with_join", which is not
# what this cell measures.)
with timer("max(event_date) over full table (get_max_partition_from_table step 1)"):
    df = spark.table(TABLE_NAME)
    max_date_df = df.select(f.max(f.col("event_date")).alias("max_date"))
    max_date_df.show()



+----------+
|  max_date|
+----------+
|2024-12-30|
+----------+

get_max_partition_from_table_with_showpartitions_with_join completed in 6.72s


                                                                                

In [11]:
# Stop the SparkSession created at the top of the notebook and release its resources.
spark.stop()