In [0]:
from pyspark.sql import DataFrame
from pyspark.sql import functions as F
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# Apply seaborn's "whitegrid" style once, up front, so the plots drawn by the
# helper functions below share the same look.
sns.set(style="whitegrid")

In [0]:
def collect_missing_stats(df: DataFrame, treat_empty_str_as_null: bool = True) -> pd.DataFrame:
    """
    Summarise missing values per column of a Spark DataFrame.

    A value counts as missing when it is NULL, when it is NaN (float/double
    columns only), or - if ``treat_empty_str_as_null`` is True - when it is
    the empty string (string columns only).

    The checks are chosen per column type because ``isnan`` is only valid for
    float/double input and comparing a complex-typed column (array, map,
    struct) with ``''`` is rejected by Spark's analyzer; applying them to
    every column makes the whole query fail on such schemas.

    Parameters
    ----------
    df : pyspark.sql.DataFrame
        Table to profile.
    treat_empty_str_as_null : bool, default True
        Whether ``''`` in string columns is counted as missing.

    Returns
    -------
    pandas.DataFrame
        One row per input column with the columns
        ['column', 'missing_count', 'total_count', 'missing_percent'].
        ``missing_percent`` is rounded to 3 decimals and is 0.0 for an
        empty table.
    """
    column_types = df.dtypes  # list of (name, simple type string)

    exprs = []
    for idx, (name, dtype) in enumerate(column_types):
        # Back-tick quoting keeps names containing dots or spaces from being
        # parsed as struct-field access.
        col = F.col("`" + name.replace("`", "``") + "`")
        cond = col.isNull()
        if treat_empty_str_as_null and dtype == "string":
            cond = cond | (col == "")
        if dtype in ("float", "double"):
            cond = cond | F.isnan(col)
        # Positional aliases avoid any clash with real column names.
        exprs.append(F.sum(F.when(cond, 1).otherwise(0)).alias(f"missing_{idx}"))

    # Row count is computed in the same aggregation instead of a separate df.count().
    exprs.append(F.count(F.lit(1)).alias("total_rows"))

    result = df.select(*exprs).collect()[0]
    total_count = int(result[len(column_types)] or 0)

    rows = []
    for idx, (name, _) in enumerate(column_types):
        # SUM over zero rows yields NULL, hence the `or 0`.
        missing = int(result[idx] or 0)
        pct = round(100 * missing / total_count, 3) if total_count > 0 else 0.0
        rows.append({
            "column": name,
            "missing_count": missing,
            "total_count": total_count,
            "missing_percent": pct,
        })
    return pd.DataFrame(rows)

In [0]:
def plot_pct_missing_bar(missing_df: pd.DataFrame, top_n: int = None, figsize=(10,8)):
    """
    Draw a horizontal bar chart of the percentage of missing values per column.

    Parameters
    ----------
    missing_df : pandas.DataFrame
        Must provide the columns 'column' and 'missing_percent'.
    top_n : int, optional
        When truthy, only the ``top_n`` rows with the highest
        'missing_percent' are plotted; otherwise every row is plotted.
    figsize : tuple
        Figure size passed to matplotlib.

    The figure is shown with ``plt.show()``; nothing is returned.
    """
    ranked = missing_df.sort_values("missing_percent", ascending=True)
    # Rows are in ascending order, so the tail holds the largest percentages.
    to_plot = ranked.tail(top_n) if top_n else ranked

    fig, ax = plt.subplots(figsize=figsize)
    sns.barplot(data=to_plot, x="missing_percent", y="column", edgecolor="k", ax=ax)
    ax.set_xlabel("% missing")
    ax.set_ylabel("Column")
    ax.set_title("Percent missing per column")
    fig.tight_layout()
    plt.show()

In [0]:
def sample_missing_matrix(df: DataFrame, sample_size: int = 500, seed: int = 42, treat_empty_str_as_null: bool = True) -> pd.DataFrame:
    """
    Build a 0/1 missing-value indicator matrix from a row sample of ``df``.

    A cell is 1 when the value is NULL, NaN (float/double columns only) or -
    if ``treat_empty_str_as_null`` is True - the empty string (string columns
    only); otherwise it is 0. The NaN and empty-string checks are applied
    only to column types that support them, so columns of other types
    (dates, arrays, ...) are checked for NULL only.

    Parameters
    ----------
    df : pyspark.sql.DataFrame
        Table to sample.
    sample_size : int, default 500
        Upper bound on the number of rows returned. Tables with at most this
        many rows are used in full.
    seed : int, default 42
        Seed passed to ``DataFrame.sample``.
    treat_empty_str_as_null : bool, default True
        Whether ``''`` in string columns is flagged as missing.

    Returns
    -------
    pandas.DataFrame
        Indicator matrix with the same column names as ``df``. Because
        ``DataFrame.sample`` is fraction-based, the row count may be below
        ``sample_size``; it is never above it.
    """
    total = df.count()
    if total <= sample_size:
        df_s = df
    else:
        frac = float(sample_size) / float(total)
        # limit() already leaves a short sample untouched, so no extra
        # count() job is needed to decide whether to apply it.
        df_s = df.sample(withReplacement=False, fraction=frac, seed=seed).limit(sample_size)

    indicators = []
    for name, dtype in df.dtypes:
        # Back-tick quoting keeps names containing dots or spaces from being
        # parsed as struct-field access.
        col = F.col("`" + name.replace("`", "``") + "`")
        cond = col.isNull()
        if treat_empty_str_as_null and dtype == "string":
            cond = cond | (col == "")
        if dtype in ("float", "double"):
            cond = cond | F.isnan(col)
        indicators.append(F.when(cond, 1).otherwise(0).alias(name))

    return df_s.select(*indicators).toPandas()

In [0]:
def plot_missing_matrix(pdf, max_cols=80, figsize=(14,6), save_path: str = None):
    """
    Render a heatmap of a 0/1 missing-indicator matrix.

    Parameters
    ----------
    pdf : pandas.DataFrame
        Indicator values, one row per sampled record and one column per
        source column.
    max_cols : int
        If ``pdf`` has more columns than this, only the first ``max_cols``
        columns are drawn.
    figsize : tuple
        Figure size passed to matplotlib.
    save_path : str, optional
        When truthy, the figure is written to this path (dpi=150) before it
        is shown.

    The matrix is transposed for plotting, so source columns appear on the
    y-axis and sampled rows on the x-axis. Nothing is returned.
    """
    shown = pdf.iloc[:, :max_cols] if pdf.shape[1] > max_cols else pdf

    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(shown.T, cbar=False, cmap="Greys", vmin=0, vmax=1, ax=ax)
    ax.set_xlabel("Sampled rows")
    ax.set_ylabel("Columns")
    ax.set_title(f"Missingness matrix (sample rows={shown.shape[0]})")
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.show()