# Write-Audit-Publish (WAP) implementation for Iceberg versions < 1.2.1

In [None]:
!pip install pyspark=3.3
!pip install findspark
!pip install pydeequ

In [None]:
import findspark
# Make the local Spark installation importable as `pyspark` from this kernel.
findspark.init()
# Display the Spark installation path that findspark located.
findspark.find()

Loading Iceberg jars 
- .config('spark.jars.packages', 'org.apache.iceberg:iceberg-spark-runtime-3.3_2.12:1.2.0')

Loading Iceberg Extensions to call stored procedures
- .config('spark.sql.extensions','org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions')

Configuring Spark's built-in catalog (`spark_catalog`) as an Iceberg catalog of type `hive`, which loads tables from the `Hive Metastore`. `SparkSessionCatalog` adds Iceberg support to the built-in catalog.
- .config('spark.sql.catalog.spark_catalog','org.apache.iceberg.spark.SparkSessionCatalog')
- .config('spark.sql.catalog.spark_catalog.type','hive')

Creating an Iceberg catalog named `local` of type `hadoop`. This is a directory-based catalog stored under the configured warehouse path (for example on HDFS or a local filesystem).
- .config('spark.sql.catalog.local','org.apache.iceberg.spark.SparkCatalog')
- .config('spark.sql.catalog.local.type','hadoop')
- .config('spark.sql.catalog.local.warehouse','<path_to_warehouse>')

If `type` is not set, `spark.sql.catalog.<catalog-name>.catalog-impl` **must** be set instead.

In [None]:
from pyspark.sql import SparkSession
import os

# SPARK_VERSION has to be in the environment before pydeequ is imported.
os.environ["SPARK_VERSION"] = '3.3'
import pydeequ

warehouse_directory = "local_path"

# Session settings, applied in this order:
#   - Iceberg runtime + Deequ jars (and the jar Deequ asks to exclude)
#   - Iceberg SQL extensions (needed for CALL ... stored procedures)
#   - `spark_catalog`: Iceberg session catalog backed by the Hive Metastore
#   - `local`: Hadoop (directory-based) Iceberg catalog under the warehouse directory
spark_conf = {
    'spark.jars.packages': f'org.apache.iceberg:iceberg-spark-runtime-3.3_2.12:1.2.0,{pydeequ.deequ_maven_coord}',
    'spark.jars.excludes': pydeequ.f2j_maven_coord,
    'spark.sql.extensions': 'org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions',
    'spark.sql.catalog.spark_catalog': 'org.apache.iceberg.spark.SparkSessionCatalog',
    'spark.sql.catalog.spark_catalog.type': 'hive',
    'spark.sql.catalog.local': 'org.apache.iceberg.spark.SparkCatalog',
    'spark.sql.catalog.local.type': 'hadoop',
    'spark.sql.catalog.local.warehouse': f'{warehouse_directory}/warehouse',
}

builder = SparkSession.builder.master("local[4]").appName("wap-iceberg-1.2.0")
for conf_key, conf_value in spark_conf.items():
    builder = builder.config(conf_key, conf_value)

spark = builder.getOrCreate()

## Reading and Creating Iceberg data

In [8]:
# Load the September 2023 green-taxi trips from the Parquet directory.
green_df = spark.read.parquet("../nyc-taxi-trips/green/sep-2023/")
# Print the schema and show the row count; printSchema() returns None, hence the (None, <count>) tuple in the output.
green_df.printSchema(), green_df.count()

