# Renewable Energy Forecasting Pipeline

This notebook walks through building a **next-24h renewable generation forecast system** with:

- **EIA data integration** - Hourly wind/solar generation for US regions
- **Weather features** - Open-Meteo integration (wind speed, solar radiation)
- **Probabilistic forecasting** - Dual prediction intervals (80%, 95%)
- **Drift monitoring** - Automatic detection of model degradation

## Architecture Overview

```
```
┌─────────────┐      ┌──────────────┐      ┌─────────────┐
│  EIA API    │──┬──▶│ Data         │──┬──▶│ StatsForecast│
│ (WND/SUN)   │  │   │ Pipeline     │  │   │ Models       │
└─────────────┘  │   └──────────────┘  │   └─────────────┘
                 │                     │           │
┌─────────────┐  │   ┌──────────────┐  │   ┌─────▼──────┐
│ Open-Meteo  │──┘   │ Validation   │  │   │Probabilistic│
│ Weather API │      │ & Quality    │  │   │Forecasts    │
└─────────────┘      │ Gates        │  │   │(80%, 95% CI)│
                     └──────────────┘  │   └────────────┘
                                       │           │
                                       │   ┌───────▼─────┐
                                       └──▶│  Artifacts  │
                                           │  Commit &   │
                                           │  Dashboard  │
                                           └─────────────┘
```

## Key Concepts

1. **StatsForecast format**: `[unique_id, ds, y]` - where `unique_id` = `{region}_{fuel_type}`
2. **Zero-value handling**: Solar generates 0 at night - we use RMSE/MAE, NOT MAPE
3. **Leakage prevention**: Use **forecasted** weather for predictions, not historical
4. **Drift detection**: Threshold = mean + 2*std from backtest

## Setup

First, let's ensure we have the project root in our path, configure logging and set up our pyproject.toml (will be updated over time.)

In [1]:
import sys
import logging
from pathlib import Path
import os 

# Add project root to path
project_root = r"c:\docker_projects\atsaf"
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

if os.getcwd() != str(project_root):
    os.chdir(project_root)
    print(f"Changed working directory to project root: {project_root} we are currently at {os.getcwd()}")

# Configure logging for visibility
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)

print(f"Project root: {project_root}")

Changed working directory to project root: c:\docker_projects\atsaf we are currently at c:\docker_projects\atsaf
Project root: c:\docker_projects\atsaf


In [2]:
%%writefile pyproject.toml
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "atsaf"
version = "0.1.0"
description = "Applied Time Series Analysis and Forecasting - Learning with real EIA data"
readme = "README.md"
license = {text = "MIT"}
authors = [
    {name = "Time Series Learner", email = "learner@example.com"}
]
keywords = ["time-series", "forecasting", "analysis", "eia", "electricity"]
classifiers = [
    "Development Status :: 3 - Alpha",
    "Intended Audience :: Education",
    "Topic :: Scientific/Engineering :: Information Analysis",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
    "Programming Language :: Python :: 3.13",
]
requires-python = ">=3.11,<3.14"



dependencies = [
    # Data fetching and processing
    "requests>=2.28.0",           # HTTP requests for API calls
    "pandas>=1.5.0",              # Data manipulation and analysis
    "numpy>=1.24.0",              # Numerical computing
    "duckdb>=0.8.0",              # Fast SQL queries on data
    "python-dotenv>=1.0.0",       # Load environment variables from .env file

    # Time series forecasting
    "statsforecast>=1.5.0",       # Statistical forecasting methods
    "mlforecast>=0.10.0",         # ML-based forecasting
    "nixtla>=0.0.1",              # Nixtla ecosystem utilities
    "skforecast>=0.14.0",         # ML-based recursive forecasting
    "lightgbm>=4.0.0",            # Gradient boosting for forecasting

    # Model interpretability
    "shap>=0.50.0",               # SHAP values for model explanations

    # Statistical analysis and modeling
    "statsmodels>=0.14.0",        # Statistical modeling and tests
    "scipy>=1.10.0",              # Scientific computing
    "scikit-learn>=1.2.0",        # Machine learning and preprocessing

    # Visualization
    "matplotlib>=3.7.0",          # Static plotting
    "seaborn>=0.12.0",            # Statistical data visualization
    "plotly>=5.0.0",              # Interactive web-based visualization

    # MLOps and workflow
    "mlflow>=2.0.0",              # Experiment tracking and model registry
    "apache-airflow>=2.5.0",      # Workflow orchestration

    # CLI and monitoring (Chapter 3-4)
    "typer>=0.9.0",               # CLI framework
    "rich>=13.0.0",               # Pretty CLI output
    "evidently>=0.4.0",           # Drift detection and monitoring
    "streamlit>=1.30.0",          # Dashboard UI
    "streamlit-mermaid~=0.3.0",   # Mermaid diagrams in Streamlit (latest 0.3.0) 
    "pydantic-settings>=2.0.0",   # Type-safe configuration
    "protobuf>=4.21,<7",
    "ipykernel>=6.0.0",            # IPython kernel for Jupyter

]

[project.optional-dependencies]
dev = [
    "pytest>=7.3.0",              # Testing framework
    "pytest-cov>=4.1.0",          # Coverage reporting
    "black>=23.3.0",              # Code formatter
    "ruff>=0.0.270",              # Fast Python linter
    "mypy>=1.0.0",                # Static type checker
]

jupyter = [
    "jupyter>=1.0.0",             # Jupyter notebook
    "jupyterlab>=4.0.0",          # JupyterLab
    "ipywidgets>=8.0.0",          # Interactive widgets
]

[project.urls]
Homepage = "https://github.com/RamiKrispin/atsaf"
Documentation = "https://github.com/RamiKrispin/atsaf"
Repository = "https://github.com/RamiKrispin/atsaf.git"
Issues = "https://github.com/RamiKrispin/atsaf/issues"

[tool.hatch.build.targets.wheel]
packages = ["src"]

[tool.black]
line-length = 100
target-version = ['py311', 'py312', 'py313']
include = '\.pyi?$'
extend-exclude = '''
/(
  # directories
  \.eggs
  | \.git
  | \.hg
  | \.mypy_cache
  | \.tox
  | \.venv
  | build
  | dist
)/
'''

[tool.ruff]
line-length = 100
select = [
    "E",    # pycodestyle errors
    "W",    # pycodestyle warnings
    "F",    # Pyflakes
    "I",    # isort
    "C",    # flake8-comprehensions
    "B",    # flake8-bugbear
]
ignore = [
    "E501",  # line too long (handled by black)
    "W503",  # line break before binary operator
]

[tool.mypy]
python_version = "3.11"
warn_return_any = true
warn_unused_configs = true
disallow_untyped_defs = false
disallow_incomplete_defs = false
check_untyped_defs = false
no_implicit_optional = true
warn_redundant_casts = true
warn_unused_ignores = true
warn_no_return = true

[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
addopts = "-v --cov=. --cov-report=html --cov-report=term"


Overwriting pyproject.toml


---

# Module 1: Region Definitions

**File:** `src/renewable/regions.py`

This module maps **EIA balancing authority regions** to their geographic coordinates. Why do we need coordinates?

- **Weather API lookup**: Open-Meteo requires latitude/longitude
- **Regional analysis**: Compare forecast accuracy across regions
- **Timezone handling**: Each region has a primary timezone

## Key Design Decisions

1. **NamedTuple for RegionInfo**: Immutable, type-safe, and memory-efficient
2. **Centroid coordinates**: Approximate centers - good enough for hourly weather
3. **Fuel type codes**: `WND` (wind), `SUN` (solar) - match EIA's API

In [3]:
%%writefile src/renewable/regions.py
# src/renewable/regions.py
from __future__ import annotations

from typing import NamedTuple, Optional


class RegionInfo(NamedTuple):
    """Region metadata for EIA and weather lookups."""
    name: str
    lat: float
    lon: float
    timezone: str
    # Some internal regions may not map cleanly to an EIA respondent.
    # We keep them in REGIONS for weather/features, but EIA fetch requires this.
    eia_respondent: Optional[str] = None


REGIONS: dict[str, RegionInfo] = {
    # Western Interconnection
    "CALI": RegionInfo(
        name="California ISO",
        lat=36.7,
        lon=-119.4,
        timezone="America/Los_Angeles",
        eia_respondent="CISO",
    ),
    "NW": RegionInfo(
        name="Northwest",
        lat=45.5,
        lon=-122.0,
        timezone="America/Los_Angeles",
        eia_respondent=None,  # intentionally unset until verified
    ),
    "SW": RegionInfo(
        name="Southwest",
        lat=33.5,
        lon=-112.0,
        timezone="America/Phoenix",
        eia_respondent=None,  # intentionally unset until verified
    ),

    # Texas Interconnection
    "ERCO": RegionInfo(
        name="ERCOT (Texas)",
        lat=31.0,
        lon=-100.0,
        timezone="America/Chicago",
        eia_respondent="ERCO",
    ),

    # Midwest
    "MISO": RegionInfo(
        name="Midcontinent ISO",
        lat=41.0,
        lon=-93.0,
        timezone="America/Chicago",
        eia_respondent="MISO",
    ),
    "PJM": RegionInfo(
        name="PJM Interconnection",
        lat=39.0,
        lon=-77.0,
        timezone="America/New_York",
        eia_respondent="PJM",
    ),
    "SWPP": RegionInfo(
        name="Southwest Power Pool",
        lat=37.0,
        lon=-97.0,
        timezone="America/Chicago",
        eia_respondent="SWPP",
    ),

    # Internal/aggregate regions kept for non-EIA use (weather/features/etc.)
    "SE": RegionInfo(name="Southeast", lat=33.0, lon=-84.0, timezone="America/New_York", eia_respondent=None),
    "FLA": RegionInfo(name="Florida", lat=28.0, lon=-82.0, timezone="America/New_York", eia_respondent=None),
    "CAR": RegionInfo(name="Carolinas", lat=35.5, lon=-80.0, timezone="America/New_York", eia_respondent=None),
    "TEN": RegionInfo(name="Tennessee Valley", lat=35.5, lon=-86.0, timezone="America/Chicago", eia_respondent=None),

    "US48": RegionInfo(name="Lower 48 States", lat=39.8, lon=-98.5, timezone="America/Chicago", eia_respondent=None),
}

FUEL_TYPES = {"WND": "Wind", "SUN": "Solar"}


def list_regions() -> list[str]:
    return sorted(REGIONS.keys())


def get_region_info(region_code: str) -> RegionInfo:
    return REGIONS[region_code]


def get_region_coords(region_code: str) -> tuple[float, float]:
    r = REGIONS[region_code]
    return (r.lat, r.lon)


def get_eia_respondent(region_code: str) -> str:
    """Return the code EIA expects for facets[respondent][]. Fail loudly if missing."""
    info = REGIONS[region_code]
    if not info.eia_respondent:
        raise ValueError(
            f"Region '{region_code}' has no configured eia_respondent. "
            f"Set REGIONS['{region_code}'].eia_respondent to a verified EIA respondent code "
            f"before using it for EIA fetches."
        )
    return info.eia_respondent


def validate_region(region_code: str) -> bool:
    return region_code in REGIONS


def validate_fuel_type(fuel_type: str) -> bool:
    return fuel_type in FUEL_TYPES



if __name__ == "__main__":
    # Example run - test region functions

    print("=== Available Regions ===")
    print(f"Total regions: {len(REGIONS)}")
    print(f"Region codes: {list_regions()}")

    print("\n=== Example: California ===")
    cali_info = get_region_info("CALI")
    print(f"Name: {cali_info.name}")
    print(f"Coordinates: ({cali_info.lat}, {cali_info.lon})")
    print(f"Timezone: {cali_info.timezone}")

    print("\n=== Weather API Coordinates ===")
    for region in ["CALI", "ERCO", "MISO"]:
        lat, lon = get_region_coords(region)
        print(f"{region}: lat={lat}, lon={lon}")

    print("\n=== Fuel Types ===")
    for code, name in FUEL_TYPES.items():
        print(f"{code}: {name}")

    print("\n=== Validation ===")
    print(f"validate_region('CALI'): {validate_region('CALI')}")
    print(f"validate_region('INVALID'): {validate_region('INVALID')}")
    print(f"validate_fuel_type('WND'): {validate_fuel_type('WND')}")


Overwriting src/renewable/regions.py


### Example: Using Region Definitions

---

# Module 2: EIA Data Fetcher

**File:** `src/renewable/eia_renewable.py`

This module fetches **hourly wind and solar generation** from the EIA API.

## Critical Concepts

### StatsForecast Format
StatsForecast expects data in a specific format:
```
unique_id | ds                  | y
----------|---------------------|--------
CALI_WND  | 2024-01-01 00:00:00 | 1234.5
CALI_WND  | 2024-01-01 01:00:00 | 1456.7
ERCO_WND  | 2024-01-01 00:00:00 | 2345.6
```

- `unique_id`: Identifies the time series (e.g., "CALI_WND" = California Wind)
- `ds`: Datetime column (timezone-naive UTC)
- `y`: Target value (generation in MWh)

### API Rate Limiting
- EIA API has rate limits (~5 requests/second)
- We use controlled parallelism with delays

In [4]:
# %%writefile src/renewable/eia_renewable.py
# src/renewable/eia_renewable.py
from __future__ import annotations

import logging
import os
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Optional
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit

import pandas as pd
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
from dotenv import find_dotenv, load_dotenv

from src.renewable.regions import (REGIONS, get_eia_respondent,
                                   validate_fuel_type, validate_region)

logger = logging.getLogger(__name__)


def _sanitize_url(url: str) -> str:
    parts = urlsplit(url)
    q = [(k, v) for k, v in parse_qsl(parts.query, keep_blank_values=True) if k.lower() != "api_key"]
    return urlunsplit((parts.scheme, parts.netloc, parts.path, urlencode(q), parts.fragment))


def _load_env_once(*, debug: bool = False) -> Optional[str]:
    """
    Load .env if present.
    - Primary: find_dotenv(usecwd=True) (walk up from CWD)
    - Fallback: repo_root/.env based on this file location
    Returns the path loaded (or None).
    """
    # 1) Try from current working directory upward
    dotenv_path = find_dotenv(usecwd=True)
    if dotenv_path:
        load_dotenv(dotenv_path, override=False)
        if debug:
            logger.info("Loaded .env via find_dotenv: %s", dotenv_path)
        return dotenv_path

    # 2) Fallback: assume src-layout -> repo root is ../../ from this file
    try:
        repo_root = Path(__file__).resolve().parents[2]
        fallback = repo_root / ".env"
        if fallback.exists():
            load_dotenv(fallback, override=False)
            if debug:
                logger.info("Loaded .env via fallback: %s", str(fallback))
            return str(fallback)
    except Exception:
        pass

    if debug:
        logger.info("No .env found to load.")
    return None


class EIARenewableFetcher:
    BASE_URL = "https://api.eia.gov/v2/electricity/rto/fuel-type-data/data/"
    MAX_RECORDS_PER_REQUEST = 5000
    RATE_LIMIT_DELAY = 0.2  # 5 requests/second max

    def __init__(self, api_key: Optional[str] = None, *, timeout: int = 60, debug_env: bool = False):
        """
        Initialize API key and configuration.

        Args:
            api_key: EIA API key (or reads from EIA_API_KEY env var)
            timeout: Request timeout in seconds (default: 60)
            debug_env: Enable debug logging for environment loading
        """
        loaded_env = _load_env_once(debug=debug_env)

        self.api_key = api_key or os.getenv("EIA_API_KEY")
        if not self.api_key:
            raise ValueError(
                "EIA API key required but not found.\n"
                "- Ensure .env contains EIA_API_KEY=...\n"
                "- Ensure your process CWD is under the repo (so find_dotenv can locate it), OR\n"
                "- Pass api_key=... explicitly.\n"
                f"Loaded .env path: {loaded_env}"
            )

        self.timeout = timeout
        self.session = self._create_session()  # Add retry-enabled session

        # Debug without leaking the key
        if debug_env:
            masked = self.api_key[:4] + "..." + self.api_key[-4:] if len(self.api_key) >= 8 else "***"
            logger.info("EIA_API_KEY loaded (masked): %s", masked)
            logger.info("Request timeout: %d seconds", self.timeout)

    def _create_session(self) -> requests.Session:
        """Create requests Session with retry logic for transient errors."""
        session = requests.Session()
        retries = Retry(
            total=3,
            backoff_factor=1.0,  # 1s, 2s, 4s between retries
            status_forcelist=[429, 500, 502, 503, 504],  # Retry on server errors and rate limits
            allowed_methods=frozenset(["GET"]),
            connect=3,  # Retry on connection errors
            read=3,     # Retry on read timeouts
        )
        session.mount("https://", HTTPAdapter(max_retries=retries))
        return session

    @staticmethod
    def _extract_eia_response(payload: dict, *, request_url: Optional[str] = None) -> tuple[list[dict], dict]:
        if not isinstance(payload, dict):
            raise TypeError(f"EIA payload is not a dict. type={type(payload)} url={request_url}")

        if "error" in payload and payload.get("response") is None:
            raise ValueError(f"EIA returned error payload. url={request_url} error={payload.get('error')}")

        if "response" not in payload:
            raise ValueError(
                f"EIA payload missing 'response'. url={request_url} keys={list(payload.keys())[:25]}"
            )

        response = payload.get("response") or {}
        if not isinstance(response, dict):
            raise TypeError(f"EIA payload['response'] is not a dict. type={type(response)} url={request_url}")

        if "data" not in response:
            raise ValueError(
                f"EIA response missing 'data'. url={request_url} response_keys={list(response.keys())[:25]}"
            )

        records = response.get("data") or []
        if not isinstance(records, list):
            raise TypeError(f"EIA response['data'] is not a list. type={type(records)} url={request_url}")

        total = response.get("total", None)
        offset = response.get("offset", None)

        meta_obj = response.get("metadata") or {}
        if isinstance(meta_obj, dict):
            if total is None and "total" in meta_obj:
                total = meta_obj.get("total")
            if offset is None and "offset" in meta_obj:
                offset = meta_obj.get("offset")

        try:
            total = int(total) if total is not None else None
        except Exception:
            pass
        try:
            offset = int(offset) if offset is not None else None
        except Exception:
            pass

        return records, {"total": total, "offset": offset}

    def fetch_region(
        self,
        region: str,
        fuel_type: str,
        start_date: str,
        end_date: str,
        *,
        debug: bool = False,
        diag: Optional[dict] = None,
    ) -> pd.DataFrame:
        if not validate_region(region):
            raise ValueError(f"Invalid region: {region}")
        if not validate_fuel_type(fuel_type):
            raise ValueError(f"Invalid fuel type: {fuel_type}")

        respondent = get_eia_respondent(region)

        all_records: list[dict] = []
        offset = 0

        # ✅ FIX: initialize loop diagnostics counters
        page_count = 0
        total_hint: Optional[int] = None

        while True:
            params = {
                "api_key": self.api_key,
                "data[]": "value",
                "facets[respondent][]": respondent,
                "facets[fueltype][]": fuel_type,
                "frequency": "hourly",
                "start": f"{start_date}T00",
                "end": f"{end_date}T23",
                "length": self.MAX_RECORDS_PER_REQUEST,
                "offset": offset,
                "sort[0][column]": "period",
                "sort[0][direction]": "asc",
            }

            resp = self.session.get(self.BASE_URL, params=params, timeout=self.timeout)
            resp.raise_for_status()
            payload = resp.json()

            records, meta = self._extract_eia_response(payload, request_url=resp.url)

            page_count += 1
            if total_hint is None:
                total_hint = meta.get("total")

            returned = len(records)

            if debug:
                safe_url = _sanitize_url(resp.url)
                print(
                    f"[PAGE] region={region} fuel={fuel_type} returned={returned} "
                    f"offset={offset} total={meta.get('total')} url={safe_url}"
                )

            # Empty on first page: legitimate empty series for that window
            if returned == 0 and offset == 0:
                if diag is not None:
                    diag.update({
                        "region": region,
                        "fuel_type": fuel_type,
                        "start_date": start_date,
                        "end_date": end_date,
                        "total_records": total_hint,
                        "pages": page_count,
                        "rows_parsed": 0,
                        "empty": True,
                    })
                return pd.DataFrame(columns=["ds", "value", "region", "fuel_type"])

            if returned == 0:
                break

            all_records.extend(records)

            if returned < self.MAX_RECORDS_PER_REQUEST:
                break

            offset += self.MAX_RECORDS_PER_REQUEST
            time.sleep(self.RATE_LIMIT_DELAY)

        df = pd.DataFrame(all_records)

        missing_cols = [c for c in ["period", "value"] if c not in df.columns]
        if missing_cols:
            sample_keys = sorted(set().union(*(r.keys() for r in all_records[:5]))) if all_records else []
            raise ValueError(
                f"EIA records missing expected keys {missing_cols}. "
                f"columns={df.columns.tolist()} sample_record_keys={sample_keys}"
            )

        # EIA returns timestamps in UTC format WITHOUT timezone marker (e.g., "2026-01-21T00")
        # Simply parse and treat as UTC (no conversion needed)
        df["ds"] = pd.to_datetime(df["period"], utc=True, errors="coerce").dt.tz_localize(None)

        df["value"] = pd.to_numeric(df["value"], errors="coerce")

        df["region"] = region
        df["fuel_type"] = fuel_type

        df = df.dropna(subset=["ds", "value"]).sort_values("ds").reset_index(drop=True)

        # Log negative values for investigation (but don't clamp - let dataset builder handle)
        neg_mask = df["value"] < 0
        if neg_mask.any():
            neg_count = int(neg_mask.sum())
            neg_min = float(df.loc[neg_mask, "value"].min())
            neg_max = float(df.loc[neg_mask, "value"].max())
            neg_pct = 100 * neg_count / max(len(df), 1)
            logger.warning(
                "[fetch_region][NEGATIVE] region=%s fuel=%s count=%d (%.1f%%) range=[%.2f, %.2f]",
                region, fuel_type, neg_count, neg_pct, neg_min, neg_max,
            )
            # Log sample for debugging
            neg_sample = df.loc[neg_mask, ["ds", "value"]].head(5)
            for _, row in neg_sample.iterrows():
                logger.debug("  ds=%s value=%.2f", row["ds"], row["value"])

            # NOTE: Keeping negative values in raw data for transparency
            # Dataset builder will handle negatives according to configured policy

        if diag is not None:
            diag.update({
                "region": region,
                "fuel_type": fuel_type,
                "start_date": start_date,
                "end_date": end_date,
                "total_records": total_hint,
                "pages": page_count,
                "rows_parsed": int(len(df)),
                "empty": bool(len(df) == 0),
            })

        return df[["ds", "value", "region", "fuel_type"]]

    def fetch_all_regions(
        self,
        fuel_type: str,
        start_date: str,
        end_date: str,
        regions: Optional[list[str]] = None,
        max_workers: int = 3,
        diagnostics: Optional[list[dict]] = None,
    ) -> pd.DataFrame:
        """Fetch generation data for all regions for a given fuel type.

        Args:
            fuel_type: Fuel type code (WND, SUN, etc.)
            start_date: Start date (YYYY-MM-DD)
            end_date: End date (YYYY-MM-DD)
            regions: List of region codes (defaults to all non-US48 regions)
            max_workers: Number of parallel workers
            diagnostics: Optional list to collect diagnostic info

        Returns:
            DataFrame with columns [unique_id, ds, y]

        Raises:
            RuntimeError: If no regions could be fetched (complete failure)
        """
        if regions is None:
            # Only include regions with configured EIA respondent (exclude US48 and None)
            regions = [
                code for code, info in REGIONS.items()
                if code != "US48" and info.eia_respondent is not None
            ]

        all_dfs: list[pd.DataFrame] = []
        failed_regions: list[tuple[str, str]] = []  # (region, error_msg)

        def _run_one(region: str) -> tuple[str, pd.DataFrame, dict]:
            d: dict = {}
            df = self.fetch_region(region, fuel_type, start_date, end_date, diag=d)
            return region, df, d

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = {executor.submit(_run_one, region): region for region in regions}
            for future in as_completed(futures):
                region = futures[future]
                try:
                    _, df, d = future.result()
                    if diagnostics is not None:
                        diagnostics.append(d)

                    if len(df) > 0:
                        all_dfs.append(df)
                        print(f"[OK] {region}: {len(df)} rows")
                    else:
                        print(f"[EMPTY] {region}: 0 rows")
                        failed_regions.append((region, "Empty response (0 rows)"))
                except Exception as e:
                    failed_regions.append((region, str(e)))
                    if diagnostics is not None:
                        diagnostics.append({
                            "region": region,
                            "fuel_type": fuel_type,
                            "start_date": start_date,
                            "end_date": end_date,
                            "error": str(e),
                        })
                    print(f"[FAIL] {region}: {e}")

        # Explicit validation: require at least one successful region
        if not all_dfs:
            error_details = "; ".join([f"{r[0]}({r[1][:80]})" for r in failed_regions])
            raise RuntimeError(
                f"[EIA][FETCH] Failed to fetch {fuel_type} data for ALL regions. "
                f"Failures: {error_details}. "
                f"Check EIA API availability, API key validity, network connectivity, "
                f"and consider increasing timeout or reducing concurrency."
            )

        # Warn if partial failure (some regions succeeded, some failed)
        if failed_regions:
            failed_count = len(failed_regions)
            total_count = len(regions)
            print(f"[WARNING] Partial {fuel_type} fetch: {failed_count}/{total_count} regions failed")
            for region, error_msg in failed_regions:
                # Print first 100 chars of error
                print(f"  - {region}: {error_msg[:100]}")

        combined = pd.concat(all_dfs, ignore_index=True)
        combined["unique_id"] = combined["region"] + "_" + combined["fuel_type"]
        combined = combined.rename(columns={"value": "y"})

        result = combined[["unique_id", "ds", "y"]].sort_values(["unique_id", "ds"]).reset_index(drop=True)

        print(f"[SUMMARY] {fuel_type} data: {result['unique_id'].nunique()} series, {len(result)} total rows")

        return result

    def get_series_summary(self, df: pd.DataFrame) -> pd.DataFrame:
        return df.groupby("unique_id").agg(
            count=("y", "count"),
            min_value=("y", "min"),
            max_value=("y", "max"),
            mean_value=("y", "mean"),
            zero_count=("y", lambda x: (x == 0).sum()),
        ).reset_index()


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    fetcher = EIARenewableFetcher(debug_env=True)

    print("=== Testing Single Region Fetch ===")
    df_single = fetcher.fetch_region("CALI", "WND", "2024-12-01", "2024-12-03", debug=True)
    print(f"Single region: {len(df_single)} rows")
    print(df_single.head())

    print("\n=== Testing Multi-Region Fetch ===")
    df_multi = fetcher.fetch_all_regions("WND", "2024-12-01", "2024-12-03", regions=["CALI", "ERCO", "MISO"])
    print(f"\nMulti-region: {len(df_multi)} rows")
    print(f"Series: {df_multi['unique_id'].unique().tolist()}")

    print("\n=== Series Summary ===")
    print(fetcher.get_series_summary(df_multi))

    # sun checks:
    f = EIARenewableFetcher()
    df = f.fetch_region("CALI", "SUN", "2024-12-01", "2024-12-03", debug=True)
    print(df.head(), len(df))


2026-01-22 16:50:02,132 - __main__ - INFO - Loaded .env via find_dotenv: c:\docker_projects\atsaf\.env
2026-01-22 16:50:02,133 - __main__ - INFO - EIA_API_KEY loaded (masked): 8pqH...8Xk7
2026-01-22 16:50:02,133 - __main__ - INFO - Request timeout: 60 seconds


=== Testing Single Region Fetch ===
[PAGE] region=CALI fuel=WND returned=72 offset=0 total=72 url=https://api.eia.gov/v2/electricity/rto/fuel-type-data/data/?data%5B%5D=value&facets%5Brespondent%5D%5B%5D=CISO&facets%5Bfueltype%5D%5B%5D=WND&frequency=hourly&start=2024-12-01T00&end=2024-12-03T23&length=5000&offset=0&sort%5B0%5D%5Bcolumn%5D=period&sort%5B0%5D%5Bdirection%5D=asc
Single region: 72 rows
                   ds  value region fuel_type
0 2024-12-01 00:00:00    359   CALI       WND
1 2024-12-01 01:00:00    277   CALI       WND
2 2024-12-01 02:00:00    498   CALI       WND
3 2024-12-01 03:00:00    692   CALI       WND
4 2024-12-01 04:00:00    879   CALI       WND

=== Testing Multi-Region Fetch ===
[OK] CALI: 72 rows
[OK] ERCO: 72 rows
[OK] MISO: 72 rows
[SUMMARY] WND data: 3 series, 216 total rows

Multi-region: 216 rows
Series: ['CALI_WND', 'ERCO_WND', 'MISO_WND']

=== Series Summary ===
  unique_id  count  min_value  max_value   mean_value  zero_count
0  CALI_WND     72        



[PAGE] region=CALI fuel=SUN returned=72 offset=0 total=72 url=https://api.eia.gov/v2/electricity/rto/fuel-type-data/data/?data%5B%5D=value&facets%5Brespondent%5D%5B%5D=CISO&facets%5Bfueltype%5D%5B%5D=SUN&frequency=hourly&start=2024-12-01T00&end=2024-12-03T23&length=5000&offset=0&sort%5B0%5D%5Bcolumn%5D=period&sort%5B0%5D%5Bdirection%5D=asc
                   ds  value region fuel_type
0 2024-12-01 00:00:00   7805   CALI       SUN
1 2024-12-01 01:00:00   3369   CALI       SUN
2 2024-12-01 02:00:00    186   CALI       SUN
3 2024-12-01 03:00:00    -47   CALI       SUN
4 2024-12-01 04:00:00    -41   CALI       SUN 72


---

# Module 3: Weather Integration

**File:** `src/renewable/open_meteo.py`

Weather is **critical** for renewable forecasting:
- **Wind generation** depends on wind speed (especially at hub height ~100m)
- **Solar generation** depends on radiation and cloud cover

## Key Concept: Preventing Leakage

**Data leakage** occurs when training uses information that wouldn't be available at prediction time.

```
❌ WRONG: Using historical weather to predict future generation
   - At prediction time, we don't have future actual weather!
   
✅ CORRECT: Use forecasted weather for predictions
   - Training: historical weather aligned with historical generation
   - Prediction: weather forecast for the prediction horizon
