In [1]:
import findspark
findspark.init()

In [2]:
%run ./14-streaming-incremental-aggregates.ipynb

In [3]:
import shutil
import os
import time
from pyspark.sql import SparkSession
from delta import *

In [4]:
spark = SparkSession.builder.master("local[*]")\
                    .appName("Streaming Incremental Aggregates")\
                    .getOrCreate()
# spark = SparkSession.builder \
#     .appName("MedallionApproachTest") \
#     .config("spark.jars.packages", "io.delta:delta-core_2.12:2.1.0") \
#     .getOrCreate()
# spark = SparkSession.builder \
#         .appName("Streaming Incremental Aggregates") \
#         .config("spark.sql.catalogImplementation", "hive") \
#         .config("spark.jars.packages", "io.delta:delta-core_2.12:2.1.0") \
#         .master("local[*]") \
#         .getOrCreate()


In [5]:
class AggregationTestSuite():
    def __init__(self):
        self.base_data_dir = "data/invoices"
        self.spark_warehouse_dir = "spark-warehouse"

    def cleanTests(self):
        print(f"Starting Cleanup...", end='')
        spark.sql("drop table if exists invoices_bz")
        spark.sql("drop table if exists customer_rewards")

        def remove_dir(path):
            if os.path.exists(path):
                shutil.rmtree(path)

        remove_dir(f"{self.spark_warehouse_dir}/invoices_bz")
        remove_dir(f"{self.spark_warehouse_dir}/customer_rewards")
        # spark.sql(f"CREATE TABLE customer_rewards(CustomerCardNo STRING, TotalAmount DOUBLE, TotalPoints DOUBLE)")

        remove_dir(f"{self.base_data_dir}/checkpoint/invoices_bz")
        remove_dir(f"{self.base_data_dir}/checkpoint/customer_rewards")

        remove_dir(f"{self.base_data_dir}/data/aggregate/invoices")
        print("Done removing directories.")
        os.makedirs(f"{self.base_data_dir}/data/aggregate/invoices")
        print("Done")

    def ingestData(self, itr):
        print(f"\tStarting Ingestion...", end='')
        shutil.copy(f"{self.base_data_dir}/invoices-{itr}.json", f"{self.base_data_dir}/data/aggregate/invoices/")
        print("Done")

    def assertBronze(self, expected_count):
        print(f"\tStarting Bronze validation...", end='')
        actual_count = spark.sql("select count(*) from invoices_bz").collect()[0][0]
        assert expected_count == actual_count, f"Test failed! actual count is {actual_count}"
        print("Done")

    def assertGold(self, expected_value):
        print(f"\tStarting Gold validation...", end='')
        actual_value = spark.sql("select TotalAmount from customer_rewards").filter("CustomerCardNo = '2262471989'").collect()[0][0]
        assert expected_value == actual_value, f"Test failed! actual value is {actual_value}"
        print("Done")

    def waitForMicroBatch(self, sleep=120):
        import time
        print(f"\tWaiting for {sleep} seconds...", end='')
        time.sleep(sleep)
        print("Done.")    

    def runTests(self):
        self.cleanTests()
        # Ensure the customer_rewards table is created
        # create in memory table in spark-warehouse

        spark.conf.set("spark.sql.streaming.stateStore.rpoviderClass",
                       "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")
        bzStream = Bronze()
        bzQuery = bzStream.process()
        gdStream = Gold()
        gdQuery = gdStream.process()       

        print("\nTesting first iteration of invoice stream...") 
        self.ingestData(1)
        self.waitForMicroBatch()        
        self.assertBronze(501)
        self.assertGold(36859)
        print("Validation passed.\n")

        print("\nTesting second iteration of invoice stream...") 
        self.ingestData(2)
        self.waitForMicroBatch()        
        self.assertBronze(501+500)
        self.assertGold(36859+20740)
        print("Validation passed.\n")

        print("\nTesting second iteration of invoice stream...") 
        self.ingestData(3)
        self.waitForMicroBatch()        
        self.assertBronze(501+500+590)
        self.assertGold(36859+20740+31959)
        print("Validation passed.\n")

        bzQuery.stop()
        gdQuery.stop()

In [6]:
aTS = AggregationTestSuite()
aTS.runTests()	

Starting Cleanup...Done removing directories.
Done

Starting Bronze Stream...Done

Starting Silver Stream...
Testing first iteration of invoice stream...
	Starting Ingestion...Done
	Waiting for 120 seconds...Done.
	Starting Bronze validation...Done
	Starting Gold validation...

AnalysisException: [TABLE_OR_VIEW_NOT_FOUND] The table or view `customer_rewards` cannot be found. Verify the spelling and correctness of the schema and catalog.
If you did not qualify the name with a schema, verify the current_schema() output, or qualify the name with the correct schema and catalog.
To tolerate the error on drop use DROP VIEW IF EXISTS or DROP TABLE IF EXISTS.; line 1 pos 24;
'Project ['TotalAmount]
+- 'UnresolvedRelation [customer_rewards], [], false