root
 |-- VendorID: integer (nullable = true)
 |-- lpep_pickup_datetime: timestamp (nullable = true)
 |-- lpep_dropoff_datetime: timestamp (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- RatecodeID: long (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- passenger_count: long (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- ehail_fee: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- payment_type: long (nullable = true)
 |-- trip_type: long (nullable = true)
 |-- congestion_surcharge: double (nullable = true)



(None, 65471)

In [9]:
from pyspark.sql.functions import lit

# Tag every row with constant month and year values (month is added first, then year).
green_df = green_df.withColumn("month", lit(9))
green_df = green_df.withColumn("year", lit(2023))

In [124]:
# Create an Iceberg table (format version 2, snappy-compressed Parquet) partitioned by year and month.
table_writer = (
    green_df.writeTo("local.nyc_tlc.green_taxi_trips")
    .partitionedBy("year", "month")
    .using("iceberg")
    .tableProperty("format-version", "2")
    .tableProperty("write.parquet.compression-codec", "snappy")
)
table_writer.create()

In [11]:
# Capture the CREATE TABLE statement Spark reports for the new table.
ddl_rows = spark.sql("show create table local.nyc_tlc.green_taxi_trips").collect()
table_stmt = ddl_rows[0]['createtab_stmt']

```sql
CREATE TABLE Iceberg local.nyc_tlc.green_taxi_trips (
    VendorID INT,
    lpep_pickup_datetime TIMESTAMP,
    lpep_dropoff_datetime TIMESTAMP,
    store_and_fwd_flag STRING,
    RatecodeID BIGINT,
    PULocationID INT,
    DOLocationID INT,
    passenger_count BIGINT,
    trip_distance DOUBLE,
    fare_amount DOUBLE,
    extra DOUBLE,
    mta_tax DOUBLE,
    tip_amount DOUBLE,
    tolls_amount DOUBLE,
    ehail_fee DOUBLE,
    improvement_surcharge DOUBLE,
    total_amount DOUBLE,
    payment_type BIGINT,
    trip_type BIGINT,
    congestion_surcharge DOUBLE,
    month INT,
    year INT)
    USING iceberg
    PARTITIONED BY (year, month)
    LOCATION '<warehouse_path>/nyc_tlc/green_taxi_trips'
    TBLPROPERTIES (
        'current-snapshot-id' = '9122000941857891650',
        'format' = 'iceberg/parquet',
        'format-version' = '2',
        'write.parquet.compression-codec' = 'snappy')
```


In [126]:
# List the snapshots recorded so far in the table's `snapshots` metadata table.
spark.sql("select * from local.nyc_tlc.green_taxi_trips.snapshots").show(truncate=False)

+-----------------------+-------------------+---------+---------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|committed_at           |snapshot_id        |parent_id|operation|manifest_list                                                                                                                                                                    |summary                                                                                                                                                                                                           

## Write new data into the production table with WAP enabled

In [127]:
# Load the October 2023 trips that are to be written into the production table,
# tagging each row with its year and month.
new_df = spark.read.parquet("../nyc-taxi-trips/green/oct-2023/")
new_df = new_df.withColumn("year", lit(2023))
new_df = new_df.withColumn("month", lit(10))
new_df.count()

66177

In [128]:
prod_table = "local.nyc_tlc.green_taxi_trips"

# Turn on Write-Audit-Publish for the production table by setting write.wap.enabled = true.
# The property is set unconditionally; no check for an existing value is performed here.
spark.sql(f"ALTER TABLE {prod_table} SET TBLPROPERTIES ('write.wap.enabled' = 'true')")

23/12/26 18:02:40 WARN BaseTransaction: Failed to load metadata for a committed snapshot, skipping clean-up


DataFrame[]

In [None]:
# Confirm that the table properties now include write.wap.enabled = true.
# The table name comes from prod_table so this check always targets the table that was altered.
spark.sql(f"show create table {prod_table}").collect()[0]['createtab_stmt']

In [130]:
# Branching is not available in Iceberg < 1.2.1, so instead generate a wap.id.
# The wap.id is used later to identify the snapshot that holds the newly written data.

import uuid
wap_id = uuid.uuid4().hex

# Set the wap.id on the Spark session so that it is recorded in the snapshot summary of subsequent writes.
spark.conf.set("spark.wap.id", wap_id)

print(wap_id)

de6d592742b54cf09353fb168d6dc856


In [131]:
# Append the new data to the production table.
# With WAP enabled and spark.wap.id set, the following cells show this write is not yet visible to normal reads.
new_df.writeTo(prod_table).append()

In [132]:
# Row counts per partition as seen by a normal read of the production table.
spark.table(prod_table).groupBy("year","month").count().show()

+----+-----+-----+
|year|month|count|
+----+-----+-----+
|2023|    9|65471|
+----+-----+-----+



In [None]:
# Inspect the table's snapshots after the write; the new snapshot's summary should carry the wap.id.
spark.sql(f"select * from {prod_table}.snapshots").show(truncate=False)

## Audit

- Get snapshot_id corresponding to the wap.id that was generated.
- Read data based on this snapshot_id.
- Run your DQ tests on it.
- Based on the DQ test results, decide whether the new snapshot should be published, or instead expired/deleted.


In [134]:
from pyspark.sql.functions import col

# Find the snapshot whose summary carries the wap.id generated for this run.
staged_snapshots = (
    spark.sql(f"select snapshot_id, summary['wap.id'] as wap_id from {prod_table}.snapshots")
    .filter(col("wap_id") == wap_id)
    .collect()
)

# Fail with a clear message instead of an opaque IndexError when nothing was staged under this wap.id.
if not staged_snapshots:
    raise ValueError(f"No snapshot with wap.id={wap_id} found in {prod_table}")

staged_data_snap_id = staged_snapshots[0]['snapshot_id']
staged_data_snap_id

23/12/26 18:03:06 WARN SparkScanBuilder: Failed to check if IsNotNull(summary) can be pushed down: Cannot find field 'summary' in struct: struct<>


3620034948650558114

In [135]:
# Read the production table as of the staged snapshot (selected through the "snapshot-id" read option).
staged_data = spark.read.option("snapshot-id", staged_data_snap_id).table(prod_table)


In [136]:
# Row counts per partition as seen through the staged snapshot.
staged_data.groupBy("year","month").count().show()

+----+-----+-----+
|year|month|count|
+----+-----+-----+
|2023|    9|65471|
|2023|   10|66177|
+----+-----+-----+



### Running DQ tests using PyDeequ.

Let's say your DQ requirement is:
- Completeness of `VendorID` should be 1.0, i.e. the column contains no `null` values.
- `total_amount` must not be negative.
- `payment_type` must be one of the discrete values 1 through 6.

In [137]:
# Import only the names this cell uses instead of wildcard imports, so it is clear where each comes from.
from pydeequ.checks import Check, CheckLevel
from pydeequ.verification import VerificationResult, VerificationSuite

# A Check represents a list of constraints that can be applied to a provided Spark DataFrame.
# VendorID and payment_type violations are reported at Error level, total_amount at Warning level.
vendorID_checks = Check(spark, CheckLevel.Error, "VendorID Checks")
payment_type_checks = Check(spark, CheckLevel.Error, "payment_type Checks")
tot_amt_checks = Check(spark, CheckLevel.Warning, "total_amount Checks")


# Run all three checks against the staged snapshot's data.
checkResult = VerificationSuite(spark) \
    .onData(staged_data) \
    .addCheck(vendorID_checks.isComplete("VendorID")) \
    .addCheck(payment_type_checks.isContainedIn("payment_type", ['1','2','3','4','5','6'])) \
    .addCheck(tot_amt_checks.isNonNegative("total_amount")) \
    .run()

# Overall status first, then one row per constraint with its individual status and message.
print(f"Verification Run Status: {checkResult.status}")
checkResult_df = VerificationResult.checkResultsAsDataFrame(spark, checkResult)
checkResult_df.show(truncate=False)

+-------------------+-----------+------------+--------------------------------------------------------------------------------------------------------------------------------------------------+-----------------+-------------------------------------------------------------------+
|check              |check_level|check_status|constraint                                                                                                                                        |constraint_status|constraint_message                                                 |
+-------------------+-----------+------------+--------------------------------------------------------------------------------------------------------------------------------------------------+-----------------+-------------------------------------------------------------------+
|VendorID Checks    |Error      |Success     |CompletenessConstraint(Completeness(VendorID,None))                                                               



In [113]:
# Keep only rows with a non-negative total_amount.
# (Alternative not used here: a SQL delete on the table, i.e. DELETE FROM {prod_table} where total_amount < 0.)
cleaned_df = staged_data.filter(col("total_amount") >= 0)
# Writing cleaned_df back to prod_table happens in a separate cell. spark.wap.id is still set on the
# session, so that write is recorded under the same wap.id; after writing, check that parent_id is the
# same for both staged commits.
# Row counts per partition after cleaning.
cleaned_df.groupBy("year","month").count().show()

+----+-----+-----+
|year|month|count|
+----+-----+-----+
|2023|    9|65263|
|2023|   10|65974|
+----+-----+-----+



In [115]:
# Write the cleaned data, replacing the partitions present in cleaned_df.
# The wap.id is still set on the session; the pre-publish check further down shows this write is not yet visible in the production table.
cleaned_df.writeTo(prod_table).overwritePartitions()

## Publish

In [117]:
# Get the most recent snapshot written under this wap.id, i.e. the one holding the cleaned data.
# committed_at is selected explicitly so the ordering does not depend on Spark resolving a column
# that is absent from the projection.
latest_staged_snapshot = (
    spark.sql(f"select snapshot_id, committed_at, summary['wap.id'] as wap_id from {prod_table}.snapshots")
    .filter(col("wap_id") == wap_id)
    .orderBy(col("committed_at").desc())
    .first()
)

# first() returns None when no snapshot matches; fail with a clear message in that case.
if latest_staged_snapshot is None:
    raise ValueError(f"No snapshot with wap.id={wap_id} found in {prod_table}")

# The variable name (including its spelling) is kept because the publish cell refers to it.
udpated_data_snap_id = latest_staged_snapshot['snapshot_id']
udpated_data_snap_id

23/12/26 17:26:47 WARN SparkScanBuilder: Failed to check if IsNotNull(summary) can be pushed down: Cannot find field 'summary' in struct: struct<>


7581908246933781414

In [118]:
# Check the data visible in the production table before publishing.
spark.table(prod_table).groupBy("year","month").count().show()

+----+-----+-----+
|year|month|count|
+----+-----+-----+
|2023|    9|65471|
+----+-----+-----+



In [119]:
# Publish: cherry-pick the final staged snapshot (the cleaned data) so that it becomes the
# current snapshot of the production table.
spark.sql(f"CALL local.system.cherrypick_snapshot('{prod_table}', {udpated_data_snap_id})").show()

# The WAP cycle is finished: clear the session-level wap.id so that later writes from this
# session are not staged under an id that has already been published.
spark.conf.unset("spark.wap.id")

23/12/26 17:26:56 WARN BaseTransaction: Failed to load metadata for a committed snapshot, skipping clean-up
+-------------------+-------------------+
| source_snapshot_id|current_snapshot_id|
+-------------------+-------------------+
|7581908246933781414|7581908246933781414|
+-------------------+-------------------+



In [120]:
# Verify the data visible in the production table after publishing.
spark.table(prod_table).groupBy("year","month").count().show()

+----+-----+-----+
|year|month|count|
+----+-----+-----+
|2023|    9|65263|
|2023|   10|65974|
+----+-----+-----+



In [122]:
# Validate that the published table contains no records with a negative total_amount (expected count: 0).
spark.sql(f"select count(*) from {prod_table} where total_amount < 0").show()

+--------+
|count(1)|
+--------+
|       0|
+--------+

