In [2]:
# Cell 1: setup, imports, constants, utility helpers
import os
import json
import math
import time
import shutil
import logging
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import joblib
import torch

# optional: deltalake for reading infer-ready table; fallback to pyarrow dataset if not present
try:
    from deltalake import DeltaTable
    HAS_DELTALAKE = True
except Exception:
    HAS_DELTALAKE = False

# Configure logging: INFO level with timestamp/level/message, and a named logger
# ("dtc_infer") shared by every helper defined in this notebook.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
log = logging.getLogger("dtc_infer")

# Root paths (Windows). The sub-paths are raw strings containing backslashes, so
# they only split into separate directory components on a Windows path flavour.
ROOT = Path(r"C:\engine_module_pipeline")
INFER_READY = ROOT / r"delta\engine_module_infer_ready"  # input table read by read_infer_ready_batch
ARTIFACTS_ROOT = ROOT / r"DTC_stage\artifacts"           # one sub-directory per DTC code
OUTPUT_ROOT = ROOT / r"DTC_stage\data\Output"
PER_DTC_OUT = OUTPUT_ROOT / "per_dtc"
COMBINED_OUT = OUTPUT_ROOT / "combined"

# Create output dirs. NOTE: this is an import-time side effect; running the cell
# creates the directories (and any missing parents) if they do not exist yet.
PER_DTC_OUT.mkdir(parents=True, exist_ok=True)
COMBINED_OUT.mkdir(parents=True, exist_ok=True)

# Which DTCs to run inference for; each code must have a directory under ARTIFACTS_ROOT
# (see ensure_artifact_for_dtc).
DTC_CODES = ["P0234", "P0300", "P0420", "P0501", "P0562"]

# Max number of infer-ready rows to process in one run (as requested).
# It is used as a default argument value below, so it is bound when those
# functions are defined, not when they are called.
MAX_ROWS = 2000

# Device: use CUDA when torch reports it as available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
log.info(f"Using device: {DEVICE}")

# --- tiny helpers ---
def atomic_write_df_to_parquet(df: pd.DataFrame, out_path: Path, partition_on_date: bool = True) -> None:
    """
    Write ``df`` as parquet using a write-to-temp-then-rename sequence.

    Args:
        df: frame to persist.
        out_path: target directory (partitioned mode) or target file (single-file mode).
        partition_on_date: when True and ``df`` has a ``date`` column, one part file
            is written per distinct date into ``out_path/date=YYYY-MM-DD/``.
            Otherwise ``df`` is written to the single file ``out_path``.

    Notes:
        * Part files are named ``part_<epoch>_<date>_<random>.parquet``. The random
          suffix keeps two writes for the same date within the same second from
          replacing each other.
        * Temp files end in ``.tmp`` (never ``.parquet``) and are removed if the
          write fails, so ``*.parquet`` scans never pick up a partial file.
        * In partitioned mode rows are grouped with ``DataFrame.groupby("date")``;
          rows whose ``date`` is missing are presumably dropped by the groupby
          default and therefore not written - verify against the pandas version in use.
        * An empty frame in partitioned mode writes no files.
    """
    import uuid  # local import: only needed for unique part-file names

    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    def _write_then_swap(frame: pd.DataFrame, tmp_file: Path, final_file: Path) -> None:
        """Write ``frame`` to ``tmp_file`` and rename it onto ``final_file``."""
        try:
            frame.to_parquet(tmp_file, engine="pyarrow", index=False)
            os.replace(tmp_file, final_file)
        finally:
            # After a successful replace the temp file no longer exists; this only
            # fires when the write or the rename failed part-way.
            if tmp_file.exists():
                tmp_file.unlink()

    if partition_on_date and "date" in df.columns:
        for date_val, grp in df.groupby("date"):
            date_str = pd.to_datetime(date_val).strftime("%Y-%m-%d")
            part_dir = out_path / f"date={date_str}"
            part_dir.mkdir(parents=True, exist_ok=True)
            file_name = f"part_{int(time.time())}_{date_str}_{uuid.uuid4().hex[:8]}.parquet"
            _write_then_swap(grp, part_dir / (file_name + ".tmp"), part_dir / file_name)
        return

    # single-file write
    _write_then_swap(df, out_path.with_name(out_path.name + ".tmp"), out_path)

