In [0]:
# Standard library imports
from dataclasses import dataclass
from typing import Optional, Dict, List
from datetime import datetime
from pathlib import Path
import json
import csv
import uuid
import io
import re

In [0]:
@dataclass
class ProcessingConfig:
    """Configuration for file processing.

    Instances are built from notebook widget values by
    ``get_config_from_widgets`` and consumed by ``FileProcessor``. An empty
    string in an optional field means "not provided by the user".
    """
    # Storage container; files are listed under /mnt/<container_name>
    container_name: str
    # Prefix that a file name must start with to be considered a match
    file_pattern: str
    # User-supplied encoding; when empty, the encoding is detected from the file
    encoding: str = ""
    # CSV delimiter; when empty, it is sniffed from the first lines of the file
    delimiter: str = ""
    # CSV quote character; when empty, it is sniffed from the first lines of the file
    quotechar: str = ""
    # CSV escape character; when empty, the quote character is used as the escape
    escapechar: str = ""
    # Number of leading lines to skip before the CSV content starts
    skip_lines: int = 0
    # Name of the table that receives one audit row per processed file
    audit_table: str = "file_processing_audit"
    # Excel sheet to read (used only for .xlsx/.xls files)
    sheet_name: str = ""
    # Top-left cell of the Excel data range (used only for .xlsx/.xls files)
    excel_starting_cell: str = ""

In [0]:
    %pip install openpyxl xlrd

In [0]:
    class WidgetManager:
        """Manages the lifecycle of Databricks widgets.

        All methods are static and operate on the notebook-global ``dbutils``
        object. ``WIDGET_GROUPS`` maps each data domain to the names of the
        mapping widgets that belong to it.
        """

        # Widget definitions by domain
        WIDGET_GROUPS = {
            "Accounts Payable": [
                "INVOICE_NUMBER",
                "INVOICE_LINE_NUM",
                "INVOICE_DATE",
                "SUPPLIER_MASTER_ID",
                "SUPPLIER_NAME",
                "INVOICE_NET_AMOUNT",
                "INVOICE_NET_AMOUNT_USD",
                "INVOICE_QTY",
                "RECEIVED_DATE",
                "confirm_mapping"
            ],
            "Purchase Orders": [
                "PURCHASE_ORDER_NUM",
                "PURCHASE_ORDER_LINE_NUM",
                "PURCHASE_ORDER_DATE",
                "SUPPLIER_MASTER_ID",
                "SUPPLIER_NAME",
                "TOTAL_SPEND_LOCAL",
                "TOTAL_SPEND_USD",
                "RECEIVED_DATE",
                "confirm_mapping"
            ],
            "Item Master": [
                "PART_NUMBER",
                "PART_NAME",
                "PART_DESCRIPTION",
                "confirm_mapping"
            ],
            "Supplier Master": [
                "SUPPLIER_ID",
                "SUPPLIER_NAME",
                "confirm_mapping"
            ]
        }

        @staticmethod
        def check_widget_exists(widget_name: str) -> bool:
            """Safely check if a widget exists.

            Args:
                widget_name: Name of the widget to look up.

            Returns:
                True if ``dbutils.widgets.get`` succeeds for the name, False if
                it raises for any reason.
            """
            try:
                dbutils.widgets.get(widget_name)
                return True
            except Exception:
                return False

        @staticmethod
        def cleanup_widgets(domain: Optional[str] = None):
            """Removes the mapping widgets of one domain, or of all domains.

            Args:
                domain: Key of ``WIDGET_GROUPS`` whose widgets should be removed.
                    When omitted or unknown, the widgets of every domain are
                    removed.

            Removal is best-effort: a widget that cannot be removed is reported
            with a warning and the remaining widgets are still processed.
            """
            if domain and domain in WidgetManager.WIDGET_GROUPS:
                widgets_to_remove = list(WidgetManager.WIDGET_GROUPS[domain])
                print(f"Cleaning up widgets for domain: {domain}")
            else:
                # Several domains share names (e.g. "confirm_mapping");
                # dict.fromkeys drops the repeats while keeping first-seen order.
                widgets_to_remove = list(dict.fromkeys(
                    name
                    for group in WidgetManager.WIDGET_GROUPS.values()
                    for name in group
                ))
                print("Cleaning up all mapping widgets...")

            removed_count = 0
            for widget in widgets_to_remove:
                if not WidgetManager.check_widget_exists(widget):
                    continue
                try:
                    dbutils.widgets.remove(widget)
                except Exception as exc:
                    # Keep going, but make the failure visible instead of hiding it
                    print(f"Warning: could not remove widget '{widget}': {exc}")
                else:
                    removed_count += 1

            if removed_count > 0:
                print(f"✓ Removed {removed_count} existing widgets")

        @staticmethod
        def create_base_widgets():
            """Creates the base widgets needed for file processing.

            Defines the text widgets that feed ``ProcessingConfig`` plus the
            ``processing_option`` dropdown, whose choices are "null" followed by
            the domain names in ``WIDGET_GROUPS``.
            """
            # Text widgets
            dbutils.widgets.text("container_name", "gibsonanalytics/testlake", "Blob Storage Container")
            dbutils.widgets.text("file_pattern", "", "File Name")
            dbutils.widgets.text("file_encoding", "", "File Encoding")
            dbutils.widgets.text("file_quotechar", "", "File Quote Character")
            dbutils.widgets.text("file_escapechar", "", "File Escape Character")
            dbutils.widgets.text("file_delimiter", "", "File Delimiter")
            dbutils.widgets.text("skip_lines", "0", "Number of lines to skip")
            dbutils.widgets.text("audit_table", "file_processing_audit", "Audit Table")
            dbutils.widgets.text("gc_portco_id", "", "GC Portfolio Company ID")
            dbutils.widgets.text("business_unit_id", "", "Business Unit ID")
            dbutils.widgets.text("sheet_name", "", "Excel Sheet Name")
            dbutils.widgets.text("start_cell", "A1", "Start Cell (e.g. A1, B4)")

            # Dropdown widget for data domain
            dbutils.widgets.dropdown(
                "processing_option",
                "null",
                ["null"] + list(WidgetManager.WIDGET_GROUPS.keys()),
                label="Select a Data Domain"
            )

    def get_config_from_widgets() -> ProcessingConfig:
        """Gets configuration from widgets.

        Reads the base widgets defined by ``WidgetManager.create_base_widgets``
        and packs their values into a ``ProcessingConfig``.

        Returns:
            ProcessingConfig populated from the current widget values. A
            non-integer ``skip_lines`` value is replaced by 0 with a warning.
        """
        try:
            skip_lines = int(dbutils.widgets.get("skip_lines"))
        except ValueError:
            print("Warning: Invalid value for skip_lines. Using 0 as default.")
            skip_lines = 0

        return ProcessingConfig(
            container_name=dbutils.widgets.get("container_name"),
            file_pattern=dbutils.widgets.get("file_pattern"),
            encoding=dbutils.widgets.get("file_encoding"),
            delimiter=dbutils.widgets.get("file_delimiter"),
            quotechar=dbutils.widgets.get("file_quotechar"),
            escapechar=dbutils.widgets.get("file_escapechar"),
            skip_lines=skip_lines,
            audit_table=dbutils.widgets.get("audit_table"),
            sheet_name=dbutils.widgets.get("sheet_name"),
            # The start-cell widget is registered under the name "start_cell"
            excel_starting_cell=dbutils.widgets.get("start_cell")
        )


