In [1]:
%run nb_03_fact_record_wrangler

StatementMeta(, fd9a733b-21e8-4c42-89a2-aac6afe0b551, 3, Finished, Available, Finished)

In [2]:
import unittest
import datetime
from pyspark.sql import SparkSession, Row
from pyspark.sql.types import StructType, StructField, DateType, StringType, ShortType, IntegerType
from pyspark.testing.utils import assertDataFrameEqual, assertSchemaEqual
from delta.tables import DeltaTable

class TestFactRecordWrangler(unittest.TestCase):

    @classmethod
    def setUpClass(cls):

        cls.spark = SparkSession.builder.appName('fact_record_test').getOrCreate()
        cls.delta_table_name = 'fact_record_test'
        FactRecordWrangler.create_delta_table(cls.spark, cls.delta_table_name)
        cls.spark.sql(f'DELETE FROM {cls.delta_table_name}')

    def expected_schema(self):
        return StructType([
            StructField('ReportedDate', DateType()),
            StructField('Suburb', StringType()),
            StructField('Postcode', ShortType()),
            StructField('DescID', IntegerType()),
            StructField('Count', IntegerType()),
        ])

    def test_extract_silver_df(self):

        # Prepare silver_df and dim_desc_table
        silver_data = [
            Row(ReportedDate=datetime.date(2023, 12, 5), Suburb='A', Postcode=1234, LevelOneDesc='L1', LevelTwoDesc='L2', LevelThreeDesc='L3', Count=5),
            Row(ReportedDate=datetime.date(2023, 12, 6), Suburb='B', Postcode=5678, LevelOneDesc='L4', LevelTwoDesc='L5', LevelThreeDesc='L6', Count=7),
        ]
        silver_schema = StructType([
            StructField('ReportedDate', DateType()),
            StructField('Suburb', StringType()),
            StructField('Postcode', ShortType()),
            StructField('LevelOneDesc', StringType()),
            StructField('LevelTwoDesc', StringType()),
            StructField('LevelThreeDesc', StringType()),
            StructField('Count', IntegerType()),
        ])
        silver_df = spark.createDataFrame(silver_data, silver_schema)

        desc_data = [
            Row(LevelOneDesc='L1', LevelTwoDesc='L2', LevelThreeDesc='L3', DescID=1),
            Row(LevelOneDesc='L4', LevelTwoDesc='L5', LevelThreeDesc='L6', DescID=2),
        ]
        desc_schema = StructType([
            StructField('LevelOneDesc', StringType()),
            StructField('LevelTwoDesc', StringType()),
            StructField('LevelThreeDesc', StringType()),
            StructField('DescID', IntegerType()),
        ])
        dim_desc_table = spark.createDataFrame(desc_data, desc_schema)
        result_df = FactRecordWrangler.extract_silver_df(silver_df, dim_desc_table)

        expected_data = [
            (datetime.date(2023, 12, 5), 'A', 1234, 1, 5),
            (datetime.date(2023, 12, 6), 'B', 5678, 2, 7),
        ]
        expected_df = spark.createDataFrame(expected_data, self.expected_schema())
        assertDataFrameEqual(result_df, expected_df)

    def test_create_delta_table_schema(self):

        table_schema = spark.table(self.delta_table_name).schema
        assertSchemaEqual(table_schema, self.expected_schema())

    def test_upsert_delta_table_insert(self):
        # Insert a new row
        data = [
            (datetime.date(2023, 12, 5), 'A', 1234, 1, 5),
        ]
        schema = self.expected_schema()
        df = spark.createDataFrame(data, schema)
        delta_table = DeltaTable.forName(spark, self.delta_table_name)

        FactRecordWrangler.upsert_delta_table(delta_table, df)
        result_df = spark.sql(f'SELECT * FROM {self.delta_table_name}')

        assert result_df.count() == 1
        row = result_df.first()
        assert row['ReportedDate'] == datetime.date(2023, 12, 5)
        assert row['Suburb'] == 'A'
        assert row['Postcode'] == 1234
        assert row['DescID'] == 1
        assert row['Count'] == 5

        spark.sql(f'DELETE FROM {self.delta_table_name}')

    def test_upsert_delta_table_update(self):

        # Insert, then update Count
        schema = self.expected_schema()
        data = [
            (datetime.date(2023, 12, 5), 'A', 1234, 1, 5),
        ]
        df = spark.createDataFrame(data, schema)
        delta_table = DeltaTable.forName(spark, self.delta_table_name)

        FactRecordWrangler.upsert_delta_table(delta_table, df)

        # Update Count
        update_data = [
            (datetime.date(2023, 12, 5), 'A', 1234, 1, 10),
        ]
        update_df = spark.createDataFrame(update_data, schema)

        FactRecordWrangler.upsert_delta_table(delta_table, update_df)
        result_df = spark.sql(f'SELECT * FROM {self.delta_table_name}')
        
        assert result_df.count() == 1
        row = result_df.first()
        assert row['Count'] == 10

        spark.sql(f'DELETE FROM {self.delta_table_name}')

    def test_upsert_delta_table_no_duplicate(self):

        # Insert the same row again, should not duplicate
        data = [
            (datetime.date(2023, 12, 5), 'A', 1234, 1, 5),
        ]
        schema = self.expected_schema()
        df = spark.createDataFrame(data, schema)
        delta_table = DeltaTable.forName(spark, self.delta_table_name)

        FactRecordWrangler.upsert_delta_table(delta_table, df)
        FactRecordWrangler.upsert_delta_table(delta_table, df)
        
        result_df = spark.sql(f'SELECT * FROM {self.delta_table_name}')
        assert result_df.count() == 1

        spark.sql(f'DELETE FROM {self.delta_table_name}')

    def test_upsert_delta_table_multiple(self):

        # Insert multiple new rows
        data = [
            (datetime.date(2023, 12, 5), 'A', 1234, 1, 5),
            (datetime.date(2023, 12, 6), 'B', 5678, 2, 7),
        ]
        schema = self.expected_schema()
        df = spark.createDataFrame(data, schema)
        delta_table = DeltaTable.forName(spark, self.delta_table_name)

        FactRecordWrangler.upsert_delta_table(delta_table, df)

        result_df = spark.sql(f'SELECT * FROM {self.delta_table_name}')
        assert result_df.count() == 2

        spark.sql(f'DELETE FROM {self.delta_table_name}')

    @classmethod
    def tearDownClass(cls):

        cls.spark.sql(f'DROP TABLE IF EXISTS {cls.delta_table_name}')
        cls.spark.stop()

