In [None]:
import fnmatch
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from datetime import datetime
import os

print('||-------------run_stage_load.ipynb--------------||')
print('||-----')

# Load all REGISTERED files
registered_files_df = spark.sql("""
    SELECT * FROM lk_cdsa_bronze.meta_db.data_file
    WHERE file_status = 'REGISTERED'
""")

staged_count = 0
for file_row in registered_files_df.collect():
    try:
        file_id = file_row.file_id
        filename = file_row.filename
        object_id = file_row.object_id
        source_id = file_row.object_id
        batch_id = file_row.batch_id
        landing_directory = file_row.file_path

        full_path = f"abfss://CDSA@onelake.dfs.fabric.microsoft.com/lk_cdsa_landing_zone.Lakehouse{landing_directory}/{filename}"
        print(f"\n|| Processing file_id: {file_id}")
        print(f"|| Filename: {filename}")
        print(f"|| Full path: {full_path}")

        # Load source attributes
        source_attr_df = spark.sql(f"""
            SELECT *, sf.source_id as source_feed_id, t.target_id as target_object_id
            FROM lk_cdsa_bronze.meta_db.source s
            JOIN lk_cdsa_bronze.meta_db.data_object do ON s.source_id = do.object_id
            LEFT JOIN lk_cdsa_bronze.meta_db.source_feed sf ON s.source_id = sf.source_id
            LEFT JOIN lk_cdsa_bronze.meta_db.target t ON sf.target_object_name = t.target_name
            WHERE s.source_id = '{source_id}'
        """)

        if source_attr_df.count() == 0:
            print(f"|| Skipping file_id {file_id}: No source attributes found.")
            continue

        source_row = source_attr_df.first()
        column_delimiter = source_row.column_delimiter or ","
        stage_name = source_row.stage_name

        # Generate stage table name
        stage_table_name_df = spark.sql(f"""
            SELECT UPPER(LOWER(c.stage_name) || '_' || RIGHT('0000000' || CAST(a.file_id AS STRING), 7)) AS table_name
            FROM lk_cdsa_bronze.meta_db.data_file a
            JOIN lk_cdsa_bronze.meta_db.source c ON a.object_id = c.source_id
            WHERE a.file_id = {file_id}
        """)
        stage_table_name = stage_table_name_df.first()["table_name"]
        print(f"|| Stage table name: {stage_table_name}")

        # Get source feed columns
        feed_columns_df = spark.sql(f"""
            SELECT column_name, data_type, max_length, scale
            FROM lk_cdsa_bronze.meta_db.source_feed_column
            WHERE source_id = '{source_id}'
            ORDER BY ordinal_position, column_id
        """)
        feed_columns = feed_columns_df.collect()

        # Load file
        df = spark.read \
            .option("header", True) \
            .option("delimiter", column_delimiter) \
            .csv(full_path)

        print(f"|| Records read from file: {df.count()}")

        # Normalize column names
        df = df.toDF(*[col_name.lower() for col_name in df.columns])

        # Check for schema mismatch
        expected_columns = [row.column_name.lower() for row in feed_columns]
        df_columns = [col.lower() for col in df.columns]

        if set(expected_columns) != set(df_columns):
            print(f"|| Schema mismatch for file_id {file_id}.")
            print(f"|| Expected columns: {expected_columns}")
            print(f"|| Found columns: {df_columns}")
            continue

        # Cast columns to expected types
        for row in feed_columns:
            col_name = row.column_name.lower()
            expected_type = row.data_type.lower()

            if expected_type == "decimal":
                df = df.withColumn(col_name, col(col_name).cast(f"decimal({int(row.max_length)},{int(row.scale)})"))
            elif expected_type in ["char", "varchar", "string"]:
                df = df.withColumn(col_name, col(col_name).cast("string"))
            elif expected_type in ["bigint", "long"]:
                df = df.withColumn(col_name, col(col_name).cast("long"))
            elif expected_type in ["int", "integer"]:
                df = df.withColumn(col_name, col(col_name).cast("int"))
            elif expected_type == "double":
                df = df.withColumn(col_name, col(col_name).cast("double"))
            elif expected_type == "float":
                df = df.withColumn(col_name, col(col_name).cast("float"))
            elif expected_type == "boolean":
                df = df.withColumn(col_name, col(col_name).cast("boolean"))
            elif expected_type == "date":
                df = df.withColumn(col_name, col(col_name).cast("date"))
            elif expected_type == "timestamp":
                df = df.withColumn(col_name, col(col_name).cast("timestamp"))
            else:
                df = df.withColumn(col_name, col(col_name).cast("string"))  # fallback

        # Reorder columns
        df = df.select(*expected_columns)

        # Create stage table
        column_defs = []
        for row in feed_columns:
            col_name = row.column_name.lower()
            data_type = row.data_type.lower()

            if data_type in ["char", "varchar"]:
                col_def = f"{col_name} string"
            elif data_type == "decimal":
                col_def = f"{col_name} decimal({int(row.max_length)},{int(row.scale)})"
            else:
                col_def = f"{col_name} {data_type}"

            column_defs.append(col_def)

        create_table_sql = f"""
        CREATE TABLE IF NOT EXISTS lk_cdsa_bronze.bronze_db.{stage_table_name} (
            {', '.join(column_defs)}
        ) USING DELTA
        """
        spark.sql(create_table_sql)

        # Write to stage table
        output_table = f"lk_cdsa_bronze.bronze_db.{stage_table_name}"
        df.write.format("delta").mode("append").saveAsTable(output_table)

        row_count = df.count()
        print(f"|| Records staged to {stage_table_name}: {row_count}")

        # Update metadata
        spark.sql(f"""
            UPDATE lk_cdsa_bronze.meta_db.data_file
            SET row_count = {row_count}, file_status = 'COMPLETED'
            WHERE file_id = {file_id}
        """)

        print(f"|| ✅ Staged file_id {file_id} successfully.")
        staged_count += 1

    except Exception as e:
        print(f"|| ❌ Error processing file_id {file_row.file_id}: {str(e)}")


# Final output
if staged_count == 0:
    print("|| No REGISTERED files were staged.")
    print("||----------------SKIPPED----------------||")
else:
    print(f"|| Stage load completed successfully for {staged_count} file(s).")
    print("||----------------SUCCESS----------------||")

print('||-----')


In [1]:
%%sql
-- update lk_cdsa_bronze.meta_db.data_file set file_status = 'REGISTERED' where file_id = 1;
-- drop table lk_cdsa_bronze.bronze_db.stg_cdtq_customer_0000001;
select * from lk_cdsa_bronze.meta_db.data_file;
select * from lk_cdsa_bronze.bronze_db.stg_cdtq_customer_0000001;


StatementMeta(, a68279ec-f6ed-4a92-b896-942b94d735d5, 3, Finished, Available, Finished)

<Spark SQL result set with 1 rows and 20 fields>

<Spark SQL result set with 500 rows and 16 fields>