In [0]:
    class FileProcessor:
        """Processes files in Databricks with built-in monitoring"""

        def __init__(self, spark: SparkSession, config: ProcessingConfig):
            self.spark = spark
            self.config = config
            self.execution_id = str(uuid.uuid4())
            self.start_time = datetime.now()
            self.domain = None  # Add domain tracking
            print("\n" + "=" * 140)
            print(f"Initializing FileProcessor with execution ID: {self.execution_id}")
            print("=" * 140 + "\n")
            self._setup_audit_table()

        def _setup_audit_table(self):
            """Creates audit table if it doesn't exist"""
            print("Setting up audit table...")
            self.spark.sql(f"""
                CREATE TABLE IF NOT EXISTS {self.config.audit_table} (
                    execution_id STRING,
                    file_name STRING,
                    file_size LONG,
                    file_modified_timestamp TIMESTAMP,
                    start_time TIMESTAMP,
                    end_time TIMESTAMP,
                    user_name STRING,
                    status STRING,
                    rows_processed LONG,
                    schema_info STRING,
                    file_properties STRING,
                    processing_details STRING
                ) USING DELTA
            """)
            print("✓ Audit table ready")

        def _sum_numeric_columns(self, df: DataFrame) -> Dict:
            """
            Calculates sum for all numeric columns in one pass.
            """
            numeric_cols = []
            for column_name, dtype in df.dtypes:
                if any(numeric_type in dtype.lower()
                       for numeric_type in ['int', 'bigint', 'double', 'decimal', 'float', 'long']):
                    numeric_cols.append(column_name)

            if not numeric_cols:
                return {}

            sums_row = df.agg(*(F.sum(c).alias(c) for c in numeric_cols)).collect()[0]

            numeric_sums = {}
            for c in numeric_cols:
                val = sums_row[c]
                if val is not None:
                    numeric_sums[c] = float(val)

            return numeric_sums

        def _analyze_dataframe(self, df: DataFrame) -> Dict:
            """Analyzes DataFrame schema and content.

            Args:
                df: DataFrame to profile.

            Returns:
                Dict with "total_rows", "total_columns" and "schema_details".
                "schema_details" holds one dict per column with its name, type,
                null / non-null / distinct counts, up to five sample values
                (stringified, truncated to 20 characters) and, when
                ``_sum_numeric_columns`` reported one, its "sum".

            Note:
                Distinct counts and samples each run a separate Spark job per
                column.
            """
            print("\nAnalyzing DataFrame Schema")
            print("-" * 140)

            def quoted(name: str) -> str:
                # Backticks make Spark read the text as one column name even
                # when it contains dots or spaces.
                return "`" + name.replace("`", "``") + "`"

            total_rows = df.count()
            numeric_sums = self._sum_numeric_columns(df)

            # One aggregation pass for all null counts; results are keyed by
            # the original (unquoted) column name through the alias.
            null_counts_row = df.select([
                F.count(F.when(F.col(quoted(c)).isNull(), c)).alias(c) for c in df.columns
            ]).collect()[0].asDict()

            schema_analysis = []
            total_columns = len(df.dtypes)

            for idx, (column_name, dtype) in enumerate(df.dtypes, 1):
                print(f"Processing column {idx}/{total_columns}: {column_name}")

                column_ref = quoted(column_name)
                null_count = null_counts_row[column_name]
                non_null_count = total_rows - null_count
                distinct_count = df.select(column_ref).distinct().count()

                # Up to five random distinct non-null values
                sample_values = (df.select(column_ref)
                                 .where(F.col(column_ref).isNotNull())
                                 .distinct()
                                 .orderBy(F.rand())
                                 .limit(5)
                                 .collect())
                sample_values = [str(row[0])[:20] for row in sample_values]

                column_info = {
                    "column_name": column_name,
                    "type": dtype,
                    "non_null_count": non_null_count,
                    "null_count": null_count,
                    "distinct_count": distinct_count,
                    "sample_values": sample_values
                }

                if column_name in numeric_sums:
                    column_info["sum"] = numeric_sums[column_name]

                schema_analysis.append(column_info)

            return {
                "total_rows": total_rows,
                "total_columns": len(df.columns),
                "schema_details": schema_analysis
            }

        def _display_processing_info(self, file_info: Dict, schema_analysis: Dict, properties: Dict):
            """Displays processing information in a readable format"""
            print("\n" + "=" * 140)
            print(f"{'File Processing Report - ' + self.execution_id:^140}")
            print("=" * 140 + "\n")

            print("PROCESSING SUMMARY")
            print("-" * 140)
            print(f"File Name:              {file_info['name']}")
            print(f"Size:                   {file_info['size']:,} bytes")
            print(f"Total Rows:             {schema_analysis['total_rows']:,}")
            print(f"Total Columns:          {schema_analysis['total_columns']}")
            print(f"Processing Duration:    {str(datetime.now() - self.start_time).split('.')[0]}")

            print("\nFILE PROPERTIES")
            print("-" * 140)
            print(f"Encoding:               {properties['encoding']}")
            print(f"Delimiter:              {repr(properties['delimiter'])}")
            print(f"Quote Character:        {repr(properties['quotechar'])}")
            print(
                f"Escape Character:       {repr(properties['escapechar']) if properties['escapechar'] else 'Same as quote character'}")
            print(f"Skip Lines:             {properties['skip_lines']}")

            print("\nSCHEMA ANALYSIS")
            print("-" * 140)
            col_format = "{:<32} {:<8} {:>12} {:>12} {:>8} {:>20} {:<35}"
            print(col_format.format(
                "Column Name", "Type", "Non-Null", "Distinct", "Null %", "Sum", "Samples"
            ))
            print("-" * 140)

            for col_info in schema_analysis["schema_details"]:
                total = col_info["null_count"] + col_info["non_null_count"]
                null_pct = (col_info["null_count"] / total * 100) if total > 0 else 0

                formatted_non_null = f"{col_info['non_null_count']:,}"
                formatted_distinct = f"{col_info['distinct_count']:,}"
                sum_str = f"{col_info['sum']:,.2f}" if 'sum' in col_info and col_info['sum'] is not None else ""
                samples_str = ', '.join(col_info['sample_values'])

                print(col_format.format(
                    col_info['column_name'],
                    col_info['type'],
                    formatted_non_null,
                    formatted_distinct,
                    f"{null_pct:>6.1f}%",
                    sum_str,
                    samples_str
                ))
            print("=" * 140 + "\n")

        def get_common_encodings(self) -> List[str]:
            """Returns a list of common encodings that are generally supported"""
            return [
                "Big5", "Big5-HKSCS", "CESU-8", "EUC-JP", "EUC-KR", "GB18030", "GB2312", "GBK", "IBM-Thai", "IBM00858",
                "IBM01140", "IBM01141", "IBM01142", "IBM01143", "IBM01144", "IBM01145", "IBM01146", "IBM01147",
                "IBM01148",
                "IBM01149", "IBM037", "IBM1026", "IBM1047", "IBM273", "IBM277", "IBM278", "IBM280", "IBM284", "IBM285",
                "IBM290", "IBM297", "IBM420", "IBM424", "IBM437", "IBM500", "IBM775", "IBM850", "IBM852", "IBM855",
                "IBM857",
                "IBM860", "IBM861", "IBM862", "IBM863", "IBM864", "IBM865", "IBM866", "IBM868", "IBM869", "IBM870",
                "IBM871",
                "IBM918", "ISO-2022-CN", "ISO-2022-JP", "ISO-2022-JP-2", "ISO-2022-KR", "ISO-8859-1", "ISO-8859-13",
                "ISO-8859-15", "ISO-8859-2", "ISO-8859-3", "ISO-8859-4", "ISO-8859-5", "ISO-8859-6", "ISO-8859-7",
                "ISO-8859-8",
                "ISO-8859-9", "JIS_X0201", "JIS_X0212-1990", "KOI8-R", "KOI8-U", "Shift_JIS", "TIS-620", "US-ASCII",
                "UTF-16",
                "UTF-16BE", "UTF-16LE", "UTF-32", "UTF-32BE", "UTF-32LE", "UTF-8", "windows-1250", "windows-1251",
                "windows-1252", "windows-1253", "windows-1254", "windows-1255", "windows-1256", "windows-1257",
                "windows-1258",
                "windows-31j"
            ]

        def validate_encoding(self, encoding: str) -> str:
            """Validates the detected encoding against the list of common encodings"""
            common_encodings = self.get_common_encodings()
            if encoding and encoding.upper() in [enc.upper() for enc in common_encodings]:
                return encoding
            else:
                print(f"Encoding {encoding} not supported. Falling back to UTF-8.")
                return 'UTF-8'

        def detect_file_properties(self, file_path: Path) -> Dict:
            """Detects file encoding and CSV properties"""
            print("Detecting file properties...")

            try:
                if self.config.encoding:
                    validated_encoding = self.validate_encoding(self.config.encoding)
                    print(f"Using user-provided encoding: {validated_encoding}")
                else:
                    print("Detecting file encoding...")
                    raw_data = file_path.read_bytes()[:3000000]  # Read first 3MB
                    result = chardet.detect(raw_data)

                    if result and result['confidence'] > 0.5:
                        detected_encoding = result['encoding']
                        print(f"Detected encoding: {detected_encoding} with confidence {result['confidence']}")
                        validated_encoding = self.validate_encoding(detected_encoding)
                    else:
                        print("Low confidence in detected encoding. Defaulting to UTF-8.")
                        validated_encoding = 'utf-8'

                print("Analyzing CSV format...")
                if not (self.config.delimiter and self.config.quotechar):
                    try:
                        with open(file_path, 'r', encoding=validated_encoding) as file:
                            for _ in range(self.config.skip_lines):
                                file.readline()
                            content = ''.join([file.readline() for _ in range(5)])

                        dialect = csv.Sniffer().sniff(content)
                        has_header = csv.Sniffer().has_header(content)

                        delimiter = self.config.delimiter or dialect.delimiter
                        quotechar = self.config.quotechar or dialect.quotechar
                        escapechar = self.config.escapechar

                        print(f"Using Delimiter: {repr(delimiter)}")
                        print(f"Using Quote Character: {repr(quotechar)}")
                        print(
                            f"Using Escape Character: {repr(escapechar) if escapechar else 'Same as quote character'}")
                        print(f"Skipping Lines: {self.config.skip_lines}")
                        print(f"Has Header: {has_header}")

                    except Exception as e:
                        print(f"Error detecting CSV dialect: {str(e)}")
                        delimiter = self.config.delimiter or ","
                        quotechar = self.config.quotechar or '"'
                        escapechar = self.config.escapechar
                else:
                    delimiter = self.config.delimiter
                    quotechar = self.config.quotechar
                    escapechar = self.config.escapechar
                    print(f"Using user-provided delimiter: {repr(delimiter)}")
                    print(f"Using user-provided quote character: {repr(quotechar)}")
                    print(
                        f"Using user-provided escape character: {repr(escapechar) if escapechar else 'Same as quote character'}")
                    print(f"Skipping Lines: {self.config.skip_lines}")

                print("✓ File properties detected successfully")
                return {
                    "encoding": validated_encoding,
                    "delimiter": delimiter,
                    "quotechar": quotechar,
                    "escapechar": escapechar,
                    "skip_lines": self.config.skip_lines
                }

            except Exception as e:
                print(f"Error detecting file properties: {str(e)}")
                return {
                    "encoding": "utf-8",
                    "delimiter": ",",
                    "quotechar": '"',
                    "escapechar": self.config.escapechar,
                    "skip_lines": self.config.skip_lines
                }

        def get_latest_file(self) -> Optional[str]:
            """Gets the most recent file matching the pattern"""
            print(f"Searching for files matching pattern: {self.config.file_pattern}")

            try:
                files = dbutils.fs.ls(f"/mnt/{self.config.container_name}")
                matching_files = [f for f in files if f.name.startswith(self.config.file_pattern)]

                if not matching_files:
                    print(f"No files found matching pattern: {self.config.file_pattern}")
                    return None

                latest = sorted(matching_files, key=lambda x: x.modificationTime, reverse=True)[0]
                print(f"✓ Selected file: {latest.name}")
                return latest.name

            except Exception as e:
                print(f"Error listing files: {str(e)}")
                return None

        def _write_audit_record(self, file_name, file_size, file_modified, status,
                                rows_processed, schema_info, file_properties, processing_details):
            """Appends one row describing this execution to the audit table.

            The schema is given explicitly so that columns whose only value is
            None (e.g. schema_info on failures) do not depend on type inference.
            """
            user = dbutils.notebook.entry_point.getDbutils().notebook().getContext().userName().get()
            audit_schema = (
                "execution_id: string, file_name: string, file_size: bigint, "
                "file_modified_timestamp: timestamp, start_time: timestamp, "
                "end_time: timestamp, user_name: string, status: string, "
                "rows_processed: bigint, schema_info: string, "
                "file_properties: string, processing_details: string"
            )
            record = (
                self.execution_id,
                file_name,
                file_size,
                file_modified,
                self.start_time,
                datetime.now(),
                user,
                status,
                rows_processed,
                schema_info,
                file_properties,
                processing_details,
            )
            self.spark.createDataFrame([record], schema=audit_schema) \
                .write.mode("append").saveAsTable(self.config.audit_table)

        def process_file(self, file_name: str) -> Optional[DataFrame]:
            """Main file processing method.

            Loads an Excel (.xlsx/.xls) or delimited (.csv/.txt) file from
            /mnt/<container_name>, profiles it, prints the report and appends a
            SUCCESS or ERROR row to the audit table.

            Args:
                file_name: Name of the file inside the configured container.

            Returns:
                The loaded DataFrame, or None when processing failed (the error
                is printed and recorded in the audit table).
            """
            self.domain = dbutils.widgets.get("processing_option")
            if self.domain != "null":
                WidgetManager.cleanup_widgets(self.domain)

            print(f"\nStarting file processing - {file_name}")

            # Bound up front so the error handler can report whatever was
            # gathered before the failure.
            file_stats = None
            file_modified = None
            properties = None

            try:
                file_path = Path(f"/dbfs/mnt/{self.config.container_name}/{file_name}")
                file_stats = file_path.stat()
                file_modified = datetime.fromtimestamp(file_stats.st_mtime)

                file_info = {
                    "name": file_name,
                    "size": file_stats.st_size,
                    "modified": file_modified
                }

                if file_name.lower().endswith((".xlsx", ".xls")):
                    print("✓ Detected Excel file format")

                    sheet_name = self.config.sheet_name
                    excel_start_cell = self.config.excel_starting_cell or "A1"
                    # Without a sheet name, pass only the cell address rather
                    # than an empty quoted sheet ("''!A1").
                    if sheet_name:
                        data_address = f"'{sheet_name}'!{excel_start_cell}"
                    else:
                        data_address = excel_start_cell

                    df = self.spark.read.format("com.crealytics.spark.excel") \
                        .option("dataAddress", data_address) \
                        .option("header", "true") \
                        .option("inferSchema", "true") \
                        .option("timestampFormat", "MM-dd-yyyy HH:mm:ss") \
                        .load(f"/mnt/{self.config.container_name}/{file_name}")

                    properties = {
                        "encoding": "binary",
                        "delimiter": None,
                        "quotechar": None,
                        "escapechar": None,
                        "skip_lines": 0
                    }

                elif file_name.lower().endswith((".csv", ".txt")):
                    print("✓ Detected CSV or text file format")

                    properties = self.detect_file_properties(file_path)

                    reader = self.spark.read.format("csv") \
                        .option("header", "true") \
                        .option("inferSchema", "true") \
                        .option("delimiter", properties["delimiter"]) \
                        .option("quote", properties["quotechar"]) \
                        .option("multiline", "true") \
                        .option("encoding", properties["encoding"])

                    # With no explicit escape character the quote character doubles as escape
                    if properties["escapechar"]:
                        reader = reader.option("escape", properties["escapechar"])
                    else:
                        reader = reader.option("escape", properties["quotechar"])

                    if properties["skip_lines"] > 0:
                        try:
                            print(f"Attempting to skip {properties['skip_lines']} line(s) using 'skipRows' option")
                            df = reader.option("skipRows", properties["skip_lines"]) \
                                .load(f"/mnt/{self.config.container_name}/{file_name}")
                        except Exception as e1:
                            print(f"Error with skipRows: {str(e1)}")
                            try:
                                with open(file_path, 'r', encoding=properties["encoding"]) as f:
                                    first_lines = [f.readline().strip() for _ in range(properties["skip_lines"] + 1)]
                                comment_char = first_lines[0][0] if first_lines[0] else '#'
                                print(f"Using comment filtering with character: {repr(comment_char)}")
                                with open(file_path, 'r', encoding=properties["encoding"]) as f:
                                    content = f.readlines()
                                for i in range(min(properties["skip_lines"], len(content))):
                                    if not content[i].startswith(comment_char):
                                        content[i] = comment_char + content[i]
                                temp_file_path = file_path.with_suffix('.temp.csv')
                                with open(temp_file_path, 'w', encoding=properties["encoding"]) as f:
                                    f.writelines(content)
                                # NOTE(review): the file written above is named
                                # "<stem>.temp.csv" but the path loaded below is
                                # "<file_name>.temp" - these do not match. The
                                # temp file is also deleted right after load(),
                                # before the DataFrame is used. Left unchanged;
                                # needs a design decision.
                                df = reader.option("comment", comment_char) \
                                    .load(f"/mnt/{self.config.container_name}/{file_name}.temp")
                                temp_file_path.unlink(missing_ok=True)
                            except Exception as e2:
                                print(f"Error with comment workaround: {str(e2)}")
                                print("Falling back to loading all data and dropping first rows")
                                df = reader.load(f"/mnt/{self.config.container_name}/{file_name}")
                                # NOTE(review): limit(count - n) keeps the first
                                # count - n rows, i.e. it drops rows from the end,
                                # which does not match the message above - confirm
                                # intended behaviour.
                                df = df.limit(df.count() - properties["skip_lines"])
                    else:
                        df = reader.load(f"/mnt/{self.config.container_name}/{file_name}")

                else:
                    raise ValueError(f"Unsupported file type: {file_name}")

                analysis = self._analyze_dataframe(df)
                self._display_processing_info(file_info, analysis, properties)

                self._write_audit_record(
                    file_name,
                    file_stats.st_size,
                    file_modified,
                    "SUCCESS",
                    analysis["total_rows"],
                    json.dumps(analysis["schema_details"]),
                    json.dumps(properties),
                    json.dumps({"file_processing_successful": True}),
                )

                print("✓ File processing completed successfully")
                return df

            except Exception as e:
                error_msg = str(e)
                print("\nError processing file")
                print(f"File: {file_name}")
                print(f"Error: {error_msg}")

                try:
                    self._write_audit_record(
                        file_name,
                        file_stats.st_size if file_stats is not None else None,
                        file_modified,
                        "ERROR",
                        0,
                        None,
                        json.dumps(properties) if properties is not None else None,
                        json.dumps({"error": error_msg}),
                    )
                except Exception as audit_error:
                    # Do not let an audit failure hide the processing error
                    print(f"Warning: could not write error audit record: {audit_error}")

                return None

    def create_file_processor(spark: SparkSession) -> FileProcessor:
        """Builds a FileProcessor configured from the notebook widgets.

        Args:
            spark: Session handed to the processor.

        Returns:
            A ready-to-use FileProcessor.
        """
        # The base widgets are (re)created first so every value read by
        # get_config_from_widgets() is defined.
        WidgetManager.create_base_widgets()
        return FileProcessor(spark, get_config_from_widgets())