StatementMeta(, fd9a733b-21e8-4c42-89a2-aac6afe0b551, 4, Finished, Available, Finished)



In [3]:
# Drive the suite by hand instead of through a unittest runner: build a
# test-case instance, then call the class-level fixture explicitly.
test_case = TestFactRecordWrangler()
TestFactRecordWrangler.setUpClass()

StatementMeta(, fd9a733b-21e8-4c42-89a2-aac6afe0b551, 5, Finished, Available, Finished)

In [4]:
# Check extract_silver_df against in-memory silver and description fixtures.
test_case.test_extract_silver_df()

StatementMeta(, fd9a733b-21e8-4c42-89a2-aac6afe0b551, 6, Finished, Available, Finished)

In [5]:
# Check that the table created in setUpClass has the expected schema.
test_case.test_create_delta_table_schema()

StatementMeta(, fd9a733b-21e8-4c42-89a2-aac6afe0b551, 7, Finished, Available, Finished)

In [6]:
# Upsert a single row into the empty table and verify its column values.
test_case.test_upsert_delta_table_insert()

StatementMeta(, fd9a733b-21e8-4c42-89a2-aac6afe0b551, 8, Finished, Available, Finished)

In [7]:
# Upsert a row, then upsert it again with a different Count; expect one row with the new Count.
test_case.test_upsert_delta_table_update()

StatementMeta(, fd9a733b-21e8-4c42-89a2-aac6afe0b551, 9, Finished, Available, Finished)

In [8]:
# Upsert the identical row twice; expect the table to hold a single row.
test_case.test_upsert_delta_table_no_duplicate()

StatementMeta(, fd9a733b-21e8-4c42-89a2-aac6afe0b551, 10, Finished, Available, Finished)

In [9]:
# Upsert two distinct rows in one call; expect both to be stored.
test_case.test_upsert_delta_table_multiple()

StatementMeta(, fd9a733b-21e8-4c42-89a2-aac6afe0b551, 11, Finished, Available, Finished)

In [10]:
# Class-level teardown: drops the test table and stops the Spark session.
TestFactRecordWrangler.tearDownClass()

StatementMeta(, fd9a733b-21e8-4c42-89a2-aac6afe0b551, 12, Finished, Available, Finished)