In [0]:
import pandas as pd
import seaborn as sns
from pyspark.sql.functions import sum, when, col, rand
from functools import reduce
import matplotlib.pyplot as plt







In [0]:
def df_shape(df):
    """Print the dimensions of `df` as "<rows> rows × <columns> columns".

    The row total comes from ``df.count()`` and the column total from
    ``len(df.columns)``. Nothing is returned.
    """
    # Row total first, then column total.
    shape = (df.count(), len(df.columns))

    print(f"{shape[0]} rows × {shape[1]} columns")

In [0]:
def plot_dist_and_box(
    df,
    cols,
    bins,
    hist_color: str = "#5C666C",
    box_color: str = "#BED62F",
):
    """
    Draw, for each requested column, a histogram with KDE stacked above a
    horizontal box plot that shares the same x-axis.

    Parameters
    ----------
    df : DataFrame exposing `.columns`, `.select(...)`, `.where(...)` and
        `.toPandas()`.
    cols : a single column name or an iterable of column names.
    bins : forwarded unchanged to `sns.histplot(bins=...)`.
    hist_color : colour of the histogram / KDE.
    box_color : colour of the box plot.

    Notes
    -----
    - Columns missing from `df`, and columns that are empty or non-numeric
      after collection, are reported with a message and skipped.
    - All non-null values of each column are collected with `toPandas()`.
    - `sns.set_theme(style="whitegrid")` is applied on every call.
    - One figure is shown per column and closed right after it is shown.
    """

    # Accept a single string or any iterable of strings
    if isinstance(cols, str):
        cols = [cols]

    sns.set_theme(style="whitegrid")

    for c in cols:
        if c not in df.columns:
            print(f"Column '{c}' not found – skipping.")
            continue

        # Pull only non-null rows for this column
        pdf = df.select(c).where(col(c).isNotNull()).toPandas()

        # Skip empty or non-numeric data
        if pdf.empty or not pd.api.types.is_numeric_dtype(pdf[c]):
            print(f"Column '{c}' is empty or non-numeric – skipping.")
            continue

        # Create the two-row figure
        fig, (ax_dist, ax_box) = plt.subplots(
            2, 1, figsize=(8, 6), sharex=True, gridspec_kw={"height_ratios": [3, 1]}
        )

        # Histogram + KDE
        sns.histplot(pdf[c], bins=bins, kde=True, color=hist_color, ax=ax_dist)
        ax_dist.set_title(f"Distribution of {c}")
        ax_dist.set_xlabel("")  # Hide x-axis on the top panel
        ax_dist.set_ylabel("Frequency")

        # Horizontal box-plot
        sns.boxplot(x=pdf[c], orient="h", color=box_color, ax=ax_box)
        ax_box.set_xlabel(c)
        ax_box.set_yticks([])

        plt.tight_layout()
        plt.show()

        # Release the figure once rendered: this loop creates one figure per
        # column and, without closing, they all stay registered with pyplot.
        plt.close(fig)

In [0]:
def summarize_nulls_and_dtype(df, only_missing=False):
    """
    Print one line per column:
        <column name> | <# nulls> | <Spark dtype>
    If only_missing is True, only columns with at least one null are shown.

    All null counts are computed in a single aggregation over `df`.
    An aggregate `sum` over zero rows yields NULL rather than 0, so a missing
    count is reported as 0 instead of failing while formatting the line.
    """
    # compute null counts for every column
    first_row = (
        df
        .select([ sum(when(col(c).isNull(), 1).otherwise(0)).alias(c)
                  for c in df.columns ])
        .first()                     # a single Row with all counts
    )
    # No row at all -> treat every count as 0.
    counts_row = first_row.asDict() if first_row is not None else {}

    # header
    header = f"{'column':30} | {'# nulls':>10} | dtype"
    print(header)
    print("-" * len(header))

    # iterate through (name, dtype) pairs
    for name, dtype in df.dtypes:
        # `or 0` maps a NULL sum (empty DataFrame) to 0 so that the
        # thousands-separator format below does not raise on None.
        nulls = counts_row.get(name) or 0
        # if only_missing, skip columns with zero nulls
        if only_missing and nulls == 0:
            continue
        print(f"{name:30} | {nulls:10,} | {dtype}")