In [0]:
    if __name__ == "__main__":
        # Create processor instance
        processor = create_file_processor(spark)

        # Process latest file
        if latest_file := processor.get_latest_file():
            df = processor.process_file(latest_file)
            # process_file() signals failure with None; test for that
            # explicitly rather than relying on the truthiness of a DataFrame.
            if df is not None:
                print("\nPreview of processed data:")
                display(df.limit(20))

In [0]:
# Function to sanitize column names
def sanitize_column_name(name):
    """Returns `name` lower-cased, with each run of disallowed characters replaced by one underscore.

    Allowed characters are ASCII letters, digits and "$".
    """
    disallowed_run = r'[^0-9a-zA-Z$]+'
    cleaned = re.sub(disallowed_run, '_', name)
    return cleaned.lower()

# Getting unique column names after sanitization
unique_names = set()
new_column_names = {}
# New names in column order, used for the positional rename below
sanitized_columns = []

for original_name in df.columns:
    sanitized_name = sanitize_column_name(original_name)
    unique_name = sanitized_name
    count = 1

    # Ensure the uniqueness of the column name
    while unique_name in unique_names:
        unique_name = f"{sanitized_name}_{count}"
        count += 1

    unique_names.add(unique_name)
    new_column_names[original_name] = unique_name
    sanitized_columns.append(unique_name)

