# 02 – Configuration and DAG utilities

This notebook contains helper functions to:

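- Read a DAG.json from the Lakehouse (`file:`, `/lakehouse/default/...` and relative `Files/...` paths are accepted).
- Validate that the DAG belongs to the expected `SOURCE_NAME`.
- Extract the list of enabled tables for a source.
- Retrieve global settings from the DAG:
  - `base_files` – root folder where parquet files are written.
  - `bronze_schema` – schema name for bronze Delta tables.
- Normalize Lakehouse paths to `Files/...` form and wrap common file/Spark reads with that normalization.
- Choose the number of workers for the next run from the run history in the summary log table.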

It also relaxes Spark's date/time handling to match the original notebooks:
ANSI mode is disabled and the legacy time parser policy is used.


In [None]:
# [1] Imports and Spark configuration

import json
from notebookutils import mssparkutils
from pyspark.sql import functions as F


# Relax Spark's date/time handling so parsing matches the original notebooks:
# legacy time parser policy, ANSI mode switched off.
for _conf_key, _conf_value in (
    ("spark.sql.legacy.timeParserPolicy", "LEGACY"),
    ("spark.sql.ansi.enabled", "false"),
):
    spark.conf.set(_conf_key, _conf_value)
# Don't leave loop variables behind in the notebook namespace.
del _conf_key, _conf_value


## [2] DAG helper functions

This cell defines:

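- `read_dag(dag_path)` – read and parse the DAG.json; accepts `file:`, `/lakehouse/default/...` and `Files/...` paths.
- `validate_dag_for_source(dag, source_name)` – ensure the DAG is for the expected source.
- `get_tables_for_source(dag, source_name)` – list of enabled tables.
- `get_base_files(dag)` – returns the configured base_files folder (default: `greenhouse_sources`).
- `get_bronze_schema(dag)` – returns the bronze schema name (default: `bronze`).
- `normalize_files_path(path)` / `build_files_path(*segments)` – produce canonical `Files/...` paths.
- `fs_ls`, `fs_head`, `spark_read_parquet`, `spark_read_json`, `spark_read_csv` – thin wrappers that normalize the path before calling `mssparkutils.fs` or `spark.read`.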


In [None]:
# [2] DAG utilities
def read_dag(dag_path: str, max_bytes: int = 10 * 1024 * 1024) -> dict:
    """
    Read and parse the DAG JSON file from the Lakehouse.

    Accepted path forms, all reduced to a relative 'Files/...' path:
      - "Files/config/dag_anva_meeus_week.json"
      - "/lakehouse/default/Files/config/dag_anva_meeus_week.json"
      - "file:/lakehouse/default/Files/config/dag_anva_meeus_week.json"
      - "file:///lakehouse/default/Files/config/dag_anva_meeus_week.json"
    Any other path is passed to mssparkutils.fs.head unchanged.

    Args:
        dag_path: Location of the DAG JSON file.
        max_bytes: Maximum number of bytes requested from mssparkutils.fs.head.
            That call returns only the head of a file, so this must exceed the
            DAG's size; a truncated file would fail to parse.

    Returns:
        The parsed DAG (top-level JSON object) as a dict.

    Raises:
        ValueError: if the path is empty/blank, the file cannot be read, or the
            top-level JSON value is not an object. Malformed JSON raises
            json.JSONDecodeError (a ValueError subclass).
    """
    # Check after stripping so a whitespace-only path is rejected as well.
    if not dag_path or not dag_path.strip():
        raise ValueError("dag_path is empty")

    p = dag_path.strip()

    # Strip prefixes that are not needed here: the 'file:' scheme and the
    # '/lakehouse/default/' mount prefix (tolerating extra leading slashes,
    # as in 'file:///lakehouse/default/...').
    if p.startswith("file:"):
        p = p[len("file:"):]
    unslashed = p.lstrip("/")
    if unslashed.startswith("lakehouse/default/"):
        p = unslashed[len("lakehouse/default/"):]

    # Now we expect something like: 'Files/config/dag_anva_meeus_week.json'
    try:
        txt = mssparkutils.fs.head(p, max_bytes)
    except Exception as e:
        raise ValueError(
            f"Failed to read DAG at '{dag_path}' (normalized '{p}'): {e}"
        ) from e

    dag = json.loads(txt)
    # Callers use dag.get(...), so anything other than a JSON object is unusable.
    if not isinstance(dag, dict):
        raise ValueError(
            f"DAG at '{dag_path}' must be a JSON object, got {type(dag).__name__}"
        )
    return dag

def validate_dag_for_source(dag: dict, source_name: str) -> None:
    """
    Check that `dag` was written for `source_name`.

    Raises:
        ValueError: if the DAG's "source" entry (None when absent) differs from
            `source_name`.
    """
    actual = dag.get("source")
    if actual == source_name:
        return
    raise ValueError(f"DAG is for source {actual}, but expected {source_name}")


def get_tables_for_source(dag: dict, source_name: str):
    """
    Return the list of enabled table entries for a given source in the DAG.

    Currently the DAG contains a single 'source'; `source_name` is checked
    against it with validate_dag_for_source, which keeps the function flexible
    in case this changes in the future.

    A table entry counts as enabled unless it has a falsy "enabled" value.
    A missing or null "tables" entry is treated as an empty list.

    Raises:
        ValueError: if the DAG is for another source, or no enabled tables remain.
    """
    validate_dag_for_source(dag, source_name)

    # `or []` also covers an explicit `"tables": null` in the JSON, which would
    # otherwise raise a TypeError when iterated.
    tables = dag.get("tables") or []
    enabled = [t for t in tables if t.get("enabled", True)]

    if not enabled:
        raise ValueError(f"No enabled tables found in DAG for source '{source_name}'")

    return enabled


def get_base_files(dag: dict) -> str:
    """
    Return the base_files folder used to store parquet exports.

    Defaults to 'greenhouse_sources' when the setting is missing, null or empty.
    """
    return dag.get("base_files") or "greenhouse_sources"


def get_bronze_schema(dag: dict) -> str:
    """
    Return the schema name used for bronze Delta tables.

    Reads the DAG's 'bronze_schema' setting and defaults to 'bronze' when it is
    missing, null or empty.
    """
    return dag.get("bronze_schema") or "bronze"


def normalize_files_path(path: str) -> str:
    """
    Normalize a Lakehouse path to canonical 'Files/...' form.

    Accepts:
      - 'Files/...'
      - 'config/dag.json'                      (treated as relative to Files)
      - '/lakehouse/default/Files/...'
      - 'file:/lakehouse/default/Files/...'     (any number of slashes after 'file:')
    and returns 'Files/...' for all of them. A bare 'Files' becomes 'Files/',
    and None or an empty string also yields 'Files/'.

    Fully-qualified URIs with another scheme (e.g. 'abfss://...') are returned
    unchanged, because prefixing them with 'Files/' would produce an invalid path.
    """
    p = (path or "").strip()

    # 1) 'file:' is a local-mount prefix we can drop; any other 'scheme://'
    #    URI is already absolute and must not be rewritten.
    if p.startswith("file:"):
        p = p[len("file:"):]
    else:
        scheme, sep, _ = p.partition("://")
        if sep and scheme and "/" not in scheme:
            return p

    # 2) strip leading slashes
    p = p.lstrip("/")

    # 3) strip the 'lakehouse/default/' prefix (case-insensitive match)
    if p.lower().startswith("lakehouse/default/"):
        p = p[len("lakehouse/default/"):]

    # 4) make sure the path starts with 'Files/'
    if p == "Files":
        # Without this, the Files root itself would become 'Files/Files'.
        return "Files/"
    if not p.startswith("Files/"):
        p = "Files/" + p

    return p

def fs_ls(path: str):
    """List a Lakehouse directory with mssparkutils.fs.ls after normalizing `path` to 'Files/...' form."""
    target = normalize_files_path(path)
    return mssparkutils.fs.ls(target)


def fs_head(path: str, max_bytes=None) -> str:
    """
    Read a small text file (e.g. a JSON config) from Lakehouse Files.

    The path is normalized to 'Files/...' form first. mssparkutils.fs.head
    returns at most a limited number of leading bytes. Pass `max_bytes` to
    raise that limit for larger files. When it is omitted, mssparkutils' own
    default applies.
    """
    normalized = normalize_files_path(path)
    if max_bytes is None:
        return mssparkutils.fs.head(normalized)
    return mssparkutils.fs.head(normalized, max_bytes)


def spark_read_parquet(path: str):
    """Load parquet data under Lakehouse Files (path normalized to 'Files/...') into a Spark DataFrame."""
    normalized = normalize_files_path(path)
    return spark.read.parquet(normalized)


def spark_read_json(path: str):
    """Load JSON data under Lakehouse Files (path normalized to 'Files/...') into a Spark DataFrame."""
    normalized = normalize_files_path(path)
    return spark.read.json(normalized)


def spark_read_csv(path: str, **options):
    """
    Load CSV data under Lakehouse Files into a Spark DataFrame.

    The path is normalized to 'Files/...' form; keyword arguments are passed
    unchanged to the reader's .options(...).
    """
    reader = spark.read.options(**options)
    return reader.csv(normalize_files_path(path))

def build_files_path(*segments: str) -> str:
    """
    Join path segments and normalize the result to 'Files/...'.

    Leading/trailing slashes and backslashes are stripped from each segment.
    Segments that are empty, or become empty after stripping (e.g. "/"), are
    skipped so the joined path never contains '//'.

    Example:
      build_files_path("config", "dag.json") -> 'Files/config/dag.json'
      build_files_path("Files", "config", "dag.json") -> 'Files/config/dag.json'
    """
    stripped = (s.strip("/\\") for s in segments if s)
    rel = "/".join(part for part in stripped if part)
    return normalize_files_path(rel)


In [None]:
# [X] Helper: choose worker profile based on last efficiency + last workers

def choose_worker_profile(source_name: str, debug: bool = False) -> int:
    """
    Pick the worker count for the next run of `source_name` from its last run.

    The newest row (ordered by run_ts descending) for this source in
    SUMMARY_LOG_TABLE_FULLNAME that has non-null `efficiency_pct` and `workers`
    is used. The efficiency is clamped to 0..100 and mapped to an adjustment
    of the previous worker count (always a step of 2):

      efficiency band   adjustment
      0–20%             -2  (very bad)
      20–40%            -2  (bad)
      40–60%             0  (keep)
      60–80%            +2  (good)
      80–100%           +2  (very good)

    The result is clamped to [2, 12]. If the summary table does not exist, no
    matching row is found, or anything raises, 12 workers are returned.
    With `debug=True`, a "[MASTER] ..." diagnostic line is printed.
    """
    MIN_WORKERS = 2
    MAX_WORKERS = 12
    DEFAULT_WORKERS = 12

    # (exclusive upper bound of the efficiency band, worker delta, band label);
    # efficiencies not below any bound fall into the "80-100" band.
    bands = (
        (20.0, -2, "0-20"),
        (40.0, -2, "20-40"),
        (60.0, 0, "40-60"),
        (80.0, 2, "60-80"),
    )

    try:
        if not spark.catalog.tableExists(SUMMARY_LOG_TABLE_FULLNAME):
            if debug:
                print(f"[MASTER] No summary table yet ({SUMMARY_LOG_TABLE_FULLNAME}), using default {DEFAULT_WORKERS} workers.")
            return DEFAULT_WORKERS

        latest_rows = (
            spark.table(SUMMARY_LOG_TABLE_FULLNAME)
                 .where(
                     (F.col("source") == source_name)
                     & F.col("efficiency_pct").isNotNull()
                     & F.col("workers").isNotNull()
                 )
                 .orderBy(F.col("run_ts").desc())
                 .limit(1)
                 .collect()
        )

        if not latest_rows:
            if debug:
                print(f"[MASTER] No historical efficiency/workers for source={source_name}, using default {DEFAULT_WORKERS} workers.")
            return DEFAULT_WORKERS

        latest = latest_rows[0]
        eff = float(latest["efficiency_pct"])
        prev_workers = int(latest["workers"])

        eff = min(max(eff, 0.0), 100.0)

        delta, band = 2, "80-100"
        for upper, step, label in bands:
            if eff < upper:
                delta, band = step, label
                break

        new_workers = min(max(prev_workers + delta, MIN_WORKERS), MAX_WORKERS)

        if debug:
            print(
                f"[MASTER] Historical: workers_last={prev_workers}, eff_last={eff:.1f}% (band {band}), "
                f"delta={delta} -> new_workers={new_workers}"
            )

        return new_workers

    except Exception as e:
        if debug:
            print(f"[MASTER] Error reading historical profile for {source_name}: {e}. Using default {DEFAULT_WORKERS} workers.")
        return DEFAULT_WORKERS


In [None]:
def choose_worker_profile_from_history(
    source_name: str,
    default_workers: int = 8,
    min_workers: int = 2,
    max_workers_cap: int = 12,
    lookback_runs: int = 3,
) -> int:
    """
    Pick a worker count for the next run of `source_name`, based on its most
    recent runs in SUMMARY_LOG_TABLE_FULLNAME (logs.bronze_run_summary).

    - Uses up to `lookback_runs` (clamped to 1..5) most recent runs for the
      source, newest first (by `run_start` when that column exists, otherwise
      by `run_ts`).
    - Skips runs without usable metrics (null workers, null or non-positive
      duration_seconds / total_rows) and runs with efficiency_pct < 20.
    - For each remaining run: throughput = total_rows / duration_seconds.
    - Averages throughput per worker count and targets the smallest worker
      count whose average is within 5% of the best one.
    - Applies volume-based caps to the target (median rows < 100k -> at most
      4 workers, < 1M -> at most 8) so tiny runs don't get 12 workers.
    - Moves from the last run's worker count towards the target by at most
      2 workers.

    Args:
        source_name: Value matched against the summary table's `source` column.
        default_workers: Returned unchanged when the summary table does not
            exist or has no rows for this source; also used when the last run
            has no worker count.
        min_workers / max_workers_cap: Bounds applied to every history-based
            result.
        lookback_runs: Number of recent runs to consider (clamped to 1..5).

    Returns:
        The number of workers to use for the next run.
    """
    from collections import defaultdict

    def _clamp(n: int) -> int:
        # Keep history-based results inside [min_workers, max_workers_cap].
        return max(min_workers, min(max_workers_cap, n))

    # Clamp lookback between 1 and 5
    lookback_runs = max(1, min(lookback_runs, 5))

    # No summary table yet -> just use the default
    if not spark.catalog.tableExists(SUMMARY_LOG_TABLE_FULLNAME):
        return int(default_workers)

    base_df = spark.table(SUMMARY_LOG_TABLE_FULLNAME).filter(
        F.col("source") == source_name
    )

    # No data for this source -> default
    if not base_df.head(1):
        return int(default_workers)

    # Newest run first: order by run_start if present, otherwise run_ts
    order_col = "run_start" if "run_start" in base_df.columns else "run_ts"
    rows = (
        base_df.orderBy(F.col(order_col).desc())
        .limit(lookback_runs)
        .select("workers", "efficiency_pct", "total_rows", "duration_seconds")
        .collect()
    )
    if not rows:
        return int(default_workers)

    # The last run is the first row in this ordering
    last_raw = rows[0]["workers"]
    last_workers = int(last_raw) if last_raw is not None else int(default_workers)

    # Build the list of runs with usable metrics
    history = []
    for r in rows:
        w = r["workers"]
        eff = r["efficiency_pct"]
        tot = r["total_rows"]
        dur = r["duration_seconds"]

        if w is None or dur is None or dur <= 0 or tot is None or tot <= 0:
            continue

        # Skip extremely poor efficiency (<20%) as input for the throughput analysis
        if eff is not None and eff < 20.0:
            continue

        history.append(
            {
                "workers": int(w),
                "rows": int(tot),
                "throughput": float(tot) / float(dur),
            }
        )

    # Nothing usable left -> keep the last worker count (or the default),
    # but never outside the configured bounds.
    if not history:
        return int(_clamp(last_workers or int(default_workers)))

    # 1) Volume profile: median row count over these runs
    row_counts = sorted(h["rows"] for h in history)
    mid = len(row_counts) // 2
    if len(row_counts) % 2 == 1:
        median_rows = row_counts[mid]
    else:
        # Integer arithmetic avoids float rounding on very large row counts
        median_rows = (row_counts[mid - 1] + row_counts[mid]) // 2

    # 2) Average throughput per worker count
    thr_by_workers = defaultdict(list)
    for h in history:
        thr_by_workers[h["workers"]].append(h["throughput"])

    avg_thr_by_workers = {
        w: sum(vals) / len(vals) for w, vals in thr_by_workers.items()
    }

    # 3) Target: the smallest worker count within 5% of the best throughput
    #    (saves resources when more workers barely help)
    best_thr = max(avg_thr_by_workers.values())
    target_workers = min(
        w for w, thr in avg_thr_by_workers.items() if thr >= 0.95 * best_thr
    )

    # 4) Volume-based caps:
    #    - very small runs gain little from many workers
    #    - adjust these thresholds to your own workload
    if median_rows < 100_000:
        target_workers = min(target_workers, 4)
    elif median_rows < 1_000_000:
        target_workers = min(target_workers, 8)
    # for larger volumes max_workers_cap is the upper bound

    # 5) Clamp the target to the global min/max
    target_workers = _clamp(target_workers)

    # 6) Move at most 2 workers relative to the last run (small steps)
    if target_workers > last_workers:
        new_workers = min(last_workers + 2, target_workers)
    elif target_workers < last_workers:
        new_workers = max(last_workers - 2, target_workers)
    else:
        new_workers = last_workers

    # last_workers itself may lie outside the bounds (e.g. after the bounds
    # were changed), so the stepped result is clamped as well.
    return int(_clamp(new_workers))