In [0]:

def count_null_overlaps(df, cols):
    """
    For a given DataFrame `df` and column names `cols`:
      - Counts rows where all of `cols` are null.
      - Counts rows where some but not all of `cols` are null.
    Prints both counts.

    Parameters
    ----------
    df : DataFrame exposing `.filter(...)` and `.count()`.
    cols : a single column name or an iterable of column names.

    Returns
    -------
    (all_null_count, partial_null_count) : tuple of int

    Raises
    ------
    ValueError
        If `cols` is empty, since there is nothing to compare.
    """
    # Accept a single string; otherwise a bare string would be iterated
    # character by character.
    if isinstance(cols, str):
        cols = [cols]
    # Materialise once so generators work and len() is available below.
    cols = list(cols)
    if not cols:
        raise ValueError("count_null_overlaps requires at least one column name")

    # build expressions
    null_flags = [col(c).isNull() for c in cols]
    any_missing = reduce(lambda a, b: a | b, null_flags)
    all_missing = reduce(lambda a, b: a & b, null_flags)

    # perform counts
    all_null_count = df.filter(all_missing).count()
    partial_null_count = df.filter(any_missing & ~all_missing).count()

    # output
    print(f"Rows with all {len(cols)} columns null at once: {all_null_count:,}")
    if partial_null_count == 0:
        print("✅ No partial overlaps: whenever one is null, they’re all null together.")
    else:
        print(f"❌ Found {partial_null_count:,} rows with some nulls but not all.")

    # return the counts so callers can use them programmatically
    return all_null_count, partial_null_count


In [0]:

def inspect_missing(df, column_name, n=10, seed=None):
    """
    Prints the percentage of missing values in `column_name` and
    displays `n` random rows where that column is null.

    Parameters
    ----------
    df : DataFrame exposing `.count()` and `.filter(...)`.
    column_name : name of the column to inspect.
    n : number of sample rows to display (default 10).
    seed : optional integer passed to `rand` so the sample can be reproduced;
        None (the default) gives a different sample on every call.

    Notes
    -----
    `display` is not defined in this function; it is expected to be provided
    by the notebook environment.
    """
    # Total rows
    total_count = df.count()

    # Rows with null in the specified column
    missing_df = df.filter(col(column_name).isNull())
    missing_count = missing_df.count()

    # Compute percentage (guard against an empty DataFrame)
    missing_pct = (missing_count / total_count * 100) if total_count > 0 else 0.0
    print(f"{column_name}: {missing_count:,} missing out of {total_count:,} "
          f"→ {missing_pct:.2f}% missing")

    # Display up to `n` sample rows where column is null
    random_nulls = missing_df.orderBy(rand(seed)).limit(n)
    print(f"\nShowing {n} random rows where `{column_name}` IS NULL:")
    display(random_nulls)

In [0]:

