In [0]:
from pyspark.sql import functions as F
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Load the raw San Francisco Airbnb listings with a header row and inferred types.
# NOTE(review): the listings file is assumed to contain quoted free-text fields
# with embedded newlines and RFC-4180 doubled quotes (the Learning Spark v2
# notebooks read this file with these same options) — confirm against the data.
# Without multiLine/escape, such records are split across several rows, which
# inflates the row count and skews every per-column missing-value statistic.
df = (
    spark.read.option("header", "true")
    .option("inferSchema", "true")
    .option("multiLine", "true")
    .option("escape", '"')
    .csv("/databricks-datasets/learning-spark-v2/sf-airbnb/sf-airbnb.csv")
)


def _missing_condition(col_name, dtype, treat_empty_str_as_null):
    """Return a boolean Spark Column that is true where ``col_name`` counts as missing.

    ``dtype`` is the simple type string reported by ``DataFrame.dtypes``.
    """
    # Backtick-quote the name (doubling literal backticks) so column names
    # containing dots or spaces are not parsed as nested-field references.
    col = F.col("`" + col_name.replace("`", "``") + "`")
    cond = col.isNull()
    if treat_empty_str_as_null and dtype == "string":
        cond = cond | (col == "")
    # Only floating-point columns can hold NaN. Integer columns never can, and
    # Spark names decimals "decimal(p,s)", so an exact "decimal" match never fired.
    if dtype in ("double", "float"):
        cond = cond | F.isnan(col)
    return cond


def _missing_stats_frame(columns, missing_counts, total_count):
    """Build the pandas summary table from per-column missing counts.

    A ``None`` count (Spark's SUM over zero rows) is treated as 0.
    """
    rows = []
    for name, raw in zip(columns, missing_counts):
        missing = int(raw or 0)
        pct = round(100 * missing / total_count, 3) if total_count > 0 else 0.0
        rows.append(
            {
                "column": name,
                "missing_count": missing,
                "total_count": total_count,
                "missing_percent": pct,
            }
        )
    # Pass explicit columns so an input DataFrame without columns still yields
    # a frame callers can sort by "missing_percent".
    return pd.DataFrame(
        rows, columns=["column", "missing_count", "total_count", "missing_percent"]
    )


def collect_missing_stats(df, treat_empty_str_as_null=True):
    """
    Compute missing values per column in a Spark DataFrame.

    A value counts as missing when it is null, NaN (float/double columns only),
    or, if ``treat_empty_str_as_null`` is true, the empty string in a string
    column. All counts, including the row total, come from a single Spark
    aggregation, so ``df`` is scanned once.

    Args:
        df: Spark DataFrame to inspect.
        treat_empty_str_as_null: also count ``""`` in string columns as missing.

    Returns:
        pandas DataFrame with columns
        ['column', 'missing_count', 'total_count', 'missing_percent'],
        one row per column of ``df`` in ``df.columns`` order.
        ``missing_percent`` is rounded to 3 decimals and is 0.0 when ``df``
        has no rows.
    """
    dtypes = df.dtypes
    # Positional aliases keep results unambiguous even if a column name would
    # collide with another alias; values are read back by position below.
    exprs = [
        F.sum(
            F.when(_missing_condition(name, dtype, treat_empty_str_as_null), 1).otherwise(0)
        ).alias(f"_missing_{i}")
        for i, (name, dtype) in enumerate(dtypes)
    ]
    # Fold the row count into the same aggregation instead of a separate df.count() job.
    exprs.append(F.count(F.lit(1)).alias("_total_count"))

    values = list(df.agg(*exprs).collect()[0])
    total_count = int(values[-1])
    return _missing_stats_frame([name for name, _ in dtypes], values[:-1], total_count)

# Per-column missing-value summary, most-missing first. A stable sort keeps the
# original column order among ties, so the "top 20" selection is reproducible.
missing_stats = collect_missing_stats(df)
missing_stats_sorted = missing_stats.sort_values(
    "missing_percent", ascending=False, kind="stable"
)

top_missing = missing_stats_sorted.head(20)
display(top_missing)

top_20_cols = top_missing["column"].tolist()

# Bar chart: percent missing for the 20 worst columns (largest at the top).
plt.figure(figsize=(10, 6))
plt.barh(top_missing["column"], top_missing["missing_percent"], color="steelblue")
plt.gca().invert_yaxis()
plt.xlabel("Percent Missing (%)")
plt.title("Top 20 Columns by Missing Values — Airbnb Dataset")
plt.grid(axis="x", linestyle="--", alpha=0.7)
plt.tight_layout()
plt.show()

# Heatmap: bring only a reproducible random sample of at most 300 rows to the
# driver (previously the whole table was collected despite the "sample" label).
pdf_sample = (
    df.select(top_20_cols)
    .orderBy(F.rand(seed=42))
    .limit(300)
    .toPandas()
)
# Match collect_missing_stats(treat_empty_str_as_null=True): empty strings count as missing.
missing_matrix = (pdf_sample.isnull() | pdf_sample.eq("")).astype(int)

plt.figure(figsize=(12, 8))
# Fix the color range so an all-present or all-missing matrix still maps 0/1 correctly.
plt.imshow(
    missing_matrix.T,
    aspect="auto",
    cmap="gray_r",
    interpolation="nearest",
    vmin=0,
    vmax=1,
)
plt.yticks(range(len(missing_matrix.columns)), missing_matrix.columns)
plt.xticks([])
plt.xlabel(f"Row index (random sample of {len(pdf_sample)} rows)")
plt.title("Heatmap of Missing Values — Top 20 Airbnb Columns")
plt.colorbar(label="Missing (1 = missing, 0 = present)")
plt.tight_layout()
plt.show()