# Rename all columns positionally in one step. Renaming one by one can hit
# several columns at once when a new name equals the name of another column
# that has not been renamed yet (e.g. "A b" -> "a_b" next to an existing "a_b").
df = df.toDF(*sanitized_columns)

# Optionally, print the updated column names to verify
print("Updated column names:", df.columns)

In [0]:
try:
    # Row count prior to de-duplication
    row_count_before = df.count()

    # dropDuplicates() with no arguments compares complete rows
    df_deduplicated = df.dropDuplicates()
    row_count_after = df_deduplicated.count()

    # Difference between the two counts = rows that were exact duplicates
    total_duplicate_rows = row_count_before - row_count_after

    if total_duplicate_rows <= 0:
        print("No duplicate rows found. No action required.")
    else:
        print(f"Total number of duplicate rows removed: {total_duplicate_rows}")
        # Only swap in the de-duplicated frame when something was removed
        df = df_deduplicated

    # Reported in both cases
    print(f"Number of rows after deduplication: {row_count_after}")

except Exception as exc:
    print(f"Error during duplicate row removal: {exc}")

In [0]:
try:
    # Values that come from widgets / the notebook context
    gc_portco_id = dbutils.widgets.get("gc_portco_id")
    business_unit_id = dbutils.widgets.get("business_unit_id")
    ingested_by = dbutils.notebook.entry_point.getDbutils().notebook().getContext().userName().get()

    # Name of the source file
    df = df.withColumn("SRC_FILE", lit(latest_file))
    print("Added SRC_FILE column to DataFrame.")

    # Constant-valued lineage columns, added in this order:
    # execution id of the FileProcessor, portfolio company, business unit, current user
    for _meta_column, _meta_value in (
        ("EXECUTION_ID", processor.execution_id),
        ("GC_PORTCO_ID", gc_portco_id),
        ("BUSINESS_UNIT_ID", business_unit_id),
        ("INGESTED_BY", ingested_by),
    ):
        df = df.withColumn(_meta_column, lit(_meta_value))
        print(f"Added {_meta_column} column to DataFrame with value:", _meta_value)

    # Timestamp of the ingestion
    df = df.withColumn("INGESTION_DT", current_timestamp())
    print("Added INGESTION_DT column to DataFrame.")

    # Preview of the enriched DataFrame
    print("\nDisplaying a 20 record sample of the DataFrame:")
    display(df.limit(20))

except Exception as exc:
    print(f"Error modifying DataFrame: {exc}")

In [0]:
try:
    print("Starting to trim leading and trailing spaces from each column.")

    # Only string columns are trimmed: applying trim() to a numeric or
    # timestamp column would turn it into a string column and lose its type.
    for col_name, col_type in df.dtypes:
        if col_type == "string":
            df = df.withColumn(col_name, trim(df[col_name]))

    print("Completed trimming spaces from all columns.")
except Exception as e:
    print(f"Error during trimming process: {e}")