def load_json(path: Path) -> Dict[str, Any]:
    """Read the UTF-8 text file at ``path`` and return its decoded JSON content."""
    with open(path, mode="r", encoding="utf-8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)

def save_json(obj: Any, path: Path) -> None:
    """Serialise ``obj`` as 2-space-indented JSON at ``path``, creating parent directories first."""
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as handle:
        json.dump(obj, handle, indent=2)

def ensure_artifact_for_dtc(dtc_code: str) -> Path:
    """
    Return the artifact directory ``ARTIFACTS_ROOT/<dtc_code>``.

    Raises:
        FileNotFoundError: if that path does not exist.
    """
    artifact_dir = ARTIFACTS_ROOT / dtc_code
    if artifact_dir.exists():
        return artifact_dir
    raise FileNotFoundError(f"Artifacts for {dtc_code} not found under {artifact_dir}")


2025-10-25 01:45:42,650 INFO Using device: cpu


In [3]:
# Cell 2: read up to `limit` rows from infer-ready table into a DataFrame
def _to_naive_utc_with_date(frame: pd.DataFrame) -> pd.DataFrame:
    """Convert ``timestamp`` to tz-naive UTC datetimes in place and derive a ``date`` column from it."""
    frame["timestamp"] = pd.to_datetime(frame["timestamp"], utc=True).dt.tz_convert(None)
    frame["date"] = frame["timestamp"].dt.date
    return frame


def read_infer_ready_batch(limit: int = MAX_ROWS) -> pd.DataFrame:
    """
    Read up to ``limit`` rows from the infer-ready table at ``INFER_READY``.

    The ``deltalake`` reader is tried first when it was importable; if that read
    fails (including a table without a ``timestamp`` column) the function falls
    back to scanning parquet files under the table directory, oldest
    modification time first.

    Args:
        limit: maximum number of rows to return; must be >= 1.

    Returns:
        DataFrame holding the source columns, with ``timestamp`` normalised to
        tz-naive UTC and a derived ``date`` column. Row order is whatever the
        reader produced; no sorting is applied here.

    Raises:
        ValueError: if ``limit`` is not positive.
        FileNotFoundError: if ``INFER_READY`` is missing or holds no data parquet files.
        RuntimeError: if the fallback scan yields no rows.
    """
    if limit < 1:
        # A zero/negative limit previously produced an empty or tail-truncated
        # frame on the fast path and a RuntimeError on the fallback path.
        raise ValueError(f"limit must be a positive integer, got {limit!r}")
    if not INFER_READY.exists():
        raise FileNotFoundError(f"Infer-ready Delta path not found: {INFER_READY}")

    if HAS_DELTALAKE:
        log.info("Reading infer-ready via deltalake (fast path)")
        dt = DeltaTable(INFER_READY.as_posix())
        try:
            # NOTE(review): this materialises the whole table before trimming to
            # `limit`; acceptable for small tables, costly for large ones.
            df = dt.to_pyarrow_table().to_pandas()
            if "timestamp" not in df.columns:
                raise RuntimeError("infer-ready table missing 'timestamp' column")
            df = _to_naive_utc_with_date(df.head(limit).copy())
            log.info("Read %d rows from infer-ready", len(df))
            return df
        except Exception as e:
            log.exception("deltalake read path failed; falling back to parquet scanning: %s", e)

    # Fallback: scan parquet files under the delta folder.
    # NOTE(review): this bypasses the Delta transaction log, so files that were
    # logically removed from the table may still be read - confirm this is acceptable.
    log.info("Reading infer-ready via parquet fallback (scanning files)")
    parquet_files = sorted(
        # Delta keeps transaction-log checkpoints as parquet under _delta_log;
        # those are table metadata, not data rows, so they are excluded.
        (p for p in INFER_READY.rglob("*.parquet") if "_delta_log" not in p.parts),
        key=lambda p: p.stat().st_mtime,
    )
    if not parquet_files:
        raise FileNotFoundError(f"No parquet files found under {INFER_READY}")

    chunks = []
    remaining = limit
    for pf in parquet_files:
        try:
            df_part = pq.read_table(pf.as_posix()).to_pandas()
            if "timestamp" not in df_part.columns:
                # skip files without timestamp column
                continue
            # Trim first, then normalise, so only the rows we keep are parsed.
            chunk = _to_naive_utc_with_date(df_part.iloc[:remaining].copy())
        except Exception:
            log.exception("Failed reading parquet file %s; skipping", pf)
            continue
        chunks.append(chunk)
        remaining -= len(chunk)
        if remaining <= 0:
            break

    if not chunks:
        raise RuntimeError("No rows read from infer-ready source")
    df = pd.concat(chunks, ignore_index=True)
    log.info("Read %d rows from infer-ready", len(df))
    return df

# quick smoke read (not executed automatically)
# df_infer = read_infer_ready_batch(100)
# df_infer.head()


In [4]:
# Cell 3: windowing and thresholding utilities

from sklearn.preprocessing import RobustScaler
from sklearn.isotonic import IsotonicRegression

def sliding_windows_from_df(
    df: pd.DataFrame,
    features: List[str],
    window: int,
    stride: int = 1,
) -> Tuple[np.ndarray, List[pd.Timestamp], List[str]]:
    """
    Build sliding windows of shape [N_windows, window, n_features] from df[features].

    Windows are taken over row *positions* (not index labels): window ``i``
    covers rows ``i*stride .. i*stride + window - 1``.

    Args:
        df: source rows; must contain ``features`` and ``timestamp``. ``row_hash``
            is optional.
        features: column names, in the order they should appear on the last axis.
        window: number of rows per window (>= 1).
        stride: step between consecutive window starts (>= 1).

    Returns:
        (X_windows, end_timestamps, row_hashes_for_end)
        - X_windows: float32 numpy array (N_windows, window, C)
        - end_timestamps: timestamp of the last row of each window
        - row_hashes_for_end: ``str(row_hash)`` of the last row of each window,
          or ``""`` for every window when ``df`` has no ``row_hash`` column
        Missing values remain as np.nan (scaler must handle or we impute later).
        When ``df`` has fewer than ``window`` rows the result is an empty
        ``(0, window, C)`` array and two empty lists.

    Raises:
        ValueError: if ``window`` or ``stride`` is smaller than 1.
    """
    if window < 1 or stride < 1:
        raise ValueError(f"window and stride must be >= 1, got window={window}, stride={stride}")

    n_rows = len(df)
    if n_rows < window:
        return np.zeros((0, window, len(features)), dtype=np.float32), [], []

    starts = np.arange(0, n_rows - window + 1, stride)
    ends = starts + (window - 1)

    # Convert the feature block once, then gather every window with a single
    # fancy-index: idx[i, j] is the row position of step j in window i.
    values = df[features].to_numpy(dtype=np.float32)
    idx = starts[:, None] + np.arange(window)[None, :]
    X = values[idx]

    end_ts = [pd.to_datetime(v) for v in df["timestamp"].iloc[ends]]
    if "row_hash" in df.columns:
        end_row_hash = [str(v) for v in df["row_hash"].iloc[ends]]
    else:
        end_row_hash = [""] * len(ends)
    return X, end_ts, end_row_hash

def apply_thresholds_with_hysteresis(
    probs: np.ndarray,
    dPdt: np.ndarray,
    thresholds: Dict[str, Any],
) -> np.ndarray:
    """
    Turn a probability series into an ON/OFF series with hysteresis.

    probs: 1D array of calibrated probabilities (NaN is treated as 0.0)
    dPdt: 1D array of derivative, same length (NaN is treated as 0.0)
    thresholds: dict with optional keys
        'T_on' (default 0.75), 'T_off' (default max(0, T_on - 0.10)),
        'rate_threshold' (default 0.1), 'min_consec_on' (default 1)

    While OFF, a step "triggers" when p >= T_on, or when dPdt >= rate_threshold
    and p >= max(0.4, T_on - 0.2). The state switches ON after 'min_consec_on'
    consecutive triggering steps and stays ON until p <= T_off.

    Returns a uint8 array of the same length: 1 when predicted ON, else 0.
    """
    on_level = float(thresholds.get("T_on", 0.75))
    off_level = float(thresholds.get("T_off", max(0.0, on_level - 0.10)))
    rate_level = float(thresholds.get("rate_threshold", 0.1))
    needed_streak = int(thresholds.get("min_consec_on", 1))

    states = np.zeros(len(probs), dtype=np.uint8)
    active = False
    streak = 0
    for idx, raw_p in enumerate(probs):
        raw_dp = dPdt[idx]
        p = 0.0 if np.isnan(raw_p) else float(raw_p)
        dp = 0.0 if np.isnan(raw_dp) else float(raw_dp)

        if active:
            # Only the release threshold matters while ON.
            if p <= off_level:
                active = False
                streak = 0
        elif (p >= on_level) or (dp >= rate_level and p >= max(0.4, on_level - 0.2)):
            streak += 1
            if streak >= needed_streak:
                active = True
                streak = 0
        else:
            streak = 0

        states[idx] = 1 if active else 0
    return states


In [5]:
# Cell 4 (FIXED): per-DTC inference (returns tidy DataFrame)
from math import ceil
from datetime import datetime

def infer_for_one_dtc(
    dtc_code: str,
    df_batch: pd.DataFrame,
    cadence_seconds: float = 1.0,
    batch_windows_inference: int = 256,
) -> pd.DataFrame:
    """
    Run inference for one DTC over df_batch and return one output row per sliding window.

    Pipeline, in order: load the DTC's artifacts -> add any missing feature
    columns as NaN -> build sliding windows -> scale -> run the TorchScript
    model in mini-batches -> calibrate both probability series -> finite-difference
    dP/dt -> hysteresis thresholding -> assemble the tidy frame.

    Args:
        dtc_code: DTC identifier; selects the artifact directory and file names.
        df_batch: infer-ready rows. It is copied, not modified. Must provide
            ``timestamp``; ``row_hash`` is optional (see sliding_windows_from_df).
        cadence_seconds: fallback time step used when two consecutive window-end
            timestamps are not strictly increasing. Overridden by
            ``feature_spec["cadence_seconds"]`` when that key is present.
        batch_windows_inference: number of windows sent to the model per forward pass.

    Returns:
        DataFrame with columns row_hash, timestamp, date, dtc_code,
        p_raw_precursor, p_calib_precursor, p_raw_fault, p_calib_fault,
        dPdt_precursor, dPdt_fault, pred_precursor_on, pred_fault_on,
        alert_level, feature_snapshot. Empty (columns only) when df_batch has
        fewer rows than the window length.

    Raises:
        FileNotFoundError: if the artifact directory or any required artifact file is missing.
        KeyError: if feature_spec.json has no "features" entry.
        RuntimeError: if the model output has an unexpected type, keys or shape.
    """
    art = ensure_artifact_for_dtc(dtc_code)
    feature_spec = load_json(art / "feature_spec.json")
    features = feature_spec["features"]
    window_length = int(feature_spec.get("window_length", 64))
    stride = int(feature_spec.get("stride", 1))
    # The spec value, when present, wins over the function argument.
    cadence_seconds = float(feature_spec.get("cadence_seconds", cadence_seconds))

    # Load artifacts (fail early with clear message)
    scaler_path = art / f"scaler_{dtc_code}.pkl"
    iso_prec_path = art / f"calib_{dtc_code}_precursor.pkl"
    iso_fault_path = art / f"calib_{dtc_code}_fault.pkl"
    thresholds_path = art / "thresholds.json"
    model_ts_path = art / f"model_{dtc_code}.ts"

    for p in (scaler_path, iso_prec_path, iso_fault_path, thresholds_path, model_ts_path):
        if not p.exists():
            raise FileNotFoundError(f"Missing artifact for {dtc_code}: {p}")

    # NOTE(review): joblib.load unpickles arbitrary objects; only load artifacts
    # from a trusted location.
    scaler = joblib.load(scaler_path.as_posix())
    iso_prec = joblib.load(iso_prec_path.as_posix())
    iso_fault = joblib.load(iso_fault_path.as_posix())
    thresholds = load_json(thresholds_path)

    # Deserialise on CPU first, then move to the selected device.
    model = torch.jit.load(model_ts_path.as_posix(), map_location="cpu").to(DEVICE)
    model.eval()

    # Ensure the df_batch contains all features; if missing, create columns with NaN and coerce to float
    df_proc = df_batch.copy().reset_index(drop=True)
    for f in features:
        if f not in df_proc.columns:
            df_proc[f] = np.nan
    # Coerce feature columns to Python float (float64); NaNs remain NaN.
    # The float32 cast happens later, when the windows are built and after scaling.
    df_proc[features] = df_proc[features].astype(float)

    # Build sliding windows
    X_wnd, end_ts, end_row_hash = sliding_windows_from_df(df_proc, features, window_length, stride)
    n_w = X_wnd.shape[0]
    if n_w == 0:
        log.info("No windows produced for %s (dataset too small for window %d)", dtc_code, window_length)
        # Same column set as the populated result below, so callers can concat safely.
        return pd.DataFrame(
            columns=[
                "row_hash", "timestamp", "date", "dtc_code",
                "p_raw_precursor", "p_calib_precursor", "p_raw_fault", "p_calib_fault",
                "dPdt_precursor", "dPdt_fault", "pred_precursor_on", "pred_fault_on",
                "alert_level", "feature_snapshot"
            ]
        )

    # Apply scaler robustly:
    # Strategy: flatten windows -> (n_w * window, C) so the scaler sees one row per
    # time step. If scaler.transform raises for any reason, retry after per-column
    # median imputation of NaNs.
    # NOTE(review): whether NaNs actually make transform raise depends on the
    # scaler type stored in the pickle - confirm; if it passes NaNs through, they
    # reach the model unimputed.
    X_reshaped = X_wnd.reshape(-1, len(features))  # (n_w * window, C)
    try:
        X_scaled_flat = scaler.transform(X_reshaped)
    except Exception as e:
        log.warning("scaler.transform failed for %s due to %s; performing median impute before transform", dtc_code, e)
        # Medians are computed over the flattened windows (X_reshaped), so rows that
        # appear in several overlapping windows are counted several times.
        col_medians = np.nanmedian(X_reshaped, axis=0)
        nan_mask = np.isnan(X_reshaped)
        X_filled = X_reshaped.copy()
        for ci in range(X_filled.shape[1]):
            # An all-NaN column has a NaN median; fall back to 0.0 for it.
            X_filled[nan_mask[:, ci], ci] = col_medians[ci] if not np.isnan(col_medians[ci]) else 0.0
        X_scaled_flat = scaler.transform(X_filled)

    X_scaled = X_scaled_flat.reshape(n_w, window_length, len(features)).astype(np.float32)

    # Inference in mini-batches
    p_prec_raw_list = []
    p_fault_raw_list = []
    with torch.no_grad():
        for i in range(0, n_w, batch_windows_inference):
            xb = torch.from_numpy(X_scaled[i : i + batch_windows_inference]).to(DEVICE)  # [B, T, C]
            res = model(xb)
            # Accepted output layouts:
            #   dict with "p_precursor" and "p_fault"  -> used as-is
            #   dict with "logits" = (prec, fault)      -> sigmoid applied here
            #   tuple/list (p_precursor, p_fault)       -> used as-is
            if isinstance(res, dict):
                if "p_precursor" in res and "p_fault" in res:
                    pp = res["p_precursor"].cpu().numpy()
                    pf = res["p_fault"].cpu().numpy()
                elif "logits" in res:
                    # logits may be tuple/list with two tensors
                    l = res["logits"]
                    if isinstance(l, (list, tuple)) and len(l) >= 2:
                        pp = torch.sigmoid(l[0]).cpu().numpy()
                        pf = torch.sigmoid(l[1]).cpu().numpy()
                    else:
                        raise RuntimeError("Unexpected 'logits' format from model for %s" % dtc_code)
                else:
                    raise RuntimeError("Unexpected model output dict keys: %s" % list(res.keys()))
            elif isinstance(res, (tuple, list)) and len(res) >= 2:
                # maybe model returned (p_precursor, p_fault)
                # NOTE(review): no sigmoid here - assumes these are already probabilities; confirm.
                pp = res[0].cpu().numpy()
                pf = res[1].cpu().numpy()
            else:
                raise RuntimeError("Unexpected model output type: %s" % type(res))

            # For 2D output keep the last column of each row (presumably the final
            # timestep of the window - verify against the model); 1D output is
            # taken as one value per window.
            if pp.ndim == 2:
                p_prec_raw_list.append(pp[:, -1].reshape(-1))
            elif pp.ndim == 1:
                # some models directly output a single prob per window
                p_prec_raw_list.append(pp.reshape(-1))
            else:
                raise RuntimeError("Unexpected shape for pp output: %s" % (pp.shape,))
            if pf.ndim == 2:
                p_fault_raw_list.append(pf[:, -1].reshape(-1))
            elif pf.ndim == 1:
                p_fault_raw_list.append(pf.reshape(-1))
            else:
                raise RuntimeError("Unexpected shape for pf output: %s" % (pf.shape,))

    p_prec_raw = np.concatenate(p_prec_raw_list, axis=0)[:n_w]
    p_fault_raw = np.concatenate(p_fault_raw_list, axis=0)[:n_w]

    # Calibration (isotonic on the scalar probs) with safe fallback:
    # if a calibrator raises, the raw values clipped to [0, 1] are used instead.
    try:
        p_prec_calib = iso_prec.transform(p_prec_raw)
    except Exception:
        log.warning("precursor calibrator failed for %s; clipping raw probabilities", dtc_code)
        p_prec_calib = np.clip(p_prec_raw, 0.0, 1.0)
    try:
        p_fault_calib = iso_fault.transform(p_fault_raw)
    except Exception:
        log.warning("fault calibrator failed for %s; clipping raw probabilities", dtc_code)
        p_fault_calib = np.clip(p_fault_raw, 0.0, 1.0)

    # dP/dt (finite diff) using plain python datetime diffs (robust)
    # ensure end_ts are datetime objects
    end_ts_dt = []
    for ts in end_ts:
        if isinstance(ts, (pd.Timestamp, np.datetime64)):
            end_ts_dt.append(pd.to_datetime(ts).to_pydatetime())
        elif isinstance(ts, datetime):
            end_ts_dt.append(ts)
        else:
            # fallback: try parse string
            end_ts_dt.append(pd.to_datetime(ts).to_pydatetime())

    # Backward difference; element 0 has no predecessor and stays 0.
    dp_prec = np.zeros_like(p_prec_calib, dtype=np.float32)
    dp_fault = np.zeros_like(p_fault_calib, dtype=np.float32)
    for i in range(1, len(end_ts_dt)):
        dt_sec = (end_ts_dt[i] - end_ts_dt[i - 1]).total_seconds()
        if dt_sec <= 0:
            # Duplicate or out-of-order timestamps: assume the nominal cadence
            # rather than dividing by zero or a negative step.
            dt_sec = cadence_seconds
        dp_prec[i] = float((p_prec_calib[i] - p_prec_calib[i - 1]) / dt_sec)
        dp_fault[i] = float((p_fault_calib[i] - p_fault_calib[i - 1]) / dt_sec)

    # Thresholding/hysteresis -> 0/1 series.
    # Map thresholds.json keys onto the generic keys apply_thresholds_with_hysteresis expects.
    # NOTE(review): the nested .get() calls repeat the same key, so the inner lookup
    # is redundant; only the final literal default can ever apply.
    thr_prec = {
        "T_on": thresholds.get("prec_on", thresholds.get("prec_on", 0.75)),
        "T_off": thresholds.get("prec_off", thresholds.get("prec_off", max(0.0, thresholds.get("prec_on", 0.75) - 0.1))),
        "rate_threshold": thresholds.get("dPdt_prec", 0.1),
        "min_consec_on": thresholds.get("min_consec_on_prec", 1),
    }
    thr_fault = {
        "T_on": thresholds.get("fault_on", thresholds.get("fault_on", 0.75)),
        "T_off": thresholds.get("fault_off", thresholds.get("fault_off", max(0.0, thresholds.get("fault_on", 0.75) - 0.1))),
        "rate_threshold": thresholds.get("dPdt_fault", 0.1),
        "min_consec_on": thresholds.get("min_consec_on_fault", 1),
    }
    pred_prec = apply_thresholds_with_hysteresis(p_prec_calib, dp_prec, thr_prec)
    pred_fault = apply_thresholds_with_hysteresis(p_fault_calib, dp_fault, thr_fault)

    # alert_level = precursor + 2 * fault: 0 none, 1 precursor-only, 2 fault-only, 3 both
    alert_level = (pred_prec.astype(np.uint8) + pred_fault.astype(np.uint8) * 2)

    # Build feature_snapshot as JSON strings (small) for hover/debug.
    # The start positions are recomputed here and must stay in step with the ones
    # sliding_windows_from_df uses, so snapshot k describes the end row of window k.
    df_reset = df_proc.reset_index(drop=True)
    n_rows = len(df_reset)
    starts = list(range(0, n_rows - window_length + 1, stride))
    feature_snapshots = []
    for s in starts:
        row = df_reset.iloc[s + window_length - 1]
        # NaN becomes JSON null; everything else is stored as a plain float.
        snap = {f: (None if (pd.isna(row.get(f))) else float(row.get(f))) for f in features}
        feature_snapshots.append(json.dumps(snap))

    # assemble output DataFrame (one row per window, keyed by the window's end row)
    out_df = pd.DataFrame({
        "row_hash": end_row_hash,
        "timestamp": pd.to_datetime(end_ts_dt),
        "date": [d.date() for d in end_ts_dt],
        "dtc_code": [dtc_code] * n_w,
        "p_raw_precursor": p_prec_raw.astype(np.float32),
        "p_calib_precursor": p_prec_calib.astype(np.float32),
        "p_raw_fault": p_fault_raw.astype(np.float32),
        "p_calib_fault": p_fault_calib.astype(np.float32),
        "dPdt_precursor": dp_prec.astype(np.float32),
        "dPdt_fault": dp_fault.astype(np.float32),
        "pred_precursor_on": pred_prec.astype(np.uint8),
        "pred_fault_on": pred_fault.astype(np.uint8),
        "alert_level": alert_level.astype(np.uint8),
        "feature_snapshot": feature_snapshots,
    })

    # ensure proper dtypes
    out_df["timestamp"] = pd.to_datetime(out_df["timestamp"])
    return out_df


In [6]:
# Cell 5: orchestrator runner
def run_inference_batch_and_write(limit: int = MAX_ROWS) -> Dict[str, Any]:
    """
    Read up to `limit` infer-ready rows, run each DTC model, and write per-DTC and combined parquet outputs.

    For every code in ``DTC_CODES`` the result of ``infer_for_one_dtc`` is written,
    partitioned by date, under ``PER_DTC_OUT/<dtc>/``. A failure for one DTC is
    logged and recorded in the summary; the remaining DTCs still run. All
    non-empty results are then concatenated, de-duplicated on
    (dtc_code, timestamp, row_hash) and written under ``COMBINED_OUT``.

    Returns:
        Summary dict with keys:
        - "rows_read": number of input rows
        - "dtc": per-code dict, either {"rows", "written"} or {"error"};
          "written" is None when the model produced no rows and nothing was written
        - "combined": {"rows", "written"} or {"rows": 0}
        - "time_seconds": wall-clock duration of the run

    Raises:
        RuntimeError: if no infer-ready rows could be read.
    """
    start_ts = time.time()
    df_infer = read_infer_ready_batch(limit)
    if df_infer.empty:
        raise RuntimeError("No infer-ready rows to process")

    # Columns are not reordered here; each DTC picks the features it needs.
    summary: Dict[str, Any] = {"rows_read": len(df_infer), "dtc": {}}
    per_dtc_outputs: List[pd.DataFrame] = []
    for dtc in DTC_CODES:
        try:
            log.info("Running inference for %s", dtc)
            out_df = infer_for_one_dtc(dtc, df_infer)
            dtc_dir = PER_DTC_OUT / dtc
            dtc_dir.mkdir(parents=True, exist_ok=True)
            if out_df.empty:
                # The partitioned writer emits no file for an empty frame, so do not
                # report a write, and keep empty frames out of the concat below.
                summary["dtc"][dtc] = {"rows": 0, "written": None}
                log.info("No rows produced for %s; nothing written", dtc)
                continue
            # write partitioned by date under per_dtc/<dtc>/
            atomic_write_df_to_parquet(out_df, dtc_dir, partition_on_date=True)
            summary["dtc"][dtc] = {"rows": len(out_df), "written": str(dtc_dir)}
            per_dtc_outputs.append(out_df)
            log.info("Wrote %d rows for %s", len(out_df), dtc)
        except Exception as e:
            # Best effort per DTC: record the failure and carry on with the next code.
            log.exception("Inference for %s failed: %s", dtc, e)
            summary["dtc"][dtc] = {"error": str(e)}

    if per_dtc_outputs:
        combined_df = pd.concat(per_dtc_outputs, ignore_index=True)
        # dedupe on (dtc_code,timestamp,row_hash), keeping the last occurrence after sorting
        combined_df = combined_df.sort_values(["dtc_code", "timestamp"]).drop_duplicates(
            subset=["dtc_code", "timestamp", "row_hash"], keep="last"
        )
        # write combined output (partition by date)
        atomic_write_df_to_parquet(combined_df, COMBINED_OUT, partition_on_date=True)
        summary["combined"] = {"rows": len(combined_df), "written": str(COMBINED_OUT)}
        log.info("Wrote combined table with %d rows to %s", len(combined_df), COMBINED_OUT)
    else:
        summary["combined"] = {"rows": 0}

    summary["time_seconds"] = time.time() - start_ts
    return summary

# Execute one inference pass capped at MAX_ROWS input rows, then echo the
# resulting run summary as indented JSON.
summary = run_inference_batch_and_write(MAX_ROWS)
print(json.dumps(summary, indent=2))


2025-10-25 01:45:56,232 INFO Reading infer-ready via deltalake (fast path)
2025-10-25 01:45:56,489 INFO Running inference for P0234
2025-10-25 01:45:58,029 INFO Wrote 1953 rows for P0234
2025-10-25 01:45:58,029 INFO Running inference for P0300
2025-10-25 01:45:59,200 INFO Wrote 1937 rows for P0300
2025-10-25 01:45:59,200 INFO Running inference for P0420
2025-10-25 01:46:00,694 INFO Wrote 1873 rows for P0420
2025-10-25 01:46:00,694 INFO Running inference for P0501
2025-10-25 01:46:01,904 INFO Wrote 1937 rows for P0501
2025-10-25 01:46:01,904 INFO Running inference for P0562
2025-10-25 01:46:03,042 INFO Wrote 1937 rows for P0562
2025-10-25 01:46:03,091 INFO Wrote combined table with 9637 rows to C:\engine_module_pipeline\DTC_stage\data\Output\combined


{
  "rows_read": 2000,
  "dtc": {
    "P0234": {
      "rows": 1953,
      "written": "C:\\engine_module_pipeline\\DTC_stage\\data\\Output\\per_dtc\\P0234"
    },
    "P0300": {
      "rows": 1937,
      "written": "C:\\engine_module_pipeline\\DTC_stage\\data\\Output\\per_dtc\\P0300"
    },
    "P0420": {
      "rows": 1873,
      "written": "C:\\engine_module_pipeline\\DTC_stage\\data\\Output\\per_dtc\\P0420"
    },
    "P0501": {
      "rows": 1937,
      "written": "C:\\engine_module_pipeline\\DTC_stage\\data\\Output\\per_dtc\\P0501"
    },
    "P0562": {
      "rows": 1937,
      "written": "C:\\engine_module_pipeline\\DTC_stage\\data\\Output\\per_dtc\\P0562"
    }
  },
  "combined": {
    "rows": 9637,
    "written": "C:\\engine_module_pipeline\\DTC_stage\\data\\Output\\combined"
  },
  "time_seconds": 6.858848571777344
}


In [8]:
# Cell: Export per-DTC and combined parquet outputs -> strict CSVs (atomic)
import os
import shutil
from pathlib import Path
import pandas as pd
import pyarrow.parquet as pq

# Paths are re-declared here so this cell can run without the setup cell.
# They must stay identical to the values used by the inference cells.
ROOT = Path(r"C:\engine_module_pipeline")
PER_DTC_OUT = ROOT / r"DTC_stage\data\Output\per_dtc"    # one sub-directory per DTC code
COMBINED_OUT = ROOT / r"DTC_stage\data\Output\combined"
CSV_OUT_DIR = ROOT / r"DTC_stage\data\csv"               # destination for the exported CSVs
# Side effect at cell execution: creates the CSV directory (and parents) if missing.
CSV_OUT_DIR.mkdir(parents=True, exist_ok=True)

# Define the canonical output columns (must match the DataFrame built during inference).
# The list order is the column order used for the leading columns of every CSV.
EXPECTED_COLUMNS = [
    "row_hash",
    "timestamp",
    "date",
    "dtc_code",
    "p_raw_precursor",
    "p_calib_precursor",
    "p_raw_fault",
    "p_calib_fault",
    "dPdt_precursor",
    "dPdt_fault",
    "pred_precursor_on",
    "pred_fault_on",
    "alert_level",
    "feature_snapshot",
]

def _read_all_parquets_under(path: Path) -> pd.DataFrame:
    """Load every ``*.parquet`` file found recursively below ``path`` into one DataFrame.

    Files are visited in sorted path order. Each file is read with pyarrow first
    and with ``pandas.read_parquet`` as a second attempt; a file that fails both
    ways is reported on stdout and skipped. Returns an empty ``DataFrame()`` when
    nothing could be read.
    """
    frames = []
    for parquet_path in sorted(p for p in path.rglob("*.parquet") if p.is_file()):
        location = parquet_path.as_posix()
        try:
            # Use pyarrow for robust reading, then to_pandas
            frame = pq.read_table(location).to_pandas()
        except Exception as primary_err:
            # fallback: pandas read_parquet
            try:
                frame = pd.read_parquet(location)
            except Exception as fallback_err:
                print(f"Warning: failed to read {parquet_path}: {primary_err} / {fallback_err} — skipping")
                continue
        frames.append(frame)

    if not frames:
        return pd.DataFrame()
    return pd.concat(frames, ignore_index=True, sort=False)

def _ensure_columns_and_fill(df: pd.DataFrame, expected_cols: list) -> pd.DataFrame:
    """Return a copy of ``df`` with the expected columns first and no missing values.

    * Columns whose name starts with "unnamed" (case-insensitive), such as the
      ``Unnamed: 0`` column left behind by an earlier CSV round-trip, are dropped.
    * Every name in ``expected_cols`` that is absent is added, filled with "".
    * Other columns are kept, placed after the expected ones in their original order.
    * Remaining NaN/None cells are replaced by "" so every cell has a value.

    The input frame is never modified; a new DataFrame is returned.
    """
    expected = list(expected_cols)  # accept any sequence, not only list
    expected_set = set(expected)

    kept = [c for c in df.columns if not str(c).lower().startswith("unnamed")]
    kept_set = set(kept)
    missing = [c for c in expected if c not in kept_set]
    extras = [c for c in kept if c not in expected_set]

    # reindex returns a new frame, so the assignments below cannot leak into `df`.
    result = df.reindex(columns=expected + extras)
    for col in missing:
        result[col] = ""  # create with empty strings
    return result.fillna("")  # strict: no NaNs, no None

def _atomic_write_csv(df: pd.DataFrame, out_path: Path, index: bool = False) -> None:
    """Write ``df`` as CSV to ``out_path`` via a sibling ``<name>.tmp`` file.

    The parent directory is created if needed. The CSV is written to the temp
    file first and then renamed over ``out_path`` with ``os.replace``.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)
    staging_name = out_path.with_suffix(f"{out_path.suffix}.tmp").as_posix()
    final_name = out_path.as_posix()
    df.to_csv(staging_name, index=index)
    # atomic replace
    os.replace(staging_name, final_name)

def _stringify_time_columns(frame: pd.DataFrame) -> pd.DataFrame:
    """Render ``timestamp`` and ``date`` as text for CSV output (modifies ``frame`` in place).

    Timestamps that cannot be parsed become "" instead of the literal text "NaT":
    casting NaT to str yields the string "NaT", which a later fillna cannot catch.
    """
    if "timestamp" in frame.columns:
        parsed = pd.to_datetime(frame["timestamp"], errors="coerce")
        frame["timestamp"] = parsed.astype(str).where(parsed.notna(), "")
    if "date" in frame.columns:
        frame["date"] = frame["date"].astype(str)
    return frame


summary = {}

# 1) Per-DTC folders: one CSV per sub-directory of PER_DTC_OUT
if PER_DTC_OUT.exists():
    for dtc_dir in sorted(d for d in PER_DTC_OUT.iterdir() if d.is_dir()):
        dtc_code = dtc_dir.name
        df = _read_all_parquets_under(dtc_dir)
        if df.empty:
            print(f"No parquet parts found for {dtc_code} under {dtc_dir}, skipping CSV creation.")
            summary[dtc_code] = {"rows": 0, "csv": None}
            continue
        # Ensure canonical columns and fill, then make timestamp/date plain strings
        df_strict = _stringify_time_columns(_ensure_columns_and_fill(df, EXPECTED_COLUMNS))
        out_csv = CSV_OUT_DIR / f"{dtc_code}_infer_output.csv"
        _atomic_write_csv(df_strict, out_csv, index=False)
        summary[dtc_code] = {"rows": len(df_strict), "csv": str(out_csv)}
else:
    print(f"Per-DTC output directory not found at {PER_DTC_OUT}, skipping per-DTC CSV creation.")

# 2) Combined table
if COMBINED_OUT.exists():
    df_comb = _read_all_parquets_under(COMBINED_OUT)
    if df_comb.empty:
        print("No combined parquet parts found, skipping combined CSV creation.")
        summary["combined"] = {"rows": 0, "csv": None}
    else:
        df_comb_strict = _stringify_time_columns(_ensure_columns_and_fill(df_comb, EXPECTED_COLUMNS))
        out_csv_comb = CSV_OUT_DIR / "combined_infer_output.csv"
        _atomic_write_csv(df_comb_strict, out_csv_comb, index=False)
        summary["combined"] = {"rows": len(df_comb_strict), "csv": str(out_csv_comb)}
else:
    print(f"Combined output directory not found at {COMBINED_OUT}, skipping combined CSV creation.")
    summary["combined"] = {"rows": 0, "csv": None}

# Report
print("CSV export summary:")
for k, v in summary.items():
    print(f" - {k}: rows={v['rows']}, csv={v['csv']}")


CSV export summary:
 - P0101: rows=1937, csv=C:\engine_module_pipeline\DTC_stage\data\csv\P0101_infer_output.csv
 - P0125: rows=1905, csv=C:\engine_module_pipeline\DTC_stage\data\csv\P0125_infer_output.csv
 - P0133: rows=1937, csv=C:\engine_module_pipeline\DTC_stage\data\csv\P0133_infer_output.csv
 - P0171: rows=1937, csv=C:\engine_module_pipeline\DTC_stage\data\csv\P0171_infer_output.csv
 - P0217: rows=1905, csv=C:\engine_module_pipeline\DTC_stage\data\csv\P0217_infer_output.csv
 - P0234: rows=1953, csv=C:\engine_module_pipeline\DTC_stage\data\csv\P0234_infer_output.csv
 - P0300: rows=1937, csv=C:\engine_module_pipeline\DTC_stage\data\csv\P0300_infer_output.csv
 - P0420: rows=1873, csv=C:\engine_module_pipeline\DTC_stage\data\csv\P0420_infer_output.csv
 - P0501: rows=1937, csv=C:\engine_module_pipeline\DTC_stage\data\csv\P0501_infer_output.csv
 - P0562: rows=1937, csv=C:\engine_module_pipeline\DTC_stage\data\csv\P0562_infer_output.csv
 - combined: rows=19258, csv=C:\engine_module_pipe