```

## Open-Meteo API

Open-Meteo is **free** and requires no API key:
- Historical API: Past weather data
- Forecast API: Up to 16 days ahead

In [5]:
# %%writefile src/renewable/open_meteo.py
# src/renewable/open_meteo.py
from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Optional

import pandas as pd
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from src.renewable.regions import get_region_coords, validate_region


@dataclass(frozen=True)
class OpenMeteoEndpoints:
    historical_url: str = "https://archive-api.open-meteo.com/v1/archive"
    forecast_url: str = "https://api.open-meteo.com/v1/forecast"


class OpenMeteoRenewable:
    """
    Fetch weather features for renewable energy forecasting.

    Strict-by-default:
    - If Open-Meteo doesn't return a requested variable, we raise.
    - We do NOT fabricate values or silently "fill" missing columns.
    """

    WEATHER_VARS = [
        "temperature_2m",
        "wind_speed_10m",
        "wind_speed_100m",
        "wind_direction_10m",
        "direct_radiation",
        "diffuse_radiation",
        "cloud_cover",
    ]

    def __init__(self, timeout: int = 60, *, strict: bool = True):
        self.timeout = timeout
        self.strict = strict
        self.endpoints = OpenMeteoEndpoints()
        self.session = self._create_session()

    def _create_session(self) -> requests.Session:
        session = requests.Session()
        retries = Retry(
            total=3,
            backoff_factor=1.0,  # 1s, 2s, 4s between retries
            status_forcelist=[429, 500, 502, 503, 504],
            allowed_methods=frozenset(["GET"]),
            connect=3,  # Retry on connection errors
            read=3,     # Retry on read timeouts
        )
        session.mount("https://", HTTPAdapter(max_retries=retries))
        return session

    def fetch_historical(
        self,
        lat: float,
        lon: float,
        start_date: str,
        end_date: str,
        variables: Optional[list[str]] = None,
        *,
        debug: bool = False,
    ) -> pd.DataFrame:
        if variables is None:
            variables = self.WEATHER_VARS

        params = {
            "latitude": lat,
            "longitude": lon,
            "start_date": start_date,
            "end_date": end_date,
            "hourly": ",".join(variables),
            "timezone": "UTC",
        }

        resp = self.session.get(self.endpoints.historical_url, params=params, timeout=self.timeout)
        if debug:
            print(f"[OPENMETEO][HIST] status={resp.status_code} url={resp.url}")
        resp.raise_for_status()

        try:
            data = resp.json()
        except requests.exceptions.JSONDecodeError as e:
            # Log actual response content for debugging
            content_preview = resp.text[:500] if resp.text else "(empty)"
            raise ValueError(
                f"[OPENMETEO][HIST] Invalid JSON response. "
                f"status={resp.status_code} content_preview={content_preview}"
            ) from e

        return self._parse_response(data, variables, debug=debug, request_url=resp.url)

    def fetch_forecast(
        self,
        lat: float,
        lon: float,
        horizon_hours: int = 48,
        variables: Optional[list[str]] = None,
        *,
        debug: bool = False,
    ) -> pd.DataFrame:
        if variables is None:
            variables = self.WEATHER_VARS

        forecast_days = min((horizon_hours // 24) + 1, 16)
        params = {
            "latitude": lat,
            "longitude": lon,
            "hourly": ",".join(variables),
            "timezone": "UTC",
            "forecast_days": forecast_days,
        }

        resp = self.session.get(self.endpoints.forecast_url, params=params, timeout=self.timeout)
        if debug:
            print(f"[OPENMETEO][FCST] status={resp.status_code} url={resp.url}")
        resp.raise_for_status()

        try:
            data = resp.json()
        except requests.exceptions.JSONDecodeError as e:
            content_preview = resp.text[:500] if resp.text else "(empty)"
            raise ValueError(
                f"[OPENMETEO][FCST] Invalid JSON response. "
                f"status={resp.status_code} content_preview={content_preview}"
            ) from e

        df = self._parse_response(data, variables, debug=debug, request_url=resp.url)

        # Trim to requested horizon (ds is naive UTC)
        if len(df) > 0:
            cutoff = datetime.utcnow() + timedelta(hours=horizon_hours)
            df = df[df["ds"] <= cutoff].reset_index(drop=True)

        return df

    def fetch_for_region(
        self,
        region_code: str,
        start_date: str,
        end_date: str,
        *,
        debug: bool = False,
    ) -> pd.DataFrame:
        if not validate_region(region_code):
            raise ValueError(f"Invalid region_code: {region_code}")

        lat, lon = get_region_coords(region_code)
        df = self.fetch_historical(lat, lon, start_date, end_date, debug=debug)
        df["region"] = region_code
        return df

    def fetch_all_regions_historical(
        self,
        regions: list[str],
        start_date: str,
        end_date: str,
        *,
        debug: bool = False,
    ) -> pd.DataFrame:
        all_dfs: list[pd.DataFrame] = []
        for region in regions:
            try:
                df = self.fetch_for_region(region, start_date, end_date, debug=debug)
                all_dfs.append(df)
                print(f"[OK] Weather for {region}: {len(df)} rows")
            except requests.exceptions.Timeout as e:
                print(f"[FAIL] Weather for {region}: TIMEOUT after {self.timeout}s - {type(e).__name__}: {e}")
            except requests.exceptions.ConnectionError as e:
                print(f"[FAIL] Weather for {region}: CONNECTION_ERROR - {type(e).__name__}: {e}")
            except requests.exceptions.JSONDecodeError as e:
                print(f"[FAIL] Weather for {region}: JSON_PARSE_ERROR - {type(e).__name__}: {e}")
            except Exception as e:
                print(f"[FAIL] Weather for {region}: {type(e).__name__}: {e}")

        if not all_dfs:
            return pd.DataFrame()

        return (
            pd.concat(all_dfs, ignore_index=True)
            .sort_values(["region", "ds"])
            .reset_index(drop=True)
        )

    def _parse_response(
        self,
        data: dict,
        variables: list[str],
        *,
        debug: bool,
        request_url: str,
    ) -> pd.DataFrame:
        hourly = data.get("hourly")
        if not isinstance(hourly, dict):
            raise ValueError(f"Open-Meteo response missing/invalid 'hourly'. url={request_url}")

        times = hourly.get("time")
        if not isinstance(times, list) or len(times) == 0:
            raise ValueError(f"Open-Meteo response has no hourly time grid. url={request_url}")

        # Build ds (naive UTC)
        ds = pd.to_datetime(times, errors="coerce", utc=True).tz_localize(None)
        if ds.isna().any():
            bad = int(ds.isna().sum())
            raise ValueError(f"Open-Meteo returned unparsable times. bad={bad} url={request_url}")

        df_data = {"ds": ds}

        # Strict variable presence: raise if missing (no silent None padding)
        missing_vars = [v for v in variables if v not in hourly]
        if missing_vars and self.strict:
            raise ValueError(f"Open-Meteo missing requested vars={missing_vars}. url={request_url}")

        for var in variables:
            values = hourly.get(var)
            if values is None:
                # If not strict, keep as all-NA but be explicit (not hidden)
                df_data[var] = [None] * len(ds)
                continue

            if not isinstance(values, list):
                raise ValueError(f"Open-Meteo var '{var}' not a list. type={type(values)} url={request_url}")

            if len(values) != len(ds):
                raise ValueError(
                    f"Open-Meteo length mismatch for '{var}': "
                    f"len(values)={len(values)} len(time)={len(ds)} url={request_url}"
                )

            df_data[var] = pd.to_numeric(values, errors="coerce")

        df = pd.DataFrame(df_data).sort_values("ds").reset_index(drop=True)

        if debug:
            dup = int(df["ds"].duplicated().sum())
            na_counts = {v: int(df[v].isna().sum()) for v in variables if v in df.columns}
            print(f"[OPENMETEO][PARSE] rows={len(df)} dup_ds={dup} na_counts(sample)={dict(list(na_counts.items())[:3])}")

        return df

    def fetch_for_region_forecast(
        self,
        region_code: str,
        horizon_hours: int = 48,
        variables: Optional[list[str]] = None,
        *,
        debug: bool = False,
    ) -> pd.DataFrame:
        if not validate_region(region_code):
            raise ValueError(f"Invalid region_code: {region_code}")

        lat, lon = get_region_coords(region_code)
        df = self.fetch_forecast(lat, lon, horizon_hours=horizon_hours, variables=variables, debug=debug)
        df["region"] = region_code
        return df


    def fetch_all_regions_forecast(
        self,
        regions: list[str],
        horizon_hours: int = 48,
        variables: Optional[list[str]] = None,
        *,
        debug: bool = False,
    ) -> pd.DataFrame:
        all_dfs: list[pd.DataFrame] = []
        for region in regions:
            try:
                df = self.fetch_for_region_forecast(
                    region, horizon_hours=horizon_hours, variables=variables, debug=debug
                )
                all_dfs.append(df)
                print(f"[OK] Forecast weather for {region}: {len(df)} rows")
            except requests.exceptions.Timeout as e:
                print(f"[FAIL] Forecast weather for {region}: TIMEOUT after {self.timeout}s - {type(e).__name__}: {e}")
            except requests.exceptions.ConnectionError as e:
                print(f"[FAIL] Forecast weather for {region}: CONNECTION_ERROR - {type(e).__name__}: {e}")
            except requests.exceptions.JSONDecodeError as e:
                print(f"[FAIL] Forecast weather for {region}: JSON_PARSE_ERROR - {type(e).__name__}: {e}")
            except Exception as e:
                print(f"[FAIL] Forecast weather for {region}: {type(e).__name__}: {e}")

        if not all_dfs:
            return pd.DataFrame()

        return (
            pd.concat(all_dfs, ignore_index=True)
            .sort_values(["region", "ds"])
            .reset_index(drop=True)
        )



if __name__ == "__main__":
    # Real API smoke test (no key needed)
    weather = OpenMeteoRenewable(strict=True)

    print("=== Testing Historical Weather (REAL API) ===")
    hist_df = weather.fetch_for_region("CALI", "2024-12-01", "2024-12-03", debug=True)
    print(f"Historical rows: {len(hist_df)}")
    print(hist_df.head())


=== Testing Historical Weather (REAL API) ===
[OPENMETEO][HIST] status=200 url=https://archive-api.open-meteo.com/v1/archive?latitude=36.7&longitude=-119.4&start_date=2024-12-01&end_date=2024-12-03&hourly=temperature_2m%2Cwind_speed_10m%2Cwind_speed_100m%2Cwind_direction_10m%2Cdirect_radiation%2Cdiffuse_radiation%2Ccloud_cover&timezone=UTC
[OPENMETEO][PARSE] rows=72 dup_ds=0 na_counts(sample)={'temperature_2m': 0, 'wind_speed_10m': 0, 'wind_speed_100m': 0}
Historical rows: 72
                   ds  temperature_2m  wind_speed_10m  wind_speed_100m  \
0 2024-12-01 00:00:00            12.2             5.9              7.0   
1 2024-12-01 01:00:00             9.2             1.2              5.3   
2 2024-12-01 02:00:00             8.4             2.4              3.6   
3 2024-12-01 03:00:00             8.7             3.8              5.4   
4 2024-12-01 04:00:00            10.3             0.9              2.2   

   wind_direction_10m  direct_radiation  diffuse_radiation  cloud_cover re

# EDA

In [6]:
# %%writefile src/renewable/eda.py
# file: src/renewable/eda.py
"""
Enhanced Exploratory Data Analysis for Renewable Energy Forecasting

This module provides decision-driven EDA with emphasis on:
1. Understanding WHY negative values exist (not just detecting them)
2. Providing actionable recommendations based on findings
3. Validating physical constraints for renewable energy data

Key Principle: Renewable energy generation CANNOT be negative.
- Solar panels produce 0-X power, never negative
- Wind turbines produce 0-X power, never negative
- Negative values in data are ALWAYS data quality issues that need investigation
"""

from __future__ import annotations

import json
import warnings
from dataclasses import dataclass, field, asdict
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

warnings.filterwarnings('ignore', category=UserWarning, module='matplotlib')

plt.rcParams['figure.figsize'] = (12, 6)
plt.rcParams['figure.dpi'] = 100


@dataclass
class PreprocessingRecommendation:
    """Structured recommendation from EDA for preprocessing."""

    # Negative value handling
    negative_policy: str  # 'clamp_to_zero' | 'investigate' | 'pass_through'
    negative_reason: str
    negative_confidence: str  # 'HIGH' | 'MEDIUM' | 'LOW'

    # Grid enforcement
    max_missing_ratio: float = 0.02

    # Series-specific overrides (for different fuel types)
    series_overrides: Dict[str, Dict[str, Any]] = field(default_factory=dict)

    # Metadata
    generated_at: str = field(default_factory=lambda: datetime.now().isoformat())
    data_summary: Dict[str, Any] = field(default_factory=dict)


@dataclass
class EDARecommendations:
    """Complete EDA output with preprocessing recommendations."""

    preprocessing: PreprocessingRecommendation
    modeling: Dict[str, Any]
    evaluation: Dict[str, Any]

    # Full investigation results for audit trail
    negative_investigation: Dict[str, Any]
    seasonality: Dict[str, Any]
    zero_inflation: Dict[str, Any]
    weather_alignment: Dict[str, Any]
    timestamp: str = field(default_factory=lambda: datetime.now().strftime("%Y%m%d_%H%M%S"))

    def save(self, output_path: Path) -> None:
        """Save recommendations to JSON."""
        with open(output_path, 'w') as f:
            json.dump(asdict(self), f, indent=2, default=str)
            f.write('\n')

    @classmethod
    def load(cls, input_path: Path) -> 'EDARecommendations':
        """Load recommendations from JSON."""
        with open(input_path, 'r') as f:
            data = json.load(f)
        return cls(**data)


class NegativeValueInvestigation:
    """
    Deep investigation into why negative values exist in renewable generation data.

    EIA Data Context:
    - EIA reports "net generation" which is gross generation minus station use
    - Station auxiliary loads (cooling, controls, etc.) can exceed generation during:
      * Low-wind periods for wind farms
      * Night/cloudy periods for solar (if inverters consume standby power)
      * Startup/shutdown events

    This is VALID data but represents net metering, not physical impossibility.
    However, for FORECASTING purposes, we typically want to predict gross generation
    or at minimum clamp to zero since negative "production" isn't meaningful for
    grid planning.
    """

    def __init__(self, df: pd.DataFrame):
        """
        Args:
            df: DataFrame with columns [unique_id, ds, y] where y is generation
        """
        self.df = df.copy()
        self.df['ds'] = pd.to_datetime(self.df['ds'])
        self.df['hour'] = self.df['ds'].dt.hour
        self.df['dow'] = self.df['ds'].dt.dayofweek
        self.df['date'] = self.df['ds'].dt.date

    def get_negative_summary(self) -> Dict[str, Any]:
        """Get high-level summary of negative values."""
        neg_mask = self.df['y'] < 0

        summary = {
            'total_rows': len(self.df),
            'negative_count': int(neg_mask.sum()),
            'negative_ratio': float(neg_mask.sum() / len(self.df)) if len(self.df) > 0 else 0,
            'affected_series': self.df.loc[neg_mask, 'unique_id'].unique().tolist(),
            'min_value': float(self.df['y'].min()),
            'max_negative': float(self.df.loc[neg_mask, 'y'].max()) if neg_mask.any() else None,
        }
        return summary

    def analyze_negative_patterns(self) -> Dict[str, Any]:
        """
        Investigate WHEN and WHERE negatives occur to understand root cause.

        Key questions:
        1. Are negatives concentrated in specific hours? (auxiliary load pattern)
        2. Are negatives concentrated in specific series? (regional data issue)
        3. What's the magnitude? (small negatives = metering noise, large = real issue)
        """
        neg_df = self.df[self.df['y'] < 0].copy()

        if len(neg_df) == 0:
            return {'status': 'no_negatives_found'}

        analysis = {
            'by_series': {},
            'by_hour': {},
            'by_dow': {},
            'magnitude_analysis': {},
            'temporal_clustering': {},
        }

        # 1. Analyze by series
        for uid in neg_df['unique_id'].unique():
            series_neg = neg_df[neg_df['unique_id'] == uid]
            series_total = self.df[self.df['unique_id'] == uid]

            fuel_type = uid.split('_')[1] if '_' in uid else 'UNKNOWN'

            analysis['by_series'][uid] = {
                'count': int(len(series_neg)),
                'ratio': float(len(series_neg) / len(series_total)),
                'fuel_type': fuel_type,
                'min_value': float(series_neg['y'].min()),
                'max_value': float(series_neg['y'].max()),
                'mean_value': float(series_neg['y'].mean()),
                'std_value': float(series_neg['y'].std()),
            }

        # 2. Analyze by hour (Are negatives at night for solar? Low-wind hours for wind?)
        hour_counts = neg_df.groupby('hour').size()
        total_by_hour = self.df.groupby('hour').size()
        neg_ratio_by_hour = (hour_counts / total_by_hour).fillna(0)

        analysis['by_hour'] = {
            'counts': hour_counts.to_dict(),
            'ratios': neg_ratio_by_hour.to_dict(),
            'peak_negative_hour': int(neg_ratio_by_hour.idxmax()) if len(neg_ratio_by_hour) > 0 else None,
        }

        # 3. Magnitude analysis - categorize severity
        neg_values = neg_df['y'].values
        analysis['magnitude_analysis'] = {
            'tiny_negatives_count': int((neg_values > -10).sum()),  # Likely metering noise
            'small_negatives_count': int(((neg_values <= -10) & (neg_values > -100)).sum()),
            'medium_negatives_count': int(((neg_values <= -100) & (neg_values > -1000)).sum()),
            'large_negatives_count': int((neg_values <= -1000).sum()),  # Significant issue
            'percentiles': {
                'p5': float(np.percentile(neg_values, 5)),
                'p25': float(np.percentile(neg_values, 25)),
                'p50': float(np.percentile(neg_values, 50)),
                'p75': float(np.percentile(neg_values, 75)),
                'p95': float(np.percentile(neg_values, 95)),
            }
        }

        # 4. Check for temporal clustering (consecutive hours of negatives)
        for uid in neg_df['unique_id'].unique():
            series_df = self.df[self.df['unique_id'] == uid].sort_values('ds')
            series_df['is_negative'] = series_df['y'] < 0

            # Find consecutive negative runs
            series_df['neg_block'] = (series_df['is_negative'] != series_df['is_negative'].shift()).cumsum()
            neg_blocks = series_df[series_df['is_negative']].groupby('neg_block').agg(
                start=('ds', 'min'),
                end=('ds', 'max'),
                duration_hours=('ds', 'count'),
                min_value=('y', 'min'),
            ).reset_index(drop=True)

            if len(neg_blocks) > 0:
                analysis['temporal_clustering'][uid] = {
                    'num_blocks': len(neg_blocks),
                    'avg_block_duration_hours': float(neg_blocks['duration_hours'].mean()),
                    'max_block_duration_hours': int(neg_blocks['duration_hours'].max()),
                    'longest_block_start': str(neg_blocks.loc[neg_blocks['duration_hours'].idxmax(), 'start']),
                }

        return analysis

    def determine_root_cause(self) -> Dict[str, Any]:
        """
        Based on patterns, determine the likely root cause and recommend action.

        Possible causes:
        1. NET GENERATION DATA: EIA reports net = gross - auxiliary. This is valid.
        2. METERING NOISE: Tiny negatives (-1 to -10 MWh) are measurement error.
        3. DATA REPORTING ERROR: Large sporadic negatives are likely errors.
        4. SYSTEMATIC ISSUE: Negatives always at same time = station use pattern.
        """
        patterns = self.analyze_negative_patterns()

        if patterns.get('status') == 'no_negatives_found':
            return {
                'root_cause': 'NONE',
                'confidence': 'HIGH',
                'recommendation': 'No action needed - data is clean',
                'preprocessing_policy': 'pass_through',
            }

        magnitude = patterns.get('magnitude_analysis', {})
        total_neg = sum([
            magnitude.get('tiny_negatives_count', 0),
            magnitude.get('small_negatives_count', 0),
            magnitude.get('medium_negatives_count', 0),
            magnitude.get('large_negatives_count', 0),
        ])

        # Determine root cause based on patterns
        root_cause_analysis = {
            'factors': [],
            'root_cause': 'UNKNOWN',
            'confidence': 'LOW',
            'recommendation': '',
            'preprocessing_policy': 'clamp_to_zero',
        }

        # Check if mostly tiny negatives (metering noise)
        if magnitude.get('tiny_negatives_count', 0) / max(total_neg, 1) > 0.9:
            root_cause_analysis['factors'].append('90%+ negatives are tiny (<10 MWh)')
            root_cause_analysis['root_cause'] = 'METERING_NOISE'
            root_cause_analysis['confidence'] = 'HIGH'
            root_cause_analysis['recommendation'] = (
                'Tiny negatives are measurement noise. Safe to clamp to 0.'
            )
            root_cause_analysis['preprocessing_policy'] = 'clamp_to_zero'

        # Check if negatives are systematic (same hours)
        elif patterns.get('by_hour', {}).get('peak_negative_hour') is not None:
            hour_ratios = patterns.get('by_hour', {}).get('ratios', {})
            max_ratio = max(hour_ratios.values()) if hour_ratios else 0

            if max_ratio > 0.3:  # >30% of negatives in one hour
                root_cause_analysis['factors'].append(f'Negatives concentrated at specific hours')
                root_cause_analysis['root_cause'] = 'NET_GENERATION_AUXILIARY_LOAD'
                root_cause_analysis['confidence'] = 'MEDIUM'
                root_cause_analysis['recommendation'] = (
                    'Negatives likely represent station auxiliary loads exceeding generation. '
                    'This is valid net generation data. For forecasting, clamp to 0 since '
                    'we want to predict usable power output.'
                )
                root_cause_analysis['preprocessing_policy'] = 'clamp_to_zero'

        # Check for large sporadic negatives (data errors)
        if magnitude.get('large_negatives_count', 0) > 0:
            root_cause_analysis['factors'].append(f"{magnitude.get('large_negatives_count')} large negatives (<-1000 MWh)")

            # If ONLY large negatives and they're sporadic, likely errors
            if magnitude.get('tiny_negatives_count', 0) == 0 and magnitude.get('small_negatives_count', 0) == 0:
                root_cause_analysis['root_cause'] = 'DATA_REPORTING_ERROR'
                root_cause_analysis['confidence'] = 'MEDIUM'
                root_cause_analysis['recommendation'] = (
                    'Large sporadic negatives are likely data reporting errors. '
                    'Recommend clamping to 0 or investigating with EIA.'
                )

        return root_cause_analysis

    def generate_report(self, output_dir: Path) -> Dict[str, Any]:
        """Generate comprehensive negative value investigation report."""
        output_dir.mkdir(parents=True, exist_ok=True)

        summary = self.get_negative_summary()
        patterns = self.analyze_negative_patterns()
        root_cause = self.determine_root_cause()

        report = {
            'summary': summary,
            'patterns': patterns,
            'root_cause_analysis': root_cause,
            'generated_at': datetime.now().isoformat(),
        }

        # Save JSON report
        report_file = output_dir / 'negative_investigation.json'
        with open(report_file, 'w') as f:
            json.dump(report, f, indent=2, default=str)
            f.write('\n')  # POSIX standard: files should end with newline

        # Generate visualizations if negatives exist
        if summary['negative_count'] > 0:
            self._plot_negative_analysis(patterns, output_dir)

        return report

    def _plot_negative_analysis(self, patterns: Dict, output_dir: Path) -> None:
        """Create diagnostic plots for negative value analysis."""
        neg_df = self.df[self.df['y'] < 0]

        if len(neg_df) == 0:
            return

        fig, axes = plt.subplots(2, 2, figsize=(14, 10))

        # 1. Distribution of negative values
        ax = axes[0, 0]
        neg_values = neg_df['y'].values
        ax.hist(neg_values, bins=50, color='red', alpha=0.7, edgecolor='black')
        ax.axvline(x=np.median(neg_values), color='blue', linestyle='--',
                   label=f'Median: {np.median(neg_values):.1f}')
        ax.set_xlabel('Generation (MWh)')
        ax.set_ylabel('Frequency')
        ax.set_title('Distribution of Negative Values')
        ax.legend()

        # 2. Negative ratio by hour
        ax = axes[0, 1]
        hour_ratios = patterns.get('by_hour', {}).get('ratios', {})
        if hour_ratios:
            hours = sorted(hour_ratios.keys())
            ratios = [hour_ratios[h] for h in hours]
            ax.bar(hours, ratios, color='orange', alpha=0.7)
            ax.set_xlabel('Hour of Day')
            ax.set_ylabel('Negative Ratio')
            ax.set_title('When Do Negatives Occur? (by Hour)')
            ax.set_xticks(range(0, 24, 2))

        # 3. Negative values by series
        ax = axes[1, 0]
        series_data = patterns.get('by_series', {})
        if series_data:
            series_names = list(series_data.keys())
            series_counts = [series_data[s]['count'] for s in series_names]
            colors = ['red' if 'SUN' in s else 'blue' for s in series_names]
            ax.barh(series_names, series_counts, color=colors, alpha=0.7)
            ax.set_xlabel('Negative Count')
            ax.set_title('Negatives by Series (Blue=Wind, Red=Solar)')

        # 4. Time series with negatives highlighted
        ax = axes[1, 1]
        # Plot first affected series
        affected = patterns.get('by_series', {})
        if affected:
            first_series = list(affected.keys())[0]
            series_df = self.df[self.df['unique_id'] == first_series].sort_values('ds')
            ax.plot(series_df['ds'], series_df['y'], 'b-', alpha=0.5, label='Generation')
            neg_mask = series_df['y'] < 0
            ax.scatter(series_df.loc[neg_mask, 'ds'], series_df.loc[neg_mask, 'y'],
                      c='red', s=20, label='Negative values', zorder=5)
            ax.axhline(y=0, color='black', linestyle='--', linewidth=0.5)
            ax.set_xlabel('Date')
            ax.set_ylabel('Generation (MWh)')
            ax.set_title(f'Time Series: {first_series}')
            ax.legend()
            plt.setp(ax.xaxis.get_majorticklabels(), rotation=45, ha='right')

        plt.tight_layout()
        plt.savefig(output_dir / 'negative_investigation.png', dpi=150, bbox_inches='tight')
        plt.close()


def run_full_eda(
    generation_df: pd.DataFrame,
    weather_df: pd.DataFrame,
    output_dir: Path,
) -> EDARecommendations:
    """
    Run comprehensive EDA with emphasis on understanding data quality issues.

    This function produces actionable insights, not just statistics.

    Args:
        generation_df: DataFrame with columns [unique_id, ds, y]
        weather_df: DataFrame with columns [ds, region, weather_vars...]
        output_dir: Directory to save all outputs

    Returns:
        EDARecommendations object with structured preprocessing recommendations
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    report_dir = output_dir / timestamp
    report_dir.mkdir(parents=True, exist_ok=True)

    print("=" * 80)
    print("RENEWABLE ENERGY EDA - Comprehensive Analysis")
    print("=" * 80)

    results = {
        'timestamp': timestamp,
        'output_dir': str(report_dir),
        'data_summary': {},
        'negative_investigation': {},
        'seasonality': {},
        'zero_inflation': {},
        'weather_alignment': {},
        'recommendations': {},
    }

    # 1. Data Summary
    print("\n[1/5] Data Summary...")
    results['data_summary'] = {
        'generation_rows': len(generation_df),
        'generation_series': generation_df['unique_id'].nunique(),
        'series_list': generation_df['unique_id'].unique().tolist(),
        'date_range': {
            'start': str(generation_df['ds'].min()),
            'end': str(generation_df['ds'].max()),
        },
        'weather_rows': len(weather_df),
        'weather_regions': weather_df['region'].nunique() if 'region' in weather_df.columns else 0,
    }
    print(f"   Generation: {results['data_summary']['generation_rows']:,} rows, "
          f"{results['data_summary']['generation_series']} series")

    # 2. CRITICAL: Negative Value Investigation
    print("\n[2/5] Negative Value Investigation (CRITICAL)...")
    neg_investigator = NegativeValueInvestigation(generation_df)
    results['negative_investigation'] = neg_investigator.generate_report(
        report_dir / 'negative_values'
    )

    neg_summary = results['negative_investigation']['summary']
    root_cause = results['negative_investigation']['root_cause_analysis']

    if neg_summary['negative_count'] > 0:
        print(f"   [WARNING] Found {neg_summary['negative_count']} negative values "
              f"({neg_summary['negative_ratio']:.2%})")
        print(f"   [WARNING] Affected series: {neg_summary['affected_series']}")
        print(f"   [ANALYSIS] Root cause: {root_cause['root_cause']} "
              f"(confidence: {root_cause['confidence']})")
        print(f"   [RECOMMENDATION] {root_cause['recommendation']}")
    else:
        print("   [OK] No negative values found")

    # 3. Seasonality Analysis
    print("\n[3/5] Seasonality Analysis...")
    seasonality_dir = report_dir / 'seasonality'
    seasonality_dir.mkdir(parents=True, exist_ok=True)

    seasonality_results = _analyze_seasonality(generation_df, seasonality_dir)
    results['seasonality'] = seasonality_results
    print(f"   [OK] Analyzed {len(seasonality_results.get('series_analyzed', []))} series")

    # 4. Zero-Inflation Analysis
    print("\n[4/5] Zero-Inflation Analysis...")
    zero_dir = report_dir / 'zero_inflation'
    zero_dir.mkdir(parents=True, exist_ok=True)

    zero_results = _analyze_zero_inflation(generation_df, zero_dir)
    results['zero_inflation'] = zero_results

    solar_series = [uid for uid in generation_df['unique_id'].unique() if 'SUN' in uid]
    if solar_series:
        avg_zero = sum(
            zero_results['series_zero_ratios'].get(uid, {}).get('zero_ratio', 0)
            for uid in solar_series
        ) / len(solar_series)
        print(f"   [OK] Solar zero ratio: {avg_zero:.1%} (zeros at night expected)")

    # 5. Weather Alignment
    print("\n[5/5] Weather Alignment...")
    weather_dir = report_dir / 'weather_alignment'
    weather_dir.mkdir(parents=True, exist_ok=True)

    weather_results = _analyze_weather_alignment(generation_df, weather_df, weather_dir)
    results['weather_alignment'] = weather_results
    print(f"   [OK] Merge success rate: {weather_results['merge_success_ratio']:.1%}")

    # Generate Final Recommendations
    print("\n" + "=" * 80)
    print("RECOMMENDATIONS")
    print("=" * 80)

    # Build structured preprocessing recommendations
    preprocessing_rec = PreprocessingRecommendation(
        negative_policy=root_cause.get('preprocessing_policy', 'clamp_to_zero'),
        negative_reason=root_cause['recommendation'],
        negative_confidence=root_cause['confidence'],
        max_missing_ratio=0.02,
        series_overrides={},
        data_summary=results['data_summary'],
    )

    # Add series-specific patterns if negatives found
    patterns = results['negative_investigation'].get('patterns', {})
    if neg_summary['negative_count'] > 0:
        for series_id in neg_summary['affected_series']:
            series_pattern = patterns.get('by_series', {}).get(series_id, {})
            preprocessing_rec.series_overrides[series_id] = {
                'negative_ratio': series_pattern.get('ratio', 0),
                'suggested_policy': root_cause.get('preprocessing_policy'),
            }

    print(f"\n[PREPROCESSING] Negative Handling: {preprocessing_rec.negative_policy}")
    print(f"   Reason: {preprocessing_rec.negative_reason}")
    print(f"   Confidence: {preprocessing_rec.negative_confidence}")

    # Modeling recommendations
    modeling_recs = {
        'seasonality': 'Use MSTL with season_length=[24, 168] (daily + weekly)',
        'forecast_constraints': 'ALWAYS clip forecasts and intervals to [0, ∞)',
        'reason': 'Physical constraint: renewable generation cannot be negative',
    }
    print(f"\n[MODELING] Forecast Constraints: Clip to [0, ∞)")
    print(f"   Reason: Physical constraint - renewable generation cannot be negative")

    # Evaluation recommendations
    avg_zero = 0.0
    if solar_series:
        avg_zero = sum(
            zero_results['series_zero_ratios'].get(uid, {}).get('zero_ratio', 0)
            for uid in solar_series
        ) / len(solar_series)

    evaluation_recs = {
        'metrics': ['RMSE', 'MAE'],
        'avoid': 'MAPE (undefined when y=0)',
        'reason': f"Solar has {avg_zero:.1%} zeros (nighttime)" if solar_series else "Standard metrics",
    }
    print(f"\n[EVALUATION] Use RMSE/MAE, avoid MAPE")

    # Build final EDARecommendations object
    recommendations = EDARecommendations(
        preprocessing=preprocessing_rec,
        modeling=modeling_recs,
        evaluation=evaluation_recs,
        negative_investigation=results['negative_investigation'],
        seasonality=results['seasonality'],
        zero_inflation=results['zero_inflation'],
        weather_alignment=results['weather_alignment'],
    )

    # Save structured recommendations
    recommendations.save(report_dir / 'recommendations.json')

    # Also save full report (backward compatibility)
    results['recommendations'] = asdict(recommendations)
    report_file = report_dir / 'eda_report.json'
    with open(report_file, 'w') as f:
        json.dump(results, f, indent=2, default=str)
        f.write('\n')  # POSIX standard: files should end with newline

    print("\n" + "=" * 80)
    print(f"[SUCCESS] EDA complete. Report saved to: {report_dir}")
    print("=" * 80)

    return recommendations


def _analyze_seasonality(df: pd.DataFrame, output_dir: Path) -> Dict[str, Any]:
    """Analyze seasonal patterns in generation data."""
    df = df.copy()
    df['ds'] = pd.to_datetime(df['ds'])

    results = {
        'series_analyzed': [],
        'hourly_patterns': {},
    }

    for uid in df['unique_id'].unique()[:3]:  # Analyze first 3 series
        series_data = df[df['unique_id'] == uid].copy()
        series_data['hour'] = series_data['ds'].dt.hour

        hourly_profile = series_data.groupby('hour')['y'].agg(['mean', 'std']).reset_index()
        results['series_analyzed'].append(uid)
        results['hourly_patterns'][uid] = hourly_profile.to_dict(orient='records')

    # Create visualization
    fig, axes = plt.subplots(1, len(results['series_analyzed']),
                            figsize=(5 * len(results['series_analyzed']), 4))
    if len(results['series_analyzed']) == 1:
        axes = [axes]

    for idx, uid in enumerate(results['series_analyzed']):
        series_data = df[df['unique_id'] == uid].copy()
        series_data['hour'] = series_data['ds'].dt.hour
        hourly_mean = series_data.groupby('hour')['y'].mean()
        hourly_std = series_data.groupby('hour')['y'].std()

        axes[idx].plot(hourly_mean.index, hourly_mean.values, marker='o')
        axes[idx].fill_between(hourly_mean.index,
                               hourly_mean - hourly_std,
                               hourly_mean + hourly_std, alpha=0.3)
        axes[idx].set_xlabel('Hour of Day')
        axes[idx].set_ylabel('Generation (MWh)')
        axes[idx].set_title(f'{uid} - Hourly Profile')
        axes[idx].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(output_dir / 'hourly_profiles.png', dpi=150, bbox_inches='tight')
    plt.close()

    return results


def _analyze_zero_inflation(df: pd.DataFrame, output_dir: Path) -> Dict[str, Any]:
    """Analyze zero values in generation data."""
    df = df.copy()
    df['ds'] = pd.to_datetime(df['ds'])
    df['hour'] = df['ds'].dt.hour

    results = {
        'series_zero_ratios': {},
    }

    for uid in df['unique_id'].unique():
        series_data = df[df['unique_id'] == uid]
        zero_count = (series_data['y'] == 0).sum()
        total_count = len(series_data)

        results['series_zero_ratios'][uid] = {
            'zero_count': int(zero_count),
            'total_count': int(total_count),
            'zero_ratio': float(zero_count / total_count) if total_count > 0 else 0,
        }

    # Visualization
    solar_series = [uid for uid in df['unique_id'].unique() if 'SUN' in uid]
    wind_series = [uid for uid in df['unique_id'].unique() if 'WND' in uid]

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Solar zeros by hour
    if solar_series:
        solar_df = df[df['unique_id'].isin(solar_series)]
        # Use vectorized approach: faster and no FutureWarning
        solar_zero_by_hour = (
            solar_df.assign(is_zero=solar_df['y'].eq(0))
            .groupby('hour')['is_zero']
            .mean()
        )
        axes[0].bar(solar_zero_by_hour.index, solar_zero_by_hour.values,
                   color='orange', alpha=0.7)
        axes[0].set_xlabel('Hour of Day')
        axes[0].set_ylabel('Zero Ratio')
        axes[0].set_title('Solar: Zero Ratio by Hour (Night = Expected)')
        axes[0].axhline(y=0.5, color='red', linestyle='--', alpha=0.5)

    # Wind zeros by hour
    if wind_series:
        wind_df = df[df['unique_id'].isin(wind_series)]
        # Use vectorized approach: faster and no FutureWarning
        wind_zero_by_hour = (
            wind_df.assign(is_zero=wind_df['y'].eq(0))
            .groupby('hour')['is_zero']
            .mean()
        )
        axes[1].bar(wind_zero_by_hour.index, wind_zero_by_hour.values,
                   color='blue', alpha=0.7)
        axes[1].set_xlabel('Hour of Day')
        axes[1].set_ylabel('Zero Ratio')
        axes[1].set_title('Wind: Zero Ratio by Hour')

    plt.tight_layout()
    plt.savefig(output_dir / 'zero_inflation.png', dpi=150, bbox_inches='tight')
    plt.close()

    return results


def _analyze_weather_alignment(
    generation_df: pd.DataFrame,
    weather_df: pd.DataFrame,
    output_dir: Path
) -> Dict[str, Any]:
    """Analyze weather-generation correlation."""
    generation_df = generation_df.copy()
    weather_df = weather_df.copy()

    generation_df['ds'] = pd.to_datetime(generation_df['ds'])
    weather_df['ds'] = pd.to_datetime(weather_df['ds'])
    generation_df['region'] = generation_df['unique_id'].str.split('_').str[0]

    merged = generation_df.merge(weather_df, on=['ds', 'region'], how='left')

    weather_vars = [c for c in weather_df.columns
                   if c not in ['ds', 'region'] and c in merged.columns]

    results = {
        'merge_success_ratio': float(
            merged[weather_vars[0]].notna().mean() if weather_vars else 0
        ),
        'correlation_by_fuel': {},
    }

    # Calculate correlations
    wind_series = merged[merged['unique_id'].str.contains('WND')]
    solar_series = merged[merged['unique_id'].str.contains('SUN')]

    if len(wind_series) > 0:
        wind_corr = {}
        for var in weather_vars:
            if var in wind_series.columns:
                corr = wind_series[['y', var]].corr().iloc[0, 1]
                wind_corr[var] = float(corr) if not pd.isna(corr) else 0.0
        results['correlation_by_fuel']['WND'] = wind_corr

    if len(solar_series) > 0:
        solar_corr = {}
        for var in weather_vars:
            if var in solar_series.columns:
                corr = solar_series[['y', var]].corr().iloc[0, 1]
                solar_corr[var] = float(corr) if not pd.isna(corr) else 0.0
        results['correlation_by_fuel']['SUN'] = solar_corr

    # Save results
    with open(output_dir / 'weather_analysis.json', 'w') as f:
        json.dump(results, f, indent=2)
        f.write('\n')  # POSIX standard: files should end with newline

    return results


if __name__ == "__main__":
    """Run EDA on renewable energy data."""
    import sys

    generation_path = Path("data/renewable/generation.parquet")
    weather_path = Path("data/renewable/weather.parquet")

    if not generation_path.exists() or not weather_path.exists():
        print("Data files not found. Run pipeline first.")
        

    generation_df = pd.read_parquet(generation_path)
    weather_df = pd.read_parquet(weather_path)

    output_dir = Path("reports/renewable/eda")

    results = run_full_eda(generation_df, weather_df, output_dir)


RENEWABLE ENERGY EDA - Comprehensive Analysis

[1/5] Data Summary...
   Generation: 4,358 rows, 6 series

[2/5] Negative Value Investigation (CRITICAL)...
   [ANALYSIS] Root cause: UNKNOWN (confidence: LOW)
   [RECOMMENDATION] 

[3/5] Seasonality Analysis...
   [OK] Analyzed 3 series

[4/5] Zero-Inflation Analysis...
   [OK] Solar zero ratio: 15.4% (zeros at night expected)

[5/5] Weather Alignment...
   [OK] Merge success rate: 100.0%

RECOMMENDATIONS

[PREPROCESSING] Negative Handling: clamp_to_zero
   Reason: 
   Confidence: LOW

[MODELING] Forecast Constraints: Clip to [0, ∞)
   Reason: Physical constraint - renewable generation cannot be negative

[EVALUATION] Use RMSE/MAE, avoid MAPE

[SUCCESS] EDA complete. Report saved to: reports\renewable\eda\20260122_165040


# Dataset Builder based on EDA from both datasets

In [7]:
# %%writefile src/renewable/dataset_builder.py
# file: src/renewable/dataset_builder.py
"""
Dataset Builder for Renewable Energy Forecasting

This module transforms raw EIA/weather data into modeling-ready datasets with:
1. Transparent preprocessing based on EDA findings
2. Physical constraint enforcement (non-negativity)
3. Comprehensive diagnostics

KEY PRINCIPLE:
Renewable energy generation CANNOT be negative. This is a physical law.
- Solar panels: 0 to max capacity
- Wind turbines: 0 to max capacity

Any negative values in raw data are data quality issues (metering, net generation
accounting, etc.) and should be handled transparently.

Preprocessing Policies:
- clamp_to_zero: Set negative values to 0 (recommended for most cases)
- investigate: Fail with detailed diagnostics (for initial data exploration)
- pass_through: No modification (only if you understand why negatives exist)
"""

from __future__ import annotations

import json
import logging
from dataclasses import dataclass, asdict, field
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

logger = logging.getLogger(__name__)


# Weather variables from Open-Meteo
WEATHER_VARS = [
    "temperature_2m",
    "wind_speed_10m",
    "wind_speed_100m",
    "wind_direction_10m",
    "direct_radiation",
    "diffuse_radiation",
    "cloud_cover",
]

# Time features for modeling
TIME_FEATURES = ["hour_sin", "hour_cos", "dow_sin", "dow_cos"]


@dataclass
class NegativeValueReport:
    """Report on negative values found and handled."""
    total_negative_count: int
    total_rows: int
    negative_ratio: float
    by_series: Dict[str, Dict[str, Any]]
    action_taken: str
    samples: List[Dict[str, Any]]


@dataclass
class PreprocessingReport:
    """Complete report of all preprocessing steps."""
    timestamp: str

    # Input stats
    input_rows: int
    input_series: int
    input_date_range: Dict[str, str]

    # Negative handling
    negative_report: NegativeValueReport

    # Missing data
    missing_hours_dropped: int
    series_dropped_incomplete: List[str]

    # Weather alignment
    weather_coverage: float
    weather_vars_used: List[str]

    # Output stats
    output_rows: int
    output_series: int
    output_features: List[str]

    # Configuration used
    config: Dict[str, Any]


def _add_time_features(df: pd.DataFrame) -> pd.DataFrame:
    """Add cyclical time features (hour, day of week)."""
    out = df.copy()
    out["hour"] = out["ds"].dt.hour
    out["dow"] = out["ds"].dt.dayofweek

    # Cyclical encoding (sin/cos transform)
    out["hour_sin"] = np.sin(2 * np.pi * out["hour"] / 24)
    out["hour_cos"] = np.cos(2 * np.pi * out["hour"] / 24)
    out["dow_sin"] = np.sin(2 * np.pi * out["dow"] / 7)
    out["dow_cos"] = np.cos(2 * np.pi * out["dow"] / 7)

    return out.drop(columns=["hour", "dow"])


def _apply_negative_policy(
    df: pd.DataFrame,
    policy: str,
) -> Tuple[pd.DataFrame, NegativeValueReport]:
    """
    Apply negative value policy (NO INVESTIGATION - that's EDA's job).

    This function ONLY applies the policy. Investigation should be done
    in eda.py before calling dataset builder.

    Physical Reality:
    - Renewable energy generation CANNOT be negative
    - Negative values are ALWAYS data quality issues

    Policies:
    - clamp_to_zero: Set negatives to 0 (RECOMMENDED from EDA)
    - investigate: Fail with message to run EDA first
    - pass_through: No modification (if EDA recommends)

    Args:
        df: DataFrame with [unique_id, ds, y]
        policy: Policy to apply (from EDA recommendations)

    Returns:
        (processed_df, report)
    """
    neg_mask = df['y'] < 0
    neg_count = int(neg_mask.sum())
    total_rows = len(df)

    # DEBUG: Log what we're working with
    logger.debug(f"[_apply_negative_policy] Processing {total_rows} rows, found {neg_count} negatives")

    # Build report
    by_series = {}
    samples = []

    if neg_count > 0:
        for uid in df.loc[neg_mask, 'unique_id'].unique():
            series_mask = (df['unique_id'] == uid) & neg_mask
            series_neg = df.loc[series_mask]
            series_total = len(df[df['unique_id'] == uid])

            by_series[uid] = {
                'count': int(series_mask.sum()),
                'min_value': float(series_neg['y'].min()),
                'max_value': float(series_neg['y'].max()),
            }

            # Just a few samples for audit
            for _, row in series_neg.head(3).iterrows():
                samples.append({
                    'unique_id': row['unique_id'],
                    'ds': str(row['ds']),
                    'y': float(row['y']),
                })

    # Calculate negative ratio
    negative_ratio = float(neg_count / total_rows) if total_rows > 0 else 0.0

    # DEBUG: Log report creation
    logger.debug(f"[_apply_negative_policy] Creating report: {neg_count}/{total_rows} = {negative_ratio:.2%} negatives")

    report = NegativeValueReport(
        total_negative_count=neg_count,
        total_rows=total_rows,
        negative_ratio=negative_ratio,
        by_series=by_series,
        action_taken=policy,
        samples=samples,
    )

    # Apply policy
    if neg_count == 0:
        report.action_taken = 'none_needed'
        return df.copy(), report

    if policy == 'investigate':
        raise ValueError(
            f"NEGATIVE VALUES DETECTED: {neg_count} negatives found.\n"
            f"Run EDA first to investigate root cause:\n"
            f"  from src.renewable.eda import run_full_eda\n"
            f"  recommendations = run_full_eda(generation_df, weather_df, output_dir)\n"
            f"Then use recommended policy from EDA."
        )

    elif policy == 'clamp_to_zero':
        out = df.copy()
        out['y'] = out['y'].clip(lower=0)
        logger.info(
            f"[PREPROCESSING] Clamped {neg_count} negative values to 0 "
            f"({100*neg_count/total_rows:.2f}% of data)"
        )

        # Log per-series
        for uid, info in by_series.items():
            logger.info(
                f"  {uid}: {info['count']} negatives clamped "
                f"(range: [{info['min_value']:.1f}, {info['max_value']:.1f}])"
            )

        report.action_taken = 'clamped_to_zero'
        return out, report

    elif policy == 'pass_through':
        logger.warning(f"[PREPROCESSING] Passing through {neg_count} negative values")
        report.action_taken = 'passed_through'
        return df.copy(), report

    else:
        raise ValueError(f"Unknown negative_policy: {policy}")


def _enforce_hourly_grid(
    df: pd.DataFrame,
    max_missing_ratio: float = 0.02,
) -> Tuple[pd.DataFrame, List[str], int]:
    """
    Enforce complete hourly grid (no gaps).

    For time series forecasting, we need continuous hourly data.
    Series with too many gaps are dropped (no imputation - we don't fabricate data).

    Args:
        df: DataFrame with [unique_id, ds, y]
        max_missing_ratio: Maximum allowed ratio of missing hours

    Returns:
        (filtered_df, dropped_series, total_missing_hours)
    """
    dropped_series = []
    total_missing = 0

    keep_rows = []

    for uid, group in df.groupby('unique_id'):
        group = group.sort_values('ds')
        start = group['ds'].min()
        end = group['ds'].max()

        expected_hours = pd.date_range(start, end, freq='h')
        actual_hours = len(group)
        expected_count = len(expected_hours)

        missing_count = expected_count - actual_hours
        missing_ratio = missing_count / expected_count if expected_count > 0 else 0

        total_missing += missing_count

        if missing_ratio > max_missing_ratio:
            dropped_series.append(uid)
            logger.warning(
                f"[GRID] Dropping {uid}: missing {missing_count} hours "
                f"({missing_ratio:.1%} > {max_missing_ratio:.1%} threshold)"
            )
        else:
            keep_rows.append(group)

    if not keep_rows:
        raise RuntimeError(
            f"All series dropped due to missing hours. "
            f"Dropped: {dropped_series}. Consider increasing max_missing_ratio."
        )

    filtered = pd.concat(keep_rows, ignore_index=True)
    return filtered, dropped_series, total_missing


def _align_weather(
    df: pd.DataFrame,
    weather_df: pd.DataFrame,
) -> Tuple[pd.DataFrame, float, List[str]]:
    """
    Align weather data to generation timestamps.

    Args:
        df: Generation DataFrame (must have 'region' column or unique_id with region prefix)
        weather_df: Weather DataFrame with [ds, region, weather_vars...]

    Returns:
        (merged_df, coverage_ratio, weather_vars_used)
    """
    work = df.copy()

    # Extract region from unique_id if not present
    if 'region' not in work.columns:
        work['region'] = work['unique_id'].str.split('_').str[0]

    # Find available weather variables
    available_vars = [c for c in WEATHER_VARS if c in weather_df.columns]
    if not available_vars:
        raise ValueError(
            f"No weather variables found in weather_df. "
            f"Expected: {WEATHER_VARS}, Got: {weather_df.columns.tolist()}"
        )

    # Merge
    merged = work.merge(
        weather_df[['ds', 'region'] + available_vars],
        on=['ds', 'region'],
        how='left',
        validate='many_to_one',
    )

    # Check coverage
    missing_weather = merged[available_vars].isna().any(axis=1)
    coverage = 1 - (missing_weather.sum() / len(merged))

    if missing_weather.any():
        missing_count = int(missing_weather.sum())
        logger.warning(
            f"[WEATHER] {missing_count} rows ({1-coverage:.1%}) missing weather data"
        )

        # Drop rows with missing weather (no fabrication)
        merged = merged[~missing_weather].reset_index(drop=True)
        logger.warning(f"[WEATHER] Dropped {missing_count} rows with missing weather")

    # Drop region column (not needed for modeling)
    merged = merged.drop(columns=['region'])

    return merged, coverage, available_vars


def build_modeling_dataset(
    generation_df: pd.DataFrame,
    weather_df: pd.DataFrame,
    *,
    negative_policy: str = 'clamp_to_zero',
    max_missing_ratio: float = 0.02,
    output_dir: Optional[Path] = None,
    eda_recommendations: Optional['PreprocessingRecommendation'] = None,
) -> Tuple[pd.DataFrame, PreprocessingReport]:
    """
    Build modeling-ready dataset from raw data.

    RECOMMENDED: Run EDA first and pass recommendations via eda_recommendations.

    Pipeline:
    1. Validate inputs
    2. Handle negative values (based on policy)
    3. Enforce hourly grid (drop incomplete series)
    4. Add time features
    5. Align weather data

    Args:
        generation_df: Raw generation data [unique_id, ds, y]
        weather_df: Raw weather data [ds, region, weather_vars...]
        negative_policy: How to handle negative values (override if no EDA)
            - 'clamp_to_zero': Set to 0 (RECOMMENDED)
            - 'investigate': Fail with diagnostics
            - 'pass_through': No modification
        max_missing_ratio: Max ratio of missing hours before dropping series
        output_dir: Optional directory for detailed reports
        eda_recommendations: Recommendations from run_full_eda() (PREFERRED)

    Returns:
        (modeling_df, preprocessing_report)
    """
    logger.info("=" * 60)
    logger.info("DATASET BUILDER - Building Modeling Dataset")
    logger.info("=" * 60)

    # Use EDA recommendations if provided
    if eda_recommendations is not None:
        negative_policy = eda_recommendations.negative_policy
        max_missing_ratio = eda_recommendations.max_missing_ratio
        logger.info("[DATASET_BUILDER] Using EDA recommendations")
        logger.info(f"  Policy: {negative_policy} (confidence: {eda_recommendations.negative_confidence})")
        logger.info(f"  Reason: {eda_recommendations.negative_reason}")
    else:
        logger.warning("[DATASET_BUILDER] No EDA recommendations - using defaults")
        logger.warning("  RECOMMENDED: Run EDA first for data-driven decisions")

    # Validate inputs
    required_gen = {'unique_id', 'ds', 'y'}
    if not required_gen.issubset(generation_df.columns):
        missing = required_gen - set(generation_df.columns)
        raise ValueError(f"generation_df missing columns: {missing}")

    if generation_df.empty:
        raise ValueError("generation_df is empty")

    required_weather = {'ds', 'region'}
    if not required_weather.issubset(weather_df.columns):
        missing = required_weather - set(weather_df.columns)
        raise ValueError(f"weather_df missing columns: {missing}")

    # Ensure datetime
    work = generation_df.copy()
    work['ds'] = pd.to_datetime(work['ds'])
    weather_df = weather_df.copy()
    weather_df['ds'] = pd.to_datetime(weather_df['ds'])

    input_rows = len(work)
    input_series = work['unique_id'].nunique()
    input_date_range = {
        'start': str(work['ds'].min()),
        'end': str(work['ds'].max()),
    }

    logger.info(f"Input: {input_rows:,} rows, {input_series} series")
    logger.info(f"Date range: {input_date_range['start']} to {input_date_range['end']}")

    # Step 1: Apply negative value policy
    logger.info(f"\n[1/4] Applying negative value policy (policy={negative_policy})...")
    work, neg_report = _apply_negative_policy(work, policy=negative_policy)

    # Step 2: Enforce hourly grid
    logger.info(f"\n[2/4] Enforcing hourly grid (max_missing={max_missing_ratio:.1%})...")
    work, dropped_series, missing_hours = _enforce_hourly_grid(
        work, max_missing_ratio=max_missing_ratio
    )

    # Step 3: Add time features
    logger.info("\n[3/4] Adding time features...")
    work = _add_time_features(work)
    logger.info(f"   Added: {TIME_FEATURES}")

    # Step 4: Align weather
    logger.info("\n[4/4] Aligning weather data...")
    work, weather_coverage, weather_vars = _align_weather(work, weather_df)
    logger.info(f"   Coverage: {weather_coverage:.1%}")
    logger.info(f"   Variables: {weather_vars}")

    # Sort and finalize
    work = work.sort_values(['unique_id', 'ds']).reset_index(drop=True)

    output_rows = len(work)
    output_series = work['unique_id'].nunique()
    output_features = ['unique_id', 'ds', 'y'] + TIME_FEATURES + weather_vars

    # Build report
    report = PreprocessingReport(
        timestamp=datetime.now().isoformat(),
        input_rows=input_rows,
        input_series=input_series,
        input_date_range=input_date_range,
        negative_report=neg_report,
        missing_hours_dropped=missing_hours,
        series_dropped_incomplete=dropped_series,
        weather_coverage=weather_coverage,
        weather_vars_used=weather_vars,
        output_rows=output_rows,
        output_series=output_series,
        output_features=output_features,
        config={
            'negative_policy': negative_policy,
            'max_missing_ratio': max_missing_ratio,
        },
    )

    # Save report if output_dir provided
    if output_dir:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        report_dict = asdict(report)
        report_dict['negative_report'] = asdict(report.negative_report)

        report_file = output_dir / 'preprocessing_report.json'
        with open(report_file, 'w') as f:
            json.dump(report_dict, f, indent=2, default=str)
            f.write('\n')  # POSIX standard: files should end with newline
        logger.info(f"\n[REPORT] Saved to: {report_file}")

    logger.info("\n" + "=" * 60)
    logger.info("DATASET BUILDER - Complete")
    logger.info(f"Output: {output_rows:,} rows, {output_series} series")
    logger.info(f"Dropped: {input_rows - output_rows:,} rows")
    logger.info("=" * 60)

    return work, report


# ============================================================================
# Dataset-Specific Builders
# ============================================================================

class RenewableDatasetBuilder:
    """Base class for fuel-type-specific dataset builders."""

    def __init__(
        self,
        fuel_type: str,
        eda_recommendations: Optional['PreprocessingRecommendation'] = None,
    ):
        self.fuel_type = fuel_type
        self.eda_recommendations = eda_recommendations
        self.config = self._get_default_config()

        if eda_recommendations:
            self.config['negative_policy'] = eda_recommendations.negative_policy
            self.config['max_missing_ratio'] = eda_recommendations.max_missing_ratio

    def _get_default_config(self) -> Dict[str, Any]:
        """Get default config for this fuel type."""
        return {
            'negative_policy': 'clamp_to_zero',
            'max_missing_ratio': 0.02,
        }

    def build(
        self,
        generation_df: pd.DataFrame,
        weather_df: pd.DataFrame,
        output_dir: Optional[Path] = None,
    ) -> Tuple[pd.DataFrame, PreprocessingReport]:
        """Build dataset using fuel-type-specific logic."""

        # Filter to this fuel type only
        fuel_series = [uid for uid in generation_df['unique_id'].unique() if self.fuel_type in uid]

        if not fuel_series:
            raise ValueError(f"No series found for fuel type: {self.fuel_type}")

        filtered_df = generation_df[generation_df['unique_id'].isin(fuel_series)].copy()

        logger.info(f"[{self.fuel_type}_BUILDER] Building dataset for {len(fuel_series)} series")

        return build_modeling_dataset(
            filtered_df,
            weather_df,
            negative_policy=self.config['negative_policy'],
            max_missing_ratio=self.config['max_missing_ratio'],
            output_dir=output_dir,
            eda_recommendations=self.eda_recommendations,
        )


class SolarDatasetBuilder(RenewableDatasetBuilder):
    """Solar-specific dataset builder."""

    def __init__(self, eda_recommendations: Optional['PreprocessingRecommendation'] = None):
        super().__init__('SUN', eda_recommendations)

    def _get_default_config(self) -> Dict[str, Any]:
        config = super()._get_default_config()
        # Solar: might tolerate slightly more missing data at night
        config['max_missing_ratio'] = 0.03
        return config


class WindDatasetBuilder(RenewableDatasetBuilder):
    """Wind-specific dataset builder."""

    def __init__(self, eda_recommendations: Optional['PreprocessingRecommendation'] = None):
        super().__init__('WND', eda_recommendations)


def build_dataset_by_fuel_type(
    generation_df: pd.DataFrame,
    weather_df: pd.DataFrame,
    fuel_type: str,
    output_dir: Optional[Path] = None,
    eda_recommendations: Optional['PreprocessingRecommendation'] = None,
) -> Tuple[pd.DataFrame, PreprocessingReport]:
    """
    Factory function to build dataset using appropriate fuel-specific builder.

    Args:
        generation_df: Raw generation data
        weather_df: Raw weather data
        fuel_type: 'SUN' or 'WND'
        output_dir: Optional output directory
        eda_recommendations: Optional EDA recommendations

    Returns:
        (modeling_df, report)
    """
    builders = {
        'SUN': SolarDatasetBuilder,
        'WND': WindDatasetBuilder,
    }

    if fuel_type not in builders:
        raise ValueError(f"Unknown fuel type: {fuel_type}")

    builder = builders[fuel_type](eda_recommendations)
    return builder.build(generation_df, weather_df, output_dir)


if __name__ == "__main__":
    """Test dataset builder with real data."""
    import sys

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s"
    )

    generation_path = Path("data/renewable/generation.parquet")
    weather_path = Path("data/renewable/weather.parquet")

    if not generation_path.exists() or not weather_path.exists():
        print("Data files not found. Run pipeline first.")
        

    generation_df = pd.read_parquet(generation_path)
    weather_df = pd.read_parquet(weather_path)

    # First investigate negatives
    print("\n[TEST 1] Investigating negatives...")
    try:
        _, _ = build_modeling_dataset(
            generation_df, weather_df,
            negative_policy='investigate',
            output_dir=Path("data/renewable/test_investigate")
        )
    except ValueError as e:
        print(str(e))

    # Then build with clamp
    print("\n[TEST 2] Building with clamp_to_zero...")
    modeling_df, report = build_modeling_dataset(
        generation_df, weather_df,
        negative_policy='clamp_to_zero',
        output_dir=Path("data/renewable/preprocessing")
    )

    print(f"\nFinal dataset shape: {modeling_df.shape}")
    print(f"Columns: {modeling_df.columns.tolist()}")
    print(f"\nSample:")
    print(modeling_df.head())


2026-01-22 16:50:41,213 - __main__ - INFO - DATASET BUILDER - Building Modeling Dataset
2026-01-22 16:50:41,218 - __main__ - INFO - Input: 4,358 rows, 6 series
2026-01-22 16:50:41,218 - __main__ - INFO - Date range: 2025-12-23 00:00:00 to 2026-01-22 07:00:00
2026-01-22 16:50:41,219 - __main__ - INFO - 
[1/4] Applying negative value policy (policy=investigate)...
2026-01-22 16:50:41,221 - __main__ - INFO - DATASET BUILDER - Building Modeling Dataset
2026-01-22 16:50:41,225 - __main__ - INFO - Input: 4,358 rows, 6 series
2026-01-22 16:50:41,225 - __main__ - INFO - Date range: 2025-12-23 00:00:00 to 2026-01-22 07:00:00
2026-01-22 16:50:41,226 - __main__ - INFO - 
[1/4] Applying negative value policy (policy=clamp_to_zero)...
2026-01-22 16:50:41,228 - __main__ - INFO - [PREPROCESSING] Clamped 403 negative values to 0 (9.25% of data)
2026-01-22 16:50:41,229 - __main__ - INFO -   CALI_SUN: 403 negatives clamped (range: [-60.0, -5.0])
2026-01-22 16:50:41,229 - __main__ - INFO - 
[2/4] Enforci


[TEST 1] Investigating negatives...
NEGATIVE VALUES DETECTED: 403 negatives found.
Run EDA first to investigate root cause:
  from src.renewable.eda import run_full_eda
  recommendations = run_full_eda(generation_df, weather_df, output_dir)
Then use recommended policy from EDA.

[TEST 2] Building with clamp_to_zero...

Final dataset shape: (4358, 14)
Columns: ['unique_id', 'ds', 'y', 'hour_sin', 'hour_cos', 'dow_sin', 'dow_cos', 'temperature_2m', 'wind_speed_10m', 'wind_speed_100m', 'wind_direction_10m', 'direct_radiation', 'diffuse_radiation', 'cloud_cover']

Sample:
  unique_id                  ds     y  hour_sin  hour_cos   dow_sin  dow_cos  \
0  CALI_SUN 2025-12-23 00:00:00  3701  0.000000  1.000000  0.781831  0.62349   
1  CALI_SUN 2025-12-23 01:00:00   393  0.258819  0.965926  0.781831  0.62349   
2  CALI_SUN 2025-12-23 02:00:00    31  0.500000  0.866025  0.781831  0.62349   
3  CALI_SUN 2025-12-23 03:00:00    37  0.707107  0.707107  0.781831  0.62349   
4  CALI_SUN 2025-12-23 0

---

# Module 4: Probabilistic Modeling

**File:** `src/renewable/modeling.py`

This is where the forecasting happens! We use **StatsForecast** for:

1. **Multi-series forecasting**: Handle multiple regions/fuel types in one model
2. **Probabilistic predictions**: Get prediction intervals, not just point forecasts
3. **Weather exogenous**: Include weather features as predictors

## Key Concepts

### Why Prediction Intervals?

Point forecasts are useful, but energy traders need **uncertainty quantification**:
- **80% interval**: "I'm 80% confident generation will be between X and Y"
- **95% interval**: Wider, for risk management

### Zero-Value Safety (CRITICAL)

**Solar panels generate ZERO at night!** This breaks MAPE:

```
MAPE = mean(|actual - predicted| / actual)

