### Prepare environment

In [0]:
%run ../environment/prepare_environment

### Prepare utility functions


The following utility functions are used during the Bronze → Silver write operations:
* **add_technical_columns**  
  Adds technical metadata columns such as `_loaded_at` used for ingestion tracking.
* **get_max_loaded_at**  
  Retrieves the maximum `_loaded_at` timestamp from the provided DataFrame.
* **build_merge_condition**  
  Constructs a SQL merge condition based on key columns joining source and target tables.
* **update_processed_flag**  
  Updates the `_is_processed` flag for records that were loaded on or before the specified timestamp.
* **merge_into_table**  
  Executes a Delta MERGE operation to upsert records and optionally delete unmatched target rows. If the target table does not exist yet, it is created from the DataFrame instead.

In [0]:
import logging
from delta import DeltaTable
from datetime import datetime
from pyspark.sql import DataFrame
from pyspark.sql import functions as F

# Root logger threshold: if the root logger has no handlers yet, this sets it
# to ERROR, so info/warning messages from this notebook are suppressed.
logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)

# Unqualified schema/table names in later cells resolve against this catalog.
CATALOG = "ai_ml_in_practice"
spark.sql(f"USE CATALOG {CATALOG}")


def add_technical_columns(df: DataFrame) -> DataFrame:
    """Return ``df`` with ingestion metadata columns attached.

    Currently adds one column, ``_loaded_at``, set to ``current_timestamp()``.
    Because ``withColumn`` is used, an existing ``_loaded_at`` column is
    replaced rather than duplicated.
    """
    loaded_at_expr = F.current_timestamp()
    return df.withColumn("_loaded_at", loaded_at_expr)

def get_max_loaded_at(df: DataFrame) -> datetime:
    """Return the largest ``_loaded_at`` value present in ``df``.

    Runs a single-row aggregation and collects it to the driver, so this
    triggers a Spark job. When ``df`` has no rows (or only null
    ``_loaded_at`` values) the aggregate is null and ``None`` is returned.
    """
    # A global aggregation (no groupBy) always yields exactly one row.
    (summary_row,) = df.agg(F.max("_loaded_at").alias("MAX_LOADED_AT")).collect()
    return summary_row["MAX_LOADED_AT"]

def build_merge_condition(
    merge_keys: list[str],
    target_alias: str = "target",
    source_alias: str = "source",
) -> str:
    """Build an equality join condition for a Delta MERGE.

    Each key contributes ``<target_alias>.<key> = <source_alias>.<key>``;
    the terms are combined with ``AND``.

    Args:
        merge_keys: Column names identifying a row in both source and target.
        target_alias: Alias given to the target table in the MERGE.
        source_alias: Alias given to the source DataFrame in the MERGE.

    Returns:
        The SQL condition string, e.g. ``"target.id = source.id"``.

    Raises:
        ValueError: If ``merge_keys`` is empty. Without any keys the joined
            string would be empty, which is not a valid merge condition.
    """
    if not merge_keys:
        raise ValueError("merge_keys must contain at least one column name")
    return " AND ".join(
        f"{target_alias}.{key} = {source_alias}.{key}" for key in merge_keys
    )

def update_processed_flag(table: str, max_loaded_at: datetime) -> None:
    """Mark not-yet-processed rows of a Delta table as processed up to a watermark.

    Sets ``_is_processed = True`` on every row whose ``_is_processed`` is
    ``False`` and whose ``_loaded_at`` is at or before ``max_loaded_at``.

    Args:
        table: Delta table name, resolvable by ``DeltaTable.forName``.
        max_loaded_at: Inclusive ``_loaded_at`` watermark. ``None`` (e.g. the
            result of ``get_max_loaded_at`` on an empty DataFrame) cannot match
            any row, since comparing with NULL is never true, so the update is
            skipped instead of issuing an UPDATE that changes nothing.
    """
    if max_loaded_at is None:
        logger.info("No watermark for %s; skipping _is_processed update", table)
        return

    delta_table = DeltaTable.forName(spark, table)
    delta_table.update(
        condition=(F.col("_is_processed") == F.lit(False))
        & (F.col("_loaded_at") <= F.lit(max_loaded_at)),
        set={"_is_processed": F.lit(True)},
    )