def plot_binary_with_target(df, binary_col, target_col):
    """
    Given a DataFrame `df`, a binary column `binary_col` (0/1),
    and a numeric target column `target_col`, this function:
      1) Plots a horizontal bar chart of counts for each binary class.
      2) Plots a boxplot of `target_col` grouped by `binary_col`.

    Both panels use the same explicit class order ("False" then "True"), so
    each class keeps the same position and palette colour in both charts.

    Notes
    -----
    - Values of `binary_col` are relabelled through {0: "False", 1: "True"};
      anything else (including null) maps to NaN.
    - Rows with a null in `binary_col` or `target_col` are dropped for the
      boxplot only; the bar chart counts come from all rows.
    - `main_palette` is not defined in this function; it must exist in the
      notebook's global scope.
    """
    class_labels = {0: "False", 1: "True"}

    # 1) Aggregate counts by the binary column
    status = (
        df
        .groupBy(binary_col)
        .count()
        .toPandas()
    )
    status = status.sort_values(by=binary_col)
    status[binary_col] = status[binary_col].map(class_labels)
    status['count'] = status['count'].astype(int)

    # 2) Prepare data for the boxplot
    box_pdf = (
        df
        .select(binary_col, target_col)
        .na.drop()
        .toPandas()
    )
    box_pdf[binary_col] = box_pdf[binary_col].map(class_labels)

    # Shared class order for both panels. Without it the boxplot orders the
    # classes by first appearance in the collected rows, which may disagree
    # with the sorted bar chart and swap the colours between panels.
    present = set(status[binary_col].dropna())
    class_order = [label for label in ("False", "True") if label in present]
    shared_order = class_order or None  # fall back to seaborn's default

    # 3) Create subplots
    fig, (ax_bar, ax_box) = plt.subplots(1, 2, figsize=(14, 6))

    # 4) Bar plot on ax_bar
    sns.barplot(
        data=status,
        y=binary_col,
        x="count",
        palette=main_palette,
        order=shared_order,
        ax=ax_bar
    )
    max_count = status['count'].max()
    for p in ax_bar.patches:
        # Place the count label just past the end of each bar.
        ax_bar.text(
            p.get_width() + max_count * 0.01,
            p.get_y() + p.get_height() / 2,
            int(p.get_width()),
            va='center'
        )
    ax_bar.set_title(f"Count of {binary_col} (False vs True)", fontsize=14)
    ax_bar.set_xlabel("Count")
    ax_bar.set_ylabel(binary_col)

    # 5) Boxplot on ax_box
    sns.boxplot(
        data=box_pdf,
        x=binary_col,
        y=target_col,
        palette=main_palette,
        order=shared_order,
        ax=ax_box
    )
    ax_box.set_title(f"{target_col} Distribution by {binary_col}", fontsize=14)
    ax_box.set_xlabel(binary_col)
    ax_box.set_ylabel(target_col)

    plt.tight_layout()
    plt.show()



In [0]:
def plot_categorical_with_target(df, cat_col, target_col):
    """
    Side-by-side overview of a categorical column against a numeric target.

    Rows where `cat_col` is null are excluded up front. The left panel is a
    horizontal bar chart of the row count per category, sorted from most to
    least frequent and annotated with the count. The right panel is a boxplot
    of `target_col` per category (rows with a null target excluded), using
    the same category order as the bar chart.

    `main_palette` is not defined in this function; it must exist in the
    notebook's global scope.
    """
    # Keep only rows that actually have a category.
    with_category = df.filter(col(cat_col).isNotNull())

    # Per-category row counts, most frequent first.
    counts_pdf = with_category.groupBy(cat_col).count().toPandas()
    counts_pdf = counts_pdf.sort_values(by="count", ascending=False)
    counts_pdf["count"] = counts_pdf["count"].astype(int)

    # Raw (category, target) pairs for the boxplot; null targets removed.
    target_pdf = (
        with_category
        .where(col(target_col).isNotNull())
        .select(cat_col, target_col)
        .toPandas()
    )

    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    ax_bar, ax_box = axes

    # Both panels follow the frequency-sorted category order.
    category_order = counts_pdf[cat_col]

    # Left panel: counts per category.
    sns.barplot(
        data=counts_pdf,
        y=cat_col,
        x="count",
        palette=main_palette,
        order=category_order,
        ax=ax_bar,
    )

    # Write each bar's count just past its right edge.
    largest = counts_pdf["count"].max()
    for bar in ax_bar.patches:
        bar_len = bar.get_width()
        mid_y = bar.get_y() + bar.get_height() / 2
        ax_bar.text(bar_len + largest * 0.01, mid_y, int(bar_len), va="center")

    ax_bar.set_title(f"Count of categories in `{cat_col}`", fontsize=14)
    ax_bar.set_xlabel("Count")
    ax_bar.set_ylabel(cat_col)

    # Right panel: target distribution per category.
    sns.boxplot(
        data=target_pdf,
        x=cat_col,
        y=target_col,
        palette=main_palette,
        order=category_order,
        ax=ax_box,
    )
    ax_box.set_title(f"{target_col} Distribution by `{cat_col}`", fontsize=14)
    ax_box.set_xlabel(cat_col)
    ax_box.set_ylabel(target_col)
    # Long category names overlap unless the tick labels are rotated.
    ax_box.tick_params(axis="x", rotation=45)

    plt.tight_layout()
    plt.show()