In [None]:
import glob
import pandas as pd

from pyspark.sql.types import *
from pyspark.sql.functions import col, when, lit, date_format, to_date, current_timestamp
from delta.tables import *

class SACrimeRecordWrangler:
    """Static helpers for loading, cleansing and upserting crime records.

    The raw records carry the columns declared by ``create_init_schema``;
    the silver Delta table holds the same columns (with ``ReportedDate`` as
    a date and ``Postcode`` as a short) plus an ``UpdatedTS`` timestamp.
    """

    # Columns that together identify one record when merging into the
    # silver table.  Kept in one place so the merge condition cannot drift.
    _MERGE_KEY_COLUMNS = (
        'ReportedDate',
        'Suburb',
        'Postcode',
        'LevelOneDesc',
        'LevelTwoDesc',
        'LevelThreeDesc',
    )

    @staticmethod
    def create_init_schema():
        """Return the schema used when first reading the raw records.

        Every column is read as a string except ``Count``, which is an
        integer; date and postcode conversion happen later in
        ``cleanse_df``.
        """
        return StructType([
            StructField('ReportedDate', StringType()),
            StructField('Suburb', StringType()),
            StructField('Postcode', StringType()),
            StructField('LevelOneDesc', StringType()),
            StructField('LevelTwoDesc', StringType()),
            StructField('LevelThreeDesc', StringType()),
            StructField('Count', IntegerType())
        ])


    @staticmethod
    def load_excel_files(spark, url, schema):
        """Load every ``*.xlsx`` file under ``/lakehouse/default/{url}``.

        The files are read with pandas, concatenated, and converted into a
        Spark DataFrame using ``schema``.

        Args:
            spark: Spark session used to create the DataFrame.
            url: Folder path relative to ``/lakehouse/default``.
            schema: Schema passed to ``spark.createDataFrame``.

        Returns:
            A Spark DataFrame holding the rows of all matched files.

        Raises:
            ValueError: If no ``.xlsx`` file matches the folder.
        """
        pattern = f'/lakehouse/default/{url}/*.xlsx'
        excel_files = glob.glob(pattern)
        if not excel_files:
            # pd.concat would raise an opaque "No objects to concatenate"
            # ValueError here; keep the exception type but say what is wrong.
            raise ValueError(f'No Excel files found matching {pattern!r}')
        df = pd.concat(pd.read_excel(f) for f in excel_files)
        return spark.createDataFrame(df, schema)


    @staticmethod
    def load_csv_files(spark, url, schema):
        """Read every ``*.csv`` file under ``url`` into a Spark DataFrame.

        The files are read with a header row and the supplied ``schema``.
        """
        return spark.read \
            .format('csv') \
            .option('header', 'true') \
            .schema(schema) \
            .load(f'{url}/*.csv')


    @staticmethod
    def remove_all_na(df):
        """Drop rows in which every column is null."""
        return df.na.drop(how='all')


    @staticmethod
    def reported_date_str_to_date(df):
        """Convert ``ReportedDate`` from a ``d/M/y`` string to a date."""
        return df.withColumn('ReportedDate', to_date(col('ReportedDate'), 'd/M/y'))


    @staticmethod
    def postcode_str_to_short(df):
        """Cast the ``Postcode`` column to ``ShortType``."""
        return df.withColumn('Postcode', col('Postcode').cast(ShortType()))


    @staticmethod
    def suburb_fill_null_empty(df):
        """Replace null or empty ``Suburb`` values with ``'N/A'``."""
        suburb = col('Suburb')
        is_missing = suburb.isNull() | (suburb == '')
        return df.withColumn('Suburb', when(is_missing, lit('N/A')).otherwise(suburb))


    @staticmethod
    def postcode_fill_null(df):
        """Replace null ``Postcode`` values with ``0``."""
        return df.na.fill(value=0, subset=['Postcode'])


    @staticmethod
    def cleanse_df(df):
        """Run the full cleansing pipeline on a raw records DataFrame.

        Steps, in order: drop all-null rows, parse ``ReportedDate``, cast
        ``Postcode`` to short, fill missing ``Suburb`` with ``'N/A'``, and
        fill null ``Postcode`` with ``0``.
        """
        tmp_df = SACrimeRecordWrangler.remove_all_na(df)
        tmp_df = SACrimeRecordWrangler.reported_date_str_to_date(tmp_df)
        tmp_df = SACrimeRecordWrangler.postcode_str_to_short(tmp_df)
        tmp_df = SACrimeRecordWrangler.suburb_fill_null_empty(tmp_df)
        return SACrimeRecordWrangler.postcode_fill_null(tmp_df)


    @staticmethod
    def create_crime_records_silver_table(spark_session, table_name):
        """Create the silver Delta table ``table_name`` if it does not exist.

        The table has the cleansed record columns plus ``UpdatedTS``.
        """
        DeltaTable.createIfNotExists(spark_session) \
            .tableName(table_name) \
            .addColumn('ReportedDate', DateType()) \
            .addColumn('Suburb', StringType()) \
            .addColumn('Postcode', ShortType()) \
            .addColumn('LevelOneDesc', StringType()) \
            .addColumn('LevelTwoDesc', StringType()) \
            .addColumn('LevelThreeDesc', StringType()) \
            .addColumn('Count', IntegerType()) \
            .addColumn('UpdatedTS', TimestampType()) \
            .execute()


    @staticmethod
    def _merge_match_condition():
        """Build the merge condition pairing ``silver`` rows with ``updates`` rows.

        Uses the null-safe ``<=>`` operator: with plain ``=`` a NULL in any
        key column makes the comparison NULL, so such a row never matches
        and would be inserted again on every merge.
        """
        return ' and '.join(
            f'silver.{name} <=> updates.{name}'
            for name in SACrimeRecordWrangler._MERGE_KEY_COLUMNS
        )


    @staticmethod
    def upsert_delta_table(delta_table, df):
        """Merge ``df`` into ``delta_table``.

        Rows are matched on the six key columns (null-safe).  A matched row
        is updated only when its ``Count`` differs from the incoming one;
        an unmatched row is inserted.  ``UpdatedTS`` is set to the current
        timestamp on both update and insert.

        Args:
            delta_table: Target ``DeltaTable`` (aliased as ``silver``).
            df: DataFrame of cleansed records (aliased as ``updates``).
        """
        insert_values = {
            name: f'updates.{name}'
            for name in SACrimeRecordWrangler._MERGE_KEY_COLUMNS
        }
        insert_values['Count'] = 'updates.Count'
        insert_values['UpdatedTS'] = current_timestamp()

        delta_table.alias('silver') \
            .merge(
                df.alias('updates'),
                SACrimeRecordWrangler._merge_match_condition()
            ) \
            .whenMatchedUpdate(
                # Null-safe inequality so a change to/from NULL is also
                # written; `!=` would evaluate to NULL and skip the update.
                condition='not (silver.Count <=> updates.Count)',
                set=
                {
                    'Count': 'updates.Count',
                    'UpdatedTS': current_timestamp()
                }
            ) \
            .whenNotMatchedInsert(values=insert_values) \
            .execute()