def merge_into_table(
    df: DataFrame,
    table: str,
    merge_keys: list[str],
    delete_unmatched_table_rows: bool = False,
) -> None:
    """Upsert ``df`` into the Delta table ``table``, matching on ``merge_keys``.

    If ``table`` does not exist yet, it is created from ``df`` with an append
    write and no merge is performed.

    Otherwise a Delta MERGE runs with the table aliased ``target`` and ``df``
    aliased ``source``:
      * matched rows get every ``df`` column except ``_loaded_at`` updated,
        so a target row keeps its existing ``_loaded_at``;
      * source rows without a match are inserted with all columns;
      * if ``delete_unmatched_table_rows`` is True, target rows that have no
        match in ``df`` are deleted.

    Args:
        df: Source DataFrame; must contain every column in ``merge_keys``.
        table: Delta table name, resolvable by ``spark.catalog`` / ``DeltaTable.forName``.
        merge_keys: Columns used to match source rows to target rows.
        delete_unmatched_table_rows: Delete target rows absent from ``df``.
    """
    if not spark.catalog.tableExists(table):
        logger.warning(
            "Table %s doesn't exist. Creating it from the DataFrame", table
        )
        (
            df.write.format("delta")
            .mode("append")
            .option("mergeSchema", "false")
            .saveAsTable(table)
        )
        return

    logger.info("Executing merge into Delta table: %s", table)
    delta_table = DeltaTable.forName(spark, table)

    # Plain `target`/`source` aliases (the build_merge_condition defaults);
    # `table` is a reserved SQL keyword and is fragile as an alias.
    merge_condition = build_merge_condition(
        merge_keys=merge_keys, target_alias="target", source_alias="source"
    )
    logger.debug("Using merge condition: %s", merge_condition)

    # `_loaded_at` is excluded so an update does not overwrite the value the
    # target row already has.
    update_columns = {
        column_name: f"source.{column_name}"
        for column_name in df.columns
        if column_name != "_loaded_at"
    }

    merge_builder = (
        delta_table.alias("target")
        .merge(source=df.alias("source"), condition=merge_condition)
        .whenMatchedUpdate(set=update_columns)
        .whenNotMatchedInsertAll()
    )

    if delete_unmatched_table_rows:
        merge_builder = merge_builder.whenNotMatchedBySourceDelete()

    merge_builder.execute()
    logger.info("Merge operation completed")

### Load bronze data to silver table.

During data loading from the bronze to the silver layer, we apply the following operations:
* column name standardization
* column type casting:
  * `senior_citizen`, `partner`, `dependents`, `phone_service`, `internet_service`, `paperless_billing`, `churn` to `Boolean`
  * `tenure` to `Integer`
* handling of missing data
* outlier removal

In [0]:
# Read only bronze rows that have not been processed yet, renaming the raw
# source columns to snake_case. The technical columns (_loaded_at,
# _is_processed, _row_id) are passed through unchanged.
bronze_query = """
    SELECT
    customer_id as customer_id,
    gender as gender,
    seniorCitizen as senior_citizen,
    partner as partner,
    Dependents as dependents,
    tenure as tenure,
    Phone_Service as phone_service,
    MultipleLines as multiple_lines,
    InternetService as internet_service,
    Online_Security as online_security,
    Online_Backup as online_backup,
    DeviceProtection as device_protection,
    TechSupport as tech_support,
    StreamingTV as streaming_tv,
    StreamingMovies as streaming_movies,
    Contract as contract,
    paperlessbilling as paperless_billing,
    paymentmethod as payment_method,
    MonthlyCharges as monthly_charges,
    TotalCharges as total_charges,
    Churn as churn,
    _loaded_at,
    _is_processed,
    _row_id
    FROM ai_ml_in_practice.telco_customer_churn_bronze.telco_bronze
    WHERE _is_processed = False
    """
bronze_df = spark.sql(bronze_query)

In [0]:
from pyspark.sql.types import BooleanType, ShortType, IntegerType
from pyspark.sql.functions import col, when

# Target Spark type per column: the flag-like columns become BooleanType and
# tenure becomes IntegerType. Casts are applied in this order, in place
# (withColumn keeps each column's position).
# NOTE(review): strings Spark cannot parse as the target type become null
# (non-ANSI mode) or raise (ANSI mode) — confirm the raw values.
binary_columns = ["senior_citizen", "partner", "dependents", "phone_service", "internet_service", "paperless_billing", "churn"]
casts = {name: BooleanType() for name in binary_columns}
casts["tenure"] = IntegerType()

for name, target_type in casts.items():
    bronze_df = bronze_df.withColumn(name, col(name).cast(target_type))

bronze_df.printSchema()
display(bronze_df)

### Outliers

Outliers are data points that fall far outside the typical range of values in a dataset. Common methods for handling outliers include removing them, filtering, transforming the data, or replacing outliers with more representative values.

Identified outliers:
* negative values in `total_charges`

In [0]:
from pyspark.sql.functions import col

# Keep rows with a positive `total_charges`, plus rows where it is missing:
# nulls are dealt with in the missing-values section below.
# The column was renamed to `total_charges` in the bronze query, so the raw
# name `TotalCharges` no longer exists on this DataFrame.
# NOTE(review): `> 0` also removes rows where total_charges == 0, while the
# markdown above lists only negative values as outliers — confirm intent.
bronze_df = bronze_df.filter((col("total_charges") > 0) | col("total_charges").isNull())


