In [0]:
%run "./01_config"

In [0]:
from pyspark.sql.functions import *
from datetime import datetime

class HistoryLoader():
    """Loads and validates historical/reference data in the target catalog.

    Relies on two names that are not defined in this cell: ``Config``
    (presumably supplied by the ``%run "./01_config"`` cell) and the
    notebook-global ``spark`` session.
    """

    def __init__(self, env):
        """Bind the loader to the catalog and database of an environment.

        Args:
            env: Environment name; the catalog is ``sbit_{env}_catalog``.
        """
        conf = Config()
        self.catalog = f"sbit_{env}_catalog"
        self.db_name = conf.db_name

    def load_date_lookup(self, start_date="2020-01-01", end_date="2030-12-31"):
        """Rebuild the ``date_lookup`` table with one row per calendar day.

        Args:
            start_date: First day of the range (inclusive), ``YYYY-MM-DD``.
            end_date: Last day of the range (inclusive), ``YYYY-MM-DD``.

        Returns:
            The number of days in the inclusive range, i.e. the row count the
            table is expected to hold afterwards.

        Raises:
            ValueError: If either date does not match ``%Y-%m-%d``.
        """
        # Parse first: a malformed date must fail here, before any Spark work
        # is started and before the existing table can be overwritten.
        first_day = datetime.strptime(start_date, "%Y-%m-%d")
        last_day = datetime.strptime(end_date, "%Y-%m-%d")

        # Only the normalised ISO form of the validated dates is placed into
        # the SQL text, never the raw caller-supplied strings.
        start_iso = first_day.date().isoformat()
        end_iso = last_day.date().isoformat()

        print(f"Creating date_lookup data ({start_date} to {end_date})...", end='')

        df_base = spark.sql(
            f"SELECT explode(sequence(to_date('{start_iso}'), to_date('{end_iso}'), interval 1 day)) as date"
        )

        # NOTE(review): week_part combines the calendar year with
        # weekofyear(); if weekofyear() is ISO-8601 based, days around New
        # Year get labels such as "2021-53" for 2021-01-01. Left unchanged
        # because downstream consumers may depend on it -- confirm.
        date_df = df_base.select(
            "date",
            weekofyear("date").alias("week"),
            year("date").alias("year"),
            month("date").alias("month"),
            dayofweek("date").alias("dayofweek"),
            dayofmonth("date").alias("dayofmonth"),
            dayofyear("date").alias("dayofyear"),
            concat(year("date"), lit("-"), lpad(weekofyear("date"), 2, "0")).alias("week_part")
        )

        date_df.write.mode("overwrite").saveAsTable(f"{self.catalog}.{self.db_name}.date_lookup")
        print("OK")

        # Both endpoints are inclusive, hence the +1.
        return (last_day - first_day).days + 1

    def assert_count(self, table_name, expected_count):
        """Check that a table in this loader's database has the expected row count.

        Args:
            table_name: Unqualified table name inside ``catalog.db_name``.
            expected_count: Number of rows the table must contain.

        Raises:
            AssertionError: If the actual row count differs.
        """
        print(f"Validating record counts in {table_name}...", end='')
        actual_count = spark.read.table(f"{self.catalog}.{self.db_name}.{table_name}").count()

        # Raised explicitly instead of via ``assert`` so the check still runs
        # when Python is started with -O (which strips assert statements).
        if actual_count != expected_count:
            raise AssertionError(
                f"Expected {expected_count:,} records, found {actual_count:,} in {table_name}"
            )
        print(f"Found {actual_count:,} / Expected {expected_count:,} records: Success")

    def validate(self, expected_date_count):
        """Validate the loaded historical data and report the elapsed time.

        Args:
            expected_date_count: Expected number of rows in ``date_lookup``.

        Raises:
            AssertionError: If ``date_lookup`` has a different row count.
        """
        import time
        start = int(time.time())
        print("\nStarting historical data load validation...")

        self.assert_count("date_lookup", expected_date_count)

        print(f"Historical data load validation completed in {int(time.time()) - start} seconds")

    def load_history(self):
        """Load ``date_lookup`` with its default range, then validate it."""
        import time
        start = int(time.time())
        print("\nStarting historical data load ...")

        expected_count = self.load_date_lookup()

        self.validate(expected_count)

        print(f"\nTotal historical data load process completed in {int(time.time()) - start} seconds")