# Storing PII Securely

Adding a pseudonymized key to incremental workloads is as simple as adding a transformation.

In this notebook, we'll examine design patterns for ensuring PII is stored securely and updated accurately. We'll also demonstrate an approach for processing delete requests to make sure these are captured appropriately.

<img src="https://files.training.databricks.com/images/ade/ADE_arch_users.png" width="60%" />

## Learning Objectives
By the end of this notebook, students will be able to:
- Apply incremental transformations to store data with pseudonymized keys
- Use windowed ranking to identify the most recent records in a CDC feed

Begin by running the following cell to set up relevant databases and paths.

In [0]:
%run ../Includes/Classroom-Setup-6.2

Execute the following cell to create the **`users`** table.

In [0]:
%sql
CREATE TABLE users
(alt_id STRING, dob DATE, sex STRING, gender STRING, first_name STRING, last_name STRING, street_address STRING, city STRING, state STRING, zip INT, updated TIMESTAMP)
USING DELTA
LOCATION '${da.paths.working_dir}/users'

## ELT with Pseudonymization
The data in the **`user_info`** topic contains complete row outputs from a Change Data Capture feed.

There are three values for **`update_type`** present in the data: **`new`**, **`update`**, and **`delete`**.

The **`users`** table will be implemented as a Type 1 table, so only the most recent value for each user matters.

Run the cell below to visually confirm that both **`new`** and **`update`** records contain all the fields we need for our **`users`** table.

In [0]:
from pyspark.sql import functions as F

schema = """
    user_id LONG, 
    update_type STRING, 
    timestamp FLOAT, 
    dob STRING, 
    sex STRING, 
    gender STRING, 
    first_name STRING, 
    last_name STRING, 
    address STRUCT<
        street_address: STRING, 
        city: STRING, 
        state: STRING, 
        zip: INT>"""

users_df = (spark.table("bronze")
                 .filter("topic = 'user_info'")
                 .select(F.from_json(F.col("value").cast("string"), schema).alias("v")).select("v.*")
                 .filter(F.col("update_type").isin(["new", "update"])))

display(users_df)
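
To make the parsing step concrete without a cluster, here is a minimal plain-Python sketch of what decoding **`value`**, parsing the JSON, and filtering on **`update_type`** does to one message. The payload values and the helper name `parse_user_info` are invented for illustration; only the field names come from the schema above.

```python
import json

# Hypothetical payload shaped like the schema above; all values are invented.
raw_value = json.dumps({
    "user_id": 12345,
    "update_type": "new",
    "timestamp": 1577836800.0,
    "dob": "01/31/1990",
    "sex": "F",
    "gender": "F",
    "first_name": "Ada",
    "last_name": "Example",
    "address": {"street_address": "1 Main St", "city": "Springfield",
                "state": "IL", "zip": 62701},
}).encode("utf-8")


def parse_user_info(value: bytes):
    """Decode the binary value, parse the JSON, and keep only new/update records."""
    record = json.loads(value.decode("utf-8"))
    if record.get("update_type") not in ("new", "update"):
        return None  # delete records are filtered out at this stage
    return record


parsed = parse_user_info(raw_value)
deleted = parse_user_info(json.dumps({"user_id": 1, "update_type": "delete"}).encode("utf-8"))
print(sorted(parsed))  # field names available after unpacking
```

In the Spark version, a record that fails the filter is simply absent from the result rather than returned as `None`.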

## Deduplication with Windowed Ranking

We've previously explored some ways to remove duplicate records:
- Using Delta Lake's **`MERGE`** syntax, we can update or insert records based on keys, matching new records with previously loaded data
- **`dropDuplicates`** will remove exact duplicates within a table or incremental microbatch

Now we have multiple records for a given primary key, but these records are not identical. **`dropDuplicates`** will not remove them, and our merge statement will raise an error if the same key is present multiple times in the source.

A third approach for removing duplicates, shown below, uses the <a href="http://spark.apache.org/docs/latest/api/python/reference/api/pyspark.sql.Window.html?highlight=window#pyspark.sql.Window" target="_blank">PySpark Window class</a>.

In [0]:
from pyspark.sql.window import Window

window = Window.partitionBy("user_id").orderBy(F.col("timestamp").desc())

ranked_df = (users_df.withColumn("rank", F.rank().over(window))
                     .filter("rank == 1")
                     .drop("rank"))
display(ranked_df)

As desired, we get only the newest (**`rank == 1`**) entry for each unique **`user_id`**.
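
One caveat about ranking is worth seeing in miniature: **`rank`** gives tied rows the same rank, so two records for the same key with the same timestamp both survive a **`rank == 1`** filter. The plain-Python sketch below imitates that behavior on invented data; the helper name `rank_one` is hypothetical and this is a model of the semantics, not Spark code.

```python
from collections import defaultdict


def rank_one(records, key="user_id", order_by="timestamp"):
    """Keep every record that ties for the largest order_by value within its key.

    This mirrors filtering on rank == 1 over a window partitioned by `key`
    and ordered by `order_by` descending.
    """
    newest = defaultdict(lambda: float("-inf"))
    for r in records:
        newest[r[key]] = max(newest[r[key]], r[order_by])
    return [r for r in records if r[order_by] == newest[r[key]]]


# Invented sample data: user 2 has two records with the same timestamp.
records = [
    {"user_id": 1, "timestamp": 10.0, "city": "A"},
    {"user_id": 1, "timestamp": 20.0, "city": "B"},
    {"user_id": 2, "timestamp": 15.0, "city": "C"},
    {"user_id": 2, "timestamp": 15.0, "city": "D"},  # ties with the row above
]

latest = rank_one(records)
print(latest)
```

User 1 is reduced to its newest record, but user 2 keeps both tied rows. If ties like this can occur in real data, **`F.row_number()`** is an alternative that keeps exactly one row per key, at the cost of choosing arbitrarily among the tied rows.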

Unfortunately, if we try to apply this to a streaming read of our data, we'll learn that
> Non-time-based windows are not supported on streaming DataFrames

Uncomment and run the following cell to see this error in action:

In [0]:
# ranked_df = (spark.readStream
#                   .table("bronze")
#                   .filter("topic = 'user_info'")
#                   .select(F.from_json(F.col("value").cast("string"), schema).alias("v"))
#                   .select("v.*")
#                   .filter(F.col("update_type").isin(["new", "update"]))
#                   .withColumn("rank", F.rank().over(window))
#                   .filter("rank == 1").drop("rank"))

# display(ranked_df)

Luckily, there is a workaround for this restriction.

## Implementing Streaming Ranked De-duplication

As we saw previously, when applying **`MERGE`** logic with a Structured Streaming job, we need to use **`foreachBatch`** logic.

Recall that while we're inside a streaming microbatch, we interact with our data using batch syntax.

This means that if we apply our ranked **`Window`** logic within our **`foreachBatch`** function, we avoid the restriction that raised the error above.

The code below sets up all the incremental logic needed to load the data from the bronze table in the correct schema. This includes:
- Filtering for the **`user_info`** topic
- Dropping identical records within the batch
- Unpacking all of the JSON fields from the **`value`** column into the correct schema
- Updating field names and types to match the **`users`** table schema
- Using the salted hash function to convert the **`user_id`** to **`alt_id`**

In [0]:
salt = "BEANS"

unpacked_df = (spark.readStream
                    .table("bronze")
                    .filter("topic = 'user_info'")
                    .dropDuplicates()  # drop identical records, as described in the list above
                    .select(F.from_json(F.col("value").cast("string"), schema).alias("v"))
                    .select("v.*")
                    # The salted SHA-256 hash of user_id becomes the pseudonymized key alt_id
                    .select(F.sha2(F.concat(F.col("user_id"), F.lit(salt)), 256).alias("alt_id"),
                            F.col("timestamp").cast("timestamp").alias("updated"),
                            F.to_date("dob", "MM/dd/yyyy").alias("dob"), "sex", "gender", "first_name", "last_name", "address.*", "update_type"))
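
For readers who want to trace the per-record transformation by hand, the following plain-Python sketch approximates what the final **`select`** does to one unpacked record. It assumes that hashing the UTF-8 bytes of the decimal **`user_id`** followed by the salt corresponds to **`sha2(concat(user_id, salt), 256)`**, and it fixes the time zone to UTC, whereas Spark uses the session time zone. The helper names and sample values are hypothetical.

```python
import hashlib
from datetime import date, datetime, timezone

SALT = "BEANS"  # same literal as the salt defined in the cell above


def pseudonymize(user_id: int, salt: str = SALT) -> str:
    """Salted SHA-256 of the user_id, returned as a 64-character hex string."""
    return hashlib.sha256(f"{user_id}{salt}".encode("utf-8")).hexdigest()


def to_users_row(record: dict) -> dict:
    """Rename and convert fields to the users table layout; user_id is not carried over."""
    return {
        "alt_id": pseudonymize(record["user_id"]),
        # Epoch seconds become a timestamp (UTC assumed for this sketch).
        "updated": datetime.fromtimestamp(record["timestamp"], tz=timezone.utc),
        # MM/dd/yyyy in Spark corresponds to %m/%d/%Y here.
        "dob": datetime.strptime(record["dob"], "%m/%d/%Y").date(),
        "sex": record["sex"],
        "gender": record["gender"],
        "first_name": record["first_name"],
        "last_name": record["last_name"],
        **record["address"],  # flatten the struct, like selecting address.*
        "update_type": record["update_type"],
    }


# Invented sample record.
row = to_users_row({
    "user_id": 12345, "update_type": "new", "timestamp": 1577836800.0,
    "dob": "01/31/1990", "sex": "F", "gender": "F",
    "first_name": "Ada", "last_name": "Example",
    "address": {"street_address": "1 Main St", "city": "Springfield",
                "state": "IL", "zip": 62701},
})
print(sorted(row))
```

Because the raw **`user_id`** never appears in the output, the **`users`** table can only be joined back to other data by someone who knows both the hashing scheme and the salt.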

The updated Window logic is provided below. Note that it is applied to each **`microBatchDF`** to produce a local **`ranked_updates`** view that will be used for merging.
 
For our **`MERGE`** statement, we need to:
- Match entries on our **`alt_id`**
- Update all when matched **if** the new record is newer than the previous entry
- When not matched, insert all
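
These three rules can be modeled in a few lines of plain Python, which makes it easier to see why the "newer than" condition matters when records arrive out of order. This is only a sketch of the intended semantics using a dictionary in place of the Delta table; the function name `merge_into_users` and the sample rows are hypothetical.

```python
def merge_into_users(users: dict, ranked_updates: list) -> dict:
    """Apply one batch of deduplicated updates to `users`, keyed by alt_id."""
    for r in ranked_updates:
        current = users.get(r["alt_id"])
        if current is None:
            users[r["alt_id"]] = dict(r)   # not matched: insert all
        elif current["updated"] < r["updated"]:
            users[r["alt_id"]] = dict(r)   # matched and newer: update all
        # matched but not newer: keep the existing row unchanged
    return users


users = {}

# First batch: a single record for a1.
merge_into_users(users, [{"alt_id": "a1", "updated": 20, "city": "B"}])

# Second batch: a late-arriving older record for a1, plus a new key a2.
merge_into_users(users, [
    {"alt_id": "a1", "updated": 10, "city": "A"},
    {"alt_id": "a2", "updated": 5, "city": "C"},
])

print(users)
```

The older record for `a1` is ignored, so the table keeps the most recent value even though that record was processed first.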

As before, use **`foreachBatch`** to apply merge operations in Structured Streaming.

In [0]:
from pyspark.sql.window import Window

window = Window.partitionBy("alt_id").orderBy(F.col("updated").desc())

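def batch_rank_upsert(microBatchDF, batchId):
    # Inside foreachBatch, microBatchDF is a static DataFrame, so the
    # non-time-based window is allowed here.
    (microBatchDF.filter(F.col("update_type").isin(["new", "update"]))
                 .withColumn("rank", F.rank().over(window))
                 .filter("rank == 1")
                 .drop("rank")
                 .createOrReplaceTempView("ranked_updates"))
    
    # Run the MERGE through the micro-batch's own session so that the
    # ranked_updates temp view is visible to the query.
    microBatchDF._jdf.sparkSession().sql("""
        MERGE INTO users u
        USING ranked_updates r
        ON u.alt_id=r.alt_id
            WHEN MATCHED AND u.updated < r.updated
              THEN UPDATE SET *
            WHEN NOT MATCHED
              THEN INSERT *
    """)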

Now we can apply this function to our data. 

Here, we'll run a trigger-available-now batch to process all records.

In [0]:
query = (unpacked_df.writeStream
                    .foreachBatch(batch_rank_upsert)
                    .outputMode("update")
                    .option("checkpointLocation", f"{DA.paths.checkpoints}/batch_rank_upsert")
                    .trigger(availableNow=True)
                    .start())

query.awaitTermination()

The **`users`** table should have only one record for each unique **`alt_id`**.

In [0]:
count_a = spark.table("users").count()
count_b = spark.table("users").select("alt_id").distinct().count()
assert count_a == count_b
print("All tests passed.")

Run the following cell to delete the tables and files associated with this lesson.

In [0]:
DA.cleanup()