In [0]:
# COMMAND ----------
from __future__ import annotations
from typing import Dict, Any, Iterable, Optional
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, explode, lit
from pyspark.sql.types import StructType
logger = get_logger("functions")

In [0]:
# Calling logging configuration file
 %run ./logging_config

In [0]:
# Calling configurations notebook
%run ./config

In [0]:
# --- I/O helpers -------------------------------------------------------------

def read_json_files(
    path: str,
    *,
    schema: Optional[StructType] = None,
    mode: str = "PERMISSIVE"
) -> DataFrame:
    """
    Read JSON files into a Spark DataFrame.

    Parameters
    ----------
    path : str
        DBFS path or mount (e.g., '/mnt/raw/nasa/neo').
    schema : StructType, optional
        Optional explicit schema for performance and stability.
    mode : str
        Parse mode: 'PERMISSIVE', 'FAILFAST', or 'DROPMALFORMED'
        (matched case-insensitively; passed to the reader as given).

    Returns
    -------
    DataFrame
        Parsed DataFrame.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported parse modes.
    """
    # Validate before touching Spark: the JSON reader does not reject an unknown
    # mode, it only logs a warning and falls back to PERMISSIVE, so a typo such
    # as 'FAIL_FAST' would silently turn strict parsing off.
    supported_modes = ("PERMISSIVE", "FAILFAST", "DROPMALFORMED")
    if not isinstance(mode, str) or mode.upper() not in supported_modes:
        raise ValueError(
            f"Unsupported JSON parse mode {mode!r}; "
            f"expected one of: {', '.join(supported_modes)}"
        )

    logger.info(f"Reading JSON from {path} (mode={mode})")
    # NOTE(review): `spark` is called as a function here, so it is presumably a
    # session factory defined outside this notebook (e.g. in ./config) rather
    # than the Databricks-provided SparkSession object — confirm.
    df = spark().read.option("mode", mode).json(path, schema=schema)
    # df.count() is a Spark action: it runs a job over the input solely to
    # produce the row count in this log line.
    logger.info(f"Read {df.count()} rows / {len(df.columns)} columns")
    return df

In [0]:
def write_delta(
    df: DataFrame,
    path: str,
    *,
    mode: str = "append",
    merge_schema: bool = True
) -> None:
    """
    Write a DataFrame to Delta.

    Parameters
    ----------
    df : DataFrame
        DataFrame to persist.
    path : str
        Output Delta path.
    mode : str
        Save mode: 'append' or 'overwrite'.
    merge_schema : bool
        Whether to merge schema on write (sent as the 'mergeSchema' writer
        option, rendered as 'true' / 'false').

    Raises
    ------
    DataQualityError
        If df is empty, to prevent writing bad checkpoints.
    """
    # Fetch at most one row to test for emptiness. This stays on the DataFrame
    # API instead of going through df.rdd, which converts rows to Python
    # objects and is not available on Spark Connect style clusters.
    if len(df.take(1)) == 0:
        raise DataQualityError(f"Refusing to write empty DataFrame to {path}")
    logger.info(f"Writing Delta to {path} (mode={mode}, merge_schema={merge_schema})")
    (df.write
       .format("delta")
       .mode(mode)
       # Writer options are strings; render the bool as 'true' / 'false'.
       .option("mergeSchema", str(merge_schema).lower())
       .save(path))
    logger.info("Write complete.")


In [0]:
# --- Transformations (example shapes for NASA NEO) ---------------------------