In [0]:
class WidgetManager:
    """
    Manages the lifecycle of Databricks widgets

    The class only holds static helpers: it knows which widget names belong to
    each data domain (WIDGET_GROUPS) and to the general ingestion report
    (GENERAL_REPORT_WIDGETS), and removes them through `dbutils.widgets`.
    All removals are best-effort: a widget that is missing or cannot be removed
    is skipped rather than aborting the cleanup.
    """

    # Widget definitions by domain
    WIDGET_GROUPS = {
        "Accounts Payable": [
            "INVOICE_NUMBER",
            "INVOICE_LINE_NUM",
            "INVOICE_DATE",
            "SUPPLIER_MASTER_ID",
            "SUPPLIER_NAME",
            "INVOICE_NET_AMOUNT",
            "INVOICE_NET_AMOUNT_USD",
            "INVOICE_QTY",
            "RECEIVED_DATE",
            "confirm_mapping"
        ],
        "Purchase Orders": [
            "PURCHASE_ORDER_NUM",
            "PURCHASE_ORDER_LINE_NUM",
            "PURCHASE_ORDER_DATE",
            "SUPPLIER_MASTER_ID",
            "SUPPLIER_NAME",
            "TOTAL_SPEND_LOCAL",
            "TOTAL_SPEND_USD",
            "RECEIVED_DATE",
            "confirm_mapping"
        ],
        "Item Master": [
            "PART_NUMBER",
            "PART_NAME",
            "PART_DESCRIPTION",
            "confirm_mapping"
        ],
        "Supplier Master": [
            "SUPPLIER_ID",
            "SUPPLIER_NAME",
            "confirm_mapping"
        ]
    }

    # Widgets owned by the general ingestion report.
    # NOTE(review): "confirm_general_mapping" is listed here, but the visible
    # notebook code only ever creates "confirm_mapping" - confirm which is intended.
    GENERAL_REPORT_WIDGETS = [
        "MAPPED_SUPPLIER_NAME",
        "MAPPED_PART_NUMBER",
        "MAPPED_TOTAL_SPEND",
        "confirm_general_mapping"
    ]

    @staticmethod
    def check_widget_exists(widget_name: str) -> bool:
        """
        Safely check if a widget exists

        Args:
            widget_name: Name of widget to check

        Returns:
            bool: True if widget exists, False otherwise
        """
        try:
            # widgets.get raises when the widget is not defined; the value itself is ignored
            dbutils.widgets.get(widget_name)
            return True
        except Exception:
            return False

    @staticmethod
    def _remove_widgets(widget_names) -> int:
        """
        Best-effort removal of the given widgets

        Args:
            widget_names: Iterable of widget names; duplicates are processed once

        Returns:
            int: Number of widgets that were actually removed
        """
        removed_count = 0
        # dict.fromkeys de-duplicates while keeping first-seen order; names such as
        # "confirm_mapping" and "SUPPLIER_NAME" appear in several domain groups.
        for widget in dict.fromkeys(widget_names):
            if not WidgetManager.check_widget_exists(widget):
                continue
            try:
                dbutils.widgets.remove(widget)
                removed_count += 1
            except Exception:
                pass  # Ignore errors during removal
        return removed_count

    @staticmethod
    def cleanup_widgets(domain: Optional[str] = None):
        """
        Cleans up widgets based on domain. If no domain specified, cleans up all known widgets.

        A domain that is not a key of WIDGET_GROUPS is treated like "no domain":
        every known mapping widget is cleaned up.

        Args:
            domain (str, optional): Specific domain whose widgets should be cleaned up
        """
        if domain and domain in WidgetManager.WIDGET_GROUPS:
            # Clean up specific domain widgets
            widgets_to_remove = WidgetManager.WIDGET_GROUPS[domain]
            print(f"Cleaning up widgets for domain: {domain}")
        else:
            # Clean up all known widgets
            widgets_to_remove = [
                widget
                for group in WidgetManager.WIDGET_GROUPS.values()
                for widget in group
            ]
            print("Cleaning up all mapping widgets...")

        removed_count = WidgetManager._remove_widgets(widgets_to_remove)

        if removed_count > 0:
            print(f"✓ Removed {removed_count} existing widgets")

    @staticmethod
    def cleanup_general_widgets():
        """
        Removes the general ingestion report widgets that currently exist

        Uses the same best-effort removal as cleanup_widgets, so a widget that
        cannot be removed does not abort the cleanup of the remaining ones.
        """
        WidgetManager._remove_widgets(WidgetManager.GENERAL_REPORT_WIDGETS)


class WidgetMappingManager:
    """
    Manages the creation and mapping of domain-specific widgets

    For each data domain, DOMAIN_MAPPINGS lists the target field names (which are
    also the widget names) together with the label shown next to the dropdown.
    The user picks a source column in each dropdown; get_field_mappings reads
    those selections back as {target_field: source_column}.
    """

    # Define the required fields for each domain: {widget/target name: dropdown label}.
    # The label is the only hint the user sees, so it must describe the target
    # field it is attached to; otherwise the wrong source column gets mapped.
    # NOTE(review): WidgetManager.WIDGET_GROUPS also lists "INVOICE_QTY" for
    # Accounts Payable, but no mapping widget is defined for it here - confirm.
    DOMAIN_MAPPINGS = {
        "Purchase Orders": {
            "PURCHASE_ORDER_NUM": "Map: Purchase Order Number",
            "PURCHASE_ORDER_LINE_NUM": "Map: Purchase Order Line Number",
            "PURCHASE_ORDER_DATE": "Map: Purchase Order Date",
            "SUPPLIER_MASTER_ID": "Map: Supplier Master ID",
            "SUPPLIER_NAME": "Map: Supplier Name",
            "TOTAL_SPEND_LOCAL": "Map: Total Spend Local",
            "TOTAL_SPEND_USD": "Map: Total Spend USD",
            "RECEIVED_DATE": "Map: Received Date"
        },
        "Accounts Payable": {
            "INVOICE_NUMBER": "Map: Invoice Number",
            "INVOICE_LINE_NUM": "Map: Invoice Line Number",
            "INVOICE_DATE": "Map: Invoice Date",
            "SUPPLIER_MASTER_ID": "Map: Supplier Master ID",
            "SUPPLIER_NAME": "Map: Supplier Name",
            "INVOICE_NET_AMOUNT": "Map: Invoice Amount",
            "INVOICE_NET_AMOUNT_USD": "Map: Invoice Amount USD",
            "RECEIVED_DATE": "Map: Received Date"
        },
        "Item Master": {
            "PART_NUMBER": "Map: Item Number",
            "PART_NAME": "Map: Item Name",
            "PART_DESCRIPTION": "Map: Item Description"
        },
        "Supplier Master": {
            "SUPPLIER_ID": "Map: Supplier ID",
            "SUPPLIER_NAME": "Map: Supplier Name"
        }
    }

    @staticmethod
    def create_mapping_widgets(df, domain: str):
        """
        Creates mapping widgets for the specified domain using DataFrame columns

        One dropdown is created per field in DOMAIN_MAPPINGS[domain], plus the
        "confirm_mapping" dropdown. An unknown domain prints an error and creates
        nothing.

        Args:
            df: Spark DataFrame containing the source data
            domain: The data domain to create widgets for
        """
        print(f"\nCreating mapping widgets for domain: {domain}")

        if domain not in WidgetMappingManager.DOMAIN_MAPPINGS:
            print(f"Error: Unknown domain '{domain}'. No mapping widgets created.")
            return

        # Get column list from DataFrame and add empty option ("" means "not mapped")
        column_list = [""] + df.columns
        print(f"Available columns: {column_list}")

        # Create widgets for the specified domain
        domain_fields = WidgetMappingManager.DOMAIN_MAPPINGS[domain]
        for field_name, field_label in domain_fields.items():
            dbutils.widgets.dropdown(field_name, "", column_list, field_label)

        # Always add the confirmation widget
        dbutils.widgets.dropdown("confirm_mapping", "no", ["no", "yes"], "Set to 'yes' to apply mapping")

        print(f"✓ Created {len(domain_fields)} mapping widgets for {domain}")
        print("Please make your selections at the top of the notebook.")

    @staticmethod
    def get_field_mappings(domain: str) -> dict:
        """
        Gets the current field mappings for the specified domain

        Args:
            domain: The data domain to get mappings for

        Returns:
            dict: Dictionary of field mappings {target_field: source_column};
                empty for an unknown domain. Fields left on the empty option
                are omitted.

        Raises:
            Whatever dbutils.widgets.get raises when one of the domain's widgets
            has not been created.
        """
        if domain not in WidgetMappingManager.DOMAIN_MAPPINGS:
            return {}

        mappings = {}
        for field_name in WidgetMappingManager.DOMAIN_MAPPINGS[domain]:
            mapped_column = dbutils.widgets.get(field_name)
            if mapped_column:  # Only include non-empty mappings
                mappings[field_name] = mapped_column

        return mappings