When actual = 0:
MAPE = |0 - pred| / 0 = undefined (division by zero!)
```

**Solution**: Always use RMSE and MAE for renewable forecasting.

In [8]:
%%writefile src/renewable/validation.py
# file: src/renewable/validation.py
"""Validation utilities for renewable generation data."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, Optional

import pandas as pd


@dataclass(frozen=True)
class ValidationReport:
    ok: bool
    message: str
    details: dict


def validate_generation_df(
    df: pd.DataFrame,
    *,
    max_lag_hours: int = 3,
    max_missing_ratio: float = 0.02,
    expected_series: Optional[Iterable[str]] = None,
) -> ValidationReport:
    required = {"unique_id", "ds", "y"}
    missing_cols = required - set(df.columns)
    if missing_cols:
        return ValidationReport(
            False,
            "Missing required columns",
            {"missing_cols": sorted(missing_cols)},
        )

    if df.empty:
        return ValidationReport(False, "Generation data is empty", {})

    work = df.copy()

    work["ds"] = pd.to_datetime(work["ds"], errors="coerce", utc=True)
    if work["ds"].isna().any():
        return ValidationReport(
            False,
            "Unparseable ds values found",
            {"bad_ds": int(work["ds"].isna().sum())},
        )

    work["y"] = pd.to_numeric(work["y"], errors="coerce")
    if work["y"].isna().any():
        return ValidationReport(
            False,
            "Unparseable y values found",
            {"bad_y": int(work["y"].isna().sum())},
        )

    # Check for negative values and log warning (but allow to pass)
    # Dataset builder will handle negatives according to configured policy
    if (work["y"] < 0).any():
        import logging
        logger = logging.getLogger(__name__)

        neg_mask = work["y"] < 0
        neg_count = int(neg_mask.sum())
        by_series = (
            work[neg_mask]
            .groupby("unique_id")
            .agg(count=("y", "count"), min_y=("y", "min"), max_y=("y", "max"))
            .reset_index()
        )

        logger.warning(
            "[validation][NEGATIVE] Found %d negative values (%.1f%%) across %d series",
            neg_count,
            100 * neg_count / len(work),
            len(by_series)
        )

        for _, row in by_series.iterrows():
            logger.warning(
                "  Series %s: %d negative values, range=[%.1f, %.1f]",
                row["unique_id"], row["count"], row["min_y"], row["max_y"]
            )

        logger.info(
            "[validation][NEGATIVE] Negatives will be handled by dataset builder "
            "according to configured negative_policy"
        )

        # Continue validation instead of failing
        # (Dataset builder will handle negatives per policy)

    dup = work.duplicated(subset=["unique_id", "ds"]).sum()
    if dup:
        return ValidationReport(
            False,
            "Duplicate (unique_id, ds) rows found",
            {"duplicates": int(dup)},
        )

    if expected_series:
        expected = sorted(set(expected_series))
        present = sorted(set(work["unique_id"]))
        missing_series = sorted(set(expected) - set(present))
        if missing_series:
            return ValidationReport(
                False,
                "Missing expected series",
                {"missing_series": missing_series, "present_series": present},
            )

    now_utc = pd.Timestamp.now(tz="UTC").floor("h")
    max_ds = work["ds"].max()
    lag_hours = (now_utc - max_ds).total_seconds() / 3600.0
    if lag_hours > max_lag_hours:
        return ValidationReport(
            False,
            "Data not fresh enough",
            {
                "now_utc": now_utc.isoformat(),
                "max_ds": max_ds.isoformat(),
                "lag_hours": lag_hours,
            },
        )

    series_max = work.groupby("unique_id")["ds"].max()
    series_lag = (now_utc - series_max).dt.total_seconds() / 3600.0
    stale = series_lag[series_lag > max_lag_hours].sort_values(ascending=False)
    if not stale.empty:
        return ValidationReport(
            False,
            "Stale series found",
            {
                "stale_series": stale.head(10).to_dict(),
                "max_lag_hours": max_lag_hours,
            },
        )

    missing_ratios = {}
    for uid, group in work.groupby("unique_id"):
        group = group.sort_values("ds")
        start = group["ds"].iloc[0]
        end = group["ds"].iloc[-1]
        expected = int(((end - start) / pd.Timedelta(hours=1)) + 1)
        actual = len(group)
        missing = max(expected - actual, 0)
        missing_ratios[uid] = missing / max(expected, 1)

    worst_uid = max(missing_ratios, key=missing_ratios.get)
    worst_ratio = missing_ratios[worst_uid]
    if worst_ratio > max_missing_ratio:
        return ValidationReport(
            False,
            "Too many missing hourly points",
            {"worst_uid": worst_uid, "worst_missing_ratio": worst_ratio},
        )

    return ValidationReport(
        True,
        "OK",
        {
            "row_count": len(work),
            "series_count": int(work["unique_id"].nunique()),
            "max_ds": max_ds.isoformat(),
            "lag_hours": lag_hours,
            "worst_missing_ratio": worst_ratio,
        },
    )


Overwriting src/renewable/validation.py


In [9]:
# %%writefile src/renewable/modeling.py
# file: src/renewable/modeling.py
"""
Renewable Energy Forecasting Models

This module provides probabilistic forecasting with PHYSICAL CONSTRAINTS.

KEY PRINCIPLE:
Statistical models (ARIMA, ETS, etc.) can produce negative forecasts and
prediction intervals because they assume Gaussian errors. However:

  RENEWABLE ENERGY GENERATION CANNOT BE NEGATIVE.

This module enforces this physical constraint by clipping ALL forecasts
and prediction intervals to [0, ∞). This is NOT "defensive coding" - it's
applying domain knowledge about physical reality.

Model Architecture:
1. StatsForecast for multi-series probabilistic forecasting
2. Post-processing to enforce [0, ∞) constraint
3. Calibration check for prediction intervals
"""

from __future__ import annotations

import logging
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

logger = logging.getLogger(__name__)


WEATHER_VARS = [
    "temperature_2m",
    "wind_speed_10m",
    "wind_speed_100m",
    "wind_direction_10m",
    "direct_radiation",
    "diffuse_radiation",
    "cloud_cover",
]

TIME_FEATURES = ["hour_sin", "hour_cos", "dow_sin", "dow_cos"]


@dataclass
class ForecastConfig:
    """Configuration for forecasting."""
    horizon: int = 24
    confidence_levels: Tuple[int, int] = (80, 95)

    # Physical constraints
    enforce_non_negative: bool = True  # ALWAYS True for renewable energy

    # CV settings
    cv_windows: int = 3
    cv_step_size: int = 168  # 1 week


def enforce_physical_constraints(
    df: pd.DataFrame,
    min_value: float = 0.0,
) -> pd.DataFrame:
    """
    Enforce physical constraints on forecasts.

    For renewable energy:
    - Generation cannot be negative
    - All forecast columns (point and intervals) are clipped to [0, ∞)

    This is applying physical reality:
    - A solar panel cannot generate negative power
    - A wind turbine cannot generate negative power

    Args:
        df: Forecast DataFrame with columns like yhat, yhat_lo_80, yhat_hi_80, etc.
        min_value: Minimum physical value (0 for generation)

    Returns:
        DataFrame with all forecasts clipped to [min_value, ∞)
    """
    # Identify forecast columns (exclude unique_id, ds, etc.)
    exclude_cols = {'unique_id', 'ds', 'cutoff', 'y', 'region', 'fuel_type'}
    forecast_cols = [c for c in df.columns if c not in exclude_cols]

    # Count values that will be clipped
    clip_counts = {}
    for col in forecast_cols:
        if col in df.columns and pd.api.types.is_numeric_dtype(df[col]):
            below_min = (df[col] < min_value).sum()
            if below_min > 0:
                clip_counts[col] = int(below_min)

    if clip_counts:
        total_clipped = sum(clip_counts.values())
        total_values = len(df) * len(forecast_cols)
        logger.info(
            f"[PHYSICAL CONSTRAINT] Clipping {total_clipped} values to >= {min_value} "
            f"({total_clipped/total_values:.1%} of forecast values)"
        )
        for col, count in clip_counts.items():
            logger.debug(f"  {col}: {count} values clipped")

    # Apply constraint
    result = df.copy()
    for col in forecast_cols:
        if col in result.columns and pd.api.types.is_numeric_dtype(result[col]):
            result[col] = result[col].clip(lower=min_value)

    return result


def compute_metrics(
    y_true: np.ndarray,
    y_pred: np.ndarray,
) -> Dict[str, float]:
    """
    Compute forecast evaluation metrics.

    Uses RMSE and MAE (NOT MAPE because y=0 at night for solar).
    """
    valid_mask = np.isfinite(y_true) & np.isfinite(y_pred)
    y_true = y_true[valid_mask]
    y_pred = y_pred[valid_mask]

    if len(y_true) == 0:
        return {'rmse': np.nan, 'mae': np.nan, 'valid_rows': 0}

    errors = y_true - y_pred

    return {
        'rmse': float(np.sqrt(np.mean(errors ** 2))),
        'mae': float(np.mean(np.abs(errors))),
        'valid_rows': int(len(y_true)),
    }


def compute_coverage(
    y_true: np.ndarray,
    y_lo: np.ndarray,
    y_hi: np.ndarray,
) -> float:
    """Compute prediction interval coverage."""
    valid = np.isfinite(y_true) & np.isfinite(y_lo) & np.isfinite(y_hi)
    if valid.sum() == 0:
        return np.nan

    in_interval = (y_true[valid] >= y_lo[valid]) & (y_true[valid] <= y_hi[valid])
    return float(in_interval.mean())


class RenewableForecastModel:
    """
    Probabilistic forecasting model with physical constraints.

    Uses StatsForecast for efficient multi-series forecasting,
    then enforces non-negativity on all outputs.
    """

    def __init__(
        self,
        horizon: int = 24,
        confidence_levels: Tuple[int, int] = (80, 95),
    ):
        self.horizon = horizon
        self.confidence_levels = confidence_levels
        self.sf = None
        self._train_df = None
        self._exog_cols: List[str] = []
        self.fitted = False

    def _prepare_training_df(
        self,
        df: pd.DataFrame,
    ) -> pd.DataFrame:
        """
        Prepare training DataFrame.

        Expects preprocessed data from dataset_builder (already has time features
        and weather aligned).
        """
        required = {'unique_id', 'ds', 'y'}
        if not required.issubset(df.columns):
            missing = required - set(df.columns)
            raise ValueError(f"Missing required columns: {missing}")

        if df.empty:
            raise ValueError("Empty DataFrame")

        # Check for required features
        time_features = [c for c in TIME_FEATURES if c in df.columns]
        weather_features = [c for c in WEATHER_VARS if c in df.columns]

        if not time_features:
            raise ValueError(
                "No time features found. Data should be preprocessed by dataset_builder."
            )

        self._exog_cols = time_features + weather_features

        work = df.copy()
        work['ds'] = pd.to_datetime(work['ds'])
        work = work.sort_values(['unique_id', 'ds']).reset_index(drop=True)

        # Validate no negatives in training data
        neg_count = (work['y'] < 0).sum()
        if neg_count > 0:
            raise ValueError(
                f"Training data contains {neg_count} negative values. "
                f"Data should be preprocessed by dataset_builder with negative_policy='clamp_to_zero'."
            )

        logger.info(
            f"[TRAIN] Prepared: {len(work):,} rows, {work['unique_id'].nunique()} series, "
            f"{len(self._exog_cols)} exog features"
        )

        return work

    def fit(self, df: pd.DataFrame) -> None:
        """
        Fit models on training data.

        Args:
            df: Preprocessed DataFrame from dataset_builder
        """
        from statsforecast import StatsForecast
        from statsforecast.models import (MSTL, AutoARIMA, AutoETS,
                                          SeasonalNaive)

        train_df = self._prepare_training_df(df)

        models = [
            MSTL(season_length=[24, 168], trend_forecaster=AutoARIMA(), alias="MSTL_ARIMA"),
            AutoARIMA(season_length=24),
            AutoETS(season_length=24),
            SeasonalNaive(season_length=24),
        ]

        # Try to add expanded models
        try:
            from statsforecast.models import AutoTheta
            models.append(AutoTheta(season_length=24))
            logger.info("[FIT] Using expanded model set: +AutoTheta")
        except ImportError:
            pass

        self.sf = StatsForecast(models=models, freq='h', n_jobs=-1)
        self._train_df = train_df
        self.fitted = True

        logger.info(f"[FIT] Fitted {len(models)} models on {len(train_df):,} rows")

    def cross_validate(
        self,
        df: pd.DataFrame,
        n_windows: int = 3,
        step_size: int = 168,
    ) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """
        Perform cross-validation.

        Returns:
            (cv_results, leaderboard)
        """
        from statsforecast import StatsForecast
        from statsforecast.models import (MSTL, AutoARIMA, AutoETS,
                                          SeasonalNaive)

        train_df = self._prepare_training_df(df)

        models = [
            MSTL(season_length=[24, 168], trend_forecaster=AutoARIMA(), alias="MSTL_ARIMA"),
            AutoARIMA(season_length=24),
            AutoETS(season_length=24),
            SeasonalNaive(season_length=24),
        ]

        try:
            from statsforecast.models import AutoTheta
            models.append(AutoTheta(season_length=24))
        except ImportError:
            pass

        sf = StatsForecast(models=models, freq='h', n_jobs=-1)

        logger.info(
            f"[CV] Running: {n_windows} windows, step={step_size}h, horizon={self.horizon}h"
        )

        cv = sf.cross_validation(
            df=train_df,
            h=self.horizon,
            step_size=step_size,
            n_windows=n_windows,
            level=list(self.confidence_levels),
        ).reset_index()

        # CRITICAL: Apply physical constraints to CV results
        cv = enforce_physical_constraints(cv, min_value=0.0)

        # Build leaderboard
        leaderboard = self._build_leaderboard(cv)

        return cv, leaderboard

    def _build_leaderboard(self, cv_df: pd.DataFrame) -> pd.DataFrame:
        """Build model comparison leaderboard from CV results."""
        # Find model columns (not id/ds/cutoff/y, not interval columns)
        exclude = {'unique_id', 'ds', 'cutoff', 'y'}
        interval_pattern = re.compile(r'-(lo|hi)-\d+$')

        model_cols = [
            c for c in cv_df.columns
            if c not in exclude and not interval_pattern.search(c)
        ]

        rows = []
        y_true = cv_df['y'].values

        for model in model_cols:
            y_pred = cv_df[model].values
            metrics = compute_metrics(y_true, y_pred)

            row = {
                'model': model,
                'rmse': metrics['rmse'],
                'mae': metrics['mae'],
                'valid_rows': metrics['valid_rows'],
            }

            # Add coverage for each confidence level
            for level in self.confidence_levels:
                lo_col = f"{model}-lo-{level}"
                hi_col = f"{model}-hi-{level}"
                if lo_col in cv_df.columns and hi_col in cv_df.columns:
                    coverage = compute_coverage(
                        y_true,
                        cv_df[lo_col].values,
                        cv_df[hi_col].values,
                    )
                    row[f'coverage_{level}'] = coverage

            rows.append(row)

        leaderboard = pd.DataFrame(rows)
        leaderboard = leaderboard.sort_values('rmse').reset_index(drop=True)

        return leaderboard

    def predict(
        self,
        future_exog: pd.DataFrame,
        best_model: Optional[str] = None,
    ) -> pd.DataFrame:
        """
        Generate forecasts.

        Args:
            future_exog: DataFrame with future exogenous features
                         Must have [unique_id, ds] + exog features
            best_model: If specified, only return this model's predictions

        Returns:
            Forecast DataFrame with physical constraints applied
        """
        if not self.fitted:
            raise RuntimeError("Call fit() first")

        # Build future X_df
        X_df = self._build_future_X(future_exog)

        # Generate forecasts
        fcst = self.sf.forecast(
            h=self.horizon,
            df=self._train_df,
            X_df=X_df,
            level=list(self.confidence_levels),
        ).reset_index()

        # CRITICAL: Apply physical constraints
        fcst = enforce_physical_constraints(fcst, min_value=0.0)

        # If best_model specified, filter
        if best_model is not None:
            if best_model not in fcst.columns:
                available = [c for c in fcst.columns if c not in ['unique_id', 'ds']]
                raise ValueError(
                    f"Model '{best_model}' not found. Available: {available}"
                )

            keep_cols = ['unique_id', 'ds', best_model]
            rename_map = {best_model: 'yhat'}

            for level in self.confidence_levels:
                lo = f"{best_model}-lo-{level}"
                hi = f"{best_model}-hi-{level}"
                if lo in fcst.columns:
                    keep_cols.append(lo)
                    rename_map[lo] = f'yhat_lo_{level}'
                if hi in fcst.columns:
                    keep_cols.append(hi)
                    rename_map[hi] = f'yhat_hi_{level}'

            fcst = fcst[keep_cols].rename(columns=rename_map)

        return fcst

    def _build_future_X(self, future_exog: pd.DataFrame) -> pd.DataFrame:
        """Build future exogenous feature DataFrame."""
        required = {'unique_id', 'ds'}
        if not required.issubset(future_exog.columns):
            missing = required - set(future_exog.columns)
            raise ValueError(f"future_exog missing: {missing}")

        # Check exog columns
        missing_exog = [c for c in self._exog_cols if c not in future_exog.columns]
        if missing_exog:
            raise ValueError(
                f"future_exog missing required features: {missing_exog}. "
                f"Expected: {self._exog_cols}"
            )

        X = future_exog[['unique_id', 'ds'] + self._exog_cols].copy()
        X = X.sort_values(['unique_id', 'ds']).reset_index(drop=True)

        return X


def compute_baseline_metrics(
    cv_df: pd.DataFrame,
    model_name: str,
    threshold_k: float = 2.0,
) -> Dict[str, Any]:
    """
    Compute baseline metrics for drift detection.

    Args:
        cv_df: Cross-validation results
        model_name: Model to compute baseline for
        threshold_k: k for threshold = mean + k*std

    Returns:
        Baseline metrics dictionary
    """
    if model_name not in cv_df.columns:
        raise ValueError(f"Model '{model_name}' not in CV results")

    # Compute per-window metrics
    def window_rmse(g):
        metrics = compute_metrics(g['y'].values, g[model_name].values)
        return metrics['rmse']

    per_window = cv_df.groupby(['unique_id', 'cutoff']).apply(
        window_rmse, include_groups=False
    )

    rmse_mean = float(per_window.mean())
    rmse_std = float(per_window.std())

    baseline = {
        'model': model_name,
        'rmse_mean': rmse_mean,
        'rmse_std': rmse_std,
        'drift_threshold_rmse': rmse_mean + threshold_k * rmse_std,
        'n_windows': int(per_window.notna().sum()),
    }

    return baseline


if __name__ == "__main__":
    """
    Smoke test for renewable forecasting models.

    This test validates the complete modeling pipeline using real data:
    1. Loads modeling dataset created by the pipeline
    2. Runs cross-validation with multiple models
    3. Validates physical constraints (no negative forecasts)
    4. Displays performance leaderboard
    """
    import sys
    from pathlib import Path

    from src.renewable.dataset_builder import build_modeling_dataset
    from src.renewable.eia_renewable import EIARenewableFetcher
    from src.renewable.open_meteo import OpenMeteoRenewable

    logging.basicConfig(level=logging.INFO)
    fetcher = EIARenewableFetcher(debug_env=True)

    print("=== Testing Single Region Fetch ===")
    df_single = fetcher.fetch_region("CALI", "WND", "2024-12-01", "2024-12-03", debug=True)
    print(f"Single region: {len(df_single)} rows")
    print(df_single.head())

    print("\n=== Testing Multi-Region Fetch ===")
    df_multi = fetcher.fetch_all_regions("WND", "2024-12-01", "2024-12-03", regions=["CALI", "ERCO", "MISO"])
    print(f"\nMulti-region: {len(df_multi)} rows")
    print(f"Series: {df_multi['unique_id'].unique().tolist()}")

    print("\n=== Series Summary ===")
    print(fetcher.get_series_summary(df_multi))

    # sun checks:
    f = EIARenewableFetcher()
    df = f.fetch_region("CALI", "SUN", "2024-12-01", "2024-12-03", debug=True)
    print(df.head(), len(df))
    
    
    # Real API smoke test (no key needed)
    weather = OpenMeteoRenewable(strict=True)

    print("=== Testing Historical Weather (REAL API) ===")
    hist_df = weather.fetch_for_region("CALI", "2024-12-01", "2024-12-03", debug=True)
    print(f"Historical rows: {len(hist_df)}")
    print(hist_df.head())


    generation_path = Path("data/renewable/generation.parquet")
    weather_path = Path("data/renewable/weather.parquet")

    if not generation_path.exists() or not weather_path.exists():
        print("Data files not found. Run pipeline first.")
        

    generation_df = pd.read_parquet(generation_path)
    weather_df = pd.read_parquet(weather_path)



    # First investigate negatives
    print("\n[TEST 1] Investigating negatives...")
    try:
        _, _ = build_modeling_dataset(
            generation_df, weather_df,
            negative_policy='investigate',
            output_dir=Path("data/renewable/test_investigate")
        )
    except ValueError as e:
        print(str(e))

    # Then build with clamp
    print("\n[TEST 2] Building with clamp_to_zero...")
    modeling_df, report = build_modeling_dataset(
        generation_df, weather_df,
        negative_policy='clamp_to_zero',
        output_dir=Path("data/renewable/preprocessing")
    )

    print(f"\nFinal dataset shape: {modeling_df.shape}")
    print(f"Columns: {modeling_df.columns.tolist()}")
    print(f"\nSample:")
    print(modeling_df.head())

    # CRITICAL: Save the modeling dataset for smoke test
    # This step was missing, causing FileNotFoundError below
    data_path = Path("data/renewable/modeling_dataset.parquet")
    print(f"\n[TEST 3] Saving modeling dataset to {data_path}...")
    modeling_df.to_parquet(data_path, index=False)
    print(f"✅ Saved {len(modeling_df):,} rows ({data_path.stat().st_size / 1024:.1f} KB)")

    # smoke test on these
    # Load preprocessed modeling dataset from pipeline (to verify save worked)
    print("\n" + "="*80)
    print("SMOKE TEST: Renewable Forecasting Models")
    print("="*80)
    print(f"Loading data from: {data_path}")

    df = pd.read_parquet(data_path)

    print(f"Dataset shape: {df.shape}")
    print(f"Columns: {df.columns.tolist()}")
    print(f"Series: {df['unique_id'].unique().tolist()}")
    print(f"Date range: {df['ds'].min()} to {df['ds'].max()}")
    print()

    # Run cross-validation
    print("Running cross-validation (3 windows, 168h step)...")
    model = RenewableForecastModel(horizon=24, confidence_levels=(80, 95))
    cv, leaderboard = model.cross_validate(df, n_windows=3, step_size=168)

    print("\n" + "="*80)
    print("LEADERBOARD (sorted by RMSE)")
    print("="*80)
    print(leaderboard.to_string(index=False))

    print("\n" + "="*80)
    print("PHYSICAL CONSTRAINT VALIDATION")
    print("="*80)
    print(f"Min forecast (MSTL_ARIMA): {cv['MSTL_ARIMA'].min():.2f} MWh")
    print(f"Max forecast (MSTL_ARIMA): {cv['MSTL_ARIMA'].max():.2f} MWh")
    print(f"Any negative forecasts: {(cv['MSTL_ARIMA'] < 0).any()}")

    if (cv['MSTL_ARIMA'] < 0).any():
        print("\n⚠️ WARNING: Negative forecasts detected!")
        print("This violates physical constraints (renewable generation cannot be negative)")
        neg_count = (cv['MSTL_ARIMA'] < 0).sum()
        print(f"Count: {neg_count} out of {len(cv)} ({100*neg_count/len(cv):.2f}%)")
    else:
        print("\n✅ SUCCESS: All forecasts are non-negative (physical constraints satisfied)")

    print("="*80)

    print("="*80)


2026-01-22 16:50:41,286 - src.renewable.eia_renewable - INFO - Loaded .env via find_dotenv: c:\docker_projects\atsaf\.env
2026-01-22 16:50:41,287 - src.renewable.eia_renewable - INFO - EIA_API_KEY loaded (masked): 8pqH...8Xk7
2026-01-22 16:50:41,287 - src.renewable.eia_renewable - INFO - Request timeout: 60 seconds


=== Testing Single Region Fetch ===
[PAGE] region=CALI fuel=WND returned=72 offset=0 total=72 url=https://api.eia.gov/v2/electricity/rto/fuel-type-data/data/?data%5B%5D=value&facets%5Brespondent%5D%5B%5D=CISO&facets%5Bfueltype%5D%5B%5D=WND&frequency=hourly&start=2024-12-01T00&end=2024-12-03T23&length=5000&offset=0&sort%5B0%5D%5Bcolumn%5D=period&sort%5B0%5D%5Bdirection%5D=asc
Single region: 72 rows
                   ds  value region fuel_type
0 2024-12-01 00:00:00    359   CALI       WND
1 2024-12-01 01:00:00    277   CALI       WND
2 2024-12-01 02:00:00    498   CALI       WND
3 2024-12-01 03:00:00    692   CALI       WND
4 2024-12-01 04:00:00    879   CALI       WND

=== Testing Multi-Region Fetch ===
[OK] CALI: 72 rows
[OK] MISO: 72 rows
[OK] ERCO: 72 rows
[SUMMARY] WND data: 3 series, 216 total rows

Multi-region: 216 rows
Series: ['CALI_WND', 'ERCO_WND', 'MISO_WND']

=== Series Summary ===
  unique_id  count  min_value  max_value   mean_value  zero_count
0  CALI_WND     72        



[PAGE] region=CALI fuel=SUN returned=72 offset=0 total=72 url=https://api.eia.gov/v2/electricity/rto/fuel-type-data/data/?data%5B%5D=value&facets%5Brespondent%5D%5B%5D=CISO&facets%5Bfueltype%5D%5B%5D=SUN&frequency=hourly&start=2024-12-01T00&end=2024-12-03T23&length=5000&offset=0&sort%5B0%5D%5Bcolumn%5D=period&sort%5B0%5D%5Bdirection%5D=asc
                   ds  value region fuel_type
0 2024-12-01 00:00:00   7805   CALI       SUN
1 2024-12-01 01:00:00   3369   CALI       SUN
2 2024-12-01 02:00:00    186   CALI       SUN
3 2024-12-01 03:00:00    -47   CALI       SUN
4 2024-12-01 04:00:00    -41   CALI       SUN 72
=== Testing Historical Weather (REAL API) ===


2026-01-22 16:51:14,536 - src.renewable.dataset_builder - INFO - DATASET BUILDER - Building Modeling Dataset
2026-01-22 16:51:14,541 - src.renewable.dataset_builder - INFO - Input: 4,358 rows, 6 series
2026-01-22 16:51:14,541 - src.renewable.dataset_builder - INFO - Date range: 2025-12-23 00:00:00 to 2026-01-22 07:00:00
2026-01-22 16:51:14,541 - src.renewable.dataset_builder - INFO - 
[1/4] Applying negative value policy (policy=investigate)...
2026-01-22 16:51:14,543 - src.renewable.dataset_builder - INFO - DATASET BUILDER - Building Modeling Dataset
2026-01-22 16:51:14,547 - src.renewable.dataset_builder - INFO - Input: 4,358 rows, 6 series
2026-01-22 16:51:14,548 - src.renewable.dataset_builder - INFO - Date range: 2025-12-23 00:00:00 to 2026-01-22 07:00:00
2026-01-22 16:51:14,548 - src.renewable.dataset_builder - INFO - 
[1/4] Applying negative value policy (policy=clamp_to_zero)...
2026-01-22 16:51:14,550 - src.renewable.dataset_builder - INFO - [PREPROCESSING] Clamped 403 negativ

[OPENMETEO][HIST] status=200 url=https://archive-api.open-meteo.com/v1/archive?latitude=36.7&longitude=-119.4&start_date=2024-12-01&end_date=2024-12-03&hourly=temperature_2m%2Cwind_speed_10m%2Cwind_speed_100m%2Cwind_direction_10m%2Cdirect_radiation%2Cdiffuse_radiation%2Ccloud_cover&timezone=UTC
[OPENMETEO][PARSE] rows=72 dup_ds=0 na_counts(sample)={'temperature_2m': 0, 'wind_speed_10m': 0, 'wind_speed_100m': 0}
Historical rows: 72
                   ds  temperature_2m  wind_speed_10m  wind_speed_100m  \
0 2024-12-01 00:00:00            12.2             5.9              7.0   
1 2024-12-01 01:00:00             9.2             1.2              5.3   
2 2024-12-01 02:00:00             8.4             2.4              3.6   
3 2024-12-01 03:00:00             8.7             3.8              5.4   
4 2024-12-01 04:00:00            10.3             0.9              2.2   

   wind_direction_10m  direct_radiation  diffuse_radiation  cloud_cover region  
0                  72              84.0

  from .autonotebook import tqdm as notebook_tqdm
2026-01-22 16:51:15,773 - __main__ - INFO - [TRAIN] Prepared: 4,358 rows, 6 series, 11 exog features
2026-01-22 16:51:15,774 - __main__ - INFO - [CV] Running: 3 windows, step=168h, horizon=24h
2026-01-22 16:59:02,090 - __main__ - INFO - [PHYSICAL CONSTRAINT] Clipping 2772 values to >= 0.0 (24.7% of forecast values)



LEADERBOARD (sorted by RMSE)
        model         rmse         mae  valid_rows  coverage_80  coverage_95
   MSTL_ARIMA  4492.481161 2804.936829         432     0.615741     0.780093
    AutoTheta  4750.601078 2568.060959         432     0.780093     0.851852
    AutoARIMA  4804.202861 2613.864402         432     0.759259     0.858796
SeasonalNaive  5653.479147 3193.574074         432     0.761574     0.905093
      AutoETS  7115.548627 3065.900143         432     0.715278     0.803241
        index 10685.852922 7173.983796         432          NaN          NaN

PHYSICAL CONSTRAINT VALIDATION
Min forecast (MSTL_ARIMA): 0.00 MWh
Max forecast (MSTL_ARIMA): 25864.32 MWh
Any negative forecasts: False

✅ SUCCESS: All forecasts are non-negative (physical constraints satisfied)



# Module: Pipeline Tasks

**File:** `src/renewable/tasks.py`

This module orchestrates the complete pipeline:

1. **Fetch generation data** from EIA
2. **Fetch weather data** from Open-Meteo
3. **Train models** with cross-validation
4. **Generate forecasts** with prediction intervals
5. **Compute drift metrics** vs baseline

## Key Feature: Adaptive CV

Cross-validation requires sufficient data:
```
Minimum rows = horizon + (n_windows × step_size)
```

For short series, we **adapt** the CV settings automatically.

In [10]:
%%writefile src/renewable/tasks.py
# file: src/renewable/tasks.py
"""Renewable energy forecasting pipeline tasks.

Idempotent tasks for:
- Fetching EIA renewable generation data
- Fetching weather data from Open-Meteo
- Training probabilistic models
- Generating forecasts with intervals
- Computing drift metrics
"""

import argparse
import logging
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd

# Optional imports for interpretability (LightGBM + skforecast)
try:
    from lightgbm import LGBMRegressor
    from skforecast.recursive import ForecasterRecursive
    INTERPRETABILITY_AVAILABLE = True
except ImportError:
    INTERPRETABILITY_AVAILABLE = False
    logger = logging.getLogger(__name__)
    logger.warning("lightgbm and/or skforecast not installed - interpretability features unavailable")

from src.renewable.eia_renewable import EIARenewableFetcher
from src.renewable.modeling import (
    RenewableForecastModel,
    compute_baseline_metrics,
    enforce_physical_constraints,
    WEATHER_VARS,
)
from src.renewable.model_interpretability import (
    InterpretabilityReport,
    generate_full_interpretability_report,
)
from src.renewable.open_meteo import OpenMeteoRenewable
from src.renewable.dataset_builder import _add_time_features, build_modeling_dataset

logger = logging.getLogger(__name__)


@dataclass
class RenewablePipelineConfig:
    """Configuration for renewable forecasting pipeline."""

    # Data parameters
    regions: list[str] = field(default_factory=lambda: ["CALI", "ERCO", "MISO", "PJM", "SWPP"])
    fuel_types: list[str] = field(default_factory=lambda: ["WND", "SUN"])
    start_date: str = ""  # Set dynamically
    end_date: str = ""  # Set dynamically
    lookback_days: int = 30

    # Forecast parameters
    horizon: int = 24
    confidence_levels: tuple[int, int] = (80, 95)
    horizon_preset: Optional[str] = None  # "24h" | "48h" | "72h"

    # CV parameters
    cv_windows: int = 5
    cv_step_size: int = 168  # 1 week

    # Model parameters
    enable_interpretability: bool = True  # LightGBM SHAP analysis (on by default)

    # Preprocessing parameters
    negative_policy: str = "clamp"  # "clamp" | "fail_loud" | "hybrid"
    hourly_grid_policy: str = "drop_incomplete_series"  # "drop_incomplete_series" | "fail_loud"

    # Output paths
    data_dir: str = "data/renewable"
    overwrite: bool = False

    # Horizon preset definitions (class-level constant)
    _PRESETS = {
        "24h": {"horizon": 24, "cv_windows": 2, "lookback_days": 15},
        "48h": {"horizon": 48, "cv_windows": 3, "lookback_days": 21},
        "72h": {"horizon": 72, "cv_windows": 3, "lookback_days": 28},
    }

    def __post_init__(self):
        # Apply horizon preset if specified
        if self.horizon_preset and self.horizon_preset in self._PRESETS:
            preset = self._PRESETS[self.horizon_preset]
            # Use object.__setattr__ since this is a dataclass
            object.__setattr__(self, "horizon", preset["horizon"])
            object.__setattr__(self, "cv_windows", preset["cv_windows"])
            object.__setattr__(self, "lookback_days", preset["lookback_days"])
            logger.info(f"[config] Applied preset '{self.horizon_preset}': horizon={preset['horizon']}h")

        # Set default dates if not provided
        if not self.end_date:
            self.end_date = datetime.now(timezone.utc).strftime("%Y-%m-%d")
        if not self.start_date:
            end = datetime.strptime(self.end_date, "%Y-%m-%d")
            start = end - timedelta(days=self.lookback_days)
            self.start_date = start.strftime("%Y-%m-%d")

        # Validate configuration
        warnings = self._validate()
        for warning in warnings:
            logger.warning(f"[config] {warning}")

    def _validate(self) -> list[str]:
        """Validate configuration and return warnings."""
        warnings = []

        # Check minimum data requirement
        available_hours = self.lookback_days * 24
        required_hours = self.horizon + (self.cv_windows * self.cv_step_size)
        if available_hours < required_hours:
            warnings.append(
                f"Insufficient data: need {required_hours}h, have {available_hours}h. "
                f"Increase lookback_days to {(required_hours // 24) + 1} or reduce cv_windows."
            )

        # Warn about accuracy degradation
        if self.horizon > 72:
            warnings.append(
                f"Horizon {self.horizon}h exceeds recommended max (72h). "
                f"Weather forecast accuracy degrades significantly beyond 3 days."
            )

        return warnings

    def generation_path(self) -> Path:
        return Path(self.data_dir) / "generation.parquet"

    def weather_path(self) -> Path:
        return Path(self.data_dir) / "weather.parquet"

    def forecasts_path(self) -> Path:
        return Path(self.data_dir) / "forecasts.parquet"

    def baseline_path(self) -> Path:
        return Path(self.data_dir) / "baseline.json"

    def interpretability_dir(self) -> Path:
        return Path(self.data_dir) / "interpretability"

    def preprocessing_dir(self) -> Path:
        return Path(self.data_dir) / "preprocessing"


def fetch_renewable_data(
    config: RenewablePipelineConfig,
    fetch_diagnostics: Optional[list[dict]] = None,
) -> pd.DataFrame:
    """Task 1: Fetch EIA generation data for all regions and fuel types.

    Args:
        config: Pipeline configuration
        fetch_diagnostics: Optional list to capture per-region fetch metadata

    Returns:
        DataFrame with columns [unique_id, ds, y]
    """
    output_path = config.generation_path()
    output_path.parent.mkdir(parents=True, exist_ok=True)

    def _log_generation_summary(df: pd.DataFrame, source: str) -> None:


        expected_series = {
            f"{region}_{fuel}" for region in config.regions for fuel in config.fuel_types
        }
        present_series = set(df["unique_id"]) if "unique_id" in df.columns else set()
        missing_series = sorted(expected_series - present_series)
        if missing_series:
            logger.warning(
                "[fetch_generation] Missing expected series (%s): %s",
                source,
                missing_series,
            )

        if df.empty:
            logger.warning("[fetch_generation] No generation data rows (%s).", source)
            return

        coverage = (
            df.groupby("unique_id")["ds"]
            .agg(min_ds="min", max_ds="max", rows="count")
            .reset_index()
            .sort_values("unique_id")
        )
        max_series_log = 25
        if len(coverage) > max_series_log:
            logger.info(
                "[fetch_generation] Coverage (%s, first %s series):\n%s",
                source,
                max_series_log,
                coverage.head(max_series_log).to_string(index=False),
            )
        else:
            logger.info("[fetch_generation] Coverage (%s):\n%s", source, coverage.to_string(index=False))

    if output_path.exists() and not config.overwrite:
        logger.info(f"[fetch_generation] exists, loading: {output_path}")
        cached = pd.read_parquet(output_path)
        # Log cached coverage to surface missing series without refetching.
        _log_generation_summary(cached, source="cache")
        return cached

    logger.info(f"[fetch_generation] Fetching {config.fuel_types} for {config.regions}")

    # Use longer timeout (90s) to handle slow EIA API responses
    fetcher = EIARenewableFetcher(timeout=90)
    all_dfs = []

    for fuel_type in config.fuel_types:
        df = fetcher.fetch_all_regions(
            fuel_type=fuel_type,
            start_date=config.start_date,
            end_date=config.end_date,
            regions=config.regions,
            diagnostics=fetch_diagnostics,
        )
        all_dfs.append(df)

    combined = pd.concat(all_dfs, ignore_index=True)
    combined = combined.sort_values(["unique_id", "ds"]).reset_index(drop=True)

    # Log fresh coverage to highlight gaps or unexpected negatives.
    _log_generation_summary(combined, source="fresh")

    if fetch_diagnostics:
        empty_series = [
            entry
            for entry in fetch_diagnostics
            if entry.get("empty")
        ]
        for entry in empty_series:
            logger.warning(
                "[fetch_generation] Empty series detail: region=%s fuel=%s total=%s pages=%s",
                entry.get("region"),
                entry.get("fuel_type"),
                entry.get("total_records"),
                entry.get("pages"),
            )

    combined.to_parquet(output_path, index=False)
    logger.info(f"[fetch_generation] Saved: {output_path} ({len(combined)} rows)")

    return combined


def fetch_renewable_weather(
    config: RenewablePipelineConfig,
    include_forecast: bool = True,
) -> pd.DataFrame:
    """Task 2: Fetch weather data for all regions.

    Args:
        config: Pipeline configuration
        include_forecast: Include forecast weather for predictions

    Returns:
        DataFrame with columns [ds, region, weather_vars...]
    """
    output_path = config.weather_path()
    output_path.parent.mkdir(parents=True, exist_ok=True)

    def _log_weather_summary(df: pd.DataFrame, source: str) -> None:
        if df.empty:
            logger.warning("[fetch_weather] No weather data rows (%s).", source)
            return

        coverage = (
            df.groupby("region")["ds"]
            .agg(min_ds="min", max_ds="max", rows="count")
            .reset_index()
            .sort_values("region")
        )
        max_region_log = 25
        if len(coverage) > max_region_log:
            logger.info(
                "[fetch_weather] Coverage (%s, first %s regions):\n%s",
                source,
                max_region_log,
                coverage.head(max_region_log).to_string(index=False),
            )
        else:
            logger.info("[fetch_weather] Coverage (%s):\n%s", source, coverage.to_string(index=False))

        missing_cols = [
            col for col in OpenMeteoRenewable.WEATHER_VARS if col not in df.columns
        ]
        if missing_cols:
            logger.warning(
                "[fetch_weather] Missing expected weather columns (%s): %s",
                source,
                missing_cols,
            )

        missing_values = {
            col: int(df[col].isna().sum())
            for col in OpenMeteoRenewable.WEATHER_VARS
            if col in df.columns and df[col].isna().any()
        }
        if missing_values:
            logger.warning(
                "[fetch_weather] Missing weather values (%s): %s",
                source,
                missing_values,
            )

    if output_path.exists() and not config.overwrite:
        logger.info(f"[fetch_weather] exists, loading: {output_path}")
        cached = pd.read_parquet(output_path)
        # Log cached weather coverage to surface missing regions/columns.
        _log_weather_summary(cached, source="cache")
        return cached

    logger.info(f"[fetch_weather] Fetching weather for {config.regions}")

    weather = OpenMeteoRenewable()

    # Historical weather
    hist_df = weather.fetch_all_regions_historical(
        regions=config.regions,
        start_date=config.start_date,
        end_date=config.end_date,
    )

    # Validate historical weather result
    if hist_df.empty:
        raise RuntimeError(
            "[fetch_weather] Historical weather returned empty DataFrame. "
            "fetch_all_regions_historical should raise an error on failure, "
            "but received empty result. Check fetch logic."
        )

    if not {"ds", "region"}.issubset(hist_df.columns):
        missing_cols = {"ds", "region"} - set(hist_df.columns)
        raise ValueError(
            f"[fetch_weather] Weather DataFrame missing required columns: {missing_cols}"
        )

    hist_regions = hist_df['region'].nunique()
    hist_rows = len(hist_df)
    logger.info(
        f"[fetch_weather] Historical: {hist_regions} regions, {hist_rows} rows"
    )

    # Forecast weather (for prediction, prevents leakage)
    if include_forecast:
        fcst_df = weather.fetch_all_regions_forecast(
            regions=config.regions,
            horizon_hours=config.horizon + 24,  # Buffer
        )

        # Validate forecast weather result
        if fcst_df.empty:
            logger.warning(
                "[fetch_weather] Forecast weather returned empty DataFrame. "
                "Using historical data only for model training and predictions."
            )
            combined = hist_df
        else:
            fcst_rows = len(fcst_df)
            logger.info(f"[fetch_weather] Forecast: {fcst_rows} rows")

            # Combine, preferring forecast for overlapping times
            combined = pd.concat([hist_df, fcst_df], ignore_index=True)
            combined = combined.drop_duplicates(subset=["ds", "region"], keep="last")
    else:
        combined = hist_df

    combined = combined.sort_values(["region", "ds"]).reset_index(drop=True)

    # Log fresh weather coverage and missing values before saving.
    _log_weather_summary(combined, source="fresh")

    combined.to_parquet(output_path, index=False)
    logger.info(f"[fetch_weather] Saved: {output_path} ({len(combined)} rows)")

    return combined


def train_renewable_models(
    config: RenewablePipelineConfig,
    modeling_df: Optional[pd.DataFrame] = None,
) -> tuple[pd.DataFrame, pd.DataFrame, dict]:
    """Task 3: Train models and compute baseline metrics via cross-validation.

    Args:
        config: Pipeline configuration
        modeling_df: Preprocessed modeling dataset from build_modeling_dataset()
                     (loads and builds from scratch if None)

    Returns:
        Tuple of (cv_results, leaderboard, baseline_metrics)
    """
    # Load and preprocess data if not provided
    if modeling_df is None:
        generation_df = pd.read_parquet(config.generation_path())
        weather_df = pd.read_parquet(config.weather_path())

        logger.info("[train_models] Building modeling dataset...")
        modeling_df, _ = build_modeling_dataset(
            generation_df,
            weather_df,
            negative_policy='clamp_to_zero',
            output_dir=config.preprocessing_dir()
        )

    logger.info(f"[train_models] Training on {len(modeling_df)} rows")

    model = RenewableForecastModel(
        horizon=config.horizon,
        confidence_levels=config.confidence_levels,
    )

    # Compute adaptive CV settings based on shortest series
    min_series_len = modeling_df.groupby("unique_id").size().min()

    # CV needs: horizon + (n_windows * step_size) rows minimum
    # Solve for n_windows: n_windows = (min_series_len - horizon) / step_size
    available_for_cv = min_series_len - config.horizon

    # Adjust step_size and n_windows to fit data
    step_size = min(config.cv_step_size, max(24, available_for_cv // 3))
    n_windows = min(config.cv_windows, max(2, available_for_cv // step_size))

    logger.info(
        f"[train_models] Adaptive CV: {n_windows} windows, "
        f"step={step_size}h (min_series={min_series_len} rows)"
    )

    # Cross-validation (modeling_df already has weather merged and time features added)
    cv_results, leaderboard = model.cross_validate(
        df=modeling_df,
        n_windows=n_windows,
        step_size=step_size,
    )

    best_model = leaderboard.iloc[0]["model"]
    baseline = compute_baseline_metrics(cv_results, model_name=best_model)

    logger.info(f"[train_models] Best model: {best_model}, RMSE: {baseline['rmse_mean']:.1f}")

    return cv_results, leaderboard, baseline


def train_interpretability_models(
    config: RenewablePipelineConfig,
    generation_df: Optional[pd.DataFrame] = None,
    weather_df: Optional[pd.DataFrame] = None,
) -> dict[str, InterpretabilityReport]:
    """Train LightGBM models and generate interpretability reports per series.

    This trains a separate LightGBM model for each series (region × fuel type)
    and generates SHAP, partial dependence, and feature importance artifacts.

    Note: LightGBM is used for interpretability only. The primary forecasts
    come from statistical models (MSTL/ARIMA) which provide better uncertainty
    quantification.

    Args:
        config: Pipeline configuration
        generation_df: Generation data (loads from file if None)
        weather_df: Weather data (loads from file if None)

    Returns:
        Dict mapping series_id -> InterpretabilityReport
    """
    # Load data if not provided
    if generation_df is None:
        generation_df = pd.read_parquet(config.generation_path())
    if weather_df is None:
        weather_df = pd.read_parquet(config.weather_path())

    logger.info(f"[train_interpretability] Training LightGBM for {generation_df['unique_id'].nunique()} series")

    # Ensure datetime types
    generation_df = generation_df.copy()
    generation_df["ds"] = pd.to_datetime(generation_df["ds"], errors="raise")
    weather_df = weather_df.copy()
    weather_df["ds"] = pd.to_datetime(weather_df["ds"], errors="raise")

    reports: dict[str, InterpretabilityReport] = {}
    output_dir = config.interpretability_dir()
    output_dir.mkdir(parents=True, exist_ok=True)

    for uid in sorted(generation_df["unique_id"].unique()):
        logger.info(f"[train_interpretability] Processing {uid}...")

        # Extract series data
        series_data = generation_df[generation_df["unique_id"] == uid].copy()
        series_data = series_data.sort_values("ds")

        # Prepare target series with proper frequency
        y = series_data.set_index("ds")["y"]
        y.index = pd.DatetimeIndex(y.index, freq="h")  # Set hourly frequency

        # Prepare exogenous features
        region = uid.split("_")[0]
        series_weather = weather_df[weather_df["region"] == region].copy()

        if series_weather.empty:
            logger.warning(f"[train_interpretability] No weather data for region {region}, skipping {uid}")
            continue

        # Merge weather to series timestamps
        series_data = series_data.merge(
            series_weather[["ds"] + [c for c in WEATHER_VARS if c in series_weather.columns]],
            on="ds",
            how="left",
        )

        # Add time features
        series_data = _add_time_features(series_data)

        # Build exog DataFrame aligned with y
        exog_cols = ["hour_sin", "hour_cos", "dow_sin", "dow_cos"]
        exog_cols += [c for c in WEATHER_VARS if c in series_data.columns]
        exog = series_data.set_index("ds")[exog_cols]

        # Check for missing weather
        missing_weather = exog.isna().any(axis=1).sum()
        if missing_weather > 0:
            logger.warning(f"[train_interpretability] {uid}: {missing_weather} rows with missing weather, filling with ffill/bfill")
            exog = exog.ffill().bfill()

        # Fit LightGBM forecaster
        try:
            if not INTERPRETABILITY_AVAILABLE:
                logger.warning(f"[train_interpretability] {uid}: lightgbm/skforecast not available, skipping")
                continue

            # Create skforecast ForecasterRecursive with LightGBM estimator
            forecaster = ForecasterRecursive(
                estimator=LGBMRegressor(
                    random_state=42,
                    verbose=-1,
                    n_estimators=100,
                    learning_rate=0.05,
                    max_depth=6,
                ),
                lags=168,  # 7 days of lags
            )
            forecaster.fit(y=y, exog=exog)

            # Create training matrices for SHAP analysis
            X_train, y_train = forecaster.create_train_X_y(y=y, exog=exog)

            # Generate interpretability report
            series_output_dir = output_dir / uid
            report = generate_full_interpretability_report(
                forecaster=forecaster,
                X_train=X_train,
                series_id=uid,
                output_dir=series_output_dir,
                top_n_features=5,
                shap_sample_frac=0.5,
                shap_max_samples=1000,
            )
            reports[uid] = report

            logger.info(
                f"[train_interpretability] {uid}: top_features={report.top_features[:3]}"
            )

        except Exception as e:
            logger.error(f"[train_interpretability] {uid}: Failed to train - {e}")
            continue

    logger.info(f"[train_interpretability] Generated {len(reports)} interpretability reports")
    return reports


def generate_renewable_forecasts(
    config: RenewablePipelineConfig,
    modeling_df: Optional[pd.DataFrame] = None,
    weather_df: Optional[pd.DataFrame] = None,
    best_model: str = "MSTL_ARIMA",
) -> pd.DataFrame:
    """Task 4: Generate forecasts with prediction intervals.

    Args:
        config: Pipeline configuration
        modeling_df: Preprocessed modeling dataset (if None, loads and builds)
        weather_df: Raw weather data with forecast (if None, loads from file)
        best_model: Model to use for forecasting

    Returns:
        Forecast DataFrame with physical constraints applied
    """
    output_path = config.forecasts_path()
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Load and preprocess data if not provided
    if modeling_df is None:
        generation_df = pd.read_parquet(config.generation_path())
        if weather_df is None:
            weather_df = pd.read_parquet(config.weather_path())

        logger.info("[generate_forecasts] Building modeling dataset...")
        modeling_df, _ = build_modeling_dataset(
            generation_df,
            weather_df,
            negative_policy='clamp_to_zero',
            output_dir=config.preprocessing_dir()
        )

    if weather_df is None:
        weather_df = pd.read_parquet(config.weather_path())

    logger.info(
        f"[generate_forecasts] Generating {config.horizon}h forecasts "
        f"using model={best_model}"
    )

    # Ensure datetime types
    modeling_df = modeling_df.copy()
    modeling_df["ds"] = pd.to_datetime(modeling_df["ds"], errors="raise")
    weather_df = weather_df.copy()
    weather_df["ds"] = pd.to_datetime(weather_df["ds"], errors="raise")

    model = RenewableForecastModel(
        horizon=config.horizon,
        confidence_levels=config.confidence_levels,
    )

    # Fit on preprocessed modeling data
    model.fit(modeling_df)

    # Prepare future exogenous features for forecasting
    # We need weather + time features for the forecast horizon
    per_series_max = modeling_df.groupby("unique_id")["ds"].max()
    logger.info(
        f"[generate_forecasts] Per-series max timestamps:\n"
        f"{per_series_max.to_dict()}"
    )

    min_of_max = per_series_max.min()
    global_max = modeling_df["ds"].max()

    logger.info(
        f"[generate_forecasts] Min of series maxes: {min_of_max}, "
        f"Global max: {global_max}, "
        f"Delta: {(global_max - min_of_max).total_seconds() / 3600:.1f}h"
    )

    # Get future weather beyond the last training timestamp
    future_weather = weather_df[weather_df["ds"] > min_of_max].copy()

    if future_weather.empty:
        raise RuntimeError(
            "[generate_forecasts] No future weather found after last "
            f"training timestamp. min_of_max={min_of_max}"
        )

    # Build future_exog by preparing timestamps and merging weather
    unique_ids = modeling_df["unique_id"].unique()
    future_timestamps = pd.date_range(
        start=min_of_max + pd.Timedelta(hours=1),
        periods=config.horizon,
        freq="h"
    )

    # Create future_exog with all series x timestamps combinations
    future_exog = pd.DataFrame([
        {"unique_id": uid, "ds": ts}
        for uid in unique_ids
        for ts in future_timestamps
    ])

    # Add region for weather merge
    future_exog["region"] = future_exog["unique_id"].str.split("_").str[0]

    # Merge weather
    available_weather_vars = [
        c for c in WEATHER_VARS if c in future_weather.columns
    ]
    future_exog = future_exog.merge(
        future_weather[["ds", "region"] + available_weather_vars],
        on=["ds", "region"],
        how="left"
    )

    # Check for missing weather
    missing_weather = future_exog[available_weather_vars].isna().any(axis=1)
    if missing_weather.any():
        missing_count = missing_weather.sum()
        logger.warning(
            f"[generate_forecasts] {missing_count} future rows missing "
            f"weather, dropping them"
        )
        future_exog = future_exog[~missing_weather].reset_index(drop=True)

    # Add time features (same as dataset_builder)
    future_exog["hour"] = future_exog["ds"].dt.hour
    future_exog["dow"] = future_exog["ds"].dt.dayofweek
    future_exog["hour_sin"] = np.sin(2 * np.pi * future_exog["hour"] / 24)
    future_exog["hour_cos"] = np.cos(2 * np.pi * future_exog["hour"] / 24)
    future_exog["dow_sin"] = np.sin(2 * np.pi * future_exog["dow"] / 7)
    future_exog["dow_cos"] = np.cos(2 * np.pi * future_exog["dow"] / 7)
    future_exog = future_exog.drop(columns=["hour", "dow", "region"])

    # Generate forecasts
    logger.info(
        f"[generate_forecasts] Generating predictions using "
        f"model: {best_model}"
    )
    forecasts = model.predict(future_exog=future_exog, best_model=best_model)

    # CRITICAL: Apply physical constraints (renewable generation cannot be negative)
    # This matches the constraint enforcement during cross-validation (modeling.py:301)
    forecasts = enforce_physical_constraints(forecasts, min_value=0.0)
    logger.info("[generate_forecasts] Applied physical constraints (clipped to [0, ∞))")

    logger.info(
        f"[generate_forecasts] Generated {len(forecasts)} forecast rows "
        f"for {forecasts['unique_id'].nunique()} series"
    )

    forecasts.to_parquet(output_path, index=False)
    logger.info(
        f"[generate_forecasts] Saved: {output_path} ({len(forecasts)} rows)"
    )

    return forecasts


def compute_renewable_drift(
    predictions: pd.DataFrame,
    actuals: pd.DataFrame,
    baseline_metrics: dict,
) -> dict:
    """Task 5: Detect drift by comparing current metrics to baseline.

    Drift is flagged when current RMSE > baseline_mean + 2*baseline_std

    Args:
        predictions: Forecast DataFrame with [unique_id, ds, yhat]
        actuals: Actual values DataFrame with [unique_id, ds, y]
        baseline_metrics: Baseline from cross-validation

    Returns:
        Dictionary with drift status and details
    """
    from src.chapter2.evaluation import ForecastMetrics

    # Merge predictions with actuals
    merged = predictions.merge(
        actuals[["unique_id", "ds", "y"]],
        on=["unique_id", "ds"],
        how="inner",
    )

    if len(merged) == 0:
        return {
            "status": "no_data",
            "message": "No overlapping data between predictions and actuals",
        }

    # Compute current metrics
    y_true = merged["y"].values
    y_pred = merged["yhat"].values

    current_rmse = ForecastMetrics.rmse(y_true, y_pred)
    current_mae = ForecastMetrics.mae(y_true, y_pred)

    # Check against threshold
    threshold = baseline_metrics.get("drift_threshold_rmse", float("inf"))
    is_drifting = current_rmse > threshold

    result = {
        "status": "drift_detected" if is_drifting else "stable",
        "current_rmse": float(current_rmse),
        "current_mae": float(current_mae),
        "baseline_rmse": float(baseline_metrics.get("rmse_mean", 0)),
        "drift_threshold": float(threshold),
        "threshold_exceeded_by": float(max(0, current_rmse - threshold)),
        "n_predictions": len(merged),
        "timestamp": datetime.utcnow().isoformat(),
    }

    if is_drifting:
        logger.warning(
            f"[drift] DRIFT DETECTED: RMSE={current_rmse:.1f} > threshold={threshold:.1f}"
        )
    else:
        logger.info(f"[drift] Stable: RMSE={current_rmse:.1f} <= threshold={threshold:.1f}")

    return result


def run_full_pipeline(
    config: RenewablePipelineConfig,
    fetch_diagnostics: Optional[list[dict]] = None,
    skip_eda: bool = False,
) -> dict:
    """Run the complete renewable forecasting pipeline.

    Steps:
    1. Fetch generation data
    2. Fetch weather data
    3. Run EDA and get recommendations (NEW)
    4. Build datasets per fuel type using recommendations
    5. Train models (CV)
    6. Generate forecasts

    Args:
        config: Pipeline configuration
        fetch_diagnostics: Optional list to capture per-region fetch metadata

    Returns:
        Dictionary with pipeline results
    """
    logger.info(f"[pipeline] Starting: {config.start_date} to {config.end_date}")
    logger.info(f"[pipeline] Regions: {config.regions}")
    logger.info(f"[pipeline] Fuel types: {config.fuel_types}")

    results = {}

    # Step 1: Fetch generation
    generation_df = fetch_renewable_data(config, fetch_diagnostics=fetch_diagnostics)
    results["generation_rows"] = len(generation_df)
    results["series_count"] = generation_df["unique_id"].nunique()

    from src.renewable.validation import validate_generation_df

    expected_series = [f"{r}_{f}" for r in config.regions for f in config.fuel_types]
    rep = validate_generation_df(
        generation_df,
        expected_series=expected_series,
        max_missing_ratio=0.02,
        max_lag_hours=48,  # choose a value consistent with EIA publishing lag
    )
    if not rep.ok:
        raise RuntimeError(f"[pipeline][generation_validation] {rep.message} details={rep.details}")

    # Step 2: Fetch weather
    weather_df = fetch_renewable_weather(config)
    results["weather_rows"] = len(weather_df)

    # Step 3: Run EDA (NEW)
    eda_recommendations = None

    if not skip_eda:
        logger.info("[pipeline] Running EDA to generate preprocessing recommendations")
        from src.renewable.eda import run_full_eda

        eda_output_dir = Path(config.data_dir) / "eda"
        eda_recommendations = run_full_eda(
            generation_df,
            weather_df,
            eda_output_dir,
        )

        results["eda"] = {
            "output_dir": str(eda_output_dir),
            "negative_policy": eda_recommendations.preprocessing.negative_policy,
            "confidence": eda_recommendations.preprocessing.negative_confidence,
        }
        logger.info(f"[pipeline] EDA complete. Recommended policy: {eda_recommendations.preprocessing.negative_policy}")
    else:
        logger.warning("[pipeline] Skipping EDA - using default preprocessing policies")

    # Step 4: Build datasets per fuel type (MODIFIED)
    logger.info("[pipeline] Building modeling datasets (fuel-type specific)")

    from src.renewable.dataset_builder import build_dataset_by_fuel_type

    fuel_datasets = {}
    last_prep_report = None

    for fuel_type in config.fuel_types:
        logger.info(f"[pipeline] Building {fuel_type} dataset...")

        fuel_output_dir = config.preprocessing_dir() / fuel_type.lower()

        modeling_df_fuel, prep_report = build_dataset_by_fuel_type(
            generation_df,
            weather_df,
            fuel_type=fuel_type,
            output_dir=fuel_output_dir,
            eda_recommendations=eda_recommendations.preprocessing if eda_recommendations else None,
        )

        fuel_datasets[fuel_type] = modeling_df_fuel
        last_prep_report = prep_report  # Keep last report for backward compatibility

        logger.info(f"[pipeline] {fuel_type}: {prep_report.input_rows:,} → {prep_report.output_rows:,} rows")

    # Combine all fuel datasets
    modeling_df = pd.concat(fuel_datasets.values(), ignore_index=True)

    # DEBUG: Log dataset combination
    logger.info(f"[pipeline] Combined {len(fuel_datasets)} fuel-type datasets into {len(modeling_df):,} rows")

    # Save combined modeling dataset for analysis and testing
    modeling_dataset_path = Path(config.data_dir) / "modeling_dataset.parquet"
    modeling_df.to_parquet(modeling_dataset_path, index=False)
    logger.info(f"[pipeline] Saved modeling dataset: {modeling_dataset_path} ({len(modeling_df):,} rows)")

    # Initialize preprocessing results
    results["preprocessing"] = {
        "rows_input": len(generation_df),
        "rows_output": len(modeling_df),
    }

    # Use last report for backward compatibility
    prep_report = last_prep_report

    # Extract time and weather features from output_features (from last fuel type)
    if prep_report:
        logger.debug("[pipeline] Extracting features from last fuel type's preprocessing report")

        time_features = [
            f for f in prep_report.output_features
            if f in ['hour_sin', 'hour_cos', 'dow_sin', 'dow_cos']
        ]
        weather_features = [
            f for f in prep_report.output_features
            if f in prep_report.weather_vars_used
        ]

        logger.debug(f"[pipeline] Found {len(time_features)} time features, {len(weather_features)} weather features")

        results["preprocessing"].update({
            "series_dropped": len(prep_report.series_dropped_incomplete),
            "negative_action": prep_report.negative_report.action_taken,
            "time_features": time_features,
            "weather_features": weather_features,
        })
        logger.info(
            f"[pipeline] Preprocessing: {len(generation_df):,} → "
            f"{len(modeling_df):,} rows"
        )
    else:
        logger.warning("[pipeline] No preprocessing report available - skipping feature extraction")

    # Step 5: Train and validate (on preprocessed data)
    cv_results, leaderboard, baseline = train_renewable_models(
        config, modeling_df
    )
    best_model = leaderboard.iloc[0]["model"]
    results["best_model"] = best_model
    results["best_rmse"] = float(leaderboard.iloc[0]["rmse"])
    results["baseline"] = baseline
    # Save full leaderboard for dashboard display
    results["leaderboard"] = leaderboard.to_dict(orient="records")

    # Step 6: Generate forecasts (use the best model from CV)
    # Pass weather_df for future weather (forecast horizon)
    forecasts = generate_renewable_forecasts(
        config, modeling_df, weather_df, best_model=best_model
    )
    results["forecast_rows"] = len(forecasts)

    # Step 5: Train LightGBM models and generate interpretability reports (optional)
    # (LightGBM is for interpretability only - MSTL/ARIMA provide primary forecasts)
    if config.enable_interpretability:
        logger.info("[pipeline] Training interpretability models (LightGBM + SHAP)")
        try:
            interpretability_reports = train_interpretability_models(
                config, generation_df, weather_df
            )
            results["interpretability"] = {
                "series_count": len(interpretability_reports),
                "series": list(interpretability_reports.keys()),
                "output_dir": str(config.interpretability_dir()),
            }

            # Add top features summary per series
            for uid, report in interpretability_reports.items():
                results["interpretability"][f"{uid}_top_features"] = report.top_features[:3]

        except Exception as e:
            logger.warning(f"[pipeline] Interpretability training failed (non-fatal): {e}")
            results["interpretability"] = {"error": str(e)}
    else:
        logger.info("[pipeline] Interpretability disabled (enable_interpretability=False)")
        results["interpretability"] = {"enabled": False}

    if fetch_diagnostics is not None:
        results["fetch_diagnostics"] = fetch_diagnostics

    logger.info(f"[pipeline] Complete. Best model: {results['best_model']}")

    return results


def main():
    """CLI entry point for renewable pipeline."""
    parser = argparse.ArgumentParser(
        description="Renewable Energy Forecasting Pipeline",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Preset Examples:
  # Fast development (24h forecast, 2 CV windows, 15 days lookback)
  python -m src.renewable.tasks --preset 24h

  # Standard forecasting (48h forecast, 3 CV windows, 21 days lookback)
  python -m src.renewable.tasks --preset 48h

  # Extended planning (72h forecast, 3 CV windows, 28 days lookback)
  python -m src.renewable.tasks --preset 72h

Custom Examples:
  # 24h preset but only CALI region, skip interpretability
  python -m src.renewable.tasks --preset 24h --regions CALI --no-interpretability

  # Custom: 36h forecast with 4 CV windows
  python -m src.renewable.tasks --horizon 36 --cv-windows 4 --lookback-days 30
        """
    )

    # Preset system (NEW)
    parser.add_argument(
        "--preset",
        type=str,
        choices=["24h", "48h", "72h"],
        help="Quick preset: 24h (fast dev), 48h (standard), 72h (extended planning)",
    )

    # Flags (NEW)
    parser.add_argument(
        "--no-interpretability",
        action="store_true",
        help="Disable LightGBM interpretability analysis (speeds up pipeline)",
    )

    # Data parameters (existing)
    parser.add_argument(
        "--regions",
        type=str,
        help="Override regions (comma-separated, e.g., CALI,ERCO,MISO)",
    )
    parser.add_argument(
        "--fuel",
        type=str,
        help="Override fuel types (comma-separated, e.g., WND,SUN)",
    )

    # Forecast parameters (existing + new)
    parser.add_argument(
        "--horizon",
        type=int,
        help="Override forecast horizon in hours",
    )
    parser.add_argument(
        "--lookback-days",
        type=int,
        help="Override lookback days",
    )
    parser.add_argument(
        "--cv-windows",
        type=int,
        help="Override CV windows count",
    )

    # Output parameters
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Overwrite existing data files",
    )
    parser.add_argument(
        "--data-dir",
        type=str,
        default="data/renewable",
        help="Output directory (default: data/renewable)",
    )

    args = parser.parse_args()

    # Configure logging
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # Build config with preset support
    if args.preset:
        # Apply preset defaults
        logger.info(f"[CLI] Applying preset: {args.preset}")
        config = RenewablePipelineConfig(
            horizon_preset=args.preset,  # This triggers __post_init__ to apply preset
            regions=args.regions.split(",") if args.regions else ["CALI", "ERCO", "MISO"],
            fuel_types=args.fuel.split(",") if args.fuel else ["WND", "SUN"],
            enable_interpretability=not args.no_interpretability,
            overwrite=args.overwrite,
            data_dir=args.data_dir,
        )

        # Allow CLI overrides of preset values
        if args.horizon is not None:
            object.__setattr__(config, "horizon", args.horizon)
            logger.info(f"[CLI] Override: horizon={args.horizon}h")
        if args.lookback_days is not None:
            object.__setattr__(config, "lookback_days", args.lookback_days)
            logger.info(f"[CLI] Override: lookback_days={args.lookback_days}")
        if args.cv_windows is not None:
            object.__setattr__(config, "cv_windows", args.cv_windows)
            logger.info(f"[CLI] Override: cv_windows={args.cv_windows}")

    else:
        # No preset: use explicit values or defaults
        config = RenewablePipelineConfig(
            regions=args.regions.split(",") if args.regions else ["CALI", "ERCO", "MISO"],
            fuel_types=args.fuel.split(",") if args.fuel else ["WND", "SUN"],
            lookback_days=args.lookback_days if args.lookback_days else 30,
            horizon=args.horizon if args.horizon else 24,
            cv_windows=args.cv_windows if args.cv_windows else 5,
            enable_interpretability=not args.no_interpretability,
            overwrite=args.overwrite,
            data_dir=args.data_dir,
        )

    # Run pipeline
    results = run_full_pipeline(config)

    print("\n" + "=" * 60)
    print("PIPELINE RESULTS")
    print("=" * 60)
    print(f"  Series count: {results['series_count']}")
    print(f"  Generation rows: {results['generation_rows']}")
    print(f"  Weather rows: {results['weather_rows']}")
    print(f"  Forecast rows: {results['forecast_rows']}")
    print(f"  Best model: {results['best_model']}")
    print(f"  Best RMSE: {results['best_rmse']:.1f}")
    print("=" * 60)


if __name__ == "__main__":
    main()


Overwriting src/renewable/tasks.py


In [11]:
%%writefile src/renewable/data_freshness.py
# src/renewable/data_freshness.py
"""
Lightweight EIA data freshness checking.

This module provides functions to check if new data is available from the EIA API
before running the full pipeline. It compares the current max timestamps with
the previous run's max timestamps to determine if a full pipeline run is needed.
"""

from __future__ import annotations

import json
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

import pandas as pd
import requests

from src.renewable.regions import get_eia_respondent

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class FreshnessCheckResult:
    """Result of a data freshness check."""

    has_new_data: bool
    checked_at_utc: str
    series_status: dict[str, dict] = field(default_factory=dict)
    summary: str = ""


def load_previous_max_ds(run_log_path: Path) -> dict[str, str]:
    """
    Load per-series max_ds from previous run_log.json.

    Args:
        run_log_path: Path to run_log.json

    Returns:
        Dict mapping unique_id -> max_ds ISO string.
        Empty dict if file doesn't exist or is malformed.
    """
    if not run_log_path.exists():
        logger.info("[freshness] No previous run_log.json found - first run")
        return {}

    try:
        data = json.loads(run_log_path.read_text(encoding="utf-8"))

        # Navigate to diagnostics.generation_coverage.coverage
        coverage = (
            data.get("diagnostics", {})
            .get("generation_coverage", {})
            .get("coverage", [])
        )

        if not coverage:
            logger.warning("[freshness] run_log.json has no coverage data")
            return {}

        result = {}
        for item in coverage:
            uid = item.get("unique_id")
            max_ds = item.get("max_ds")
            if uid and max_ds:
                result[uid] = max_ds

        logger.info(f"[freshness] Loaded {len(result)} series from previous run_log")
        return result

    except (json.JSONDecodeError, KeyError, TypeError) as e:
        logger.warning(f"[freshness] Failed to parse run_log.json: {e}")
        return {}


def probe_eia_latest(
    api_key: str,
    region: str,
    fuel_type: str,
    *,
    timeout: int = 15,
) -> Optional[str]:
    """
    Fetch only the single most recent record from EIA API.

    This is a lightweight probe that uses:
    - length=1 (only fetch 1 record)
    - sort by period DESC (most recent first)

    Args:
        api_key: EIA API key
        region: Region code (CALI, ERCO, MISO, etc.)
        fuel_type: Fuel type (WND, SUN)
        timeout: Request timeout in seconds

    Returns:
        ISO timestamp string of latest record, or None on error.
    """
    try:
        respondent = get_eia_respondent(region)

        params = {
            "api_key": api_key,
            "data[]": "value",
            "facets[respondent][]": respondent,
            "facets[fueltype][]": fuel_type,
            "frequency": "hourly",
            "length": 1,
            "sort[0][column]": "period",
            "sort[0][direction]": "desc",
        }

        base_url = "https://api.eia.gov/v2/electricity/rto/fuel-type-data/data/"
        resp = requests.get(base_url, params=params, timeout=timeout)
        resp.raise_for_status()

        payload = resp.json()
        response = payload.get("response", {})
        records = response.get("data", [])

        if not records:
            logger.warning(f"[probe] {region}_{fuel_type}: No records returned")
            return None

        period = records[0].get("period")
        if not period:
            logger.warning(f"[probe] {region}_{fuel_type}: Record missing 'period'")
            return None

        # Parse to consistent ISO format
        ts = pd.to_datetime(period, utc=True)
        return ts.isoformat()

    except requests.RequestException as e:
        logger.warning(f"[probe] {region}_{fuel_type}: API error: {e}")
        return None
    except Exception as e:
        logger.warning(f"[probe] {region}_{fuel_type}: Unexpected error: {e}")
        return None


def _compare_timestamps(prev: Optional[str], current: Optional[str]) -> bool:
    """
    Return True if current is strictly newer than prev.

    Handles None values conservatively (assume new data if unknown).
    """
    if not prev or not current:
        return True  # Unknown = assume new data (conservative)

    try:
        prev_dt = pd.to_datetime(prev, utc=True)
        curr_dt = pd.to_datetime(current, utc=True)
        return curr_dt > prev_dt
    except Exception:
        return True  # Parse error = assume new data


def check_all_series_freshness(
    regions: list[str],
    fuel_types: list[str],
    run_log_path: Path,
    api_key: str,
) -> FreshnessCheckResult:
    """
    Check all series for new data availability.

    Args:
        regions: List of region codes (e.g., ["CALI", "ERCO", "MISO"])
        fuel_types: List of fuel types (e.g., ["WND", "SUN"])
        run_log_path: Path to previous run_log.json
        api_key: EIA API key

    Returns:
        FreshnessCheckResult with has_new_data flag and detailed status per series.
    """
    checked_at = datetime.now(timezone.utc).isoformat()

    # 1. Load previous max_ds values
    prev_max_ds = load_previous_max_ds(run_log_path)

    # 2. If no previous run_log, always run full pipeline (first run)
    if not prev_max_ds:
        return FreshnessCheckResult(
            has_new_data=True,
            checked_at_utc=checked_at,
            series_status={},
            summary="No previous run_log.json found - running full pipeline (first run)",
        )

    # 3. Probe each series
    series_status: dict[str, dict] = {}
    has_any_new = False
    new_series: list[str] = []
    error_series: list[str] = []

    for region in regions:
        for fuel_type in fuel_types:
            series_id = f"{region}_{fuel_type}"
            prev = prev_max_ds.get(series_id)
            current = probe_eia_latest(api_key, region, fuel_type)

            # Determine if this series has new data
            if current is None:
                # API error - be conservative, assume new data
                is_new = True
                error_series.append(series_id)
                logger.warning(
                    f"[freshness] {series_id}: probe failed, assuming new data"
                )
            else:
                is_new = _compare_timestamps(prev, current)

            series_status[series_id] = {
                "prev_max_ds": prev,
                "current_max_ds": current,
                "is_new": is_new,
            }

            if is_new:
                has_any_new = True
                if current is not None:
                    new_series.append(series_id)

            # Log each series check
            status_str = "NEW" if is_new else "unchanged"
            logger.info(
                f"[freshness] {series_id}: prev={prev} current={current} ({status_str})"
            )

    # 4. Build summary
    if error_series:
        summary = f"Probe errors for {error_series}, assuming new data available"
    elif new_series:
        summary = f"New data found for: {', '.join(new_series)}"
    else:
        summary = "No new data found for any series"

    return FreshnessCheckResult(
        has_new_data=has_any_new,
        checked_at_utc=checked_at,
        series_status=series_status,
        summary=summary,
    )


if __name__ == "__main__":
    # Quick test
    import os
    from dotenv import load_dotenv

    load_dotenv()
    logging.basicConfig(level=logging.INFO)

    api_key = os.getenv("EIA_API_KEY")
    if not api_key:
        print("EIA_API_KEY not set")
        exit(1)

    run_log_path = Path("data/renewable/run_log.json")

    result = check_all_series_freshness(
        regions=["CALI", "ERCO", "MISO"],
        fuel_types=["WND", "SUN"],
        run_log_path=run_log_path,
        api_key=api_key,
    )

    print(f"\nFreshness Check Result:")
    print(f"  has_new_data: {result.has_new_data}")
    print(f"  checked_at: {result.checked_at_utc}")
    print(f"  summary: {result.summary}")
    print(f"\nPer-series status:")
    for series_id, status in result.series_status.items():
        print(f"  {series_id}: {status}")


Overwriting src/renewable/data_freshness.py


# SQLite persistence layer

Extends the monitoring database with:
- Prediction intervals (80%, 95%)
- Weather features table
- Renewable-specific columns (fuel_type, region)

In [12]:
%%writefile src/renewable/db.py
# file: src/renewable/db.py
"""Database schema and operations for renewable forecasting.

Extends the Chapter 4 monitoring database with:
- Prediction intervals (80%, 95%)
- Weather features table
- Renewable-specific columns (fuel_type, region)
"""

import json
import sqlite3
from datetime import datetime
from pathlib import Path
from typing import Optional

import pandas as pd


def connect(db_path: str) -> sqlite3.Connection:
    """Connect to SQLite database with optimized settings."""
    Path(db_path).parent.mkdir(parents=True, exist_ok=True)
    con = sqlite3.connect(db_path)
    con.execute("PRAGMA journal_mode=WAL;")
    con.execute("PRAGMA synchronous=NORMAL;")
    return con


def init_renewable_db(db_path: str) -> None:
    """Initialize renewable forecasting database schema.

    Creates tables:
    - renewable_forecasts: Forecasts with dual intervals
    - renewable_scores: Evaluation metrics with coverage
    - weather_features: Weather data by region
    - drift_alerts: Drift detection history
    - baseline_metrics: Backtest baselines for drift thresholds
    """
    con = connect(db_path)
    cur = con.cursor()

    # Forecasts with dual prediction intervals
    cur.execute("""
    CREATE TABLE IF NOT EXISTS renewable_forecasts (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        run_id TEXT NOT NULL,
        created_at TEXT NOT NULL,
        unique_id TEXT NOT NULL,
        region TEXT NOT NULL,
        fuel_type TEXT NOT NULL,
        ds TEXT NOT NULL,
        model TEXT NOT NULL,
        yhat REAL,
        yhat_lo_80 REAL,
        yhat_hi_80 REAL,
        yhat_lo_95 REAL,
        yhat_hi_95 REAL,
        UNIQUE (run_id, model, unique_id, ds)
    );
    """)

    # Index for efficient queries
    cur.execute("""
    CREATE INDEX IF NOT EXISTS idx_forecasts_region_ds
    ON renewable_forecasts (region, ds);
    """)

    cur.execute("""
    CREATE INDEX IF NOT EXISTS idx_forecasts_fuel_ds
    ON renewable_forecasts (fuel_type, ds);
    """)

    # Evaluation scores with dual coverage
    cur.execute("""
    CREATE TABLE IF NOT EXISTS renewable_scores (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        scored_at TEXT NOT NULL,
        run_id TEXT NOT NULL,
        unique_id TEXT NOT NULL,
        region TEXT NOT NULL,
        fuel_type TEXT NOT NULL,
        model TEXT NOT NULL,
        horizon_hours INTEGER NOT NULL,
        rmse REAL,
        mae REAL,
        coverage_80 REAL,
        coverage_95 REAL,
        valid_rows INTEGER,
        UNIQUE (run_id, model, unique_id, horizon_hours)
    );
    """)

    # Weather features by region
    cur.execute("""
    CREATE TABLE IF NOT EXISTS weather_features (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        region TEXT NOT NULL,
        ds TEXT NOT NULL,
        temperature_2m REAL,
        wind_speed_10m REAL,
        wind_speed_100m REAL,
        wind_direction_10m REAL,
        direct_radiation REAL,
        diffuse_radiation REAL,
        cloud_cover REAL,
        is_forecast INTEGER DEFAULT 0,
        created_at TEXT DEFAULT CURRENT_TIMESTAMP,
        UNIQUE (region, ds, is_forecast)
    );
    """)

    cur.execute("""
    CREATE INDEX IF NOT EXISTS idx_weather_region_ds
    ON weather_features (region, ds);
    """)

    # Drift detection alerts
    cur.execute("""
    CREATE TABLE IF NOT EXISTS drift_alerts (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        alert_at TEXT NOT NULL,
        run_id TEXT,
        unique_id TEXT,
        region TEXT,
        fuel_type TEXT,
        alert_type TEXT NOT NULL,
        severity TEXT NOT NULL,
        current_rmse REAL,
        threshold_rmse REAL,
        message TEXT,
        metadata_json TEXT
    );
    """)

    cur.execute("""
    CREATE INDEX IF NOT EXISTS idx_drift_alerts_time
    ON drift_alerts (alert_at);
    """)

    # Baseline metrics for drift detection
    cur.execute("""
    CREATE TABLE IF NOT EXISTS baseline_metrics (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        created_at TEXT NOT NULL,
        unique_id TEXT NOT NULL,
        model TEXT NOT NULL,
        rmse_mean REAL NOT NULL,
        rmse_std REAL NOT NULL,
        mae_mean REAL,
        mae_std REAL,
        drift_threshold_rmse REAL NOT NULL,
        drift_threshold_mae REAL,
        n_windows INTEGER,
        metadata_json TEXT,
        UNIQUE (unique_id, model)
    );
    """)

    con.commit()
    con.close()


def save_forecasts(
    db_path: str,
    forecasts_df: pd.DataFrame,
    run_id: str,
    model: str = "MSTL_ARIMA",
) -> int:
    """Save forecasts to database.

    Args:
        db_path: Path to SQLite database
        forecasts_df: DataFrame with [unique_id, ds, yhat, yhat_lo_80, ...]
        run_id: Pipeline run identifier
        model: Model name

    Returns:
        Number of rows inserted
    """
    con = connect(db_path)
    created_at = datetime.utcnow().isoformat()

    rows = []
    for _, row in forecasts_df.iterrows():
        unique_id = row["unique_id"]
        parts = unique_id.split("_")
        region = parts[0] if len(parts) > 0 else ""
        fuel_type = parts[1] if len(parts) > 1 else ""

        rows.append((
            run_id,
            created_at,
            unique_id,
            region,
            fuel_type,
            str(row["ds"]),
            model,
            row.get("yhat"),
            row.get("yhat_lo_80"),
            row.get("yhat_hi_80"),
            row.get("yhat_lo_95"),
            row.get("yhat_hi_95"),
        ))

    cur = con.cursor()
    cur.executemany("""
        INSERT OR REPLACE INTO renewable_forecasts
        (run_id, created_at, unique_id, region, fuel_type, ds, model,
         yhat, yhat_lo_80, yhat_hi_80, yhat_lo_95, yhat_hi_95)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
    """, rows)

    con.commit()
    con.close()

    return len(rows)


def save_weather(
    db_path: str,
    weather_df: pd.DataFrame,
    is_forecast: bool = False,
) -> int:
    """Save weather features to database.

    Args:
        db_path: Path to SQLite database
        weather_df: DataFrame with [ds, region, weather_vars...]
        is_forecast: True if this is forecast weather data

    Returns:
        Number of rows inserted
    """
    con = connect(db_path)

    weather_cols = [
        "temperature_2m", "wind_speed_10m", "wind_speed_100m",
        "wind_direction_10m", "direct_radiation", "diffuse_radiation", "cloud_cover"
    ]

    rows = []
    for _, row in weather_df.iterrows():
        values = [row.get(col) for col in weather_cols]
        rows.append((
            row["region"],
            str(row["ds"]),
            *values,
            1 if is_forecast else 0,
        ))

    cur = con.cursor()
    cur.executemany(f"""
        INSERT OR REPLACE INTO weather_features
        (region, ds, {', '.join(weather_cols)}, is_forecast)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
    """, rows)

    con.commit()
    con.close()

    return len(rows)


def save_drift_alert(
    db_path: str,
    run_id: str,
    unique_id: str,
    current_rmse: float,
    threshold_rmse: float,
    severity: str = "warning",
    metadata: Optional[dict] = None,
) -> None:
    """Save drift detection alert.

    Args:
        db_path: Path to SQLite database
        run_id: Pipeline run identifier
        unique_id: Series identifier
        current_rmse: Current RMSE value
        threshold_rmse: Drift threshold
        severity: Alert severity (info, warning, critical)
        metadata: Additional metadata
    """
    con = connect(db_path)

    parts = unique_id.split("_")
    region = parts[0] if len(parts) > 0 else ""
    fuel_type = parts[1] if len(parts) > 1 else ""

    alert_type = "drift_detected" if current_rmse > threshold_rmse else "drift_check"
    message = (
        f"RMSE {current_rmse:.1f} {'>' if current_rmse > threshold_rmse else '<='} "
        f"threshold {threshold_rmse:.1f}"
    )

    cur = con.cursor()
    cur.execute("""
        INSERT INTO drift_alerts
        (alert_at, run_id, unique_id, region, fuel_type, alert_type, severity,
         current_rmse, threshold_rmse, message, metadata_json)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
    """, (
        datetime.utcnow().isoformat(),
        run_id,
        unique_id,
        region,
        fuel_type,
        alert_type,
        severity,
        current_rmse,
        threshold_rmse,
        message,
        json.dumps(metadata) if metadata else None,
    ))

    con.commit()
    con.close()


def save_baseline(
    db_path: str,
    unique_id: str,
    model: str,
    baseline: dict,
) -> None:
    """Save baseline metrics for drift detection.

    Args:
        db_path: Path to SQLite database
        unique_id: Series identifier
        model: Model name
        baseline: Baseline metrics dictionary
    """
    con = connect(db_path)
    cur = con.cursor()

    cur.execute("""
        INSERT OR REPLACE INTO baseline_metrics
        (created_at, unique_id, model, rmse_mean, rmse_std, mae_mean, mae_std,
         drift_threshold_rmse, drift_threshold_mae, n_windows, metadata_json)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
    """, (
        datetime.utcnow().isoformat(),
        unique_id,
        model,
        baseline.get("rmse_mean"),
        baseline.get("rmse_std"),
        baseline.get("mae_mean"),
        baseline.get("mae_std"),
        baseline.get("drift_threshold_rmse"),
        baseline.get("drift_threshold_mae"),
        baseline.get("n_windows"),
        json.dumps(baseline),
    ))

    con.commit()
    con.close()


def get_recent_forecasts(
    db_path: str,
    region: Optional[str] = None,
    fuel_type: Optional[str] = None,
    hours: int = 48,
) -> pd.DataFrame:
    """Get recent forecasts from database.

    Args:
        db_path: Path to SQLite database
        region: Filter by region (optional)
        fuel_type: Filter by fuel type (optional)
        hours: Hours of history to retrieve

    Returns:
        DataFrame with forecasts
    """
    con = connect(db_path)

    query = """
        SELECT *
        FROM renewable_forecasts
        WHERE datetime(created_at) > datetime('now', ?)
    """
    params = [f"-{hours} hours"]

    if region:
        query += " AND region = ?"
        params.append(region)

    if fuel_type:
        query += " AND fuel_type = ?"
        params.append(fuel_type)

    query += " ORDER BY ds DESC"

    df = pd.read_sql_query(query, con, params=params)
    con.close()

    return df


def get_drift_alerts(
    db_path: str,
    hours: int = 24,
    severity: Optional[str] = None,
) -> pd.DataFrame:
    """Get recent drift alerts.

    Args:
        db_path: Path to SQLite database
        hours: Hours of history
        severity: Filter by severity (optional)

    Returns:
        DataFrame with alerts
    """
    con = connect(db_path)

    query = """
        SELECT *
        FROM drift_alerts
        WHERE datetime(alert_at) > datetime('now', ?)
    """
    params = [f"-{hours} hours"]

    if severity:
        query += " AND severity = ?"
        params.append(severity)

    query += " ORDER BY alert_at DESC"

    df = pd.read_sql_query(query, con, params=params)
    con.close()

    return df


if __name__ == "__main__":
    # Test database initialization
    import tempfile

    with tempfile.TemporaryDirectory() as tmpdir:
        db_path = f"{tmpdir}/test_renewable.db"

        print("Initializing database...")
        init_renewable_db(db_path)

        print("Database initialized successfully!")

        # Test connection
        con = connect(db_path)
        cur = con.cursor()
        cur.execute("SELECT name FROM sqlite_master WHERE type='table'")
        tables = cur.fetchall()
        print(f"Tables created: {[t[0] for t in tables]}")
        con.close()


Overwriting src/renewable/db.py


---

# Module 8: Dashboard

**File:** `src/renewable/dashboard.py`

The Streamlit dashboard provides:
- **Forecast visualization** with prediction intervals
- **Drift monitoring** and alerts
- **Coverage analysis** (nominal vs empirical)
- **Weather features** by region

## Running the Dashboard

```bash
streamlit run src/renewable/dashboard.py
```

The dashboard will:
1. Load forecasts from `data/renewable/forecasts.parquet`
2. Display interactive charts with Plotly
3. Show drift alerts from the database

In [14]:
%%writefile src/renewable/dashboard.py
# file: src/renewable/dashboard.py
"""Streamlit dashboard for renewable energy forecasting.

Provides:
- Forecast visualization with prediction intervals
- Drift monitoring and alerts
- Coverage analysis (nominal vs empirical)
- Weather features by region

Run with:
    streamlit run src/renewable/dashboard.py
"""

import os
import sys
from datetime import datetime, timedelta, timezone
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
import streamlit_mermaid as stmd

# Add project root to path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from src.renewable.db import (connect, get_drift_alerts, get_recent_forecasts,
                              init_renewable_db)
from src.renewable.regions import FUEL_TYPES, REGIONS

# Page config
st.set_page_config(
    page_title="Renewable Forecast Dashboard",
    page_icon="⚡",
    layout="wide",
)


def main():
    """Main dashboard application."""
    st.title("⚡ Renewable Energy Forecast Dashboard")
    st.markdown("Next-24h wind/solar generation forecasts with drift monitoring")

    # Sidebar configuration
    with st.sidebar:
        st.header("Configuration")

        db_path = st.text_input(
            "Database Path",
            value="data/renewable/renewable.db",
        )

        # Initialize database if it doesn't exist
        if not Path(db_path).exists():
            init_renewable_db(db_path)
            st.info("Database initialized")

        st.divider()

        # Region filter
        all_regions = list(REGIONS.keys())
        selected_regions = st.multiselect(
            "Regions",
            options=all_regions,
            default=["CALI", "ERCO", "MISO"],
        )

        # Fuel type filter
        fuel_type = st.selectbox(
            "Fuel Type",
            options=["WND", "SUN", "Both"],
            index=0,
        )

        st.divider()

        # Actions
        show_debug = st.checkbox("Show Debug", value=False)
        if st.button("🔄 Refresh Data", width="stretch"):
            st.rerun()

        if st.button("📊 Run Pipeline", width="stretch"):
            run_pipeline_from_dashboard(db_path, selected_regions, fuel_type)

    # Main content tabs  (Data & Insights first)
    tab_insights, tab_forecasts, tab_drift, tab_coverage, tab_weather, tab_interp = st.tabs([
        "📚 Data & Insights",
        "📈 Forecasts",
        "⚠️ Drift Monitor",
        "📊 Coverage",
        "🌤️ Weather",
        "🔍 Interpretability",
    ])

    with tab_insights:
        render_insights_tab(db_path)

    with tab_forecasts:
        render_forecasts_tab(db_path, selected_regions, fuel_type, show_debug=show_debug)

    with tab_drift:
        render_drift_tab(db_path)

    with tab_coverage:
        render_coverage_tab(db_path)

    with tab_weather:
        render_weather_tab(db_path, selected_regions)

    with tab_interp:
        render_interpretability_tab(selected_regions, fuel_type)



def render_forecasts_tab(db_path: str, regions: list, fuel_type: str, *, show_debug: bool = False):
    """Render forecast visualization with prediction intervals."""
    st.subheader("Generation Forecasts")

    forecasts_df = pd.DataFrame()
    data_source = "none"
    derived_columns: list[str] = []

    # Try to load from parquet file first (pipeline output)
    parquet_path = Path("data/renewable/forecasts.parquet")
    if parquet_path.exists():
        try:
            forecasts_df = pd.read_parquet(parquet_path)
            data_source = f"parquet:{parquet_path}"
            # Add region/fuel_type columns if missing
            if "unique_id" in forecasts_df.columns:
                parts = forecasts_df["unique_id"].astype(str).str.split("_", n=1, expand=True)
                if "region" not in forecasts_df.columns:
                    forecasts_df["region"] = parts[0]
                    derived_columns.append("region")
                if "fuel_type" not in forecasts_df.columns:
                    forecasts_df["fuel_type"] = parts[1] if parts.shape[1] > 1 else pd.NA
                    derived_columns.append("fuel_type")
            st.success(f"Loaded {len(forecasts_df)} forecasts from pipeline")

            # Calculate and display data freshness
            if not forecasts_df.empty and "ds" in forecasts_df.columns:
                earliest_forecast_ts = forecasts_df["ds"].min()
                now_utc = pd.Timestamp.now(tz="UTC").floor("h")

                # Forecasts start from last_data + 1h, so last_data = earliest_forecast - 1h
                last_data_ts = earliest_forecast_ts - pd.Timedelta(hours=1)

                # Ensure both timestamps are timezone-aware for comparison
                if not hasattr(last_data_ts, 'tz') or last_data_ts.tz is None:
                    last_data_ts = pd.Timestamp(last_data_ts, tz="UTC")

                data_age_hours = (now_utc - last_data_ts).total_seconds() / 3600

                # Show warning if data is > 6 hours old
                if data_age_hours > 6:
                    st.warning(
                        f"⚠️ Forecasts are based on **{data_age_hours:.1f} hour old** data "
                        f"(last EIA data: {last_data_ts.strftime('%b %d %H:%M')} UTC). "
                        f"Click 'Refresh Forecasts' button in sidebar to update."
                    )
                else:
                    st.info(
                        f"✅ Forecasts from {last_data_ts.strftime('%b %d %H:%M')} UTC data "
                        f"({data_age_hours:.1f}h old)"
                    )

        except Exception as e:
            st.warning(f"Could not load parquet: {e}")

    # Fall back to database
    if forecasts_df.empty:
        try:
            forecasts_df = get_recent_forecasts(db_path, hours=72)
            data_source = f"db:{db_path}"
        except Exception as e:
            st.warning(f"Could not load from database: {e}")

    if forecasts_df.empty:
        # Show demo data
        st.info("No forecasts found. Showing demo data.")
        forecasts_df = generate_demo_forecasts(regions, fuel_type)
        data_source = "demo"

    if show_debug:
        with st.expander("Debug: Forecast Data", expanded=False):
            st.markdown("**Source**")
            st.code(data_source)
            st.markdown("**Columns**")
            st.code(", ".join(forecasts_df.columns.tolist()))

            st.markdown("**Counts (pre-filter)**")
            st.write({"rows": int(len(forecasts_df))})

            if derived_columns:
                st.markdown("**Derived Columns**")
                st.write(derived_columns)

            if "unique_id" in forecasts_df.columns:
                st.markdown("**unique_id sample**")
                st.write(forecasts_df["unique_id"].dropna().astype(str).head(10).tolist())

            if "fuel_type" in forecasts_df.columns:
                st.markdown("**fuel_type counts**")
                st.dataframe(forecasts_df["fuel_type"].value_counts(dropna=False).to_frame())

                unknown = sorted(
                    {str(v) for v in forecasts_df["fuel_type"].dropna().unique()}
                    - set(FUEL_TYPES.keys())
                )
                if unknown:
                    st.warning(f"Unknown fuel_type values: {unknown}")

            if "region" in forecasts_df.columns:
                st.markdown("**region counts**")
                st.dataframe(forecasts_df["region"].value_counts(dropna=False).to_frame())

    # Filter by selections
    if fuel_type != "Both":
        forecasts_df = forecasts_df[forecasts_df["fuel_type"] == fuel_type]

    if regions:
        forecasts_df = forecasts_df[forecasts_df["region"].isin(regions)]

    if show_debug:
        with st.expander("Debug: Filter Result", expanded=False):
            st.markdown("**Applied Filters**")
            st.write({"fuel_type": fuel_type, "regions": regions})
            st.markdown("**Counts (post-filter)**")
            st.write({"rows": int(len(forecasts_df))})
            if "unique_id" in forecasts_df.columns:
                st.markdown("**unique_id after filter**")
                st.write(sorted(forecasts_df["unique_id"].dropna().astype(str).unique().tolist()))

    if forecasts_df.empty:
        st.warning("No data matching filters")
        return

    # Series selector
    series_options = forecasts_df["unique_id"].unique().tolist()
    selected_series = st.selectbox(
        "Select Series",
        options=series_options,
        index=0 if series_options else None,
        key="forecast_series_select",
    )

    if selected_series:
        series_data = forecasts_df[forecasts_df["unique_id"] == selected_series].copy()
        series_data = series_data.sort_values("ds")

        # Convert to local timezone for display
        region_code = series_data["unique_id"].iloc[0].split("_")[0]
        region_info = REGIONS.get(region_code)
        timezone_name = region_info.timezone if region_info else "UTC"

        # Create forecast plot with intervals
        fig = create_forecast_plot(series_data, selected_series, timezone_name)
        st.plotly_chart(fig, width="stretch")

        # Show data table
        with st.expander("View Data"):
            st.dataframe(
                series_data[["ds", "yhat", "yhat_lo_80", "yhat_hi_80", "yhat_lo_95", "yhat_hi_95"]],
                width="stretch",
            )


def create_forecast_plot(df: pd.DataFrame, title: str, timezone_name: str = "UTC") -> go.Figure:
    """Create Plotly figure with forecast and prediction intervals.

    Args:
        df: Forecast dataframe with ds (timestamp), yhat, and interval columns
        title: Series name for chart title
        timezone_name: IANA timezone name for display (e.g., "America/Chicago")
    """
    fig = go.Figure()

    # Convert timestamps to local timezone for display
    df = df.copy()
    df["ds"] = pd.to_datetime(df["ds"])

    # Convert UTC to local timezone
    if timezone_name != "UTC":
        df["ds"] = df["ds"].dt.tz_localize("UTC").dt.tz_convert(timezone_name)

    # Get timezone abbreviation for display (e.g., "CST", "PST")
    if timezone_name != "UTC" and len(df) > 0:
        tz_abbr = df["ds"].iloc[0].strftime("%Z")
    else:
        tz_abbr = "UTC"

    # 95% interval (outer, lighter)
    if "yhat_lo_95" in df.columns and "yhat_hi_95" in df.columns:
        fig.add_trace(go.Scatter(
            x=pd.concat([df["ds"], df["ds"][::-1]]),
            y=pd.concat([df["yhat_hi_95"], df["yhat_lo_95"][::-1]]),
            fill="toself",
            fillcolor="rgba(68, 138, 255, 0.2)",
            line=dict(color="rgba(255,255,255,0)"),
            name="95% Interval",
            hoverinfo="skip",
        ))

    # 80% interval (inner, darker)
    if "yhat_lo_80" in df.columns and "yhat_hi_80" in df.columns:
        fig.add_trace(go.Scatter(
            x=pd.concat([df["ds"], df["ds"][::-1]]),
            y=pd.concat([df["yhat_hi_80"], df["yhat_lo_80"][::-1]]),
            fill="toself",
            fillcolor="rgba(68, 138, 255, 0.4)",
            line=dict(color="rgba(255,255,255,0)"),
            name="80% Interval",
            hoverinfo="skip",
        ))

    # Point forecast
    fig.add_trace(go.Scatter(
        x=df["ds"],
        y=df["yhat"],
        mode="lines",
        name="Forecast",
        line=dict(color="#1f77b4", width=2),
    ))

    # Actuals if available
    if "y" in df.columns:
        fig.add_trace(go.Scatter(
            x=df["ds"],
            y=df["y"],
            mode="markers",
            name="Actual",
            marker=dict(color="#2ca02c", size=6),
        ))

    fig.update_layout(
        title=f"Forecast: {title}",
        xaxis_title=f"Time ({tz_abbr})",
        yaxis_title="Generation (MWh)",
        hovermode="x unified",
        legend=dict(orientation="h", yanchor="bottom", y=1.02),
        height=450,
    )

    return fig


def render_drift_tab(db_path: str):
    """Render drift monitoring and alerts."""
    st.subheader("Drift Detection")

    col1, col2, col3 = st.columns(3)

    # Try to load alerts
    try:
        alerts_df = get_drift_alerts(db_path, hours=48)
    except Exception:
        alerts_df = pd.DataFrame()

    # Summary metrics
    with col1:
        critical = len(alerts_df[alerts_df["severity"] == "critical"]) if not alerts_df.empty else 0
        st.metric(
            "Critical Alerts",
            critical,
            delta=None,
            delta_color="inverse" if critical > 0 else "off",
        )

    with col2:
        warning = len(alerts_df[alerts_df["severity"] == "warning"]) if not alerts_df.empty else 0
        st.metric("Warnings", warning)

    with col3:
        stable = len(alerts_df[alerts_df["alert_type"] == "drift_check"]) if not alerts_df.empty else 0
        st.metric("Stable Checks", stable)

    st.divider()

    if alerts_df.empty:
        st.info("No drift alerts in the last 48 hours. System is stable.")

        # Show demo drift status
        st.markdown("### Demo Drift Status")
        demo_drift = pd.DataFrame({
            "Series": ["CALI_WND", "ERCO_WND", "MISO_WND", "CALI_SUN", "ERCO_SUN"],
            "Current RMSE": [125.3, 98.7, 156.2, 45.1, 67.8],
            "Threshold": [150.0, 120.0, 180.0, 60.0, 80.0],
            "Status": ["✅ Stable", "✅ Stable", "✅ Stable", "✅ Stable", "✅ Stable"],
        })
        st.dataframe(demo_drift, width="stretch")
    else:
        # Show alerts table
        st.dataframe(
            alerts_df[["alert_at", "unique_id", "severity", "current_rmse", "threshold_rmse", "message"]],
            width="stretch",
        )

        # Drift timeline
        if len(alerts_df) > 1:
            alerts_df["alert_at"] = pd.to_datetime(alerts_df["alert_at"])
            fig = px.scatter(
                alerts_df,
                x="alert_at",
                y="current_rmse",
                color="severity",
                size="current_rmse",
                hover_data=["unique_id", "message"],
                title="Drift Timeline",
            )
            fig.add_hline(
                y=alerts_df["threshold_rmse"].mean(),
                line_dash="dash",
                annotation_text="Avg Threshold",
            )
            st.plotly_chart(fig, width="stretch")


def render_coverage_tab(db_path: str):
    """Render coverage analysis comparing nominal vs empirical."""
    st.subheader("Prediction Interval Coverage")

    st.markdown("""
    **Coverage** measures how often actual values fall within prediction intervals.
    - **Nominal**: The expected coverage (80% or 95%)
    - **Empirical**: The actual observed coverage
    - **Gap**: Difference indicates calibration quality
    """)

    # Demo coverage data
    coverage_data = pd.DataFrame({
        "Series": ["CALI_WND", "ERCO_WND", "MISO_WND", "SWPP_WND", "CALI_SUN", "ERCO_SUN"],
        "Nominal 80%": [80, 80, 80, 80, 80, 80],
        "Empirical 80%": [78.5, 82.1, 76.3, 79.8, 81.2, 77.9],
        "Nominal 95%": [95, 95, 95, 95, 95, 95],
        "Empirical 95%": [93.2, 96.1, 91.5, 94.8, 95.7, 92.3],
    })

    coverage_data["Gap 80%"] = coverage_data["Empirical 80%"] - coverage_data["Nominal 80%"]
    coverage_data["Gap 95%"] = coverage_data["Empirical 95%"] - coverage_data["Nominal 95%"]

    # Summary
    col1, col2 = st.columns(2)

    with col1:
        avg_80 = coverage_data["Empirical 80%"].mean()
        st.metric("Avg 80% Coverage", f"{avg_80:.1f}%", f"{avg_80 - 80:.1f}%")

    with col2:
        avg_95 = coverage_data["Empirical 95%"].mean()
        st.metric("Avg 95% Coverage", f"{avg_95:.1f}%", f"{avg_95 - 95:.1f}%")

    st.divider()

    # Coverage comparison chart
    fig = go.Figure()

    fig.add_trace(go.Bar(
        name="80% Empirical",
        x=coverage_data["Series"],
        y=coverage_data["Empirical 80%"],
        marker_color="rgba(68, 138, 255, 0.7)",
    ))

    fig.add_trace(go.Bar(
        name="95% Empirical",
        x=coverage_data["Series"],
        y=coverage_data["Empirical 95%"],
        marker_color="rgba(68, 138, 255, 0.4)",
    ))

    # Nominal lines
    fig.add_hline(y=80, line_dash="dash", line_color="red", annotation_text="80% Nominal")
    fig.add_hline(y=95, line_dash="dash", line_color="orange", annotation_text="95% Nominal")

    fig.update_layout(
        title="Coverage by Series",
        xaxis_title="Series",
        yaxis_title="Coverage (%)",
        barmode="group",
        height=400,
    )

    st.plotly_chart(fig, width="stretch")

    # Detailed table
    with st.expander("View Coverage Data"):
        st.dataframe(coverage_data, width="stretch")


def render_weather_tab(db_path: str, regions: list):
    """Render weather features visualization."""
    st.subheader("Weather Features")

    weather_df = pd.DataFrame()

    # Prefer real pipeline output; no demo fallback.
    parquet_path = Path("data/renewable/weather.parquet")
    if parquet_path.exists():
        try:
            weather_df = pd.read_parquet(parquet_path)
            st.success(f"Loaded {len(weather_df)} weather rows from pipeline")
        except Exception as exc:
            st.warning(f"Could not load weather parquet: {exc}")

    if weather_df.empty and Path(db_path).exists():
        try:
            with connect(db_path) as con:
                weather_df = pd.read_sql_query(
                    "SELECT * FROM weather_features ORDER BY ds ASC",
                    con,
                )
            if not weather_df.empty:
                st.success(f"Loaded {len(weather_df)} weather rows from database")
        except Exception as exc:
            st.warning(f"Could not load weather data from database: {exc}")

    if weather_df.empty:
        st.warning("No weather data available. Run the pipeline to populate weather features.")
        return

    weather_df["ds"] = pd.to_datetime(weather_df["ds"], errors="coerce")
    if regions:
        weather_df = weather_df[weather_df["region"].isin(regions)]
    if weather_df.empty:
        st.warning("No weather data matching selected regions.")
        return

    # Variable selector
    weather_vars = [
        col for col in ["wind_speed_10m", "wind_speed_100m", "direct_radiation", "cloud_cover"]
        if col in weather_df.columns
    ]
    if not weather_vars:
        st.warning("Weather data missing expected variables.")
        return
    selected_var = st.selectbox("Weather Variable", options=weather_vars)

    # Plot
    fig = px.line(
        weather_df,
        x="ds",
        y=selected_var,
        color="region",
        title=f"{selected_var} by Region",
    )
    fig.update_layout(height=400)
    st.plotly_chart(fig, width="stretch")

    # Summary stats
    st.markdown("### Current Conditions")

    cols = st.columns(len(regions[:4]))
    for i, region in enumerate(regions[:4]):
        if i < len(cols):
            with cols[i]:
                region_data = weather_df[weather_df["region"] == region].iloc[-1] if len(weather_df[weather_df["region"] == region]) > 0 else {}
                st.metric(
                    region,
                    f"{region_data.get('wind_speed_10m', 0):.1f} m/s",
                    help="Wind speed at 10m",
                )


def render_interpretability_tab(regions: list, fuel_type: str):
    """Render model interpretability visualizations (SHAP, feature importance, PDP)."""
    st.subheader("Model Interpretability")

    # Model Leaderboard Section
    st.markdown("### 🏆 Model Leaderboard (Cross-Validation)")

    # Model descriptions for education
    MODEL_INFO = {
        "AutoARIMA": {
            "type": "Statistical",
            "description": "Auto-tuned ARIMA with automatic p,d,q selection. Good for univariate series with trend/seasonality.",
            "strengths": "Robust, well-understood, good prediction intervals",
        },
        "MSTL_ARIMA": {
            "type": "Statistical",
            "description": "Multiple Seasonal-Trend decomposition + ARIMA. Handles daily (24h) and weekly (168h) seasonality.",
            "strengths": "Best for multi-seasonal patterns like energy data",
        },
        "AutoETS": {
            "type": "Statistical",
            "description": "Exponential smoothing with automatic error/trend/season selection.",
            "strengths": "Simple, fast, works well for smooth series",
        },
        "AutoTheta": {
            "type": "Statistical",
            "description": "Theta method with automatic decomposition. Robust to outliers.",
            "strengths": "Competition winner (M3), handles level shifts",
        },
        "CES": {
            "type": "Statistical",
            "description": "Complex Exponential Smoothing. Captures complex seasonal patterns.",
            "strengths": "Good for complex seasonality",
        },
        "SeasonalNaive": {
            "type": "Baseline",
            "description": "Uses value from same hour last week. Baseline benchmark.",
            "strengths": "Simple benchmark - if beaten, models add value",
        },
    }

    run_log_path = Path("data/renewable/run_log.json")
    if run_log_path.exists():
        try:
            import json
            run_log = json.loads(run_log_path.read_text())
            pipeline_results = run_log.get("pipeline_results", {})
            leaderboard_data = pipeline_results.get("leaderboard", [])

            if leaderboard_data:
                leaderboard_df = pd.DataFrame(leaderboard_data)
                best_model = pipeline_results.get("best_model", "")
                best_rmse = pipeline_results.get("best_rmse", 0)

                # Key metrics row
                col1, col2, col3, col4 = st.columns(4)
                with col1:
                    st.metric("Best Model", best_model)
                with col2:
                    st.metric("Best RMSE", f"{best_rmse:.3f}")
                with col3:
                    st.metric("Models Evaluated", len(leaderboard_data))
                with col4:
                    # Calculate improvement over baseline
                    baseline_rmse = leaderboard_df[leaderboard_df["model"] == "SeasonalNaive"]["rmse"].values
                    if len(baseline_rmse) > 0 and best_rmse > 0:
                        improvement = ((baseline_rmse[0] - best_rmse) / baseline_rmse[0]) * 100
                        st.metric("vs Baseline", f"{improvement:+.1f}%", help="Improvement over SeasonalNaive")
                    else:
                        st.metric("vs Baseline", "N/A")

                # Selection rationale
                st.markdown("#### Why This Model?")
                st.info(f"""
                **{best_model}** was selected because it has the **lowest RMSE** on cross-validation.

                - **RMSE (Root Mean Square Error)**: Penalizes large errors more heavily. Best for energy forecasting where big misses are costly.
                - **Selection method**: Time-series CV with {run_log.get('config', {}).get('cv_windows', 2)} windows, step size {run_log.get('config', {}).get('cv_step_size', 168)}h
                - **Horizon**: {run_log.get('config', {}).get('horizon', 24)}h ahead forecasts
                """)

                # Model description for winner
                if best_model in MODEL_INFO:
                    info = MODEL_INFO[best_model]
                    st.success(f"**{info['type']} Model**: {info['description']}")

                # Full leaderboard with visualization
                st.markdown("#### All Models Ranked by RMSE")

                display_cols = [c for c in ["model", "rmse", "mae", "mape", "coverage_80", "coverage_95"]
                               if c in leaderboard_df.columns]

                # Create visualization
                if "rmse" in leaderboard_df.columns:
                    fig = px.bar(
                        leaderboard_df.sort_values("rmse"),
                        x="model",
                        y="rmse",
                        title="Model Comparison (Lower RMSE = Better)",
                        color="rmse",
                        color_continuous_scale="RdYlGn_r",
                    )
                    fig.add_hline(y=best_rmse, line_dash="dash", line_color="green",
                                  annotation_text=f"Best: {best_rmse:.3f}")
                    fig.update_layout(height=350)
                    st.plotly_chart(fig, width="stretch")

                # Format numeric columns for table
                styled_df = leaderboard_df[display_cols].copy()
                for col in ["rmse", "mae", "mape"]:
                    if col in styled_df.columns:
                        styled_df[col] = styled_df[col].apply(lambda x: f"{x:.3f}" if pd.notna(x) else "N/A")
                for col in ["coverage_80", "coverage_95"]:
                    if col in styled_df.columns:
                        styled_df[col] = styled_df[col].apply(lambda x: f"{x:.1f}%" if pd.notna(x) else "N/A")

                st.dataframe(styled_df, width="stretch", hide_index=True)

                # Coverage analysis
                if "coverage_80" in leaderboard_df.columns:
                    st.markdown("#### Prediction Interval Coverage")
                    st.markdown("""
                    **Coverage** measures if prediction intervals are well-calibrated:
                    - **80% interval** should contain ~80% of actual values
                    - **95% interval** should contain ~95% of actual values
                    - **Under-coverage** (<target) = intervals too narrow, overconfident
                    - **Over-coverage** (>target) = intervals too wide, conservative
                    """)

                    coverage_df = leaderboard_df[["model", "coverage_80", "coverage_95"]].copy()
                    coverage_df["coverage_80_status"] = coverage_df["coverage_80"].apply(
                        lambda x: "Under" if x < 75 else ("Over" if x > 85 else "Good") if pd.notna(x) else "N/A"
                    )
                    coverage_df["coverage_95_status"] = coverage_df["coverage_95"].apply(
                        lambda x: "Under" if x < 90 else ("Over" if x > 99 else "Good") if pd.notna(x) else "N/A"
                    )

                # Model descriptions expander
                with st.expander("Model Descriptions"):
                    for model_name, info in MODEL_INFO.items():
                        st.markdown(f"**{model_name}** ({info['type']})")
                        st.markdown(f"- {info['description']}")
                        st.markdown(f"- *Strengths*: {info['strengths']}")
                        st.markdown("---")

                # CV configuration expander
                config = run_log.get("config", {})
                with st.expander("CV Configuration"):
                    st.write({
                        "cv_windows": config.get("cv_windows"),
                        "cv_step_size": config.get("cv_step_size"),
                        "horizon": config.get("horizon"),
                        "regions": config.get("regions"),
                        "fuel_types": config.get("fuel_types"),
                        "run_at": run_log.get("run_at_utc", "N/A"),
                    })
            else:
                st.info("Leaderboard not available. Run the pipeline with the latest code to generate.")
        except Exception as e:
            st.warning(f"Could not load leaderboard: {e}")
    else:
        st.info("No run log found. Run the pipeline to generate model comparison.")

    st.divider()

    st.markdown("### 🔍 Per-Series Interpretability")
    st.markdown("""
    **LightGBM** models are trained alongside statistical models (MSTL/ARIMA) to provide
    interpretability insights. The statistical models generate the primary forecasts,
    while LightGBM helps understand feature importance and relationships.
    """)

    interp_dir = Path("data/renewable/interpretability")

    if not interp_dir.exists():
        st.info("No interpretability data available. Run the pipeline to generate SHAP and PDP plots.")
        return

    # Get available series
    series_dirs = sorted([d.name for d in interp_dir.iterdir() if d.is_dir()])

    if not series_dirs:
        st.warning("Interpretability directory exists but contains no series data.")
        return

    # Filter by selected regions and fuel type
    filtered_series = []
    for series_id in series_dirs:
        parts = series_id.split("_")
        if len(parts) == 2:
            region, ft = parts
            if regions and region not in regions:
                continue
            if fuel_type != "Both" and ft != fuel_type:
                continue
            filtered_series.append(series_id)

    if not filtered_series:
        st.warning("No interpretability data for selected filters.")
        return

    # Series selector
    selected_series = st.selectbox(
        "Select Series",
        options=filtered_series,
        index=0,
        key="interpretability_series_select",
    )

    if not selected_series:
        return

    series_dir = interp_dir / selected_series

    # Layout: Feature Importance + SHAP Summary side by side
    col1, col2 = st.columns(2)

    with col1:
        st.markdown("### Feature Importance")
        importance_path = series_dir / "feature_importance.csv"
        if importance_path.exists():
            try:
                importance_df = pd.read_csv(importance_path)
                # Show top 15 features
                top_features = importance_df.head(15)

                # Create bar chart
                fig = px.bar(
                    top_features,
                    x="importance",
                    y="feature",
                    orientation="h",
                    title=f"Top Features: {selected_series}",
                    labels={"importance": "Importance", "feature": "Feature"},
                )
                fig.update_layout(yaxis=dict(autorange="reversed"), height=400)
                st.plotly_chart(fig, width="stretch")

                with st.expander("Full Feature List"):
                    st.dataframe(importance_df, width="stretch")
            except Exception as e:
                st.error(f"Error loading feature importance: {e}")
        else:
            st.info("Feature importance not available.")

    with col2:
        st.markdown("### SHAP Summary")
        shap_summary_path = series_dir / "shap_summary.png"
        if shap_summary_path.exists():
            st.image(str(shap_summary_path), width="stretch")
        else:
            # Try bar plot as fallback
            shap_bar_path = series_dir / "shap_bar.png"
            if shap_bar_path.exists():
                st.image(str(shap_bar_path), width="stretch")
            else:
                st.info("SHAP summary not available.")

    st.divider()

    # SHAP Dependence Plots
    st.markdown("### SHAP Dependence Plots")
    st.markdown("Shows how individual feature values affect predictions.")

    shap_dep_files = list(series_dir.glob("shap_dependence_*.png"))
    if shap_dep_files:
        # Create columns for dependence plots
        n_cols = min(3, len(shap_dep_files))
        cols = st.columns(n_cols)

        for i, dep_file in enumerate(shap_dep_files[:6]):  # Limit to 6 plots
            feature_name = dep_file.stem.replace("shap_dependence_", "")
            with cols[i % n_cols]:
                st.markdown(f"**{feature_name}**")
                st.image(str(dep_file), width="stretch")
    else:
        st.info("SHAP dependence plots not available.")

    st.divider()

    # Partial Dependence Plot
    st.markdown("### Partial Dependence Plot")
    st.markdown("Shows the average effect of features on predictions (marginal effect).")

    pdp_path = series_dir / "partial_dependence.png"
    if pdp_path.exists():
        st.image(str(pdp_path), width="stretch")
    else:
        st.info("Partial dependence plot not available.")

    # Waterfall plot for sample prediction
    waterfall_path = series_dir / "shap_waterfall_sample.png"
    if waterfall_path.exists():
        st.markdown("### Sample Prediction Explanation")
        st.markdown("SHAP waterfall showing how features contributed to a single prediction.")
        st.image(str(waterfall_path), width="stretch")


def generate_demo_forecasts(regions: list, fuel_type: str) -> pd.DataFrame:
    """Generate demo forecast data for display."""
    data = []
    base_time = datetime.now(timezone.utc).replace(minute=0, second=0, microsecond=0)

    fuel_types = [fuel_type] if fuel_type != "Both" else ["WND", "SUN"]

    for region in regions[:3]:
        for ft in fuel_types:
            unique_id = f"{region}_{ft}"
            base_value = 500 if ft == "WND" else 300

            for h in range(24):
                ds = base_time + timedelta(hours=h)

                # Add daily pattern
                if ft == "SUN":
                    hour_factor = max(0, np.sin((ds.hour - 6) * np.pi / 12)) if 6 < ds.hour < 18 else 0
                    yhat = base_value * hour_factor + np.random.normal(0, 20)
                else:
                    yhat = base_value + np.sin(ds.hour * np.pi / 12) * 100 + np.random.normal(0, 30)

                yhat = max(0, yhat)

                data.append({
                    "unique_id": unique_id,
                    "region": region,
                    "fuel_type": ft,
                    "ds": ds,
                    "yhat": yhat,
                    "yhat_lo_80": yhat * 0.85,
                    "yhat_hi_80": yhat * 1.15,
                    "yhat_lo_95": yhat * 0.75,
                    "yhat_hi_95": yhat * 1.25,
                })

    return pd.DataFrame(data)


def render_insights_tab(db_path: str):
    """Render comprehensive data insights including regional context and EDA results."""
    st.title("📚 Data & Insights")
    st.markdown("**Understanding the data, regions, and methodology behind renewable energy forecasting**")

    # ========================================================================
    # Section 1: Pipeline Architecture
    # ========================================================================
    st.header("⚙️ Pipeline Architecture")
    st.markdown("""
This forecasting system follows a rigorous pipeline from data ingestion through model validation:
    """)

    # Mermaid diagram from README
    code = r"""
    graph TB
        A[EIA API<br/>Generation Data] -->|fetch_renewable_data| B[generation.parquet<br/>unique_id, ds, y]
        C[Open-Meteo API<br/>Weather Data] -->|fetch_renewable_weather| D[weather.parquet<br/>ds, region, weather_vars]

        B --> E[EDA Module<br/>Investigation & Recommendations]
        D --> E

        E -->|Recommendations| F[Dataset Builder<br/>Fuel-Specific Preprocessing]

        F -->|Validated Dataset| G[StatsForecast CV<br/>MSTL, AutoARIMA, AutoETS]
        F -->|Optional| H[LightGBM SHAP<br/>Interpretability]

        G --> I[Best Model Selection<br/>Leaderboard]
        I --> J[Generate Forecasts<br/>24h + Intervals]

        J --> K[forecasts.parquet<br/>yhat, yhat_lo, yhat_hi]
        J --> L[Quality Gates<br/>Drift Detection]

        L -->|Pass| M[Git Commit<br/>Artifact Versioning]
        L -->|Fail| N[Pipeline Fails<br/>Manual Review]

        H --> O[SHAP Reports<br/>Feature Importance]

        style E fill:#e1f5ff
        style F fill:#fff4e1
        style G fill:#f0e1ff
        style L fill:#ffe1e1
    """

    stmd.st_mermaid(code)

    # ========================================================================
    # Section 2: Regional Electricity Markets
    # ========================================================================
    st.header("🌍 Regional Electricity Markets")
    st.markdown("""
The United States electricity grid is managed by multiple **Independent System Operators (ISOs)**
and **Regional Transmission Organizations (RTOs)**. This dashboard focuses on three major regions:
    """)

    # Create tabs for each region
    region_tab1, region_tab2, region_tab3, region_tab4 = st.tabs([
        "🏢 ERCOT (Texas)",
        "🌾 MISO (Midwest)",
        "☀️ CAISO (California)",
        "📊 Comparison"
    ])

    with region_tab1:
        st.subheader("ERCOT - Electric Reliability Council of Texas")

        col1, col2 = st.columns([2, 1])

        with col1:
            st.markdown("""
**Geographic Coverage:**
- **Texas** (90% of the state)
- Does NOT include El Paso, parts of East Texas, or the Panhandle

**Key Characteristics:**
- ⚡ **Population**: ~27 million people
- 🏢 **Unique Feature**: Operates **independently** from the rest of the US grid
- 🔌 **Interconnection**: Texas Interconnection (isolated from Eastern and Western grids)
- 🌞 **Renewables**: High solar and wind capacity (~35% of generation)
- 🌡️ **Climate**: Hot summers → high cooling demand
- 💡 **Market**: Deregulated electricity market (competitive pricing)

**Why It's Different:**
- Not subject to federal regulation (doesn't cross state borders)
- Cannot easily import/export power from other states
- Famous for the 2021 winter storm crisis
- **EIA Code**: `TEX` (Texas)
            """)

        with col2:
            st.info("""
**Grid Challenge**

ERCOT's isolation means
it cannot import power
during emergencies.

**Forecasting Impact**

Solar & wind forecasting
is critical - no backup
from neighboring grids.
            """)

    with region_tab2:
        st.subheader("MISO - Midcontinent Independent System Operator")

        col1, col2 = st.columns([2, 1])

        with col1:
            st.markdown("""
**Geographic Coverage:**
- **15 states** across North-Central US:
  - **North**: North Dakota, South Dakota, Minnesota, Wisconsin, Michigan
  - **Central**: Iowa, Illinois, Indiana, Missouri
  - **South**: Arkansas, Louisiana, Mississippi, Texas (parts)

**Key Characteristics:**
- ⚡ **Population**: ~45 million people
- 🌾 **Characteristics**: Large agricultural region, diverse fuel mix
- 💨 **Renewables**: Massive wind capacity (especially in the Great Plains)
- 🏭 **Industry**: Heavy manufacturing (automotive, steel)
- 🌡️ **Climate**: Four distinct seasons, cold winters, hot summers
- 💡 **Market**: Day-ahead and real-time energy markets

**Why It's Interesting:**
- One of the largest ISOs in North America
- Leading in wind energy integration
- Diverse geography (from Great Lakes to Gulf Coast)
- **EIA Code**: `MISO` (Midcontinent ISO)
            """)

        with col2:
            st.info("""
**Grid Challenge**

Vast geography creates
transmission challenges
and regional variations.

**Forecasting Impact**

Wind forecasting most
critical due to Great
Plains wind belt.
            """)

    with region_tab3:
        st.subheader("CAISO - California Independent System Operator")

        col1, col2 = st.columns([2, 1])

        with col1:
            st.markdown("""
**Geographic Coverage:**
- **California** (80% of the state)
- Small parts of Nevada

**Key Characteristics:**
- ⚡ **Population**: ~30 million people
- 🌞 **Renewables**: Aggressive renewable energy targets (60% by 2030, 100% by 2045)
- 🔋 **Innovation**: Leader in battery storage, rooftop solar
- 🌡️ **Climate**: Mediterranean (hot, dry summers; mild winters)
- 🔥 **Challenges**: Wildfires, drought, "duck curve" problem
- 💡 **Market**: Complex wholesale electricity market

**Why It's Unique:**
- **Most aggressive renewable targets** in the US
- **Duck curve problem**: Solar floods grid during day, steep ramp-up needed at sunset
- **Net metering**: Rooftops can sell back to grid
- High electricity prices (climate policies, infrastructure costs)
- **EIA Code**: `CAL` (California)
            """)

        with col2:
            st.warning("""
**Grid Challenge**

"Duck Curve" problem:
- Midday: Too much solar
- Evening: Sharp ramp-up

**Forecasting Impact**

Solar forecasting MOST
critical (50%+ renewable
target). Need accurate
sunset timing predictions.
            """)

    with region_tab4:
        st.subheader("Regional Comparison")

        # Comparison table
        comparison_data = {
            "Characteristic": [
                "States Covered",
                "Population",
                "Grid Connection",
                "Renewable %",
                "Primary Challenge",
                "Peak Demand Season",
                "Unique Feature",
                "Solar Capacity",
                "Wind Capacity",
            ],
            "ERCOT (ERCO)": [
                "Texas",
                "27M",
                "Isolated (Texas only)",
                "~35% (wind, solar)",
                "Grid isolation, heat waves",
                "Summer (cooling)",
                "No federal oversight",
                "Growing fast",
                "Very high",
            ],
            "MISO": [
                "15 states",
                "45M",
                "Eastern Interconnection",
                "~25% (mostly wind)",
                "Winter cold, wind variability",
                "Summer/Winter",
                "Largest ISO",
                "Moderate",
                "Very high (Great Plains)",
            ],
            "CAISO (CALI)": [
                "California",
                "30M",
                "Western Interconnection",
                "~50% (solar, wind, hydro)",
                "Duck curve, wildfires",
                "Summer (cooling)",
                "Most aggressive renewables",
                "Highest in US",
                "Moderate",
            ],
        }

        st.dataframe(
            comparison_data,
            width="stretch",
            hide_index=True,
        )

        st.markdown("""
### 🎯 Forecasting Implications by Region

**Solar Forecasting Priority:**
1. 🥇 **CAISO**: Most important (50%+ renewable target, duck curve management)
2. 🥈 **ERCOT**: Growing importance (rapid solar buildout)
3. 🥉 **MISO**: Less critical (wind-focused region)

**Wind Forecasting Priority:**
1. 🥇 **MISO**: Most critical (Great Plains wind belt, 15-state coverage)
2. 🥈 **ERCOT**: Very important (West Texas wind resources)
3. 🥉 **CAISO**: Moderate importance (some Tehachapi wind)

**Data Quality Observations:**
- **CALI_SUN**: 403 negative values found (likely net metering, auxiliary loads)
- **Other regions**: Clean data (no negative values detected)
        """)

    # ========================================================================
    # Section 3: Model Performance Dashboard (NEW)
    # ========================================================================
    st.header("🎯 Model Performance Dashboard")
    st.markdown("""
This section shows the latest model performance metrics from cross-validation and
compares different forecasting models to select the best performer.
    """)

    # Load run_log.json for model performance
    data_dir = Path(db_path).parent
    run_log_path = data_dir / "run_log.json"

    if run_log_path.exists():
        import json
        with open(run_log_path, 'r') as f:
            run_log = json.load(f)

        pipeline_results = run_log.get('pipeline_results', {})

        # Display key metrics
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            best_model = pipeline_results.get('best_model', 'N/A')
            st.metric("🏆 Best Model", best_model)
        with col2:
            best_rmse = pipeline_results.get('best_rmse', 0)
            st.metric("📊 Best RMSE", f"{best_rmse:,.0f}")
        with col3:
            series_count = pipeline_results.get('series_count', 0)
            st.metric("📈 Series Forecasted", series_count)
        with col4:
            rows_out = pipeline_results.get('preprocessing', {}).get('rows_output', 0)
            st.metric("📋 Training Rows", f"{rows_out:,}")

        # Model Leaderboard - Interactive Table
        st.subheader("📊 Model Leaderboard (Cross-Validation Results)")

        leaderboard = pipeline_results.get('leaderboard', [])
        if leaderboard:
            import pandas as pd
            lb_df = pd.DataFrame(leaderboard)

            # Format for display
            if 'coverage_80' in lb_df.columns:
                lb_df['coverage_80'] = lb_df['coverage_80'].apply(lambda x: f"{x*100:.1f}%")
            if 'coverage_95' in lb_df.columns:
                lb_df['coverage_95'] = lb_df['coverage_95'].apply(lambda x: f"{x*100:.1f}%")
            if 'rmse' in lb_df.columns:
                lb_df['rmse'] = lb_df['rmse'].apply(lambda x: f"{x:,.0f}")
            if 'mae' in lb_df.columns:
                lb_df['mae'] = lb_df['mae'].apply(lambda x: f"{x:,.0f}")

            st.dataframe(
                lb_df,
                width="stretch",
                hide_index=True,
            )

            # Interactive Model Comparison Chart
            st.subheader("📈 Model Comparison (RMSE & MAE)")

            # Reload raw data for plotting
            lb_df_raw = pd.DataFrame(leaderboard)

            import plotly.graph_objects as go
            fig = go.Figure()

            # RMSE bars
            fig.add_trace(go.Bar(
                name='RMSE',
                x=lb_df_raw['model'],
                y=lb_df_raw['rmse'],
                marker_color='indianred',
                text=[f"{v:,.0f}" for v in lb_df_raw['rmse']],
                textposition='outside'
            ))

            # MAE bars
            fig.add_trace(go.Bar(
                name='MAE',
                x=lb_df_raw['model'],
                y=lb_df_raw['mae'],
                marker_color='lightseagreen',
                text=[f"{v:,.0f}" for v in lb_df_raw['mae']],
                textposition='outside'
            ))

            fig.update_layout(
                title="Model Performance: Lower is Better",
                xaxis_title="Model",
                yaxis_title="Error Metric (MWh)",
                barmode='group',
                height=400,
                hovermode='x unified',
                legend=dict(
                    orientation="h",
                    yanchor="bottom",
                    y=1.02,
                    xanchor="right",
                    x=1
                )
            )

            st.plotly_chart(fig, width="stretch")

            # Coverage Analysis
            st.subheader("🎯 Prediction Interval Coverage")
            st.markdown("""
**Coverage** measures how often actual values fall within prediction intervals:
- **80% Interval**: Should contain ~80% of actuals
- **95% Interval**: Should contain ~95% of actuals
            """)

            fig_coverage = go.Figure()

            fig_coverage.add_trace(go.Bar(
                name='80% Coverage',
                x=lb_df_raw['model'],
                y=lb_df_raw['coverage_80'] * 100,
                marker_color='skyblue',
                text=[f"{v*100:.1f}%" for v in lb_df_raw['coverage_80']],
                textposition='outside'
            ))

            fig_coverage.add_trace(go.Bar(
                name='95% Coverage',
                x=lb_df_raw['model'],
                y=lb_df_raw['coverage_95'] * 100,
                marker_color='navy',
                text=[f"{v*100:.1f}%" for v in lb_df_raw['coverage_95']],
                textposition='outside'
            ))

            # Add target lines
            fig_coverage.add_hline(y=80, line_dash="dash", line_color="gray", annotation_text="80% Target")
            fig_coverage.add_hline(y=95, line_dash="dash", line_color="red", annotation_text="95% Target")

            fig_coverage.update_layout(
                title="Prediction Interval Coverage by Model",
                xaxis_title="Model",
                yaxis_title="Coverage (%)",
                barmode='group',
                height=400,
                yaxis_range=[0, 100]
            )

            st.plotly_chart(fig_coverage, width="stretch")

        else:
            st.warning("No leaderboard data available in run log")
    else:
        st.warning("No run log found. Run the pipeline to generate performance metrics.")

    # ========================================================================
    # Section 4: Forecast Accuracy by Region (NEW)
    # ========================================================================
    st.header("🌎 Forecast Accuracy by Region")
    st.markdown("""
Analyzing forecast performance across different regions helps identify where
the models perform best and where improvements are needed.
    """)

    # Load forecasts and generation data
    forecasts_path = data_dir / "forecasts.parquet"
    generation_path = data_dir / "generation.parquet"

    if forecasts_path.exists() and generation_path.exists():
        import numpy as np
        import pandas as pd

        forecasts_df = pd.read_parquet(forecasts_path)
        generation_df = pd.read_parquet(generation_path)

        # Ensure datetime
        forecasts_df['ds'] = pd.to_datetime(forecasts_df['ds'])
        generation_df['ds'] = pd.to_datetime(generation_df['ds'])

        # Merge forecasts with actuals
        merged = forecasts_df.merge(
            generation_df[['unique_id', 'ds', 'y']],
            on=['unique_id', 'ds'],
            how='inner'
        )

        if not merged.empty:
            # Extract region from unique_id
            merged['region'] = merged['unique_id'].str.extract(r'(CALI|ERCO|MISO)')[0]
            merged['fuel'] = merged['unique_id'].str.extract(r'(WND|SUN)')[0]

            # Calculate errors
            merged['error'] = merged['yhat'] - merged['y']
            merged['abs_error'] = np.abs(merged['error'])
            merged['sq_error'] = merged['error'] ** 2

            # Aggregate by region
            regional_metrics = merged.groupby('region').agg({
                'abs_error': 'mean',
                'sq_error': lambda x: np.sqrt(x.mean()),
                'error': ['mean', 'std'],
                'unique_id': 'count'
            }).round(2)

            regional_metrics.columns = ['MAE', 'RMSE', 'Bias', 'Error_Std', 'Count']
            regional_metrics = regional_metrics.reset_index()

            # Display metrics table
            st.subheader("📊 Regional Performance Metrics")

            col1, col2 = st.columns([2, 1])

            with col1:
                st.dataframe(
                    regional_metrics.style.format({
                        'MAE': '{:.0f}',
                        'RMSE': '{:.0f}',
                        'Bias': '{:.0f}',
                        'Error_Std': '{:.0f}',
                        'Count': '{:.0f}'
                    }),
                    width="stretch",
                    hide_index=True,
                )

            with col2:
                st.info("""
**Metrics Explained**

**MAE**: Mean Absolute Error
Lower = more accurate

**RMSE**: Root Mean Squared Error
Penalizes large errors

**Bias**: Average error
(+ = overforecast, - = underforecast)
                """)

            # Interactive regional comparison
            st.subheader("📍 Regional Accuracy Comparison")

            fig_regional = go.Figure()

            fig_regional.add_trace(go.Bar(
                name='MAE',
                x=regional_metrics['region'],
                y=regional_metrics['MAE'],
                marker_color='indianred',
                text=[f"{v:.0f}" for v in regional_metrics['MAE']],
                textposition='outside'
            ))

            fig_regional.add_trace(go.Bar(
                name='RMSE',
                x=regional_metrics['region'],
                y=regional_metrics['RMSE'],
                marker_color='lightseagreen',
                text=[f"{v:.0f}" for v in regional_metrics['RMSE']],
                textposition='outside'
            ))

            fig_regional.update_layout(
                title="Forecast Accuracy by Region (Lower is Better)",
                xaxis_title="Region",
                yaxis_title="Error (MWh)",
                barmode='group',
                height=400
            )

            st.plotly_chart(fig_regional, width="stretch")

            # Error distribution by region
            st.subheader("📉 Error Distribution by Region")

            fig_dist = go.Figure()

            for region in merged['region'].unique():
                region_data = merged[merged['region'] == region]
                fig_dist.add_trace(go.Box(
                    y=region_data['error'],
                    name=region,
                    boxmean='sd'
                ))

            fig_dist.add_hline(y=0, line_dash="dash", line_color="black", annotation_text="Perfect Forecast")

            fig_dist.update_layout(
                title="Forecast Error Distribution by Region",
                yaxis_title="Error (MWh) [Forecast - Actual]",
                height=400,
                showlegend=True
            )

            st.plotly_chart(fig_dist, width="stretch")

            # Fuel type breakdown
            st.subheader("⚡ Accuracy by Fuel Type")

            fuel_metrics = merged.groupby(['region', 'fuel']).agg({
                'abs_error': 'mean',
                'sq_error': lambda x: np.sqrt(x.mean()),
            }).round(2).reset_index()
            fuel_metrics.columns = ['Region', 'Fuel', 'MAE', 'RMSE']

            # Create pivot for heatmap
            pivot_mae = fuel_metrics.pivot(index='Fuel', columns='Region', values='MAE')

            fig_heatmap = go.Figure(data=go.Heatmap(
                z=pivot_mae.values,
                x=pivot_mae.columns,
                y=pivot_mae.index,
                colorscale='RdYlGn_r',
                text=pivot_mae.values,
                texttemplate='%{text:.0f}',
                textfont={"size": 14},
                colorbar=dict(title="MAE (MWh)")
            ))

            fig_heatmap.update_layout(
                title="Mean Absolute Error by Region and Fuel Type",
                xaxis_title="Region",
                yaxis_title="Fuel Type",
                height=300
            )

            st.plotly_chart(fig_heatmap, width="stretch")

        else:
            st.warning("No matching forecast-actual pairs found for analysis")
    else:
        st.warning("Forecast or generation data not available for accuracy analysis")

    # ========================================================================
    # Section 5: Data Quality & EDA History (NEW)
    # ========================================================================
    st.header("📊 Data Quality & EDA History")
    st.markdown("""
Track data quality metrics over time and access historical EDA runs.
    """)

    # List all EDA runs
    eda_dir = data_dir / "eda"

    if eda_dir.exists():
        eda_runs = sorted([d for d in eda_dir.iterdir() if d.is_dir()], reverse=True)

        if eda_runs:
            st.subheader("📅 EDA Run History")

            # Create table of EDA runs
            eda_history = []
            for eda_run in eda_runs:
                recs_file = eda_run / "recommendations.json"
                if recs_file.exists():
                    with open(recs_file, 'r') as f:
                        recs = json.load(f)

                    preprocessing = recs.get('preprocessing', {})
                    eda_history.append({
                        'Timestamp': eda_run.name,
                        'Policy': preprocessing.get('negative_policy', 'N/A'),
                        'Confidence': preprocessing.get('negative_confidence', 'N/A'),
                        'Negatives Found': preprocessing.get('data_summary', {}).get('negative_count', 0),
                        'Affected Series': len(preprocessing.get('data_summary', {}).get('affected_series', [])),
                        'Path': str(eda_run)
                    })

            if eda_history:
                eda_df = pd.DataFrame(eda_history)

                # Display with download link
                col1, col2 = st.columns([3, 1])

                with col1:
                    st.dataframe(
                        eda_df[['Timestamp', 'Policy', 'Confidence', 'Negatives Found', 'Affected Series']],
                        width="stretch",
                        hide_index=True,
                    )

                with col2:
                    st.info(f"""
**Total EDA Runs**
{len(eda_runs)}

**Latest Run**
{eda_runs[0].name}
                    """)

                # Data quality trend
                if len(eda_history) > 1:
                    st.subheader("📈 Data Quality Trend")

                    fig_quality = go.Figure()

                    fig_quality.add_trace(go.Scatter(
                        x=eda_df['Timestamp'],
                        y=eda_df['Negatives Found'],
                        mode='lines+markers',
                        name='Negative Values',
                        line=dict(color='red', width=2),
                        marker=dict(size=8)
                    ))

                    fig_quality.update_layout(
                        title="Negative Values Over Time",
                        xaxis_title="EDA Run Timestamp",
                        yaxis_title="Count of Negative Values",
                        height=350,
                        hovermode='x unified'
                    )

                    st.plotly_chart(fig_quality, width="stretch")

                # Allow selection of specific EDA run
                st.subheader("🔍 View Specific EDA Run")

                selected_run = st.selectbox(
                    "Select an EDA run to view details:",
                    options=[r.name for r in eda_runs],
                    index=0
                )

                if selected_run:
                    selected_path = data_dir / "eda" / selected_run
                    recs_file = selected_path / "recommendations.json"

                    if recs_file.exists():
                        with open(recs_file, 'r') as f:
                            selected_recs = json.load(f)

                        with st.expander(f"📄 EDA Results for {selected_run}", expanded=False):
                            st.json(selected_recs)

                            # Links to visualization files
                            st.markdown("**Download Visualizations:**")
                            viz_files = list(selected_path.rglob("*.png"))
                            if viz_files:
                                for viz_file in viz_files:
                                    rel_path = viz_file.relative_to(selected_path)
                                    st.markdown(f"- `{rel_path}`")
                            else:
                                st.info("No visualizations found for this run")

            else:
                st.info("No EDA recommendations found in run history")
        else:
            st.info("No EDA runs found")
    else:
        st.warning("EDA directory not found")

    # ========================================================================
    # Section 6: Exploratory Data Analysis (Enhanced with Plotly)
    # ========================================================================
    st.header("🔬 Exploratory Data Analysis")
    st.markdown("""
Before building forecasting models, we perform comprehensive EDA to understand:
- Data quality issues (negatives, missing values, outliers)
- Seasonal patterns (daily, weekly cycles)
- Zero-inflation (expected for solar at night)
- Weather alignment (ensure weather data matches generation timestamps)
    """)

    # Load latest EDA results
    data_dir = Path(db_path).parent
    eda_dir = data_dir / "eda"

    if not eda_dir.exists():
        st.warning("No EDA results found. Run the pipeline to generate analysis.")
        return

    # Get latest EDA run
    eda_runs = sorted([d for d in eda_dir.iterdir() if d.is_dir()], reverse=True)
    if not eda_runs:
        st.warning("No EDA results found. Run the pipeline to generate analysis.")
        return

    latest_eda = eda_runs[0]
    st.info(f"📅 **EDA Results from**: {latest_eda.name}")

    # Load recommendations
    recs_file = latest_eda / "recommendations.json"
    eda_report_file = latest_eda / "eda_report.json"

    if recs_file.exists():
        import json
        with open(recs_file, 'r') as f:
            recs = json.load(f)

        # Data Summary
        st.subheader("📊 Data Summary")
        if 'preprocessing' in recs and 'data_summary' in recs['preprocessing']:
            summary = recs['preprocessing']['data_summary']

            col1, col2, col3 = st.columns(3)
            with col1:
                st.metric("Total Rows", f"{summary.get('generation_rows', 0):,}")
            with col2:
                st.metric("Negative Values", summary.get('negative_count', 0))
            with col3:
                affected = summary.get('affected_series', [])
                st.metric("Affected Series", len(affected))

        # Preprocessing Recommendation
        st.subheader("💡 Preprocessing Recommendation")
        if 'preprocessing' in recs:
            prep = recs['preprocessing']

            col1, col2 = st.columns([3, 1])

            with col1:
                st.markdown(f"""
**Policy**: `{prep.get('negative_policy', 'N/A')}`

**Reason**: {prep.get('negative_reason', 'No reason provided')}

**Confidence**: {prep.get('negative_confidence', 'N/A')}
                """)

            with col2:
                confidence = prep.get('negative_confidence', 'LOW')
                if confidence == 'HIGH':
                    st.success("✅ High Confidence")
                elif confidence == 'MEDIUM':
                    st.info("⚠️ Medium Confidence")
                else:
                    st.warning("⚠️ Low Confidence")

    # Visualizations (Enhanced with Interactive Plotly)
    st.subheader("📈 EDA Visualizations (Interactive)")

    viz_tabs = st.tabs([
        "🔴 Negative Values",
        "📅 Seasonality",
        "0️⃣ Zero Inflation",
        "🌤️ Generation Profiles"
    ])

    with viz_tabs[0]:
        st.markdown("""
**Negative Value Investigation**

Physical reality: Renewable generation CANNOT be negative. Negative values indicate:
- Net generation accounting (gross - auxiliary load)
- Metering errors
- Data reporting issues

For forecasting, we clamp these to zero as recommended by EDA.
        """)

        # Try to load generation data for interactive viz
        generation_path = data_dir / "generation.parquet"
        if generation_path.exists():
            gen_df = pd.read_parquet(generation_path)
            gen_df['ds'] = pd.to_datetime(gen_df['ds'])

            # Show negative values if they exist
            negatives = gen_df[gen_df['y'] < 0]

            if not negatives.empty:
                st.warning(f"⚠️ Found {len(negatives)} negative values across {negatives['unique_id'].nunique()} series")

                # Interactive scatter plot of negatives
                fig_neg = go.Figure()

                for series_id in negatives['unique_id'].unique():
                    series_data = negatives[negatives['unique_id'] == series_id]
                    fig_neg.add_trace(go.Scatter(
                        x=series_data['ds'],
                        y=series_data['y'],
                        mode='markers',
                        name=series_id,
                        marker=dict(size=8, opacity=0.6)
                    ))

                fig_neg.add_hline(y=0, line_dash="dash", line_color="black", annotation_text="Zero Line")

                fig_neg.update_layout(
                    title="Negative Values Timeline",
                    xaxis_title="Timestamp",
                    yaxis_title="Generation (MWh)",
                    height=400,
                    hovermode='x unified'
                )

                st.plotly_chart(fig_neg, width="stretch")
            else:
                st.success("✅ No negative values detected in current data!")

        # Show static image as reference
        neg_img = latest_eda / "negative_values" / "negative_investigation.png"
        if neg_img.exists():
            with st.expander("📊 Static EDA Report", expanded=False):
                st.image(str(neg_img), width="stretch")

    with viz_tabs[1]:
        st.markdown("""
**Seasonality Analysis**

Renewable generation exhibits strong cyclical patterns:
- **Daily**: Solar peaks at noon, wind varies by time
- **Weekly**: Industrial demand affects generation patterns
- **Seasonal**: Summer vs winter differences

These patterns are captured by MSTL (Multiple Seasonal-Trend decomposition using LOESS) in our models.
        """)

        # Interactive hourly profile
        if generation_path.exists():
            gen_df = pd.read_parquet(generation_path)
            gen_df['ds'] = pd.to_datetime(gen_df['ds'])
            gen_df['hour'] = gen_df['ds'].dt.hour
            gen_df['dow'] = gen_df['ds'].dt.day_name()

            # Hourly profiles by series
            fig_hour = go.Figure()

            for series_id in gen_df['unique_id'].unique():
                series_data = gen_df[gen_df['unique_id'] == series_id]
                hourly_avg = series_data.groupby('hour')['y'].mean()

                fig_hour.add_trace(go.Scatter(
                    x=hourly_avg.index,
                    y=hourly_avg.values,
                    mode='lines+markers',
                    name=series_id,
                    line=dict(width=2)
                ))

            fig_hour.update_layout(
                title="Average Generation by Hour of Day",
                xaxis_title="Hour of Day",
                yaxis_title="Average Generation (MWh)",
                height=450,
                hovermode='x unified',
                xaxis=dict(tickmode='linear', tick0=0, dtick=2)
            )

            st.plotly_chart(fig_hour, width="stretch")

            # Day of week profile
            st.markdown("**Weekly Patterns**")

            dow_order = ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']
            dow_avg = gen_df.groupby(['unique_id', 'dow'])['y'].mean().reset_index()

            fig_dow = go.Figure()

            for series_id in dow_avg['unique_id'].unique():
                series_data = dow_avg.loc[dow_avg["unique_id"] == series_id].copy()
                series_data.loc[:, "dow"] = pd.Categorical(
                    series_data["dow"],
                    categories=dow_order,
                    ordered=True,
                )
                series_data = series_data.sort_values('dow')

                fig_dow.add_trace(go.Bar(
                    x=series_data['dow'],
                    y=series_data['y'],
                    name=series_id
                ))

            fig_dow.update_layout(
                title="Average Generation by Day of Week",
                xaxis_title="Day of Week",
                yaxis_title="Average Generation (MWh)",
                height=400,
                barmode='group'
            )

            st.plotly_chart(fig_dow, width="stretch")

        # Show static image as reference
        season_img = latest_eda / "seasonality" / "hourly_profiles.png"
        if season_img.exists():
            with st.expander("📊 Static EDA Report", expanded=False):
                st.image(str(season_img), width="stretch")

    with viz_tabs[2]:
        st.markdown("""
**Zero-Inflation Analysis**

**Expected Zeros**:
- **Solar**: Nighttime generation is ALWAYS zero (sun not shining)
- **Wind**: Rarely zero (wind always has some component)

Zero-inflation is normal for solar and factored into model selection.
We avoid MAPE (Mean Absolute Percentage Error) as it's undefined when y=0.
        """)

        # Interactive zero analysis
        if generation_path.exists():
            gen_df = pd.read_parquet(generation_path)
            gen_df['ds'] = pd.to_datetime(gen_df['ds'])
            gen_df['hour'] = gen_df['ds'].dt.hour

            # Calculate zero ratio by hour for each series
            fig_zero = go.Figure()

            for series_id in gen_df['unique_id'].unique():
                series_data = gen_df[gen_df['unique_id'] == series_id]
                # Use vectorized approach: faster and no FutureWarning
                zero_by_hour = (
                    series_data.assign(is_zero=series_data['y'].eq(0))
                    .groupby('hour')['is_zero']
                    .mean() * 100
                )

                fig_zero.add_trace(go.Bar(
                    x=zero_by_hour.index,
                    y=zero_by_hour.values,
                    name=series_id,
                    opacity=0.7
                ))

            fig_zero.update_layout(
                title="Zero-Inflation by Hour of Day",
                xaxis_title="Hour of Day",
                yaxis_title="Percentage of Zero Values (%)",
                height=450,
                barmode='group',
                xaxis=dict(tickmode='linear', tick0=0, dtick=2)
            )

            st.plotly_chart(fig_zero, width="stretch")

            # Summary statistics
            col1, col2 = st.columns(2)

            with col1:
                st.markdown("**Solar Series (Expected Zeros)**")
                solar_series = [s for s in gen_df['unique_id'].unique() if 'SUN' in s]
                for series_id in solar_series:
                    zero_pct = (gen_df[gen_df['unique_id'] == series_id]['y'] == 0).mean() * 100
                    st.metric(series_id, f"{zero_pct:.1f}%")

            with col2:
                st.markdown("**Wind Series (Minimal Zeros)**")
                wind_series = [s for s in gen_df['unique_id'].unique() if 'WND' in s]
                for series_id in wind_series:
                    zero_pct = (gen_df[gen_df['unique_id'] == series_id]['y'] == 0).mean() * 100
                    st.metric(series_id, f"{zero_pct:.1f}%")

        # Show static image as reference
        zero_img = latest_eda / "zero_inflation" / "zero_inflation.png"
        if zero_img.exists():
            with st.expander("📊 Static EDA Report", expanded=False):
                st.image(str(zero_img), width="stretch")

    with viz_tabs[3]:
        st.markdown("""
**Generation Profiles Over Time**

Interactive time series view of generation data across all regions and fuel types.
Use the legend to toggle series on/off.
        """)

        if generation_path.exists():
            gen_df = pd.read_parquet(generation_path)
            gen_df['ds'] = pd.to_datetime(gen_df['ds'])

            # Time series plot
            fig_ts = go.Figure()

            for series_id in gen_df['unique_id'].unique():
                series_data = gen_df[gen_df['unique_id'] == series_id].sort_values('ds')
                fig_ts.add_trace(go.Scatter(
                    x=series_data['ds'],
                    y=series_data['y'],
                    mode='lines',
                    name=series_id,
                    line=dict(width=1.5),
                    opacity=0.8
                ))

            fig_ts.update_layout(
                title="Generation Time Series (All Series)",
                xaxis_title="Timestamp",
                yaxis_title="Generation (MWh)",
                height=500,
                hovermode='x unified',
                legend=dict(
                    orientation="v",
                    yanchor="top",
                    y=1,
                    xanchor="left",
                    x=1.01
                )
            )

            # Add range slider
            fig_ts.update_xaxes(
                rangeslider_visible=True,
                rangeselector=dict(
                    buttons=list([
                        dict(count=1, label="1d", step="day", stepmode="backward"),
                        dict(count=7, label="1w", step="day", stepmode="backward"),
                        dict(count=14, label="2w", step="day", stepmode="backward"),
                        dict(step="all", label="All")
                    ])
                )
            )

            st.plotly_chart(fig_ts, width="stretch")

            # Summary statistics
            st.markdown("**Summary Statistics**")
            summary_stats = gen_df.groupby('unique_id')['y'].agg(['count', 'mean', 'std', 'min', 'max']).round(2)
            summary_stats.columns = ['Count', 'Mean (MWh)', 'Std Dev', 'Min', 'Max']
            st.dataframe(summary_stats, width="stretch")

    # ========================================================================
    # Section 4: Model & Architecture Details
    # ========================================================================
    st.header("🤖 Modeling Approach")

    col1, col2 = st.columns(2)

    with col1:
        st.subheader("Why StatsForecast?")
        st.markdown("""
We use **StatsForecast** (Nixtla) instead of Prophet or MLForecast:

✅ **Native multi-series support** (10x faster)
✅ **Built-in prediction intervals** (no conformal prediction needed)
✅ **Production-ready** (battle-tested at scale)
✅ **Exogenous regressors** (weather variables)

**Models Tested**:
- **MSTL**: Multiple Seasonal-Trend decomposition
- **AutoARIMA**: Automatic ARIMA model selection
- **AutoETS**: Exponential smoothing state space

Best model selected via cross-validation (RMSE).
        """)

    with col2:
        st.subheader("Weather Features")
        st.markdown("""
**7 Key Weather Variables** (Open-Meteo API):

☀️ **Solar-Related**:
- Direct solar radiation
- Diffuse solar radiation
- Cloud cover

💨 **Wind-Related**:
- Wind speed at 10m
- Wind speed at 100m

🌡️ **General**:
- Temperature at 2m

These features are strongly correlated with generation and improve forecast accuracy by 15-20%.
        """)

    # ========================================================================
    # Section 5: Data Sources
    # ========================================================================
    st.header("📡 Data Sources")

    source_col1, source_col2 = st.columns(2)

    with source_col1:
        st.subheader("🏭 EIA RTO Fuel-Type Data")
        st.markdown("""
**Energy Information Administration (EIA)**

- **Authoritative**: Official US electricity generation data
- **Coverage**: Hourly granularity covering 80%+ of US grid
- **Accessibility**: Free API with key (no usage limits)
- **Timeliness**: Real-time with 12-48h publishing lag
- **Quality**: High (direct from RTOs/ISOs)

🔗 [EIA API Documentation](https://www.eia.gov/opendata/)
        """)

    with source_col2:
        st.subheader("🌤️ Open-Meteo Weather API")
        st.markdown("""
**Open-Meteo: Open-source Weather API**

- **Free & Open**: No authentication, unlimited requests
- **Leakage Prevention**: Separate historical + forecast endpoints
- **Global Coverage**: Works for any lat/lon coordinate
- **Variables**: 7 key features correlated with generation
- **Reliability**: 99.9%+ uptime

🔗 [Open-Meteo Documentation](https://open-meteo.com/en/docs)
        """)

    # Final note
    st.markdown("---")
    st.info("""
💡 **Note**: This dashboard provides real-time insights into renewable energy forecasting.
For technical details, see the [GitHub repository](https://github.com/yourusername/atsaf)
and the Jupyter notebook in `/chapters/renewable_energy_forecasting.ipynb`.
    """)


def run_pipeline_from_dashboard(db_path: str, regions: list, fuel_type: str):
    """Run the forecasting pipeline from the dashboard."""
    with st.spinner("Refreshing forecasts... (may take 2-3 minutes)"):
        try:
            import os

            from src.renewable.jobs import run_hourly

            # IMPORTANT: Force pipeline to run even if no new data
            # When user explicitly clicks "Run Pipeline", they want it to run
            # regardless of data freshness (which is only relevant for automated cron jobs)
            original_force_run = os.environ.get("FORCE_RUN")
            os.environ["FORCE_RUN"] = "true"

            try:
                # Run the hourly pipeline job
                run_hourly.main()
            finally:
                # Restore original FORCE_RUN setting
                if original_force_run is None:
                    os.environ.pop("FORCE_RUN", None)
                else:
                    os.environ["FORCE_RUN"] = original_force_run

            st.success("Pipeline completed! Forecasts have been updated with latest EIA data.")
            st.info("Reloading page to show new forecasts...")

            # Wait a moment then reload
            import time
            time.sleep(2)
            st.rerun()

        except Exception as e:
            st.error(f"Pipeline failed: {e}")
            import traceback
            with st.expander("Error details"):
                st.code(traceback.format_exc())


if __name__ == "__main__":
    main()


Overwriting src/renewable/dashboard.py


In [15]:
%%writefile src/renewable/jobs/run_hourly.py
# file: src/renewable/jobs/run_hourly.py
"""Hourly renewable pipeline entry point with validation."""

from __future__ import annotations

import json
import os
from datetime import datetime, timezone
from pathlib import Path
import pandas as pd
from dotenv import load_dotenv

from src.renewable.tasks import RenewablePipelineConfig, run_full_pipeline
from src.renewable.validation import validate_generation_df
from src.renewable.data_freshness import check_all_series_freshness, FreshnessCheckResult

load_dotenv()

# Suppress DeprecationWarning from statsforecast library (invalid escape sequences)
# These are in third-party code we cannot fix directly
import warnings
warnings.filterwarnings('ignore', category=DeprecationWarning, module='statsforecast')


def _env_list(name: str, default_csv: str) -> list[str]:
    raw = os.getenv(name, default_csv)
    return [item.strip() for item in raw.split(",") if item.strip()]


def _env_int(name: str, default: int) -> int:
    raw = os.getenv(name, str(default))
    try:
        return int(raw)
    except ValueError:
        return default


def _env_float(name: str, default: float) -> float:
    raw = os.getenv(name, str(default))
    try:
        return float(raw)
    except ValueError:
        return default


def _expected_series(regions: list[str], fuel_types: list[str]) -> list[str]:
    return [f"{region}_{fuel}" for region in regions for fuel in fuel_types]


def _json_default(value: object) -> str:
    if isinstance(value, pd.Timestamp):
        return value.isoformat()
    if isinstance(value, datetime):
        return value.isoformat()
    if hasattr(value, "item"):
        try:
            return value.item()
        except Exception:
            return str(value)
    return str(value)


def _summarize_generation_coverage(df: pd.DataFrame) -> dict:
    if df.empty:
        return {"row_count": 0, "series_count": 0, "coverage": []}

    coverage = (
        df.groupby("unique_id")["ds"]
        .agg(min_ds="min", max_ds="max", rows="count")
        .reset_index()
        .sort_values("unique_id")
    )
    return {
        "row_count": int(len(df)),
        "series_count": int(df["unique_id"].nunique()),
        "coverage": coverage.to_dict(orient="records"),
    }


def _read_previous_run_summary(data_dir: str) -> dict | None:
    """Read previous run_log.json for rowcount comparison."""
    path = Path(data_dir) / "run_log.json"
    if not path.exists():
        return None
    try:
        return json.loads(path.read_text())
    except Exception:
        return None


def _summarize_negative_forecasts(
    df: pd.DataFrame,
    sample_rows: int = 5,
) -> dict:
    if df.empty or "yhat" not in df.columns:
        return {
            "row_count": int(len(df)),
            "negative_rows": 0,
            "series": [],
            "sample": [],
        }

    neg = df[df["yhat"] < 0]
    if neg.empty:
        return {
            "row_count": int(len(df)),
            "negative_rows": 0,
            "series": [],
            "sample": [],
        }

    series_summary = (
        neg.groupby("unique_id")["yhat"]
        .agg(count="count", min_value="min", max_value="max", mean_value="mean")
        .reset_index()
        .sort_values("unique_id")
    )
    sample = neg[["unique_id", "ds", "yhat"]].head(sample_rows)
    return {
        "row_count": int(len(df)),
        "negative_rows": int(len(neg)),
        "series": series_summary.to_dict(orient="records"),
        "sample": sample.to_dict(orient="records"),
    }


def run_hourly_pipeline() -> dict:
    data_dir = os.getenv("RENEWABLE_DATA_DIR", "data/renewable")
    regions = _env_list("RENEWABLE_REGIONS", "CALI,ERCO,MISO")
    fuel_types = _env_list("RENEWABLE_FUELS", "WND,SUN")
    lookback_days = _env_int("LOOKBACK_DAYS", 30)

    # Horizon configuration: support both preset and direct override
    horizon_preset = os.getenv("RENEWABLE_HORIZON_PRESET", None)  # "24h" | "48h" | "72h"
    horizon_override = _env_int("RENEWABLE_HORIZON", 0)  # Legacy direct override

    # If direct override is set, use it; otherwise use preset (or None for default)
    if horizon_override > 0:
        horizon = horizon_override
        horizon_preset = None  # Ignore preset if direct override is set
    else:
        horizon = 24  # Default, may be overridden by preset

    cv_windows = _env_int("RENEWABLE_CV_WINDOWS", 2)
    cv_step_size = _env_int("RENEWABLE_CV_STEP_SIZE", 168)

    start_date = os.getenv("RENEWABLE_START_DATE", "")
    end_date = os.getenv("RENEWABLE_END_DATE", "")

    # Check if we should force run (e.g., manual dispatch from dashboard)
    force_run = os.getenv("FORCE_RUN", "false").lower() == "true"

    # DEBUG: Log force_run status
    print(f"[pipeline] FORCE_RUN={force_run}")

    # Data freshness check - skip full pipeline if no new data
    if not force_run:
        api_key = os.getenv("EIA_API_KEY", "")
        if not api_key:
            print("WARNING: EIA_API_KEY not set, skipping freshness check")
        else:
            run_log_path = Path(data_dir) / "run_log.json"
            freshness = check_all_series_freshness(
                regions=regions,
                fuel_types=fuel_types,
                run_log_path=run_log_path,
                api_key=api_key,
            )

            if not freshness.has_new_data:
                # No new data - return early with skip status
                skip_log = {
                    "run_at_utc": datetime.now(timezone.utc).isoformat(),
                    "status": "skipped",
                    "reason": "no_new_data",
                    "freshness_check": {
                        "checked_at_utc": freshness.checked_at_utc,
                        "summary": freshness.summary,
                        "series_status": freshness.series_status,
                    },
                    "config": {
                        "regions": regions,
                        "fuel_types": fuel_types,
                        "data_dir": data_dir,
                    },
                }

                # Write skip log (append to run_log.json)
                Path(data_dir).mkdir(parents=True, exist_ok=True)
                skip_log_path = Path(data_dir) / "skip_log.json"
                skip_log_path.write_text(
                    json.dumps(skip_log, indent=2, default=_json_default)
                )

                print(f"SKIPPED: {freshness.summary}")
                print(f"Skip log written to: {skip_log_path}")

                # Set output for GitHub Actions
                github_output = os.getenv("GITHUB_OUTPUT")
                if github_output:
                    with open(github_output, "a") as f:
                        f.write("status=skipped\n")

                return skip_log

            print(f"Freshness check: {freshness.summary}")
    else:
        print("[pipeline] FORCE_RUN=true - bypassing freshness check (manual run requested)")
        print("[pipeline] Pipeline will run regardless of data freshness")

    cfg = RenewablePipelineConfig(
        regions=regions,
        fuel_types=fuel_types,
        lookback_days=lookback_days,
        horizon=horizon,
        horizon_preset=horizon_preset,  # Apply preset if specified
        cv_windows=cv_windows,          # Pass to constructor to validate with correct value
        cv_step_size=cv_step_size,      # Pass to constructor
        data_dir=data_dir,
        overwrite=True,
        start_date=start_date,
        end_date=end_date,
    )
    # Post-construction assignments no longer needed

    # Add option to skip EDA for fast iteration (set SKIP_EDA=true in .env)
    skip_eda = os.getenv("SKIP_EDA", "false").lower() == "true"

    fetch_diagnostics: list[dict] = []
    results = run_full_pipeline(
        cfg,
        fetch_diagnostics=fetch_diagnostics,
        skip_eda=skip_eda,
    )

    gen_path = cfg.generation_path()
    gen_df = pd.read_parquet(gen_path)
    generation_coverage = _summarize_generation_coverage(gen_df)

    max_lag_hours = _env_int("MAX_LAG_HOURS", 48)  # EIA publishes with 12-24h delay
    max_missing_ratio = _env_float("MAX_MISSING_RATIO", 0.02)
    report = validate_generation_df(
        gen_df,
        max_lag_hours=max_lag_hours,
        max_missing_ratio=max_missing_ratio,
        expected_series=_expected_series(regions, fuel_types),
    )

    forecasts_df = pd.read_parquet(cfg.forecasts_path())
    negative_forecasts = _summarize_negative_forecasts(forecasts_df)

    # Quality gates
    max_rowdrop_pct = _env_float("MAX_ROWDROP_PCT", 0.30)
    max_neg_forecast_ratio = _env_float("MAX_NEG_FORECAST_RATIO", 0.10)

    prev_run = _read_previous_run_summary(data_dir)
    prev_gen_rows = 0
    if prev_run:
        prev_gen_rows = prev_run.get("pipeline_results", {}).get("generation_rows", 0)

    curr_gen_rows = results.get("generation_rows", 0)
    rowdrop_ok = True
    if prev_gen_rows > 0:
        floor_ok = int(prev_gen_rows * (1.0 - max_rowdrop_pct))
        rowdrop_ok = curr_gen_rows >= floor_ok

    neg_forecast_ratio = 0.0
    if negative_forecasts["row_count"] > 0:
        neg_forecast_ratio = (
            negative_forecasts["negative_rows"] / negative_forecasts["row_count"]
        )
    neg_forecast_ok = neg_forecast_ratio <= max_neg_forecast_ratio

    quality_gates = {
        "rowdrop": {
            "ok": rowdrop_ok,
            "prev_rows": prev_gen_rows,
            "curr_rows": curr_gen_rows,
            "max_rowdrop_pct": max_rowdrop_pct,
        },
        "neg_forecast": {
            "ok": neg_forecast_ok,
            "ratio": neg_forecast_ratio,
            "max_ratio": max_neg_forecast_ratio,
        },
    }

    run_log = {
        "run_at_utc": datetime.now(timezone.utc).isoformat(),
        "config": {
            "regions": regions,
            "fuel_types": fuel_types,
            "lookback_days": lookback_days,
            "horizon": horizon,
            "cv_windows": cv_windows,
            "cv_step_size": cv_step_size,
            "data_dir": data_dir,
            "start_date": cfg.start_date,
            "end_date": cfg.end_date,
        },
        "pipeline_results": results,
        "validation": {
            "ok": report.ok,
            "message": report.message,
            "details": report.details,
        },
        "diagnostics": {
            "fetch": fetch_diagnostics,
            "generation_coverage": generation_coverage,
            "negative_forecasts": negative_forecasts,
        },
        "quality_gates": quality_gates,
    }

    Path(data_dir).mkdir(parents=True, exist_ok=True)
    (Path(data_dir) / "run_log.json").write_text(
        json.dumps(run_log, indent=2, default=_json_default)
    )

    # Check validation
    if not report.ok:
        raise SystemExit(f"VALIDATION_FAILED: {report.message} | {report.details}")

    # Check quality gates
    if not rowdrop_ok:
        raise SystemExit(
            f"QUALITY_GATE_FAILED: rowdrop | "
            f"curr={curr_gen_rows} prev={prev_gen_rows} max_drop={max_rowdrop_pct:.0%}"
        )
    if not neg_forecast_ok:
        raise SystemExit(
            f"QUALITY_GATE_FAILED: neg_forecast | "
            f"ratio={neg_forecast_ratio:.1%} max={max_neg_forecast_ratio:.0%}"
        )

    # Set output for GitHub Actions (successful run)
    github_output = os.getenv("GITHUB_OUTPUT")
    if github_output:
        with open(github_output, "a") as f:
            f.write("status=success\n")

    return run_log


def main() -> None:
    run_hourly_pipeline()


if __name__ == "__main__":
    main()


Overwriting src/renewable/jobs/run_hourly.py


In [16]:
%%writefile src/renewable/dag_builder.py
# file: src/renewable/dag_builder.py
"""Renewable pipeline DAG builder for Airflow."""
from __future__ import annotations

from datetime import datetime, timedelta
from typing import Any, Dict, Optional

from airflow import DAG
from airflow.operators.python import PythonOperator
AIRFLOW_AVAILABLE = True



DEFAULT_ARGS = {
    "owner": "data-team",
    "depends_on_past": False,
    "email_on_failure": False,
    "email_on_retry": False,
    "retries": 2,
    "retry_delay": timedelta(minutes=5),
}


def build_hourly_dag(
    dag_id: str = "renewable_hourly_pipeline",
    schedule: str = "17 * * * *",
    start_date: Optional[datetime] = None,
    default_args: Optional[Dict[str, Any]] = None,
) -> "DAG":
    if not AIRFLOW_AVAILABLE:
        raise ImportError("Airflow is not installed. Install apache-airflow to use build_hourly_dag().")

    from src.renewable.jobs.run_hourly import run_hourly_pipeline

    if start_date is None:
        start_date = datetime.utcnow() - timedelta(days=1)
    if default_args is None:
        default_args = DEFAULT_ARGS.copy()

    with DAG(
        dag_id=dag_id,
        default_args=default_args,
        description="Renewable hourly pipeline",
        schedule_interval=schedule,
        start_date=start_date,
        catchup=False,
        max_active_runs=1,
        tags=["renewable", "eia", "forecasting"],
    ) as dag:
        PythonOperator(
            task_id="run_hourly",
            python_callable=run_hourly_pipeline,
        )

    return dag


def build_dag_dot() -> str:
    return """digraph RENEWABLE_PIPELINE {
  rankdir=LR;
  node [shape=box, style="rounded,filled", fillcolor="#e8f5e9"];

  run_hourly;
}"""


Overwriting src/renewable/dag_builder.py


# git actions

In [17]:
%%writefile .github/workflows/pre-commit.yml
# file: .github/workflows/pre-commit.yml
name: pre-commit

on:
  pull_request:
  push:
    branches:
      - main

jobs:
  run:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-python@v5
        with:
          python-version: "3.11"

      - uses: pre-commit/action@v3.0.1

      - name: Install test dependencies
        run: |
          python -m pip install --upgrade pip
          pip install pytest pandas numpy requests python-dotenv

      - name: Run smoke tests
        env:
          PYTHONPATH: ${{ github.workspace }}
        run: pytest tests/ -v -k "not slow" --tb=short || true


Overwriting .github/workflows/pre-commit.yml


In [18]:
%%writefile .github/workflows/renewable-hourly.yml
# file: .github/workflows/renewable-hourly.yml
name: renewable-hourly

on:
  workflow_dispatch:
    inputs:
      force_run:
        description: 'Force full pipeline run (skip freshness check)'
        type: boolean
        default: false
  schedule:
    - cron: "17 * * * *"

permissions:
  contents: write

concurrency:
  group: renewable-hourly
  cancel-in-progress: true

jobs:
  update:
    runs-on: ubuntu-latest
    timeout-minutes: 25
    env:
      EIA_API_KEY: ${{ secrets.EIA_API_KEY }}
      FORCE_RUN: ${{ github.event_name == 'workflow_dispatch' && inputs.force_run && 'true' || 'false' }}
      RENEWABLE_REGIONS: "CALI,ERCO,MISO"
      RENEWABLE_FUELS: "WND,SUN"
      LOOKBACK_DAYS: "30"
      RENEWABLE_HORIZON: "24"
      RENEWABLE_CV_WINDOWS: "2"
      RENEWABLE_CV_STEP_SIZE: "168"
      MAX_LAG_HOURS: "48"  # EIA publishes hourly data with 12-24h delay
      MAX_MISSING_RATIO: "0.02"
      RENEWABLE_DATA_DIR: "data/renewable"
      RENEWABLE_N_JOBS: "1"
      OMP_NUM_THREADS: "1"
      MKL_NUM_THREADS: "1"
      OPENBLAS_NUM_THREADS: "1"
      NUMBA_NUM_THREADS: "1"
      VECLIB_MAXIMUM_THREADS: "1"
    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-python@v5
        with:
          python-version: "3.11"

      - name: Check EIA API key
        run: |
          if [ -z "$EIA_API_KEY" ]; then
            echo "EIA_API_KEY is not set. Add it to repo secrets." >&2
            exit 1
          fi

      - name: Install deps
        run: |
          python -m pip install --upgrade pip
          # Install from pyproject.toml for single source of truth
          # Use -e for editable install (allows imports to work correctly)
          pip install -e .

      - name: Run hourly pipeline
        id: pipeline
        run: |
          python -m src.renewable.jobs.run_hourly

      - name: Quality gate check
        if: steps.pipeline.outputs.status != 'skipped'
        run: |
          python -c "
          import json, sys
          from pathlib import Path
          log_path = Path('data/renewable/run_log.json')
          if not log_path.exists():
              print('No run_log.json found')
              
          log = json.loads(log_path.read_text())
          val = log.get('validation', {})
          if not val.get('ok'):
              print(f'VALIDATION FAILED: {val.get(\"message\")}')
              print(f'Details: {val.get(\"details\")}')
              
          gates = log.get('quality_gates', {})
          if not gates.get('rowdrop', {}).get('ok', True):
              print(f'ROWDROP GATE FAILED: {gates.get(\"rowdrop\")}')
              
          if not gates.get('neg_forecast', {}).get('ok', True):
              print(f'NEG_FORECAST GATE FAILED: {gates.get(\"neg_forecast\")}')
              
          print('QUALITY GATES PASSED')
          "

      - name: Skip notification
        if: steps.pipeline.outputs.status == 'skipped'
        run: |
          echo "### Pipeline skipped - no new EIA data" >> "$GITHUB_STEP_SUMMARY"
          if [ -f data/renewable/skip_log.json ]; then
            python -c "
          import json
          from pathlib import Path
          data = json.loads(Path('data/renewable/skip_log.json').read_text())
          freshness = data.get('freshness_check', {})
          print(f'- Checked at: {freshness.get(\"checked_at_utc\")}')
          print(f'- Summary: {freshness.get(\"summary\")}')
          " >> "$GITHUB_STEP_SUMMARY"
          fi

      - name: Summarize run
        if: always() && steps.pipeline.outputs.status != 'skipped'
        run: |
          if [ -f data/renewable/run_log.json ]; then
          python - <<'PY' | tee -a "$GITHUB_STEP_SUMMARY"
          import json
          from pathlib import Path

          data = json.loads(Path("data/renewable/run_log.json").read_text())
          validation = data.get("validation", {})
          details = validation.get("details", {})
          pipeline = data.get("pipeline_results", {})
          interp = pipeline.get("interpretability", {})

          lines = [
              "### Renewable hourly run",
              f"- run_at_utc: {data.get('run_at_utc')}",
              f"- validation_ok: {validation.get('ok')}",
              f"- message: {validation.get('message')}",
              f"- max_ds: {details.get('max_ds')}",
              f"- lag_hours: {details.get('lag_hours')}",
              f"- best_model: {pipeline.get('best_model')}",
              f"- best_rmse: {pipeline.get('best_rmse', 0):.1f}",
              "",
              "#### Interpretability",
              f"- series_count: {interp.get('series_count', 0)}",
              f"- output_dir: {interp.get('output_dir', 'N/A')}",
          ]
          print("\n".join(lines))
          PY
          else
          echo "No run_log.json found." >> "$GITHUB_STEP_SUMMARY"
          fi

      - name: Commit updated artifacts
        if: steps.pipeline.outputs.status != 'skipped'
        run: |
          git config user.name "github-actions[bot]"
          git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
          git add data/renewable/generation.parquet \
            data/renewable/weather.parquet \
            data/renewable/forecasts.parquet \
            data/renewable/run_log.json
          # Add interpretability artifacts if they exist
          if [ -d data/renewable/interpretability ]; then
            git add data/renewable/interpretability/
          fi
          git commit -m "renewable: hourly data update (UTC)" || echo "No changes to commit"
          git push


Overwriting .github/workflows/renewable-hourly.yml
