## Introduction

This notebook applies changes to a table in a data lake, rolls back the changes, and then reapplies them.

Points to Note:

1. Each change is identified by a changeId equal to the epoch time of the change, e.g. changeId=1558676810.
2. If the table is located at `s3://<bucket>/<prefix>/TABLE`, its changes are recorded in the Changes location, `s3://<bucket>/<prefix>/TABLE_changes`.
3. Each change is recorded as a JSON file that maps every file replaced by the change to the file version created for it.
4. A rollback deletes the file versions created by a change, then deletes the record of the change from the Changes location.
5. Each change also stages the changed files in the Staging location, `s3://<bucket>/<prefix>/TABLE_staging`.
6. The key of each staged file contains the path of the original file it replaces, so applying a change is a simple copy.
7. A change can be applied if it is the last change in the Staging location and is not recorded in the Changes location.
8. A rollback to a previous changeId deletes all changes made after that changeId, along with all file versions created after it.
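
The naming convention in points 1, 2 and 5 can be sketched as a small helper. This is an illustrative sketch only: `new_change_id` and `change_locations` are hypothetical names that do not appear in the notebook, and `my-bucket` is a placeholder.

```python
import calendar
import time

def new_change_id():
    # Epoch seconds in UTC as a string, e.g. "1558676810".
    return str(calendar.timegm(time.gmtime()))

def change_locations(table_location, change_id):
    # Hypothetical helper: derive the change record and staging prefix for one change.
    return {
        "change_record": "%s_changes/changeId=%s" % (table_location, change_id),
        "staging_prefix": "%s_staging/changeId=%s" % (table_location, change_id),
    }

locations = change_locations("s3://my-bucket/cdc/SALES_ORDER_FACT", "1558676810")
print(locations["change_record"])
print(locations["staging_prefix"])
```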


Boto3 must be installed on the EMR cluster:
```
$> wget https://bootstrap.pypa.io/get-pip.py
$> sudo python3 get-pip.py
$> sudo /usr/local/bin/pip3 install boto3
```

In [None]:
%%sh
echo '{"kernel_python_credentials" : {"url": "http://172.31.35.13:8998/"}, "session_configs": 
{"executorMemory": "2g","executorCores": 2,"numExecutors":4}}' > ~/.sparkmagic/config.json
less ~/.sparkmagic/config.json

Let's size each executor at a single core and about 7 GB of RAM (6000 MB of executor memory plus 1 GB of memory overhead). The roughly 1:7 core-to-memory ratio matches the r5.2xlarge instances used in the cluster.

In [1]:
%%configure -f
{"driverMemory": "8000M","executorMemory": "6000M", "executorCores": 1, "conf":  { "spark.task.maxFailures":"10","spark.executor.memoryOverhead":"1G"}}

In [2]:
# __future__ imports must come before any other statement
from __future__ import print_function
from pyspark.sql.functions import input_file_name
from pyspark.sql import Row
from pyspark.sql.functions import UserDefinedFunction
from pyspark.sql.types import StringType
from pyspark.sql import functions as f
from pyspark.sql.types import IntegerType
from urllib.parse import unquote
import random
import calendar
import boto3
import time
import logging
import json

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log,Current session?
163,application_1555125850663_0167,pyspark,idle,Link,Link,✔


SparkSession available as 'spark'.


Utility functions used to generate data, run tests, etc.

In [3]:
max_n=767435999
udf = UserDefinedFunction(lambda x: random.randint(0,100), IntegerType())    

def generate_delete_data(n,key):
    ids=[str(random.randrange(0, max_n)) for k in range(n)]
    return (ids,sc.parallelize(ids).map(lambda x:Row(x)).toDF([key]))

def generate_update_data(table,n,key):
    ids=[str(random.randrange(0, max_n)) for k in range(n)]
    sql="SELECT * FROM %s WHERE %s IN (%s)"%(table,key,",".join(ids))
    #print ("SQL : "+sql)
    to_update_df=spark.sql(sql)
    name='QUANTITY'
    paths=spark.sql("SELECT input_file_name() as INPUT FROM %s limit 1"%(table)).rdd.map(lambda x:x[0]).take(1)
    input_df=spark.read.load(paths)
    to_update_df_1 = to_update_df.select(*[udf(column).alias(name) if column == name else column for column in input_df.columns])
    return (ids,to_update_df_1)

def get_delete_count(table,key,ids):
    sql="SELECT count(*) FROM %s WHERE %s IN (%s)"%(table,key,",".join(ids))
    count=spark.sql(sql).rdd.map(lambda x:x[0]).take(1)[0]
    print ('Count of matching records : %d'%(count))

def get_update_count(table,key,ids):
    sql="SELECT count(*), SUM(QUANTITY) FROM %s WHERE %s IN (%s)"%(table,key,",".join(ids))
    count=spark.sql(sql).rdd.map(lambda x:(x[0],x[1])).take(1)
    print ('Count of matching records : %d'%(count[0][0]))
    print ('Sum of Quantity : %d'%(count[0][1]))

## Core Functions

The core operations implemented are:
    
   1. **apply_to_datalake** : Applies deletes and updates to a table in the data lake.
   2. **rollback** : Rolls back the most recent change to a table.
   3. **rollback_to_change** : Rolls back to a previous change.
   4. **apply_change** : Re-applies a previous change.
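
`rollback_to_change` is listed above but is not defined in the cells that follow. A minimal sketch of how it might choose which changes to undo, assuming change IDs are epoch-second strings and that later changes are undone newest first; `changes_to_roll_back` is a hypothetical helper, not part of the notebook:

```python
def changes_to_roll_back(recorded_change_ids, target_change_id):
    """Return the change IDs recorded after target_change_id, newest first."""
    target = int(target_change_id)
    later = [c for c in recorded_change_ids if int(c) > target]
    return sorted(later, key=int, reverse=True)

recorded = ["1559349039", "1558676810", "1559000000"]
to_undo = changes_to_roll_back(recorded, "1558676810")
print(to_undo)  # ['1559349039', '1559000000']
```

Each returned ID could then be passed to `rollback(table, change_id)` in that order.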

In [4]:
table_name='cdc.sales_order_fact'
primary_key='order_id'

def get_table_location(t):
    tables={"cdc.sales_order_fact":"s3://<bucket>/cdc/SALES_ORDER_FACT"}
    return tables[t]

In [5]:
def replace_file(src_bucket,target_bucket,key):    
    target_url=key.split('/')[3].split("input=")[1]
    target_url = unquote(target_url)
    target_key=target_url.split("s3://%s/"%(target_bucket))[1]

    s3 = boto3.client('s3')
    resp=s3.copy_object(Bucket=target_bucket, CopySource=src_bucket+"/"+key, Key=target_key)
    return (target_key,resp['VersionId'])
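
The key parsing in `replace_file` can be tried without S3. The sketch below assumes the staged key contains an `input=` segment holding the percent-encoded URL of the original file. Unlike the cell above, it searches for that segment instead of assuming it is the fourth path component. `staged_key_to_target_key`, the bucket name and the sample key are all hypothetical.

```python
from urllib.parse import quote, unquote

def staged_key_to_target_key(staged_key, bucket):
    # Find the "input=<escaped original URL>" segment wherever it sits in the key.
    segment = next(p for p in staged_key.split("/") if p.startswith("input="))
    original_url = unquote(segment[len("input="):])
    # Strip the "s3://<bucket>/" prefix to get the key of the file to overwrite.
    prefix = "s3://%s/" % bucket
    assert original_url.startswith(prefix), "staged file belongs to another bucket"
    return original_url[len(prefix):]

# Hypothetical staged key; quote() stands in for the escaping applied to partition values.
original = "s3://my-bucket/cdc/SALES_ORDER_FACT/part-00000.parquet"
staged = "cdc/SALES_ORDER_FACT_staging/changeId=1558676810/input=%s/part-00001.parquet" % quote(original, safe="")
target_key = staged_key_to_target_key(staged, "my-bucket")
print(target_key)  # cdc/SALES_ORDER_FACT/part-00000.parquet
```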

In [6]:
# Applies changes to the datalake
def apply_to_datalake(table,key,delete_df,update_df):
    
    try:
        ## Step 1: Get files impacted
        table_location=get_table_location(table)
        values=update_df.select(key).union(delete_df.select(key)).rdd.map(lambda x:x[0]).collect()
        sql="SELECT %s, input_file_name() as input FROM %s WHERE %s IN (%s)"%(key,table,key,",".join(list(set(values))))
        #print ("SQL : "+sql)
        start_time = time.time()
        result=spark.sql(sql).rdd.map(lambda x:(x[0],x[1])).collect()
        elapsed_time = time.time() - start_time
        keyfilemap={x[0]:x[1] for x in result}
        paths=list(set([x[1] for x in result]))
        print ("Identified impacted files in %f seconds"%(elapsed_time))
        assert (len(paths) > 0),"No records matched the keys."
        print ("%d files impacted"%(len(paths)))
    
        ## Step 2: Read the data from those files using Spark/Glue.
        input_df=spark.read.load(paths)
        #input_record_count = input_df.count()
        #print ("Impacted Record Count : %d"%(input_record_count))
        # Add the Input filename as a column
        input_df=input_df.withColumn("INPUT", input_file_name())
        #filename=input_df.select("INPUT").rdd.map(lambda x: x[0]).take(1)[0]
        #assert (len(filename) > 0),"Input Filename not populated."
    
        ## Step 3: Join with the incremental data to filter out deletes and updates
        keyfilemaps = sc.broadcast(keyfilemap)
        lookup_file_udf = UserDefinedFunction(lambda x: keyfilemaps.value[x], StringType()) 
        update_df=update_df.withColumn('INPUT', lookup_file_udf(f.col(key)))
        # Filter on the key column passed in rather than a hard-coded column name
        deleted_filter="%s NOT IN (%s)"%(key,",".join(values))
        changed_df=input_df.filter(deleted_filter)
        #changed_record_count= changed_df.count()
        #updated_record_count= update_df.count()
        result_df=changed_df.union(update_df).cache()
        #print ("Records to Remove: %d"%(input_record_count-changed_record_count))
        #print ("Records to Update: %d"%(updated_record_count))
        #expected_record_count=changed_record_count+updated_record_count
        #print ("Expected Result Record Count: %d"%(expected_record_count))

        ## Step 4: Write out new files back out to staging location in S3
        change_id=str(calendar.timegm(time.gmtime()))
        staging_location="%s_staging/changeId=%s"%(table_location,change_id)
        print ("change_id : %s"%(change_id))
        partition_columns=["input"]
        start_time = time.time()
        result_df.repartition(*partition_columns).write.partitionBy(partition_columns).parquet(staging_location)
        elapsed_time = time.time() - start_time
        print ("Staged files in %f seconds"%(elapsed_time))
        #print ("Staging Location: %s"%(staging_location))
    
        ## Step 5: Run a parallel job to replace the old files with the new files.
        apply_change(table,change_id)
    
        ## Validate counts
        #input_record_count = input_df.count()
        #print ("Final Record Count : %d"%(input_record_count))
        #assert (input_record_count == expected_record_count),"Final Record Counts do not match."
        return change_id
    except AssertionError:
        print ("Error: ")
        logging.error("Assertion Exception : ", exc_info=True)
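
Step 3 is the core of this copy-on-write approach: every row whose key appears in the delete or update set is dropped from the impacted files, and the updated rows are then appended under the file they came from. The same logic on plain Python lists, as a sketch only (the row dictionaries and the `merge_changes` name are illustrative and not part of the notebook):

```python
def merge_changes(existing_rows, delete_keys, update_rows, key):
    """Drop rows with deleted or updated keys, then append the updated rows."""
    changed_keys = set(delete_keys) | {row[key] for row in update_rows}
    # Remember which file each key currently lives in.
    key_to_file = {row[key]: row["INPUT"] for row in existing_rows}
    kept = [row for row in existing_rows if row[key] not in changed_keys]
    updated = [dict(row, INPUT=key_to_file[row[key]]) for row in update_rows]
    return kept + updated

existing = [
    {"order_id": 1, "QUANTITY": 5, "INPUT": "file_a"},
    {"order_id": 2, "QUANTITY": 7, "INPUT": "file_a"},
    {"order_id": 3, "QUANTITY": 9, "INPUT": "file_b"},
]
result = merge_changes(existing, delete_keys=[1],
                       update_rows=[{"order_id": 3, "QUANTITY": 42}], key="order_id")
for row in result:
    print(row)
```

Writing the result partitioned by `INPUT` then yields one replacement file per original file.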

In [7]:
def rollback(table,change_id):
    table_location=get_table_location(table)
    change_key="%s_changes/changeId=%s"%(table_location,change_id)
    print (change_key)
    get_last_modified = lambda obj: int(obj['LastModified'].strftime('%s'))
    start_time = time.time()
    result=[]
    s3 = boto3.client('s3')
    bucket_name=change_key.split('/')[2]
    prefix='/'.join(change_key.split('/')[3:])
    resp = s3.get_object(Bucket=bucket_name, Key=prefix) 
    changes = resp["Body"].read().decode()
    files=json.loads(changes)
    for k,v in files.items():
        # delete the versions
        resp=s3.delete_object(Bucket=bucket_name, Key=k, VersionId=v)
        respCode=resp['ResponseMetadata']['HTTPStatusCode']
        result.append(respCode)
    # delete the change record as well
    resp=s3.delete_object(Bucket=bucket_name, Key=prefix)
    #print(resp)
    elapsed_time = time.time() - start_time
    print ("Change %s rolled back in %f seconds"%(change_id,elapsed_time))

In [8]:
def apply_change(table,change_id):
    table_location=get_table_location(table)
    s3 = boto3.client('s3')
    staging_location="%s_staging/changeId=%s"%(table_location,change_id)
    bucket=staging_location.split('/')[2]
    prefix='/'.join(staging_location.split('/')[3:])
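    # list_objects_v2 returns at most 1000 keys per call, so page through the results
    paginator = s3.get_paginator('list_objects_v2')
    keys=[obj['Key'] for page in paginator.paginate(Bucket=bucket,Prefix=prefix)
          for obj in page.get('Contents',[]) if obj['Size'] > 0]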
    files=sc.parallelize(keys).map(lambda x: Row(x)).toDF(["keys"])
    files=files.repartition(files.count())
    start_time = time.time()
    
    # replace original files
    results=files.rdd.map(lambda x:replace_file(bucket,bucket,x["keys"])).collect()
    j={k[0]:k[1] for k in results}
    j=json.dumps(j, ensure_ascii=False)
    change_key="%s_changes/changeId=%s"%(table_location,change_id)
    change_key='/'.join(change_key.split('/')[3:])
    #print (change_key)
    
    # record the change id
    resp=s3.put_object(Body=j, Bucket=bucket, Key=change_key)
    #print (resp['ResponseMetadata']['HTTPStatusCode'])
    elapsed_time = time.time() - start_time
    print ("Applied changes in %f seconds"%(elapsed_time))
    assert(len(results)==len(keys)),"Count of files written does not match."
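
The change record written by `put_object` is a flat JSON object that maps each replaced key to the version ID returned by the copy. A sketch with made-up keys and version IDs (real version IDs are opaque strings assigned by S3):

```python
import json

# Hypothetical (target_key, version_id) pairs, shaped like the results of replace_file.
results = [
    ("cdc/SALES_ORDER_FACT/part-00000.parquet", "example-version-id-1"),
    ("cdc/SALES_ORDER_FACT/part-00001.parquet", "example-version-id-2"),
]
change_record = json.dumps({key: version for key, version in results}, ensure_ascii=False)
print(change_record)

# A rollback reads the record back and deletes exactly these versions.
for key, version in json.loads(change_record).items():
    print("would delete", key, "version", version)
```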
    

## Generate some Delete and Update records

Let's generate some random delete and update records to apply to the table.

In [9]:
# Let's generate some random data to delete and update
n=1000
delete_ids,cdc_delete_df=generate_delete_data(n,primary_key)
update_ids,cdc_update_df=generate_update_data(table_name,n,primary_key)

----------------------------------------
Exception happened during processing of request from ('127.0.0.1', 54628)
Traceback (most recent call last):
  File "/usr/lib64/python3.6/socketserver.py", line 320, in _handle_request_noblock
    self.process_request(request, client_address)
  File "/usr/lib64/python3.6/socketserver.py", line 351, in process_request
    self.finish_request(request, client_address)
  File "/usr/lib64/python3.6/socketserver.py", line 364, in finish_request
    self.RequestHandlerClass(request, client_address, self)
  File "/usr/lib64/python3.6/socketserver.py", line 724, in __init__
    self.handle()
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/accumulators.py", line 266, in handle
    poll(authenticate_and_accum_updates)
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/accumulators.py", line 241, in poll
    if func():
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/accumulators.py", line 254, in authenticate_and_accum_updates
    received_to

In [10]:
# Let's get the counts of records to be deleted and values before 
# records are updated.
get_delete_count(table_name,primary_key,delete_ids)
get_update_count(table_name,primary_key,update_ids)

Count of matching records : 2905
Count of matching records : 3080
Sum of Quantity : 154122

## Apply the changes

Let's apply the changes to our table.

In [11]:
# Applying the changes
change_id=apply_to_datalake(table_name,primary_key,cdc_delete_df,cdc_update_df)
print ('change_id=%s'%(change_id))

Identified impacted files in 63.430865 seconds
693 files impacted
change_id : 1559349039
Staged files in 1368.523836 seconds
Applied changes in 68.640070 seconds
change_id=1559349039

## Validate the results

Let's run some queries to validate the results after the operation.

In [12]:
# Let's get the counts of matching records after deletion  
# and values after records are updated.
get_delete_count(table_name,primary_key,delete_ids)
get_update_count(table_name,primary_key,update_ids)

Count of matching records : 0
Count of matching records : 3080
Sum of Quantity : 154704

As expected:
1. The count of matching records for the deleted IDs is 0, so the records were deleted.
2. The count of matching records for the updated IDs remains the same, but the sum of the Quantity field, which we updated, has changed.

## Rollback the changes

Let's roll back the changes now.

In [13]:
rollback(table_name,change_id)

s3://<bucket>/cdc/SALES_ORDER_FACT_changes/changeId=1559349039
Change 1559349039 rolled back in 17.296286 seconds

The changes were rolled back. A rollback consists of lightweight S3 delete operations, one for each file version recorded under the changeId, followed by a delete of the change record itself.
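
Deleting a version restores the earlier data because, in a versioned bucket, a read without a version ID returns the newest remaining version. Removing the version written by a change therefore exposes the one before it. A toy in-memory model of that behaviour (the `VersionedStore` class is purely illustrative and is not the S3 API):

```python
class VersionedStore:
    """Toy stand-in for a versioned bucket: each key keeps a list of (version_id, body)."""

    def __init__(self):
        self._versions = {}
        self._counter = 0

    def put(self, key, body):
        self._counter += 1
        version_id = "v%d" % self._counter
        self._versions.setdefault(key, []).append((version_id, body))
        return version_id

    def get(self, key):
        # Reads return the newest remaining version.
        return self._versions[key][-1][1]

    def delete_version(self, key, version_id):
        self._versions[key] = [v for v in self._versions[key] if v[0] != version_id]

store = VersionedStore()
store.put("part-0.parquet", "original rows")
change = {"part-0.parquet": store.put("part-0.parquet", "rows after change")}
after_change = store.get("part-0.parquet")
print(after_change)  # rows after change

# Roll back: delete the versions listed in the change record.
for key, version_id in change.items():
    store.delete_version(key, version_id)
after_rollback = store.get("part-0.parquet")
print(after_rollback)  # original rows
```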

## Validating the Rollback

Let's now validate that the counts match the counts from before the changes were applied.

In [15]:
get_delete_count(table_name,primary_key,delete_ids)
get_update_count(table_name,primary_key,update_ids)

Count of matching records : 2905
Count of matching records : 3080
Sum of Quantity : 154122

We can see that the counts match the counts before the changes.

## Reapply a change

Let us reapply the last change.
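
Whether a change is eligible for re-application follows the rule stated in the introduction: it must be the most recent change in the Staging location and must not have a record in the Changes location. A sketch of that check on plain lists of change IDs (`can_apply` is a hypothetical helper; `apply_change` as defined in this notebook does not perform this check itself):

```python
def can_apply(change_id, staged_change_ids, recorded_change_ids):
    """True if change_id is the newest staged change and is not currently recorded."""
    if not staged_change_ids:
        return False
    newest_staged = max(staged_change_ids, key=int)
    return change_id == newest_staged and change_id not in set(recorded_change_ids)

staged = ["1558676810", "1559349039"]
print(can_apply("1559349039", staged, recorded_change_ids=["1558676810"]))  # True
print(can_apply("1558676810", staged, recorded_change_ids=[]))              # False
```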

In [16]:
apply_change(table_name,change_id)

Applied changes in 67.277989 seconds

In [17]:
# Let's recompute the counts.
get_delete_count(table_name,primary_key,delete_ids)
get_update_count(table_name,primary_key,update_ids)

Count of matching records : 0
Count of matching records : 3080
Sum of Quantity : 154704