# Notebook-level helpers that drive the widget manager classes defined above
def create_mapping_widgets_for_domain(df, domain: str = None):
    """
    Rebuilds the mapping widgets for one data domain

    Every known mapping widget is removed first (regardless of domain), then the
    widgets for the chosen domain are created from the DataFrame's columns.
    Nothing happens when the domain is the literal string "null".

    Args:
        df: Spark DataFrame containing the source data
        domain: Optional domain override; when None the value of the
            "processing_option" widget is used
    """
    # Fall back to the widget only when the caller gave no explicit domain
    selected_domain = dbutils.widgets.get("processing_option") if domain is None else domain

    # "null" is the placeholder value meaning that no domain has been chosen yet
    if selected_domain == "null":
        print("No domain selected. Please select a data domain first.")
        return

    # No domain argument: wipe the mapping widgets of ALL domains, not just this one
    WidgetManager.cleanup_widgets()

    WidgetMappingManager.create_mapping_widgets(df, selected_domain)


def create_general_report_widgets(df):
    """
    Creates the general-purpose mapping widgets used by the ingestion report

    A general widget is only created when its domain-level counterpart does not
    exist or is still blank, so the user is never asked for the same column twice:
      - MAPPED_SUPPLIER_NAME when the SUPPLIER_NAME widget has no value
      - MAPPED_PART_NUMBER   when the PART_NUMBER widget has no value
      - MAPPED_TOTAL_SPEND   when none of INVOICE_NET_AMOUNT, TOTAL_SPEND_USD,
                             TOTAL_SPEND_LOCAL has a value
    The confirm_mapping widget is created when it does not exist yet.

    Args:
        df: Spark DataFrame whose columns become the dropdown choices
    """
    # Empty string first: it is the "not mapped" choice and the default
    columns = [""] + df.columns
    WidgetManager.cleanup_general_widgets()

    # Helper to get the actual selected value from a domain widget if it exists
    def get_widget_value_safe(widget_name):
        try:
            return dbutils.widgets.get(widget_name)
        except Exception:
            # A widget that is not defined counts as "no value selected"
            return ""

    # Only create general widgets if their equivalent domain values are blank
    if not get_widget_value_safe("SUPPLIER_NAME"):
        dbutils.widgets.dropdown("MAPPED_SUPPLIER_NAME", "", columns, "Map: Supplier Name")

    if not get_widget_value_safe("PART_NUMBER"):
        dbutils.widgets.dropdown("MAPPED_PART_NUMBER", "", columns, "Map: SKU / Part Number")

    # Spend column logic: check all domain-level options
    spend_mapped = any(
        get_widget_value_safe(spend_widget)
        for spend_widget in ("INVOICE_NET_AMOUNT", "TOTAL_SPEND_USD", "TOTAL_SPEND_LOCAL")
    )

    if not spend_mapped:
        dbutils.widgets.dropdown("MAPPED_TOTAL_SPEND", "", columns, "Map: Spend Column")

    # Always create confirm_mapping if it doesn't exist
    if not WidgetManager.check_widget_exists("confirm_mapping"):
        dbutils.widgets.dropdown("confirm_mapping", "no", ["no", "yes"], "Set to 'yes' to apply mappings")

    print("✓ General ingestion mapping widgets created (if needed).")

In [0]:
# Create mapping widgets based on the selected domain.
# Order matters: the general-report helper inspects the domain widgets
# (e.g. SUPPLIER_NAME, PART_NUMBER) to decide which general widgets to add,
# so the domain widgets have to be (re)created first.
create_mapping_widgets_for_domain(df)  # domain-specific mappings
create_general_report_widgets(df)      # general-purpose mappings

In [0]:
def apply_widget_mappings(df, domain: str = None):
    """
    Builds a DataFrame that contains only the mapped columns, renamed to their target fields

    The mapping comes from the domain's dropdown widgets and is applied only
    after the user has set the "confirm_mapping" widget to "yes". Progress and
    the resulting schema are printed along the way.

    Args:
        df: Spark DataFrame containing the source data
        domain: Optional domain override; when None the value of the
            "processing_option" widget is used

    Returns:
        DataFrame: New DataFrame with one column per mapping, or None when no
            domain is selected, the mapping is not confirmed, or nothing is mapped
    """
    banner_width = 132
    heavy_rule = "=" * banner_width
    light_rule = "-" * banner_width

    print(heavy_rule)
    print("Applying Widget Mappings".center(banner_width))
    print(heavy_rule)
    print()

    # Resolve the domain from the widget unless the caller supplied one
    if domain is None:
        domain = dbutils.widgets.get("processing_option")

    # Guard clauses: each one explains why nothing is produced
    if domain == "null":
        print("Error: No domain selected. Please select a data domain first.")
        return None

    if dbutils.widgets.get("confirm_mapping") != "yes":
        print("Warning: Mappings not confirmed. Please set confirm_mapping to 'yes' to apply mappings.")
        return None

    mappings = WidgetMappingManager.get_field_mappings(domain)
    if not mappings:
        print(f"Error: No mappings found for domain '{domain}'")
        return None

    print(f"Processing mappings for domain: {domain}")
    print(light_rule)

    # Two-column table, both columns padded to the longest name plus two spaces
    print("\nField Mappings")
    print(light_rule)
    longest_target = max(len(target_name) for target_name in mappings.keys())
    longest_source = max(len(source_name) for source_name in mappings.values())
    col_width = max(longest_target, longest_source) + 2
    print(f"{'Source Column'.ljust(col_width)} → {'Target Field'.ljust(col_width)}")
    print("-" * (col_width * 2 + 3))
    for target_field, source_column in mappings.items():
        print(f"{source_column.ljust(col_width)} → {target_field.ljust(col_width)}")

    # Project each selected source column under its target name
    projections = []
    for target_field, source_column in mappings.items():
        projections.append(df[source_column].alias(target_field))
    mapped_df = df.select(projections)

    print("\nOutput Schema")
    print(light_rule)
    mapped_df.printSchema()

    print("\n" + heavy_rule)
    print(f"✓ Successfully created DataFrame with {len(mappings)} mapped columns".center(banner_width))
    print(heavy_rule)

    return mapped_df


# Apply the domain mapping to the working DataFrame.
# mapped_df is None when no domain is selected, the mapping is not confirmed,
# or no field is mapped (see apply_widget_mappings' return contract).
mapped_df = apply_widget_mappings(df)  # domain logic