def flatten_neo_feed(df: DataFrame) -> DataFrame:
    """
    Flatten a NASA NEO 'feed' JSON structure into a row-per-object DataFrame.

    Notes
    -----
    This assumes input like:
      {'near_earth_objects': {'2020-01-01': [ {...}, {...} ], '2020-01-02': [...] }, ...}

    The 'near_earth_objects' column is handled according to its Spark type:

    * map (date -> array of objects): exploded into (close_date, objs) pairs.
    * struct with one field per date (what JSON schema inference yields for a
      date-keyed object when no explicit schema is given): each date field is
      flattened on its own and the per-date results are unioned.
    * anything else: passed to ``inline(...)`` as before, which expects an
      array of two-field structs (close_date, objs).

    If the column is absent, ``df`` is treated as already exploded and must
    provide 'close_date' and 'objs' columns.

    Parameters
    ----------
    df : DataFrame
        Raw DataFrame loaded from the feed format.

    Returns
    -------
    DataFrame
        Flattened DataFrame with one row per NEO object and the columns
        close_approach_date, neo_id, name, is_hazardous, reference_url, kps,
        miss_km.
    """
    logger.info("Flattening NEO feed structure")
    neo_map_col = "near_earth_objects"

    if neo_map_col not in df.columns:
        # Fallback for already exploded structures.
        result = _flatten_neo_objects(df)
    else:
        container = df.schema[neo_map_col].dataType
        kind = container.typeName()
        date_fields = container.fieldNames() if kind == "struct" else []

        if kind == "map":
            # explode() on a map yields one (key, value) row per entry;
            # inline() cannot be used here because it only accepts array<struct>.
            by_date = df.select(
                explode(col(neo_map_col)).alias("close_date", "objs")
            )
            result = _flatten_neo_objects(by_date)
        elif date_fields:
            # Flatten each date separately and union the flat results, so the
            # per-date object structs only need to share the selected fields.
            parts = [
                _flatten_neo_objects(
                    df.select(
                        lit(date_key).alias("close_date"),
                        col(neo_map_col).getField(date_key).alias("objs"),
                    )
                )
                for date_key in date_fields
            ]
            result = parts[0]
            for part in parts[1:]:
                result = result.union(part)
        else:
            # Original behaviour, kept for inputs already shaped as an array of
            # (close_date, objs) structs.
            by_date = df.selectExpr(f"inline({neo_map_col}) as (close_date, objs)")
            result = _flatten_neo_objects(by_date)

    # result.count() is a Spark action: it executes the whole flattening once
    # just to produce this log line.
    logger.info(f"Flattened to {result.count()} rows")
    return result


def _flatten_neo_objects(by_date: DataFrame) -> DataFrame:
    """Turn a (close_date, objs) frame into one flat row per NEO object."""
    # One row per element of the 'objs' array, keeping the date alongside it.
    flattened = by_date.select(col("close_date"), explode(col("objs")).alias("obj"))

    # Only the first close-approach entry is used for velocity / miss distance.
    # NOTE(review): this relies on [0] yielding NULL for an empty or missing
    # array, which presumably requires ANSI mode to be off — confirm.
    first = col("obj.close_approach_data")[0]

    # Pull common fields to top-level; modify as needed based on your raw schema.
    return flattened.select(
        col("close_date").alias("close_approach_date"),
        col("obj.id").cast("string").alias("neo_id"),
        col("obj.name").alias("name"),
        col("obj.is_potentially_hazardous_asteroid").alias("is_hazardous"),
        col("obj.nasa_jpl_url").alias("reference_url"),
        first["relative_velocity"]["kilometers_per_second"].cast("double").alias("kps"),
        first["miss_distance"]["kilometers"].cast("double").alias("miss_km"),
    )


In [0]:
def expect_non_empty(df: DataFrame, *, step: str) -> None:
    """
    Assert that a DataFrame is non-empty or raise DataQualityError.

    Parameters
    ----------
    df : DataFrame
        DataFrame to validate.
    step : str
        Step name to include in the error message.

    Raises
    ------
    DataQualityError
        If the DataFrame is empty.
    """
    # Fetch at most one row to test for emptiness. This stays on the DataFrame
    # API instead of going through df.rdd, which converts rows to Python
    # objects and is not available on Spark Connect style clusters.
    if len(df.take(1)) == 0:
        raise DataQualityError(f"Validation failed: '{step}' produced an empty DataFrame.")