### Handling Missing Values

To handle missing values, we first need to identify columns with a high percentage of missing data. We can then either:
* drop whole columns
* drop certain rows
* impute columns with a default value
* impute numeric columns using statistical methods
* treat null as its own category

Identified problems:
* `internet_service` will be imputed based on the `online_security` column
* rows where most columns are missing will be removed
* every row where `total_charges` or `tenure` is missing will be removed
* `Boolean` columns will be imputed with the default `False`, and `String` columns with `N/A`

In [0]:
from pyspark.sql import DataFrame
from pyspark.sql import functions as F

def count_nulls(df: DataFrame) -> DataFrame:
    """
    Count the null values in every column of a DataFrame.

    Parameters:
        df (DataFrame): Input Spark DataFrame.

    Returns:
        DataFrame: A single-row DataFrame with the same column names as ``df``;
        each value is the number of nulls in that column (0 for an empty
        input). Floating-point NaN values are not counted as null.
    """
    # count() ignores nulls, so counting `when(isNull)` counts exactly the null
    # rows; unlike sum(), it returns 0 (not null) when df has no rows.
    null_counts = [
        F.count(F.when(F.col(c).isNull(), True)).alias(c)
        for c in df.columns
    ]

    return df.agg(*null_counts)

display(count_nulls(bronze_df))

In [0]:
# Impute missing `internet_service` from `online_security`: a customer whose
# online_security is "No internet service" gets False, every other missing
# value gets True. Values already present are kept — only nulls are filled.
bronze_df = bronze_df.withColumn(
    "internet_service",
    F.coalesce(
        F.col("internet_service"),
        F.when(F.col("online_security") == "No internet service", F.lit(False))
         .otherwise(F.lit(True)),
    ),
)

display(count_nulls(bronze_df))

In [0]:
# Drop sparsely populated rows: a row survives only if at least 30% of its
# columns (rounded) are non-null. `thresh` overrides `how`, so `how` is left
# at its default.
n_cols = len(bronze_df.columns)
min_non_null = round(n_cols * .30)
bronze_df = bronze_df.dropna(thresh=min_non_null)

display(count_nulls(bronze_df))

In [0]:
from pyspark.sql.types import StringType

# Replace nulls in every StringType column with the placeholder 'N/A'.
string_cols = [
    field.name
    for field in bronze_df.schema.fields
    if field.dataType == StringType()
]
bronze_df = bronze_df.fillna('N/A', subset=string_cols)

display(count_nulls(bronze_df))

In [0]:
from pyspark.sql.types import BooleanType

# Replace nulls in every BooleanType column with False.
bool_cols = []
for field in bronze_df.schema.fields:
    if field.dataType == BooleanType():
        bool_cols.append(field.name)
bronze_df = bronze_df.fillna(False, subset=bool_cols)

display(count_nulls(bronze_df))

In [0]:
# Remove every row in which `tenure` or `total_charges` is null.
required_cols = ["tenure", "total_charges"]
bronze_df = bronze_df.dropna(how="any", subset=required_cols)

display(count_nulls(bronze_df))


### Great Expectations

The Great Expectations validation created in `2.4_telco_great_expectations` will be used to validate the DataFrame before writing it to the silver layer.

In [0]:
import sys
import great_expectations as gx

# Validate the cleaned DataFrame with the "telco_silver_validation"
# validation definition from the GX context rooted at `telco_ge`.
context_root_dir = "telco_ge"
context = gx.get_context(context_root_dir=context_root_dir)
validation_definition = context.validation_definitions.get("telco_silver_validation")

batch_parameters = {"dataframe": bronze_df}

validation_results = validation_definition.run(batch_parameters=batch_parameters)
print(validation_results)

# Raise instead of sys.exit(): a notebook kernel intercepts SystemExit, which
# may not stop "Run All" or mark a scheduled job as failed, whereas an
# uncaught exception halts execution at this cell before the silver write.
if not validation_results.get("success", False):
    raise RuntimeError(
        "Validation failed! Stopping notebook execution: "
        "'telco_silver_validation' did not pass, nothing was written to silver."
    )

In [0]:
# Get last bronze load date
max_loaded_at = get_max_loaded_at(bronze_df)

# Create silver dataframe
silver_df = add_technical_columns(bronze_df).drop("_loaded_at")

# Merge into silver table
spark.sql("CREATE SCHEMA IF NOT EXISTS telco_customer_churn_silver")
merge_into_table(silver_df, "ai_ml_in_practice.telco_customer_churn_silver.telco_silver", ["_row_id"])

# Update bronze table
update_processed_flag("ai_ml_in_practice.telco_customer_churn_bronze.telco_bronze", max_loaded_at)