In [0]:
def generate_general_ingestion_report(df, summary_table: str = "gibsonanalytics.ingestion.audit_ingestion_summary"):
    """
    Prints high-level ingestion metrics for `df` and appends them to an audit table

    The supplier, part and spend columns are resolved independently: the general
    widget (MAPPED_SUPPLIER_NAME / MAPPED_PART_NUMBER / MAPPED_TOTAL_SPEND) is
    used when it exists and has a value, otherwise the domain-level mapping is
    used. A general widget only exists when its domain counterpart was blank at
    creation time, so each column must be looked up on its own rather than
    assuming that all three general widgets exist together.

    Nothing is computed or written unless the "confirm_mapping" widget is "yes"
    and all three columns could be resolved.

    Args:
        df: Spark DataFrame holding the ingested rows
        summary_table: Delta table the one-row summary is appended to

    Returns:
        None. Metrics are printed, displayed and appended to `summary_table`.
    """
    from pyspark.sql.functions import col, current_timestamp, lit, regexp_replace, when

    if dbutils.widgets.get("confirm_mapping") != "yes":
        print("⚠ Please confirm mapping first by setting 'confirm_mapping' to yes.")
        return

    def _widget_value(widget_name):
        """Return the widget's current value, or "" when the widget does not exist."""
        try:
            return dbutils.widgets.get(widget_name)
        except Exception:
            return ""

    # Domain mappings are optional input here: if the domain widgets are not
    # available, fall back to the general widgets only.
    try:
        domain_mappings = WidgetMappingManager.get_field_mappings(dbutils.widgets.get("processing_option"))
    except Exception:
        domain_mappings = {}

    # General widget first, domain mapping second - resolved per column
    supplier_col = _widget_value("MAPPED_SUPPLIER_NAME") or domain_mappings.get("SUPPLIER_NAME")
    part_col = _widget_value("MAPPED_PART_NUMBER") or domain_mappings.get("PART_NUMBER")
    spend_col = (
        _widget_value("MAPPED_TOTAL_SPEND")
        or domain_mappings.get("INVOICE_NET_AMOUNT")
        or domain_mappings.get("TOTAL_SPEND_USD")
        or domain_mappings.get("TOTAL_SPEND_LOCAL")
    )

    if not (supplier_col and part_col and spend_col):
        print(f"⚠ Missing mappings → supplier_col: {supplier_col}, part_col: {part_col}, spend_col: {spend_col}")
        print("⚠ Missing mappings for supplier, part, or spend column.")
        return

    print("=" * 132)
    print("General Ingestion Report".center(132))
    print("=" * 132)

    dedup_df = df.dropDuplicates()
    total_rows = df.count()
    dedup_rows = dedup_df.count()
    duplicate_count = total_rows - dedup_rows

    # Strip currency symbols and thousands separators before casting to double
    spend_without_symbols = regexp_replace(col(spend_col), "[$,]", "")
    casted_df = dedup_df.withColumn(
        spend_col,
        when(
            col(spend_col).rlike(r"\(.*\)"),
            # Accounting notation: convert (1234.56) to -1234.56
            regexp_replace(spend_without_symbols, "[()]", "").cast("double") * -1
        ).otherwise(
            # Normal number
            spend_without_symbols.cast("double")
        )
    )

    # Values that could not be cast are null and are dropped before summing;
    # the sum itself is None when no value could be parsed.
    total_spend = casted_df.select(spend_col).na.drop().groupBy().sum(spend_col).collect()[0][0]
    unique_suppliers = dedup_df.select(supplier_col).distinct().count()
    unique_skus = dedup_df.select(part_col).distinct().count()

    print(f"Total Rows Read: {total_rows}")
    print(f"Duplicate Rows Removed: {duplicate_count}")
    print(f"Total Spend: {total_spend}")
    print(f"Unique Suppliers: {unique_suppliers}")
    print(f"Unique SKUs: {unique_skus}")
    print("=" * 132)

    # Try to get file name
    src_file = _widget_value("file_pattern") if WidgetManager.check_widget_exists("file_pattern") else "unknown"

    # Build the one-row summary from a one-row range rather than from df.limit(1):
    # an empty input file would otherwise yield an empty summary and leave no audit record.
    summary_df = spark.range(1).select(
        lit(src_file).alias("SRC_FILE"),
        lit(total_rows).cast("BIGINT").alias("TOTAL_ROWS"),
        lit(duplicate_count).cast("BIGINT").alias("DUPLICATE_ROWS_REMOVED"),
        lit(total_spend).cast("DOUBLE").alias("TOTAL_SPEND"),
        lit(unique_suppliers).cast("BIGINT").alias("UNIQUE_SUPPLIERS"),
        lit(unique_skus).cast("BIGINT").alias("UNIQUE_SKUS"),
        current_timestamp().alias("INGESTION_TIMESTAMP")
    )

    # ✅ Append to Delta table
    summary_df.write.mode("append").format("delta").saveAsTable(summary_table)

    # Show in notebook
    print("\n📄 Summary of Ingestion Metrics:")
    display(summary_df)


# Run the report on the working DataFrame. When the mapping is confirmed and the
# supplier/part/spend columns are resolved, this also appends a summary to the audit table.
generate_general_ingestion_report(df)  # general insights

In [0]:
# Target table for the mapped DataFrame. The default is a placeholder, not a
# valid table identifier, so the user has to supply a real name in the widget.
dbutils.widgets.text("api_mapping_table", "api mapping table", "API Mapping Table")
api_mapping_table = dbutils.widgets.get("api_mapping_table")

try:
    if mapped_df is None:
        # apply_widget_mappings returns None when the mapping was not confirmed
        # or nothing was mapped; say so instead of failing on `None.write`.
        print("Error writing to Delta table: no mapped DataFrame available "
              "(confirm the widget mappings and re-run the mapping cell)")
    else:
        # Write the DataFrame to a Delta table, replacing any previous contents
        mapped_df.write \
            .format("delta") \
            .mode("overwrite") \
            .saveAsTable(api_mapping_table)

        print("✓ Successfully wrote data to Delta table")

except Exception as e:
    # Best-effort cell: report the failure and let the notebook continue
    print(f"Error writing to Delta table: {e}")

In [0]:

# Retrieve widget values. (These widgets can be defined in an earlier cell.)
# The three text widgets are declared here with their defaults before being read.
# NOTE(review): "api_mapping_table" is also declared in an earlier cell with a
# different default ("api mapping table") - confirm which default should apply.
dbutils.widgets.text("max_new_column_percentage", "0.4", "Max New Column Percentage")
dbutils.widgets.text("file_pattern", "myfile", "File Pattern")
dbutils.widgets.text("api_mapping_table", "gibsonanalytics.ingestion.api_mapping", "API Mapping Table")

# Widget values are strings; only the percentage is converted (float() raises
# ValueError on a non-numeric entry). "business_unit_id" is read but not declared
# in this cell, so it must already exist.
max_new_column_percentage     = float(dbutils.widgets.get("max_new_column_percentage"))
business_unit_id              = dbutils.widgets.get("business_unit_id")
file_pattern                  = dbutils.widgets.get("file_pattern")
api_mapping_table             = dbutils.widgets.get("api_mapping_table")

In [0]:
import sys
import logging
from datetime import datetime
from pyspark.sql.types import StructType
from typing import Tuple, Dict, Any
from pyspark.sql import functions as F

# Retrieve table identifiers from widgets
# All three default to an empty string, so they must be filled in before this
# cell can produce a usable table name.
dbutils.widgets.text("catalog_name", "", "Bronze Catalog Name")
dbutils.widgets.text("schema_name", "", "Bronze Schema Name")
dbutils.widgets.text("table_name", "", "Bronze Table Name")


catalog_name = dbutils.widgets.get("catalog_name")
schema_name = dbutils.widgets.get("schema_name")
table_name = dbutils.widgets.get("table_name")
# Three-part name: <catalog>.<schema>.<table>. No validation happens here; with
# the empty defaults this evaluates to "..".
full_table_name = f"{catalog_name}.{schema_name}.{table_name}"

