# WAP using Branch

- Iceberg versions later than 1.2.0 support branching and tagging.
- This notebook shows how to use branches to implement the Write-Audit-Publish (WAP) pattern on AWS.

In [None]:
%%configure
{
    "conf": {
      "spark.sql.extensions": "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions",
      "spark.sql.catalog.glue": "org.apache.iceberg.spark.SparkCatalog",
        "spark.sql.catalog.glue.catalog-impl":"org.apache.iceberg.aws.glue.GlueCatalog",
        "spark.sql.catalog.glue.io-impl":"org.apache.iceberg.aws.s3.S3FileIO",
        "spark.sql.catalog.glue.warehouse":"s3://<s3_bucket>/glue/warehouse"
    }
}
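
The same settings can be generated for any catalog name, which helps when the configuration is built outside a `%%configure` cell (for example in a job script). This is a minimal sketch; `iceberg_glue_conf` is a hypothetical helper name, and the values are copied from the cell above.

```python
def iceberg_glue_conf(catalog: str, warehouse: str) -> dict:
    """Return the Spark settings that register an Iceberg catalog backed by AWS Glue."""
    prefix = f"spark.sql.catalog.{catalog}"
    return {
        # Enables Iceberg SQL extensions (branch DDL, CALL procedures, ...)
        "spark.sql.extensions": "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions",
        prefix: "org.apache.iceberg.spark.SparkCatalog",
        f"{prefix}.catalog-impl": "org.apache.iceberg.aws.glue.GlueCatalog",
        f"{prefix}.io-impl": "org.apache.iceberg.aws.s3.S3FileIO",
        f"{prefix}.warehouse": warehouse,
    }


conf = iceberg_glue_conf("glue", "s3://<s3_bucket>/glue/warehouse")
for key, value in conf.items():
    print(f"{key} = {value}")
```

The catalog name chosen here (`glue`) is the prefix used in every table identifier later in the notebook, e.g. `glue.nyc_tlc.yellow_taxi_trips`.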

## Creating a Production Table from Sep 2023 Yellow Taxi Trips Data

- This data can be found [here](https://www.nyc.gov/site/tlc/about/tlc-trip-record-data.page)

In [25]:
from pyspark.sql.functions import lit, col

# Reading NYC Yellow Taxi Data
data_bucket = "my-data-bucket" # Change here
raw_data_path = f"s3://{data_bucket}/raw_data/nyc_tlc"
yellow_df = spark.read.format("parquet").load(f"{raw_data_path}/yellow/sep2023/")

In [26]:
yellow_df.printSchema()

root
 |-- VendorID: integer (nullable = true)
 |-- tpep_pickup_datetime: timestamp_ntz (nullable = true)
 |-- tpep_dropoff_datetime: timestamp_ntz (nullable = true)
 |-- passenger_count: long (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: long (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- payment_type: long (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- Airport_fee: double (nullable = true)

In [27]:
# adding year and month column
yellow_df = yellow_df.withColumn("month", lit(9)) \
        .withColumn("year", lit(2023))

In [28]:
# Create an Iceberg table in the Glue catalog under the nyc_tlc database.
# The Parquet codec is set to snappy explicitly; without this property the default codec (zstd) would be used.
yellow_df.writeTo("glue.nyc_tlc.yellow_taxi_trips").partitionedBy("year", "month").using("iceberg") \
            .tableProperty("format-version", "2") \
            .tableProperty("write.parquet.compression-codec", "snappy") \
            .create()

```sql
CREATE TABLE nyc_tlc.yellow_taxi_trips (
  VendorID int,
  tpep_pickup_datetime timestamp,
  tpep_dropoff_datetime timestamp,
  passenger_count bigint,
  trip_distance double,
  RatecodeID bigint,
  store_and_fwd_flag string,
  PULocationID int,
  DOLocationID int,
  payment_type bigint,
  fare_amount double,
  extra double,
  mta_tax double,
  tip_amount double,
  tolls_amount double,
  improvement_surcharge double,
  total_amount double,
  congestion_surcharge double,
  Airport_fee double,
  month int,
  year int)
PARTITIONED BY (`year`, `month`)
LOCATION '<glue_db_location>/yellow_taxi_trips'
TBLPROPERTIES (
  'table_type'='iceberg',
  'write_compression'='snappy'
);
```

In [30]:
# Check the references (branches and tags) in the table
spark.sql("select * from glue.nyc_tlc.yellow_taxi_trips.refs").show(truncate=False)

+----+------+-------------------+-----------------------+---------------------+----------------------+
|name|type  |snapshot_id        |max_reference_age_in_ms|min_snapshots_to_keep|max_snapshot_age_in_ms|
+----+------+-------------------+-----------------------+---------------------+----------------------+
|main|BRANCH|2498870574022386457|null                   |null                 |null                  |
+----+------+-------------------+-----------------------+---------------------+----------------------+

In [33]:
# Check the data in the prod table
spark.table("glue.nyc_tlc.yellow_taxi_trips").groupBy("year", "month").count().show()

+----+-----+-------+
|year|month|  count|
+----+-----+-------+
|2023|    9|2846722|
+----+-----+-------+

### Preparing Source Data

- This data will be written to the production table created above using the WAP pattern.

In [45]:
## Source data that will be written later in the prod table.
source_df = spark.read.parquet(f"{raw_data_path}/yellow/oct2023/")
source_df = source_df.withColumn("month", lit(10)) \
        .withColumn("year", lit(2023))
source_df.printSchema(), source_df.count()

root
 |-- VendorID: integer (nullable = true)
 |-- tpep_pickup_datetime: timestamp_ntz (nullable = true)
 |-- tpep_dropoff_datetime: timestamp_ntz (nullable = true)
 |-- passenger_count: long (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: long (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- payment_type: long (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- Airport_fee: double (nullable = true)
 |-- month: integer (nullable = false)
 |-- year: integer (nullable = false)

(None, 3522285)

In [46]:
source_df.groupBy("year", "month").count().show()

+----+-----+-------+
|year|month|  count|
+----+-----+-------+
|2023|   10|3522285|
+----+-----+-------+
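
The `month` and `year` values above are hard-coded with `lit(...)`. If more months are loaded the same way, they could instead be derived from the folder name. The sketch below assumes folder names always look like `sep2023` or `oct2023`, as in the paths used in this notebook; `parse_period` is a hypothetical helper.

```python
import re

# Three-letter month abbreviations as they appear in the folder names.
MONTHS = {
    "jan": 1, "feb": 2, "mar": 3, "apr": 4, "may": 5, "jun": 6,
    "jul": 7, "aug": 8, "sep": 9, "oct": 10, "nov": 11, "dec": 12,
}


def parse_period(folder: str) -> tuple:
    """Turn a folder name such as 'oct2023' into (year, month)."""
    match = re.fullmatch(r"([a-z]{3})(\d{4})", folder.strip("/").lower())
    if match is None or match.group(1) not in MONTHS:
        raise ValueError(f"unexpected folder name: {folder!r}")
    return int(match.group(2)), MONTHS[match.group(1)]


year, month = parse_period("oct2023")
print(year, month)  # 2023 10
```

The resulting values can then be passed to `lit(month)` and `lit(year)` instead of literals.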

## WAP Implementation

In [31]:
prod_table = "glue.nyc_tlc.yellow_taxi_trips"
audit_branch = "audit_102023"

# Create an Audit Branch for staging the new data before writing in prod table
spark.sql(f"ALTER TABLE {prod_table} CREATE BRANCH {audit_branch}")

DataFrame[]
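
For recurring loads, the branch name and the DDL statement can be built from the load period rather than typed by hand. This is only a sketch: `audit_branch_name` and `create_branch_sql` are hypothetical helpers, and the `audit_<MM><YYYY>` naming simply mirrors the `audit_102023` name used above.

```python
import re


def audit_branch_name(year: int, month: int) -> str:
    """Build a branch name in the audit_<MM><YYYY> style used in this notebook."""
    if not 1 <= month <= 12:
        raise ValueError(f"month out of range: {month}")
    return f"audit_{month:02d}{year}"


def create_branch_sql(table: str, branch: str) -> str:
    """Return the ALTER TABLE statement that creates the branch."""
    # The name is interpolated into SQL, so only allow a plain identifier.
    if not re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", branch):
        raise ValueError(f"unsafe branch name: {branch!r}")
    return f"ALTER TABLE {table} CREATE BRANCH {branch}"


branch = audit_branch_name(2023, 10)
print(create_branch_sql("glue.nyc_tlc.yellow_taxi_trips", branch))
# With a live session: spark.sql(create_branch_sql(prod_table, branch))
```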

In [32]:
# Verify if an audit branch is created.
spark.sql(f"select * from {prod_table}.refs").show(truncate=False)

+------------+------+-------------------+-----------------------+---------------------+----------------------+
|name        |type  |snapshot_id        |max_reference_age_in_ms|min_snapshots_to_keep|max_snapshot_age_in_ms|
+------------+------+-------------------+-----------------------+---------------------+----------------------+
|main        |BRANCH|2498870574022386457|null                   |null                 |null                  |
|audit_102023|BRANCH|2498870574022386457|null                   |null                 |null                  |
+------------+------+-------------------+-----------------------+---------------------+----------------------+

In [44]:
spark.sql(f"select * from {prod_table}.snapshots").show(truncate=False)

+------------+-----------+---------+---------+-------------+-------+
|committed_at|snapshot_id|parent_id|operation|manifest_list|summary|
+------------+-----------+---------+---------+-------------+-------+
(rows truncated in the original output)

### Setting table property for WAP and Spark Session for WAP branch

In [39]:
# Set write.wap.enabled=true on the table so that writes can follow the WAP pattern.
# Note: after setting this property, the table DDL could not be viewed in Athena, although the data stayed queryable.
# The cause still needs investigation; for now, make sure to unset the property once everything is done.
spark.sql(f"ALTER TABLE {prod_table} SET TBLPROPERTIES ('write.wap.enabled'='true')")

DataFrame[]

In [41]:
## Check table DDL
spark.sql(f"show create table {prod_table}").collect()[0]['createtab_stmt']

"CREATE TABLE glue.nyc_tlc.yellow_taxi_trips (\n  VendorID INT,\n  tpep_pickup_datetime TIMESTAMP_NTZ,\n  tpep_dropoff_datetime TIMESTAMP_NTZ,\n  passenger_count BIGINT,\n  trip_distance DOUBLE,\n  RatecodeID BIGINT,\n  store_and_fwd_flag STRING,\n  PULocationID INT,\n  DOLocationID INT,\n  payment_type BIGINT,\n  fare_amount DOUBLE,\n  extra DOUBLE,\n  mta_tax DOUBLE,\n  tip_amount DOUBLE,\n  tolls_amount DOUBLE,\n  improvement_surcharge DOUBLE,\n  total_amount DOUBLE,\n  congestion_surcharge DOUBLE,\n  Airport_fee DOUBLE,\n  month INT,\n  year INT)\nUSING iceberg\nPARTITIONED BY (year, month)\nLOCATION 's3://aws-data-068910838149-ap-south-1/nyc_tlc/yellow_taxi_trips'\nTBLPROPERTIES (\n  'current-snapshot-id' = '2498870574022386457',\n  'format' = 'iceberg/parquet',\n  'format-version' = '2',\n  'write.parquet.compression-codec' = 'snappy',\n  'write.wap.enabled' = 'true')\n"

In [43]:
# Point this Spark session at the audit branch so that new data is written there.
# This switches the session's default branch from main to audit_branch: every query that does not reference another branch explicitly runs against audit_branch.
spark.conf.set("spark.wap.branch", audit_branch)

## Write

- Write the new data into the prod table.
- Because `spark.wap.branch` is set to `audit_branch`, the data is written to this branch by default.
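
An alternative that does not depend on session state is to name the branch in the table identifier itself. Iceberg's Spark integration documents a `branch_<name>` suffix for this purpose; the sketch below only builds that identifier, and `branch_identifier` is a hypothetical helper name.

```python
def branch_identifier(table: str, branch: str) -> str:
    """Return the identifier that targets a specific branch of an Iceberg table."""
    return f"{table}.branch_{branch}"


target = branch_identifier("glue.nyc_tlc.yellow_taxi_trips", "audit_102023")
print(target)
# With a live session, the append would then be (not executed here):
# source_df.writeTo(target).append()
```

Naming the branch explicitly makes each write self-describing, at the cost of having to pass the branch to every statement instead of setting it once per session.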

In [47]:
# Appending data into Iceberg Prod Table
source_df.writeTo(prod_table).append()

In [48]:
# Verify that the append created a new snapshot
spark.sql(f"select * from {prod_table}.snapshots").show()

+--------------------+-------------------+-------------------+---------+--------------------+--------------------+
|        committed_at|        snapshot_id|          parent_id|operation|       manifest_list|             summary|
+--------------------+-------------------+-------------------+---------+--------------------+--------------------+
|2023-12-27 11:53:...|2498870574022386457|               null|   append|s3://aws-data-068...|{spark.app.id -> ...|
|2023-12-27 12:28:...| 215507697857352501|2498870574022386457|   append|s3://aws-data-068...|{spark.app.id -> ...|
+--------------------+-------------------+-------------------+---------+--------------------+--------------------+

In [49]:
# Verifying refs: snapshot_id must be updated for audit_102023 branch
spark.sql(f"select * from {prod_table}.refs").show(truncate=False)

+------------+------+-------------------+-----------------------+---------------------+----------------------+
|name        |type  |snapshot_id        |max_reference_age_in_ms|min_snapshots_to_keep|max_snapshot_age_in_ms|
+------------+------+-------------------+-----------------------+---------------------+----------------------+
|main        |BRANCH|2498870574022386457|null                   |null                 |null                  |
|audit_102023|BRANCH|215507697857352501 |null                   |null                 |null                  |
+------------+------+-------------------+-----------------------+---------------------+----------------------+

In [50]:
## Let's verify data in prod table:
spark.table(prod_table).groupBy("year", "month").count().show()

+----+-----+-------+
|year|month|  count|
+----+-----+-------+
|2023|   10|3522285|
|2023|    9|2846722|
+----+-----+-------+

**Yikes!** Has the data in the prod table changed? Not really.

- `spark.wap.branch` is set to `audit_branch`, so this session is actually querying `audit_branch`, not `main`.

In [51]:
## Let's check the data in main branch
spark.read.option("BRANCH", "main").table(prod_table).groupBy("year", "month").count().show()

+----+-----+-------+
|year|month|  count|
+----+-----+-------+
|2023|    9|2846722|
+----+-----+-------+

This confirms that, with `spark.wap.branch` set to `audit_branch`, every read and write in this session happens on `audit_branch`.

## Audit

### Auditing data present in Audit Branch.

Let's assume our application's data quality standards require that:
- No row has a negative `total_amount`.
- The completeness of `VendorID` is 1.0, i.e. the `VendorID` field contains no `null` values.
- `payment_type` contains only the discrete numeric values 1 to 6.

These are simply data quality (DQ) rules that you can run against the audit branch before deciding what to do with the data, e.g.

- Discard the entire snapshot because it doesn't meet DQ expectations.
- DELETE the rows that don't meet DQ standards and preserve those records somewhere else.
- Fix the data, for example by populating `null` or missing values according to some rule.
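
The three rules can be written as SQL predicates that select the violating rows and then counted in one pass over the rule set. This is a minimal sketch, not a full DQ framework: `DQ_RULES`, `count_violations`, and `audit_passed` are hypothetical names, and treating a `null` `payment_type` as a violation is an assumption the rule list does not spell out.

```python
# Each predicate selects the rows that VIOLATE the rule.
DQ_RULES = {
    "negative_total_amount": "total_amount < 0",
    "null_vendor_id": "VendorID IS NULL",
    # Assumption: a NULL payment_type is treated as invalid as well.
    "invalid_payment_type": "payment_type IS NULL OR payment_type NOT BETWEEN 1 AND 6",
}


def count_violations(df, rules=DQ_RULES) -> dict:
    """Return {rule_name: violating_row_count} for a Spark DataFrame."""
    # DataFrame.filter accepts a SQL expression string.
    return {name: df.filter(condition).count() for name, condition in rules.items()}


def audit_passed(violations: dict) -> bool:
    """The audit passes only if no rule has any violating rows."""
    return all(count == 0 for count in violations.values())


for name, condition in DQ_RULES.items():
    print(f"{name}: {condition}")
# With a live session:
# violations = count_violations(spark.read.option("BRANCH", audit_branch).table(prod_table))
```

Each rule triggers a separate Spark job here, which keeps the code simple; for large tables a single aggregation over all predicates may be preferable.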

In [52]:
# Reading data from Audit Branch:
audit_data = spark.read.option("BRANCH", audit_branch).table(prod_table)

# check if there are any rows with negative total_amount
neg_amt_df = audit_data.filter(col("total_amount") < 0)

In [53]:
if not neg_amt_df.isEmpty():
    # Write the rejected records to a table or another location before deleting them, e.g.
    # neg_amt_df.write.partitionBy("year", "month").parquet("bad-data-location")
    spark.sql(f"DELETE FROM {prod_table} where total_amount < 0")

DataFrame[]

In [54]:
# Checking data in main branch
spark.read.option("BRANCH","main").table(prod_table).filter(col("total_amount") < 0).count()

29253

In [55]:
# Checking data in audit branch
spark.table(prod_table).filter(col("total_amount") < 0).count()

0

- If the data should not be published when a DQ check fails, just drop the audit branch.

In [None]:
# dropping Audit Branch
# spark.sql(f"ALTER TABLE {prod_table} DROP BRANCH {audit_branch}")
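
The choice between publishing and discarding can be made explicit in code. The sketch below covers only the "discard the whole snapshot on failure" option; `next_wap_statement` is a hypothetical helper, and the two statements it returns are the `DROP BRANCH` and `fast_forward` calls used in this notebook.

```python
def next_wap_statement(violations: dict, table: str, audit_branch: str,
                       catalog: str = "glue", main_branch: str = "main") -> str:
    """Return the SQL to run after the audit: drop the branch or publish it."""
    if any(count > 0 for count in violations.values()):
        # Audit failed: discard the staged data by dropping the branch.
        return f"ALTER TABLE {table} DROP BRANCH {audit_branch}"
    # Audit passed: move main forward to the audited snapshot.
    return f"CALL {catalog}.system.fast_forward('{table}', '{main_branch}', '{audit_branch}')"


table = "glue.nyc_tlc.yellow_taxi_trips"
print(next_wap_statement({"negative_total_amount": 12}, table, "audit_102023"))
print(next_wap_statement({"negative_total_amount": 0}, table, "audit_102023"))
```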

- The data is now clean and meets the DQ rules. Let's publish it.

## Publish

Once auditing is done and the data either meets the DQ rules or has been fixed, we can publish the final records to the production `main` branch.

In [57]:
# Fast-forward the main branch to the audit branch
spark.sql(f"""CALL glue.system.fast_forward('{prod_table}', 'main', '{audit_branch}')""").show()

+--------------+-------------------+-------------------+
|branch_updated|       previous_ref|        updated_ref|
+--------------+-------------------+-------------------+
|          main|2498870574022386457|8657093001848818796|
+--------------+-------------------+-------------------+
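
One way to picture this result is to treat branches as names that point at snapshot ids. The toy model below is an in-memory illustration only, not the Iceberg API. It assumes that `8657093001848818796` is the snapshot created by the DELETE on the audit branch (the notebook does not print its parent), and that a fast-forward is only valid when the branch being moved is an ancestor of the target.

```python
class ToyRefs:
    """In-memory stand-in for branch refs and snapshot lineage (not Iceberg code)."""

    def __init__(self, root_snapshot: int):
        self.parent = {root_snapshot: None}   # snapshot id -> parent snapshot id
        self.refs = {"main": root_snapshot}   # branch name -> current snapshot id

    def create_branch(self, name: str, from_branch: str = "main") -> None:
        self.refs[name] = self.refs[from_branch]

    def commit(self, branch: str, snapshot_id: int) -> None:
        """Record a new snapshot on top of the branch and move only that branch."""
        self.parent[snapshot_id] = self.refs[branch]
        self.refs[branch] = snapshot_id

    def is_ancestor(self, ancestor: int, snapshot_id) -> bool:
        while snapshot_id is not None:
            if snapshot_id == ancestor:
                return True
            snapshot_id = self.parent[snapshot_id]
        return False

    def fast_forward(self, branch: str, to: str) -> tuple:
        if not self.is_ancestor(self.refs[branch], self.refs[to]):
            raise ValueError(f"{branch} is not an ancestor of {to}")
        previous = self.refs[branch]
        self.refs[branch] = self.refs[to]
        return previous, self.refs[branch]


table = ToyRefs(2498870574022386457)
table.create_branch("audit_102023")
table.commit("audit_102023", 215507697857352501)    # the append
table.commit("audit_102023", 8657093001848818796)   # assumed: the DELETE
print(table.refs["main"])                           # main has not moved yet
print(table.fast_forward("main", "audit_102023"))   # (previous_ref, updated_ref)
```

No data is copied in this model: publishing is just a pointer move, which matches the `previous_ref` and `updated_ref` columns in the output above.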

In [58]:
# checking data in main branch
main_df = spark.read.option("BRANCH","main").table(prod_table)
main_df.groupBy("year", "month").count().show(), main_df.filter(col("total_amount") < 0).count()

+----+-----+-------+
|year|month|  count|
+----+-----+-------+
|2023|   10|3485320|
|2023|    9|2817469|
+----+-----+-------+

(None, 0)
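
Note that the September count changed as well. The `DELETE ... where total_amount < 0` statement was not restricted to the new partition, so it also removed the negative rows that were already in the table. A quick reconciliation using only the counts printed in this notebook:

```python
# Row counts per (year, month) before the DELETE and after publishing.
before = {(2023, 9): 2846722, (2023, 10): 3522285}
after = {(2023, 9): 2817469, (2023, 10): 3485320}

removed = {period: before[period] - after[period] for period in before}
print(removed)  # {(2023, 9): 29253, (2023, 10): 36965}

# 29253 is the number of negative-amount rows counted in main before publishing.
assert removed[(2023, 9)] == 29253
```

If only the newly loaded month should be cleaned, one option would be to add a partition predicate such as `AND year = 2023 AND month = 10` to the DELETE.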

### Unsetting from Spark Session

In [59]:
# Unset from SparkSession
spark.conf.unset('spark.wap.branch')

### Unsetting from Table Properties

In [60]:
# Unset the table property so that the CREATE TABLE DDL can be seen again in Athena.
spark.sql(f"ALTER TABLE {prod_table} UNSET TBLPROPERTIES ('write.wap.enabled')")

DataFrame[]
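
Because both settings have to be undone even if a write or an audit step fails, the setup and cleanup can be wrapped in a context manager. This is a sketch under the assumption that `spark` is a live `SparkSession`; `wap_branch` is a hypothetical helper that issues the same statements as the cells above.

```python
from contextlib import contextmanager


@contextmanager
def wap_branch(spark, table: str, branch: str):
    """Enable WAP on the table and session, and always undo both on exit."""
    spark.sql(f"ALTER TABLE {table} SET TBLPROPERTIES ('write.wap.enabled'='true')")
    spark.conf.set("spark.wap.branch", branch)
    try:
        yield
    finally:
        # Runs on success and on failure, so the session never stays on the audit branch.
        spark.conf.unset("spark.wap.branch")
        spark.sql(f"ALTER TABLE {table} UNSET TBLPROPERTIES ('write.wap.enabled')")


# Usage with a live session (not executed here):
# with wap_branch(spark, prod_table, audit_branch):
#     source_df.writeTo(prod_table).append()
#     ...audit, then fast-forward or drop the branch...
```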

### Dropping the Audit Branch

In [None]:
# Once the fast-forward is done, both branches point to the same snapshot, so the audit branch can be dropped if required.
spark.sql(f"ALTER TABLE {prod_table} DROP BRANCH {audit_branch}")