# Configure a dedicated logger
logger = logging.getLogger("DeltaTableLogger")
# getLogger returns the same object on every call, so only attach a handler when
# none is present; re-running the cell would otherwise duplicate every log line.
if not logger.handlers:
    logger.setLevel(logging.INFO)
    handler = logging.StreamHandler(sys.stdout)  # write to stdout so records show up in the cell output
    formatter = logging.Formatter(fmt="%(asctime)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    # Do not hand records on to ancestor loggers' handlers
    logger.propagate = False

def get_schema_metrics(existing_schema: StructType, new_schema: StructType,
                      max_new_column_percentage: float,
                      min_schema_match_percentage: float,
                      validation_result: Tuple[bool, str]) -> Dict[str, Any]:
    """
    Summarises how an incoming schema differs from an existing one

    Column names are compared as sets. Both ratios are expressed relative to the
    number of existing columns; with no existing columns the new-column ratio is
    reported as 0 and the match ratio as 100%.

    Args:
        existing_schema: Schema of the table already stored
        new_schema: Schema of the incoming DataFrame
        max_new_column_percentage: Threshold that was used for validation (reported only)
        min_schema_match_percentage: Threshold that was used for validation (reported only)
        validation_result: (is_valid, message) as returned by validate_schema_changes

    Returns:
        Dict with the validation status, the sorted new column names, counts,
        percentages formatted as strings, and an ISO timestamp.
    """
    current_names = {field.name for field in existing_schema.fields}
    incoming_names = {field.name for field in new_schema.fields}
    added_names = incoming_names - current_names
    shared_names = current_names & incoming_names

    baseline = len(current_names)
    if baseline > 0:
        added_ratio = len(added_names) / baseline
        shared_ratio = len(shared_names) / baseline
    else:
        # Nothing to compare against
        added_ratio = 0
        shared_ratio = 1.0

    # The failure category is recovered from the wording of the validation message
    passed, message = validation_result
    if passed:
        status = "SUCCESS: Schema validation passed"
    else:
        lowered = message.lower()
        if "too many new columns" in lowered:
            status = "ERROR: Too many new columns"
        elif "doesn't match enough existing columns" in lowered:
            status = "ERROR: Insufficient column match"
        else:
            status = "ERROR: Schema validation failed"

    metrics = {}
    metrics["validation_status"] = status
    metrics["new_columns"] = sorted(added_names)
    metrics["new_column_count"] = len(added_names)
    metrics["new_column_percentage"] = f"{added_ratio:.2%}"
    metrics["max_allowed_new_column_percentage"] = f"{max_new_column_percentage:.2%}"
    metrics["matching_columns_count"] = len(shared_names)
    metrics["match_percentage"] = f"{shared_ratio:.2%}"
    metrics["min_required_match_percentage"] = f"{min_schema_match_percentage:.2%}"
    metrics["total_columns"] = len(incoming_names)
    metrics["schema_validation_timestamp"] = datetime.now().isoformat()
    return metrics

def validate_schema_changes(existing_schema: StructType, new_schema: StructType,
                          max_new_column_percentage: float = 0.4,
                          min_schema_match_percentage: float = 0.8) -> Tuple[bool, str]:
    """
    Decides whether an incoming schema is close enough to the existing one

    Two checks run in order, both relative to the number of existing columns:
      1. the share of existing columns present in the incoming schema must be at
         least `min_schema_match_percentage`;
      2. the number of new columns divided by the number of existing columns must
         not exceed `max_new_column_percentage`.
    An empty existing schema always passes. Column names are compared as sets.

    Args:
        existing_schema: Schema of the table already stored
        new_schema: Schema of the incoming DataFrame
        max_new_column_percentage: Upper bound for the new-column ratio
        min_schema_match_percentage: Lower bound for the matching-column ratio

    Returns:
        (True, message) when the change is acceptable, otherwise (False, report);
        the report is also logged at error level.
    """
    current_names = {field.name for field in existing_schema.fields}
    incoming_names = {field.name for field in new_schema.fields}
    baseline = len(current_names)

    if baseline == 0:
        msg = "No existing schema to compare; skipping validation."
        logger.info(msg)
        return True, msg

    rule = "=" * 80
    banner = ["", rule, "SCHEMA VALIDATION ERROR", rule]

    # Check 1: enough of the existing columns must still be present
    shared_names = current_names & incoming_names
    shared_ratio = len(shared_names) / baseline
    if shared_ratio < min_schema_match_percentage:
        absent_listing = "\n".join(f"  - {name}" for name in sorted(current_names - shared_names))
        error_msg = "\n".join(banner + [
            "The incoming DataFrame schema doesn't match enough existing columns.",
            "",
            f"Existing table column count: {baseline}",
            f"Matching columns: {len(shared_names)}",
            f"Match percentage: {shared_ratio:.2%}",
            f"Minimum required: {min_schema_match_percentage:.2%}",
            "",
            "Missing columns that would be populated with nulls:",
            absent_listing,
            rule,
        ])
        logger.error(error_msg)
        return False, error_msg

    # Check 2: the incoming schema must not add too many columns
    added_names = incoming_names - current_names
    added_ratio = len(added_names) / baseline
    if added_ratio > max_new_column_percentage:
        added_listing = "\n".join(f"  - {name}" for name in sorted(added_names))
        error_msg = "\n".join(banner + [
            "The incoming DataFrame schema contains too many new columns.",
            "",
            f"Existing table column count: {baseline}",
            f"New columns detected ({len(added_names)}):",
            added_listing,
            "",
            f"New column percentage: {added_ratio:.2%}",
            f"Maximum allowed: {max_new_column_percentage:.2%}",
            rule,
        ])
        logger.error(error_msg)
        return False, error_msg

    return True, "Schema changes within acceptable limits."

def table_exists(table: str) -> bool:
    """
    Reports whether `table` can be read

    The check reads at most one row from the table. Any exception raised while
    doing so is interpreted as "the table does not exist" and is logged at info
    level together with the outcome.

    Args:
        table: Table identifier passed to spark.table

    Returns:
        bool: True when the probe read succeeds, False when it raises
    """
    try:
        probe = spark.table(table).limit(1)
        probe.collect()
        logger.info(f"Table exists: {table}")
        found = True
    except Exception as probe_error:
        logger.info(f"Table does not exist: {table}. Exception: {probe_error}")
        found = False
    return found

def adjust_schema(incoming_df, base_schema):
    """
    Adds every column of `base_schema` that `incoming_df` lacks, filled with nulls

    Each added column is a null literal cast to the data type of the base
    schema's field, and each addition is logged. Columns the DataFrame already
    has are left untouched.

    Args:
        incoming_df: DataFrame to pad
        base_schema: Schema whose fields must all be present afterwards

    Returns:
        The DataFrame with the missing columns appended (the input itself when
        nothing is missing)
    """
    padded_df = incoming_df
    for target_field in base_schema.fields:
        if target_field.name in padded_df.columns:
            continue
        logger.info(f"Adding missing column '{target_field.name}' with null values.")
        null_column = F.lit(None).cast(target_field.dataType)
        padded_df = padded_df.withColumn(target_field.name, null_column)
    return padded_df

# Create the Delta table on first load, otherwise validate the schema and append.
# Any failure is logged and re-raised so the notebook run stops on it.
try:
    # Normalise column names (spaces -> underscores) ONCE and use the result on
    # both paths. The table is created with the normalised names, so validating
    # or appending the raw names would never line up with the stored schema.
    df_clean = df.select([F.col(c).alias(c.replace(' ', '_')) for c in df.columns])

    if not table_exists(full_table_name):
        logger.info("Delta table does not exist. Creating a new managed Delta table...")
        df_clean.write.format("delta").mode("overwrite").saveAsTable(full_table_name)
        logger.info(f"DataFrame successfully written to: {full_table_name}")
    else:
        logger.info("Delta table exists. Validating schema changes...")
        existing_schema = spark.table(full_table_name).schema

        # The new-column threshold comes from the "max_new_column_percentage"
        # widget read in the previous cell; the match threshold is fixed.
        min_schema_match_percentage = 0.8

        validation_result = validate_schema_changes(
            existing_schema,
            df_clean.schema,
            max_new_column_percentage=max_new_column_percentage,
            min_schema_match_percentage=min_schema_match_percentage
        )

        metrics = get_schema_metrics(
            existing_schema,
            df_clean.schema,
            max_new_column_percentage=max_new_column_percentage,
            min_schema_match_percentage=min_schema_match_percentage,
            validation_result=validation_result
        )
        logger.info(f"Schema evolution metrics: {metrics}")

        if not validation_result[0]:
            raise ValueError("Schema validation error encountered. Please review the log above for details.")

        logger.info("Schema validation passed. Adjusting DataFrame schema if necessary...")
        # Columns the table has but the file lacks are added as typed nulls
        df_adjusted = adjust_schema(df_clean, existing_schema)

        pre_append_count = spark.table(full_table_name).count()
        logger.info(f"Record count before append: {pre_append_count}")

        # Validation tolerates a bounded number of NEW columns, so the append must
        # be allowed to evolve the table schema; without mergeSchema Delta rejects
        # a write whose schema has columns the table does not.
        df_adjusted.write.format("delta").mode("append") \
            .option("mergeSchema", "true") \
            .saveAsTable(full_table_name)
        logger.info("Append operation successful.")

        post_append_count = spark.table(full_table_name).count()
        logger.info(f"Record count after append: {post_append_count}")
        logger.info(f"Records added: {post_append_count - pre_append_count}")

    logger.info("Previewing the first 5 rows from the Delta table:")
    display(spark.table(full_table_name).limit(5))

except ValueError as ve:
    # Raised above on a failed schema validation; the detailed report was already logged
    logger.error(str(ve))
    raise

except Exception:
    # logger.exception records the traceback before the error is propagated
    logger.exception("An unexpected error occurred during Delta table processing.")
    raise