# todos:
- reorganize the data preprocessing so it happens in the data pull rather than in a separate preprocessing step
- add an EDA module to explore the data and show the decisions made based on it
- write up the full explanation of why we chose this data and why we chose this model
    - include the decision-making and thought process, not just the end results
    - archive the notebooks
    - update the README
    - ensure the software is automated
    - add a mermaid graph to the README and LinkedIn post
    - add startup instructions to the README
    - finally, post on LinkedIn

# Renewable Energy Forecasting Pipeline

This notebook walks through building a **next-24h renewable generation forecast system** with:

- **EIA data integration** - Hourly wind/solar generation for US regions
- **Weather features** - Open-Meteo integration (wind speed, solar radiation)
- **Probabilistic forecasting** - Dual prediction intervals (80%, 95%)
- **Drift monitoring** - Automatic detection of model degradation
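
Because the forecasts carry two interval levels, one sanity check is empirical coverage: over a backtest, roughly 80% of actuals should land inside the 80% band and 95% inside the 95% band. A minimal sketch, assuming interval columns follow StatsForecast's `{model}-lo-{level}` / `{model}-hi-{level}` naming; the model name `AutoARIMA` and all numbers are hypothetical:

```python
import pandas as pd


def interval_coverage(df: pd.DataFrame, model: str, levels: list[int]) -> dict[int, float]:
    """Return the share of actuals (column `y`) that fall inside each prediction interval."""
    coverage = {}
    for level in levels:
        lo = df[f"{model}-lo-{level}"]
        hi = df[f"{model}-hi-{level}"]
        inside = (df["y"] >= lo) & (df["y"] <= hi)
        coverage[level] = float(inside.mean())
    return coverage


# Toy backtest rows (hypothetical values, for illustration only)
fcst = pd.DataFrame({
    "y": [100.0, 120.0, 90.0, 150.0],
    "AutoARIMA-lo-80": [95.0, 110.0, 92.0, 120.0],
    "AutoARIMA-hi-80": [105.0, 130.0, 98.0, 140.0],
    "AutoARIMA-lo-95": [90.0, 100.0, 85.0, 110.0],
    "AutoARIMA-hi-95": [110.0, 140.0, 105.0, 160.0],
})

print(interval_coverage(fcst, "AutoARIMA", [80, 95]))
```

Coverage well below the nominal level over a long backtest suggests the intervals are too narrow; coverage well above it suggests they are wider than necessary.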

## Architecture Overview

```
┌─────────────┐      ┌──────────────┐      ┌───────────────┐
│  EIA API    │──┬──▶│ Data         │──┬──▶│ StatsForecast │
│ (WND/SUN)   │  │   │ Pipeline     │  │   │ Models        │
└─────────────┘  │   └──────────────┘  │   └───────┬───────┘
                 │                     │           │
┌─────────────┐  │   ┌──────────────┐  │   ┌───────▼───────┐
│ Open-Meteo  │──┘   │ Validation   │  │   │ Probabilistic │
│ Weather API │      │ & Quality    │  │   │ Forecasts     │
└─────────────┘      │ Gates        │  │   │ (80%, 95% CI) │
                     └──────────────┘  │   └───────┬───────┘
                                       │           │
                                       │   ┌───────▼───────┐
                                       └──▶│  Artifacts    │
                                           │  Commit &     │
                                           │  Dashboard    │
                                           └───────────────┘
```

## Key Concepts

1. **StatsForecast format**: `[unique_id, ds, y]` - where `unique_id` = `{region}_{fuel_type}`
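2. **Zero-value handling**: Solar generation is 0 at night, and MAPE divides by the actual value, so it is undefined for those hours; we use RMSE/MAE, NOT MAPE
3. **Leakage prevention**: Use **forecasted** weather for predictions, not observed (historical) weather for the hours being predicted
4. **Drift detection**: Threshold = mean + 2*std of the backtest error metric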
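
To make concepts 2 and 4 concrete, here is a small NumPy sketch: MAPE stops being a finite number as soon as a night-time zero appears, while RMSE and MAE stay well defined, and a `mean + 2*std` threshold is derived from per-window backtest RMSEs. The backtest values are made up, the population standard deviation (NumPy's default `ddof=0`) is a choice, and flagging drift when the latest RMSE exceeds the threshold is an assumption about how the threshold is applied.

```python
import numpy as np


def rmse(y: np.ndarray, yhat: np.ndarray) -> float:
    return float(np.sqrt(np.mean((y - yhat) ** 2)))


def mae(y: np.ndarray, yhat: np.ndarray) -> float:
    return float(np.mean(np.abs(y - yhat)))


def mape(y: np.ndarray, yhat: np.ndarray) -> float:
    # Divides by the actual value, so y == 0 yields inf (or nan for 0/0)
    with np.errstate(divide="ignore", invalid="ignore"):
        return float(np.mean(np.abs((y - yhat) / y)) * 100)


# Solar-like series: zero at night, positive during the day (hypothetical values)
y = np.array([0.0, 0.0, 50.0, 120.0, 80.0, 0.0])
yhat = np.array([2.0, 0.0, 45.0, 130.0, 70.0, 1.0])

print("RMSE:", rmse(y, yhat))
print("MAE:", mae(y, yhat))
print("MAPE:", mape(y, yhat))  # not finite: the night-time zeros break the metric


def drift_threshold(backtest_rmse: np.ndarray, k: float = 2.0) -> float:
    """Threshold = mean + k * std (population std, ddof=0) of backtest RMSEs."""
    return float(backtest_rmse.mean() + k * backtest_rmse.std())


backtest_rmse = np.array([10.0, 12.0, 11.0, 9.0, 13.0])  # hypothetical per-window RMSEs
threshold = drift_threshold(backtest_rmse)
latest_rmse = 18.0  # hypothetical RMSE of the most recent forecast window
print(f"threshold={threshold:.2f}, drift={latest_rmse > threshold}")
```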

## Setup

First, let's ensure we have the project root in our path and configure logging.

In [1]:
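import sys
import logging
import os
from pathlib import Path

# Project root for this machine (adjust if your checkout lives elsewhere)
project_root = Path(r"c:\docker_projects\atsaf")

# Add project root to sys.path so `src.renewable...` imports resolve
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))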

if os.getcwd() != str(project_root):
    os.chdir(project_root)
    print(f"Changed working directory to project root: {project_root} we are currently at {os.getcwd()}")

# Configure logging for visibility
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)

print(f"Project root: {project_root}")

Changed working directory to project root: c:\docker_projects\atsaf we are currently at c:\docker_projects\atsaf
Project root: c:\docker_projects\atsaf


---

# Module 1: Region Definitions

**File:** `src/renewable/regions.py`

This module maps **EIA balancing authority regions** to their geographic coordinates. Why do we need coordinates?

- **Weather API lookup**: Open-Meteo requires latitude/longitude
- **Regional analysis**: Compare forecast accuracy across regions
- **Timezone handling**: Each region has a primary timezone

## Key Design Decisions

1. **NamedTuple for RegionInfo**: Immutable, with typed fields, and lightweight (a plain tuple under the hood)
2. **Centroid coordinates**: Approximate centers - good enough for hourly weather
3. **Fuel type codes**: `WND` (wind), `SUN` (solar) - match EIA's API
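
A quick illustration of what the NamedTuple choice buys: fields are read-only, the optional `eia_respondent` defaults to `None`, and `_replace` returns a modified copy instead of mutating in place. This re-declares a trimmed `RegionInfo` so it runs on its own; `"VERIFIED_CODE"` is a hypothetical placeholder, not a real EIA respondent code.

```python
from typing import NamedTuple, Optional


class RegionInfo(NamedTuple):
    name: str
    lat: float
    lon: float
    timezone: str
    eia_respondent: Optional[str] = None  # default used when the argument is omitted


nw = RegionInfo(name="Northwest", lat=45.5, lon=-122.0, timezone="America/Los_Angeles")
print(nw.eia_respondent)  # None

# Fields are read-only: assignment raises AttributeError
try:
    nw.lat = 46.0
except AttributeError as exc:
    print("immutable:", exc)

# _replace returns a NEW tuple; the original is unchanged
nw_verified = nw._replace(eia_respondent="VERIFIED_CODE")  # hypothetical code
print(nw.eia_respondent, nw_verified.eia_respondent)

# Still a tuple, so positional unpacking works
name, lat, lon, tz, resp = nw_verified
```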

In [None]:
%%writefile src/renewable/regions.py
# src/renewable/regions.py
from __future__ import annotations

from typing import NamedTuple, Optional


class RegionInfo(NamedTuple):
    """Region metadata for EIA and weather lookups."""
    name: str
    lat: float
    lon: float
    timezone: str
    # Some internal regions may not map cleanly to an EIA respondent.
    # We keep them in REGIONS for weather/features, but EIA fetch requires this.
    eia_respondent: Optional[str] = None


REGIONS: dict[str, RegionInfo] = {
    # Western Interconnection
    "CALI": RegionInfo(
        name="California ISO",
        lat=36.7,
        lon=-119.4,
        timezone="America/Los_Angeles",
        eia_respondent="CISO",
    ),
    "NW": RegionInfo(
        name="Northwest",
        lat=45.5,
        lon=-122.0,
        timezone="America/Los_Angeles",
        eia_respondent=None,  # intentionally unset until verified
    ),
    "SW": RegionInfo(
        name="Southwest",
        lat=33.5,
        lon=-112.0,
        timezone="America/Phoenix",
        eia_respondent=None,  # intentionally unset until verified
    ),

    # Texas Interconnection
    "ERCO": RegionInfo(
        name="ERCOT (Texas)",
        lat=31.0,
        lon=-100.0,
        timezone="America/Chicago",
        eia_respondent="ERCO",
    ),

    # Midwest
    "MISO": RegionInfo(
        name="Midcontinent ISO",
        lat=41.0,
        lon=-93.0,
        timezone="America/Chicago",
        eia_respondent="MISO",
    ),
    "PJM": RegionInfo(
        name="PJM Interconnection",
        lat=39.0,
        lon=-77.0,
        timezone="America/New_York",
        eia_respondent="PJM",
    ),
    "SWPP": RegionInfo(
        name="Southwest Power Pool",
        lat=37.0,
        lon=-97.0,
        timezone="America/Chicago",
        eia_respondent="SWPP",
    ),

    # Internal/aggregate regions kept for non-EIA use (weather/features/etc.)
    "SE": RegionInfo(name="Southeast", lat=33.0, lon=-84.0, timezone="America/New_York", eia_respondent=None),
    "FLA": RegionInfo(name="Florida", lat=28.0, lon=-82.0, timezone="America/New_York", eia_respondent=None),
    "CAR": RegionInfo(name="Carolinas", lat=35.5, lon=-80.0, timezone="America/New_York", eia_respondent=None),
    "TEN": RegionInfo(name="Tennessee Valley", lat=35.5, lon=-86.0, timezone="America/Chicago", eia_respondent=None),

    "US48": RegionInfo(name="Lower 48 States", lat=39.8, lon=-98.5, timezone="America/Chicago", eia_respondent=None),
}

FUEL_TYPES = {"WND": "Wind", "SUN": "Solar"}


def list_regions() -> list[str]:
    return sorted(REGIONS.keys())


def get_region_info(region_code: str) -> RegionInfo:
    return REGIONS[region_code]


def get_region_coords(region_code: str) -> tuple[float, float]:
    r = REGIONS[region_code]
    return (r.lat, r.lon)


def get_eia_respondent(region_code: str) -> str:
    """Return the code EIA expects for facets[respondent][]. Fail loudly if missing."""
    info = REGIONS[region_code]
    if not info.eia_respondent:
        raise ValueError(
            f"Region '{region_code}' has no configured eia_respondent. "
            f"Set REGIONS['{region_code}'].eia_respondent to a verified EIA respondent code "
            f"before using it for EIA fetches."
        )
    return info.eia_respondent


def validate_region(region_code: str) -> bool:
    return region_code in REGIONS


def validate_fuel_type(fuel_type: str) -> bool:
    return fuel_type in FUEL_TYPES



if __name__ == "__main__":
    # Example run - test region functions

    print("=== Available Regions ===")
    print(f"Total regions: {len(REGIONS)}")
    print(f"Region codes: {list_regions()}")

    print("\n=== Example: California ===")
    cali_info = get_region_info("CALI")
    print(f"Name: {cali_info.name}")
    print(f"Coordinates: ({cali_info.lat}, {cali_info.lon})")
    print(f"Timezone: {cali_info.timezone}")

    print("\n=== Weather API Coordinates ===")
    for region in ["CALI", "ERCO", "MISO"]:
        lat, lon = get_region_coords(region)
        print(f"{region}: lat={lat}, lon={lon}")

    print("\n=== Fuel Types ===")
    for code, name in FUEL_TYPES.items():
        print(f"{code}: {name}")

    print("\n=== Validation ===")
    print(f"validate_region('CALI'): {validate_region('CALI')}")
    print(f"validate_region('INVALID'): {validate_region('INVALID')}")
    print(f"validate_fuel_type('WND'): {validate_fuel_type('WND')}")


Overwriting src/renewable/regions.py



### Example: Using Region Definitions
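
A minimal sketch of how the region helpers fit together: every region has coordinates for weather lookups, but only regions with a configured `eia_respondent` should be sent to the EIA API. It re-declares a two-region subset of `REGIONS` so it runs standalone; `eia_fetchable_regions` is a hypothetical helper, not part of `regions.py`.

```python
from typing import NamedTuple, Optional


class RegionInfo(NamedTuple):
    name: str
    lat: float
    lon: float
    timezone: str
    eia_respondent: Optional[str] = None


# Two-entry subset of REGIONS from regions.py
REGIONS = {
    "CALI": RegionInfo("California ISO", 36.7, -119.4, "America/Los_Angeles", "CISO"),
    "NW": RegionInfo("Northwest", 45.5, -122.0, "America/Los_Angeles", None),
}


def get_eia_respondent(region_code: str) -> str:
    # Same fail-loud behavior as regions.py
    info = REGIONS[region_code]
    if not info.eia_respondent:
        raise ValueError(f"Region '{region_code}' has no configured eia_respondent.")
    return info.eia_respondent


def eia_fetchable_regions() -> list[str]:
    """Hypothetical helper: regions that can be sent to the EIA API."""
    return sorted(code for code, info in REGIONS.items() if info.eia_respondent)


# Every region has coordinates for Open-Meteo...
weather_points = {code: (r.lat, r.lon) for code, r in REGIONS.items()}

# ...but only some can be fetched from EIA
fetchable = eia_fetchable_regions()
print(fetchable)
print(get_eia_respondent("CALI"))

try:
    get_eia_respondent("NW")
except ValueError as exc:
    print("skipped:", exc)
```

Filtering this way before fetching keeps unconfigured regions from showing up as fetch failures.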

---

# Module 2: EIA Data Fetcher

**File:** `src/renewable/eia_renewable.py`

This module fetches **hourly wind and solar generation** from the EIA API.

## Critical Concepts

### StatsForecast Format
StatsForecast expects data in a specific format:
```
unique_id | ds                  | y
----------|---------------------|--------
CALI_WND  | 2024-01-01 00:00:00 | 1234.5
CALI_WND  | 2024-01-01 01:00:00 | 1456.7
ERCO_WND  | 2024-01-01 00:00:00 | 2345.6
```

- `unique_id`: Identifies the time series (e.g., "CALI_WND" = California Wind)
- `ds`: Datetime column (timezone-naive UTC)
- `y`: Target value (generation in MWh)
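
Getting into this format is a concat, a string join for `unique_id`, and a rename, which is what the tail of `fetch_all_regions` does. A minimal pandas sketch with made-up values:

```python
import pandas as pd

# Raw per-region rows as returned by a fetch (hypothetical values, deliberately unsorted)
raw = pd.DataFrame({
    "ds": pd.to_datetime(["2024-01-01 01:00", "2024-01-01 00:00", "2024-01-01 00:00"]),
    "value": [1456.7, 1234.5, 2345.6],
    "region": ["CALI", "CALI", "ERCO"],
    "fuel_type": ["WND", "WND", "WND"],
})

# unique_id = {region}_{fuel_type}; the value column becomes the target y
long_df = (
    raw.assign(unique_id=raw["region"] + "_" + raw["fuel_type"])
    .rename(columns={"value": "y"})[["unique_id", "ds", "y"]]
    .sort_values(["unique_id", "ds"])
    .reset_index(drop=True)
)
print(long_df)

# Basic format checks: naive timestamps, one row per (series, hour)
assert long_df["ds"].dt.tz is None
assert not long_df.duplicated(["unique_id", "ds"]).any()
```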

### API Rate Limiting
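- We assume a budget of roughly 5 requests/second for the EIA API, hence `RATE_LIMIT_DELAY = 0.2` seconds between pages
- Parallelism is kept small (`max_workers=3` regions at a time); the delay applies between pages *within* one region's fetch, so the combined rate across workers can be higher
- Transient failures (HTTP 429/5xx, connection errors, read timeouts) are retried with exponential backoff via `urllib3`'s `Retry`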
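
The paging loop inside `fetch_region` can be isolated as below: request `page_size` records at a growing `offset`, stop on a short or empty page, and sleep between pages. This is a sketch; `fake_page` is a hypothetical stand-in for the HTTP call so the logic runs offline.

```python
import time
from typing import Callable


def fetch_all_pages(
    get_page: Callable[[int, int], list[dict]],
    page_size: int = 5000,
    delay: float = 0.2,
) -> list[dict]:
    """Collect records via offset pagination, pausing `delay` seconds between pages."""
    records: list[dict] = []
    offset = 0
    while True:
        page = get_page(offset, page_size)
        records.extend(page)
        if len(page) < page_size:  # short or empty page: no more data
            break
        offset += page_size
        time.sleep(delay)  # simple client-side throttle between requests
    return records


# Stand-in for the EIA call: 12 records served in pages
DATA = [{"period": f"2024-01-01T{h:02d}", "value": float(h)} for h in range(12)]
calls: list[int] = []


def fake_page(offset: int, length: int) -> list[dict]:
    calls.append(offset)
    return DATA[offset:offset + length]


rows = fetch_all_pages(fake_page, page_size=5, delay=0.0)
print(len(rows), calls)  # 12 [0, 5, 10]
```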

In [None]:
# %%writefile src/renewable/eia_renewable.py
# src/renewable/eia_renewable.py
from __future__ import annotations

import logging
import os
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Optional
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit

import pandas as pd
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
from dotenv import find_dotenv, load_dotenv

from src.renewable.regions import (REGIONS, get_eia_respondent,
                                   validate_fuel_type, validate_region)

logger = logging.getLogger(__name__)


def _sanitize_url(url: str) -> str:
    parts = urlsplit(url)
    q = [(k, v) for k, v in parse_qsl(parts.query, keep_blank_values=True) if k.lower() != "api_key"]
    return urlunsplit((parts.scheme, parts.netloc, parts.path, urlencode(q), parts.fragment))


def _load_env_once(*, debug: bool = False) -> Optional[str]:
    """
    Load .env if present.
    - Primary: find_dotenv(usecwd=True) (walk up from CWD)
    - Fallback: repo_root/.env based on this file location
    Returns the path loaded (or None).
    """
    # 1) Try from current working directory upward
    dotenv_path = find_dotenv(usecwd=True)
    if dotenv_path:
        load_dotenv(dotenv_path, override=False)
        if debug:
            logger.info("Loaded .env via find_dotenv: %s", dotenv_path)
        return dotenv_path

    # 2) Fallback: assume src-layout -> repo root is ../../ from this file
    try:
        repo_root = Path(__file__).resolve().parents[2]
        fallback = repo_root / ".env"
        if fallback.exists():
            load_dotenv(fallback, override=False)
            if debug:
                logger.info("Loaded .env via fallback: %s", str(fallback))
            return str(fallback)
    except Exception:
        pass

    if debug:
        logger.info("No .env found to load.")
    return None


class EIARenewableFetcher:
    BASE_URL = "https://api.eia.gov/v2/electricity/rto/fuel-type-data/data/"
    MAX_RECORDS_PER_REQUEST = 5000
    RATE_LIMIT_DELAY = 0.2  # 5 requests/second max

    def __init__(self, api_key: Optional[str] = None, *, timeout: int = 60, debug_env: bool = False):
        """
        Initialize API key and configuration.

        Args:
            api_key: EIA API key (or reads from EIA_API_KEY env var)
            timeout: Request timeout in seconds (default: 60)
            debug_env: Enable debug logging for environment loading
        """
        loaded_env = _load_env_once(debug=debug_env)

        self.api_key = api_key or os.getenv("EIA_API_KEY")
        if not self.api_key:
            raise ValueError(
                "EIA API key required but not found.\n"
                "- Ensure .env contains EIA_API_KEY=...\n"
                "- Ensure your process CWD is under the repo (so find_dotenv can locate it), OR\n"
                "- Pass api_key=... explicitly.\n"
                f"Loaded .env path: {loaded_env}"
            )

        self.timeout = timeout
        self.session = self._create_session()  # Add retry-enabled session

        # Debug without leaking the key
        if debug_env:
            masked = self.api_key[:4] + "..." + self.api_key[-4:] if len(self.api_key) >= 8 else "***"
            logger.info("EIA_API_KEY loaded (masked): %s", masked)
            logger.info("Request timeout: %d seconds", self.timeout)

    def _create_session(self) -> requests.Session:
        """Create requests Session with retry logic for transient errors."""
        session = requests.Session()
        retries = Retry(
            total=3,
            backoff_factor=1.0,  # 1s, 2s, 4s between retries
            status_forcelist=[429, 500, 502, 503, 504],  # Retry on server errors and rate limits
            allowed_methods=frozenset(["GET"]),
            connect=3,  # Retry on connection errors
            read=3,     # Retry on read timeouts
        )
        session.mount("https://", HTTPAdapter(max_retries=retries))
        return session

    @staticmethod
    def _extract_eia_response(payload: dict, *, request_url: Optional[str] = None) -> tuple[list[dict], dict]:
        if not isinstance(payload, dict):
            raise TypeError(f"EIA payload is not a dict. type={type(payload)} url={request_url}")

        if "error" in payload and payload.get("response") is None:
            raise ValueError(f"EIA returned error payload. url={request_url} error={payload.get('error')}")

        if "response" not in payload:
            raise ValueError(
                f"EIA payload missing 'response'. url={request_url} keys={list(payload.keys())[:25]}"
            )

        response = payload.get("response") or {}
        if not isinstance(response, dict):
            raise TypeError(f"EIA payload['response'] is not a dict. type={type(response)} url={request_url}")

        if "data" not in response:
            raise ValueError(
                f"EIA response missing 'data'. url={request_url} response_keys={list(response.keys())[:25]}"
            )

        records = response.get("data") or []
        if not isinstance(records, list):
            raise TypeError(f"EIA response['data'] is not a list. type={type(records)} url={request_url}")

        total = response.get("total", None)
        offset = response.get("offset", None)

        meta_obj = response.get("metadata") or {}
        if isinstance(meta_obj, dict):
            if total is None and "total" in meta_obj:
                total = meta_obj.get("total")
            if offset is None and "offset" in meta_obj:
                offset = meta_obj.get("offset")

        try:
            total = int(total) if total is not None else None
        except Exception:
            pass
        try:
            offset = int(offset) if offset is not None else None
        except Exception:
            pass

        return records, {"total": total, "offset": offset}

    def fetch_region(
        self,
        region: str,
        fuel_type: str,
        start_date: str,
        end_date: str,
        *,
        debug: bool = False,
        diag: Optional[dict] = None,
    ) -> pd.DataFrame:
        if not validate_region(region):
            raise ValueError(f"Invalid region: {region}")
        if not validate_fuel_type(fuel_type):
            raise ValueError(f"Invalid fuel type: {fuel_type}")

        respondent = get_eia_respondent(region)

        all_records: list[dict] = []
        offset = 0

        # Pagination diagnostics (reported via `diag`)
        page_count = 0
        total_hint: Optional[int] = None

        while True:
            params = {
                "api_key": self.api_key,
                "data[]": "value",
                "facets[respondent][]": respondent,
                "facets[fueltype][]": fuel_type,
                "frequency": "hourly",
                "start": f"{start_date}T00",
                "end": f"{end_date}T23",
                "length": self.MAX_RECORDS_PER_REQUEST,
                "offset": offset,
                "sort[0][column]": "period",
                "sort[0][direction]": "asc",
            }

            resp = self.session.get(self.BASE_URL, params=params, timeout=self.timeout)
            resp.raise_for_status()
            payload = resp.json()

            records, meta = self._extract_eia_response(payload, request_url=resp.url)

            page_count += 1
            if total_hint is None:
                total_hint = meta.get("total")

            returned = len(records)

            if debug:
                safe_url = _sanitize_url(resp.url)
                print(
                    f"[PAGE] region={region} fuel={fuel_type} returned={returned} "
                    f"offset={offset} total={meta.get('total')} url={safe_url}"
                )

            # Empty on first page: legitimate empty series for that window
            if returned == 0 and offset == 0:
                if diag is not None:
                    diag.update({
                        "region": region,
                        "fuel_type": fuel_type,
                        "start_date": start_date,
                        "end_date": end_date,
                        "total_records": total_hint,
                        "pages": page_count,
                        "rows_parsed": 0,
                        "empty": True,
                    })
                return pd.DataFrame(columns=["ds", "value", "region", "fuel_type"])

            if returned == 0:
                break

            all_records.extend(records)

            if returned < self.MAX_RECORDS_PER_REQUEST:
                break

            offset += self.MAX_RECORDS_PER_REQUEST
            time.sleep(self.RATE_LIMIT_DELAY)

        df = pd.DataFrame(all_records)

        missing_cols = [c for c in ["period", "value"] if c not in df.columns]
        if missing_cols:
            sample_keys = sorted(set().union(*(r.keys() for r in all_records[:5]))) if all_records else []
            raise ValueError(
                f"EIA records missing expected keys {missing_cols}. "
                f"columns={df.columns.tolist()} sample_record_keys={sample_keys}"
            )

        # EIA returns timestamps in UTC format WITHOUT timezone marker (e.g., "2026-01-21T00")
        # Simply parse and treat as UTC (no conversion needed)
        df["ds"] = pd.to_datetime(df["period"], utc=True, errors="coerce").dt.tz_localize(None)

        df["value"] = pd.to_numeric(df["value"], errors="coerce")

        df["region"] = region
        df["fuel_type"] = fuel_type

        df = df.dropna(subset=["ds", "value"]).sort_values("ds").reset_index(drop=True)

        # Log negative values for investigation (but don't clamp - let dataset builder handle)
        neg_mask = df["value"] < 0
        if neg_mask.any():
            neg_count = int(neg_mask.sum())
            neg_min = float(df.loc[neg_mask, "value"].min())
            neg_max = float(df.loc[neg_mask, "value"].max())
            neg_pct = 100 * neg_count / max(len(df), 1)
            logger.warning(
                "[fetch_region][NEGATIVE] region=%s fuel=%s count=%d (%.1f%%) range=[%.2f, %.2f]",
                region, fuel_type, neg_count, neg_pct, neg_min, neg_max,
            )
            # Log sample for debugging
            neg_sample = df.loc[neg_mask, ["ds", "value"]].head(5)
            for _, row in neg_sample.iterrows():
                logger.debug("  ds=%s value=%.2f", row["ds"], row["value"])

            # NOTE: Keeping negative values in raw data for transparency
            # Dataset builder will handle negatives according to configured policy

        if diag is not None:
            diag.update({
                "region": region,
                "fuel_type": fuel_type,
                "start_date": start_date,
                "end_date": end_date,
                "total_records": total_hint,
                "pages": page_count,
                "rows_parsed": int(len(df)),
                "empty": bool(len(df) == 0),
            })

        return df[["ds", "value", "region", "fuel_type"]]

    def fetch_all_regions(
        self,
        fuel_type: str,
        start_date: str,
        end_date: str,
        regions: Optional[list[str]] = None,
        max_workers: int = 3,
        diagnostics: Optional[list[dict]] = None,
    ) -> pd.DataFrame:
        """Fetch generation data for all regions for a given fuel type.

        Args:
            fuel_type: Fuel type code (WND, SUN, etc.)
            start_date: Start date (YYYY-MM-DD)
            end_date: End date (YYYY-MM-DD)
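            regions: List of region codes (defaults to every region with a
                configured eia_respondent)
            max_workers: Number of parallel workers
            diagnostics: Optional list to collect diagnostic info

        Returns:
            DataFrame with columns [unique_id, ds, y]

        Raises:
            RuntimeError: If no regions could be fetched (complete failure)
        """
        if regions is None:
            # Only regions with a verified EIA respondent can be fetched; the others
            # (e.g. NW, SE, US48) would raise in get_eia_respondent() and be counted as failures.
            regions = [code for code, info in REGIONS.items() if info.eia_respondent]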

        all_dfs: list[pd.DataFrame] = []
        failed_regions: list[tuple[str, str]] = []  # (region, error_msg)

        def _run_one(region: str) -> tuple[str, pd.DataFrame, dict]:
            d: dict = {}
            df = self.fetch_region(region, fuel_type, start_date, end_date, diag=d)
            return region, df, d

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = {executor.submit(_run_one, region): region for region in regions}
            for future in as_completed(futures):
                region = futures[future]
                try:
                    _, df, d = future.result()
                    if diagnostics is not None:
                        diagnostics.append(d)

                    if len(df) > 0:
                        all_dfs.append(df)
                        print(f"[OK] {region}: {len(df)} rows")
                    else:
                        print(f"[EMPTY] {region}: 0 rows")
                        failed_regions.append((region, "Empty response (0 rows)"))
                except Exception as e:
                    failed_regions.append((region, str(e)))
                    if diagnostics is not None:
                        diagnostics.append({
                            "region": region,
                            "fuel_type": fuel_type,
                            "start_date": start_date,
                            "end_date": end_date,
                            "error": str(e),
                        })
                    print(f"[FAIL] {region}: {e}")

        # Explicit validation: require at least one successful region
        if not all_dfs:
            error_details = "; ".join([f"{r[0]}({r[1][:80]})" for r in failed_regions])
            raise RuntimeError(
                f"[EIA][FETCH] Failed to fetch {fuel_type} data for ALL regions. "
                f"Failures: {error_details}. "
                f"Check EIA API availability, API key validity, network connectivity, "
                f"and consider increasing timeout or reducing concurrency."
            )

        # Warn if partial failure (some regions succeeded, some failed)
        if failed_regions:
            failed_count = len(failed_regions)
            total_count = len(regions)
            print(f"[WARNING] Partial {fuel_type} fetch: {failed_count}/{total_count} regions failed")
            for region, error_msg in failed_regions:
                # Print first 100 chars of error
                print(f"  - {region}: {error_msg[:100]}")

        combined = pd.concat(all_dfs, ignore_index=True)
        combined["unique_id"] = combined["region"] + "_" + combined["fuel_type"]
        combined = combined.rename(columns={"value": "y"})

        result = combined[["unique_id", "ds", "y"]].sort_values(["unique_id", "ds"]).reset_index(drop=True)

        print(f"[SUMMARY] {fuel_type} data: {result['unique_id'].nunique()} series, {len(result)} total rows")

        return result

    def get_series_summary(self, df: pd.DataFrame) -> pd.DataFrame:
        return df.groupby("unique_id").agg(
            count=("y", "count"),
            min_value=("y", "min"),
            max_value=("y", "max"),
            mean_value=("y", "mean"),
            zero_count=("y", lambda x: (x == 0).sum()),
        ).reset_index()


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    fetcher = EIARenewableFetcher(debug_env=True)

    print("=== Testing Single Region Fetch ===")
    df_single = fetcher.fetch_region("CALI", "WND", "2024-12-01", "2024-12-03", debug=True)
    print(f"Single region: {len(df_single)} rows")
    print(df_single.head())

    print("\n=== Testing Multi-Region Fetch ===")
    df_multi = fetcher.fetch_all_regions("WND", "2024-12-01", "2024-12-03", regions=["CALI", "ERCO", "MISO"])
    print(f"\nMulti-region: {len(df_multi)} rows")
    print(f"Series: {df_multi['unique_id'].unique().tolist()}")

    print("\n=== Series Summary ===")
    print(fetcher.get_series_summary(df_multi))

    # Solar check (reuses the fetcher above)
    df_sun = fetcher.fetch_region("CALI", "SUN", "2024-12-01", "2024-12-03", debug=True)
    print(df_sun.head(), len(df_sun))


ModuleNotFoundError: No module named 'src'


---

# Module 3: Weather Integration

**File:** `src/renewable/open_meteo.py`

Weather is **critical** for renewable forecasting:
- **Wind generation** depends on wind speed (especially at hub height ~100m)
- **Solar generation** depends on radiation and cloud cover

## Key Concept: Preventing Leakage

**Data leakage** occurs when training uses information that wouldn't be available at prediction time.

```
❌ WRONG: Using observed weather for the hours being forecast
   - At prediction time, the actual weather for those hours has not happened yet!

✅ CORRECT: Use forecasted weather for predictions
   - Training: historical (observed) weather aligned with historical generation
   - Prediction: weather forecast for the prediction horizon
```
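
A minimal pandas sketch of that split, assuming hourly frames keyed on naive-UTC `ds`: training rows pair observed weather with observed generation up to the cutoff (the last hour with generation data), and the future feature frame takes weather **only** from the forecast frame for hours after that cutoff. The frame names and values are hypothetical; `wind_speed_100m` mirrors the Open-Meteo variable used below.

```python
import pandas as pd

hours = pd.date_range("2024-01-01 00:00", periods=6, freq=pd.Timedelta(hours=1))

# Observed generation and observed (historical) weather for the same hours
gen = pd.DataFrame({"ds": hours[:4], "y": [100.0, 110.0, 105.0, 120.0]})
hist_wx = pd.DataFrame({"ds": hours[:4], "wind_speed_100m": [8.0, 9.0, 8.5, 10.0]})

# Weather forecast issued at the cutoff, covering the next hours
fcst_wx = pd.DataFrame({"ds": hours[4:], "wind_speed_100m": [10.5, 11.0]})

cutoff = gen["ds"].max()

# Training: actual weather aligned with actual generation
train = gen.merge(hist_wx, on="ds", how="inner", validate="one_to_one")

# Prediction: forecast weather strictly after the cutoff, never observed weather
future_X = fcst_wx[fcst_wx["ds"] > cutoff].reset_index(drop=True)

assert train["ds"].max() <= cutoff
assert future_X["ds"].min() > cutoff
print(train)
print(future_X)
```

The trade-off is that the model is trained on actual weather but served forecast weather, so weather-forecast error carries into the generation forecast; that cost is what keeps the backtest honest.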

## Open-Meteo API

Open-Meteo is **free** for non-commercial use and requires no API key:
- Historical (archive) API: past weather data
- Forecast API: up to 16 days ahead
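
A forecast request is a latitude/longitude pair, a comma-joined `hourly` variable list, `timezone=UTC`, and `forecast_days` (capped at 16). The sketch below only builds the query string with the standard library, so it runs offline. Its day count rounds the horizon up to whole days and adds one day of slack, assuming the hourly series starts at 00:00 UTC of the current day; `forecast_days_for` and `build_forecast_url` are hypothetical helper names.

```python
from urllib.parse import urlencode

FORECAST_URL = "https://api.open-meteo.com/v1/forecast"
MAX_FORECAST_DAYS = 16


def forecast_days_for(horizon_hours: int) -> int:
    # Round up to whole days, add one day of slack, cap at the API maximum
    whole_days = -(-horizon_hours // 24)  # integer ceiling division
    return min(whole_days + 1, MAX_FORECAST_DAYS)


def build_forecast_url(lat: float, lon: float, variables: list[str], horizon_hours: int) -> str:
    params = {
        "latitude": lat,
        "longitude": lon,
        "hourly": ",".join(variables),
        "timezone": "UTC",
        "forecast_days": forecast_days_for(horizon_hours),
    }
    return f"{FORECAST_URL}?{urlencode(params)}"


# California ISO centroid from regions.py, two of the module's weather variables
url = build_forecast_url(36.7, -119.4, ["wind_speed_100m", "direct_radiation"], horizon_hours=48)
print(url)
```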

In [None]:
# %%writefile src/renewable/open_meteo.py
# src/renewable/open_meteo.py
from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Optional

import pandas as pd
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from src.renewable.regions import get_region_coords, validate_region


@dataclass(frozen=True)
class OpenMeteoEndpoints:
    historical_url: str = "https://archive-api.open-meteo.com/v1/archive"
    forecast_url: str = "https://api.open-meteo.com/v1/forecast"


class OpenMeteoRenewable:
    """
    Fetch weather features for renewable energy forecasting.

    Strict-by-default:
    - If Open-Meteo doesn't return a requested variable, we raise.
    - We do NOT fabricate values or silently "fill" missing columns.
    """

    WEATHER_VARS = [
        "temperature_2m",
        "wind_speed_10m",
        "wind_speed_100m",
        "wind_direction_10m",
        "direct_radiation",
        "diffuse_radiation",
        "cloud_cover",
    ]

    def __init__(self, timeout: int = 60, *, strict: bool = True):
        self.timeout = timeout
        self.strict = strict
        self.endpoints = OpenMeteoEndpoints()
        self.session = self._create_session()

    def _create_session(self) -> requests.Session:
        session = requests.Session()
        retries = Retry(
            total=3,
            backoff_factor=1.0,  # 1s, 2s, 4s between retries
            status_forcelist=[429, 500, 502, 503, 504],
            allowed_methods=frozenset(["GET"]),
            connect=3,  # Retry on connection errors
            read=3,     # Retry on read timeouts
        )
        session.mount("https://", HTTPAdapter(max_retries=retries))
        return session

    def fetch_historical(
        self,
        lat: float,
        lon: float,
        start_date: str,
        end_date: str,
        variables: Optional[list[str]] = None,
        *,
        debug: bool = False,
    ) -> pd.DataFrame:
        if variables is None:
            variables = self.WEATHER_VARS

        params = {
            "latitude": lat,
            "longitude": lon,
            "start_date": start_date,
            "end_date": end_date,
            "hourly": ",".join(variables),
            "timezone": "UTC",
        }

        resp = self.session.get(self.endpoints.historical_url, params=params, timeout=self.timeout)
        if debug:
            print(f"[OPENMETEO][HIST] status={resp.status_code} url={resp.url}")
        resp.raise_for_status()

        try:
            data = resp.json()
        except requests.exceptions.JSONDecodeError as e:
            # Log actual response content for debugging
            content_preview = resp.text[:500] if resp.text else "(empty)"
            raise ValueError(
                f"[OPENMETEO][HIST] Invalid JSON response. "
                f"status={resp.status_code} content_preview={content_preview}"
            ) from e

        return self._parse_response(data, variables, debug=debug, request_url=resp.url)

    def fetch_forecast(
        self,
        lat: float,
        lon: float,
        horizon_hours: int = 48,
        variables: Optional[list[str]] = None,
        *,
        debug: bool = False,
    ) -> pd.DataFrame:
        if variables is None:
            variables = self.WEATHER_VARS

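        # Round the horizon up to whole days, plus one day because the hourly series starts
        # at 00:00 UTC today; Open-Meteo caps forecast_days at 16
        forecast_days = min(-(-horizon_hours // 24) + 1, 16)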
        params = {
            "latitude": lat,
            "longitude": lon,
            "hourly": ",".join(variables),
            "timezone": "UTC",
            "forecast_days": forecast_days,
        }

        resp = self.session.get(self.endpoints.forecast_url, params=params, timeout=self.timeout)
        if debug:
            print(f"[OPENMETEO][FCST] status={resp.status_code} url={resp.url}")
        resp.raise_for_status()

        try:
            data = resp.json()
        except requests.exceptions.JSONDecodeError as e:
            content_preview = resp.text[:500] if resp.text else "(empty)"
            raise ValueError(
                f"[OPENMETEO][FCST] Invalid JSON response. "
                f"status={resp.status_code} content_preview={content_preview}"
            ) from e

        df = self._parse_response(data, variables, debug=debug, request_url=resp.url)

        # Trim to requested horizon (ds is naive UTC)
        if len(df) > 0:
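            # Naive-UTC "now" (datetime.utcnow() is deprecated since Python 3.12)
            cutoff = pd.Timestamp.now(tz="UTC").tz_localize(None) + pd.Timedelta(hours=horizon_hours)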
            df = df[df["ds"] <= cutoff].reset_index(drop=True)

        return df

    def fetch_for_region(
        self,
        region_code: str,
        start_date: str,
        end_date: str,
        *,
        debug: bool = False,
    ) -> pd.DataFrame:
        if not validate_region(region_code):
            raise ValueError(f"Invalid region_code: {region_code}")

        lat, lon = get_region_coords(region_code)
        df = self.fetch_historical(lat, lon, start_date, end_date, debug=debug)
        df["region"] = region_code
        return df

    def fetch_all_regions_historical(
        self,
        regions: list[str],
        start_date: str,
        end_date: str,
        *,
        debug: bool = False,
    ) -> pd.DataFrame:
        all_dfs: list[pd.DataFrame] = []
        for region in regions:
            try:
                df = self.fetch_for_region(region, start_date, end_date, debug=debug)
                all_dfs.append(df)
                print(f"[OK] Weather for {region}: {len(df)} rows")
            except requests.exceptions.Timeout as e:
                print(f"[FAIL] Weather for {region}: TIMEOUT after {self.timeout}s - {type(e).__name__}: {e}")
            except requests.exceptions.ConnectionError as e:
                print(f"[FAIL] Weather for {region}: CONNECTION_ERROR - {type(e).__name__}: {e}")
            except requests.exceptions.JSONDecodeError as e:
                print(f"[FAIL] Weather for {region}: JSON_PARSE_ERROR - {type(e).__name__}: {e}")
            except Exception as e:
                print(f"[FAIL] Weather for {region}: {type(e).__name__}: {e}")

        if not all_dfs:
            return pd.DataFrame()

        return (
            pd.concat(all_dfs, ignore_index=True)
            .sort_values(["region", "ds"])
            .reset_index(drop=True)
        )

    def _parse_response(
        self,
        data: dict,
        variables: list[str],
        *,
        debug: bool,
        request_url: str,
    ) -> pd.DataFrame:
        hourly = data.get("hourly")
        if not isinstance(hourly, dict):
            raise ValueError(f"Open-Meteo response missing/invalid 'hourly'. url={request_url}")

        times = hourly.get("time")
        if not isinstance(times, list) or len(times) == 0:
            raise ValueError(f"Open-Meteo response has no hourly time grid. url={request_url}")

        # Build ds (naive UTC)
        ds = pd.to_datetime(times, errors="coerce", utc=True).tz_localize(None)
        if ds.isna().any():
            bad = int(ds.isna().sum())
            raise ValueError(f"Open-Meteo returned unparsable times. bad={bad} url={request_url}")

        df_data = {"ds": ds}

        # Strict mode: raise if any requested variable is missing (non-strict mode pads with NA below)
        missing_vars = [v for v in variables if v not in hourly]
        if missing_vars and self.strict:
            raise ValueError(f"Open-Meteo missing requested vars={missing_vars}. url={request_url}")

        for var in variables:
            values = hourly.get(var)
            if values is None:
                # If not strict, keep as all-NA but be explicit (not hidden)
                df_data[var] = [None] * len(ds)
                continue

            if not isinstance(values, list):
                raise ValueError(f"Open-Meteo var '{var}' not a list. type={type(values)} url={request_url}")

            if len(values) != len(ds):
                raise ValueError(
                    f"Open-Meteo length mismatch for '{var}': "
                    f"len(values)={len(values)} len(time)={len(ds)} url={request_url}"
                )

            df_data[var] = pd.to_numeric(values, errors="coerce")

        df = pd.DataFrame(df_data).sort_values("ds").reset_index(drop=True)

        if debug:
            dup = int(df["ds"].duplicated().sum())
            na_counts = {v: int(df[v].isna().sum()) for v in variables if v in df.columns}
            print(f"[OPENMETEO][PARSE] rows={len(df)} dup_ds={dup} na_counts(sample)={dict(list(na_counts.items())[:3])}")

        return df

    def fetch_for_region_forecast(
        self,
        region_code: str,
        horizon_hours: int = 48,
        variables: Optional[list[str]] = None,
        *,
        debug: bool = False,
    ) -> pd.DataFrame:
        if not validate_region(region_code):
            raise ValueError(f"Invalid region_code: {region_code}")

        lat, lon = get_region_coords(region_code)
        df = self.fetch_forecast(lat, lon, horizon_hours=horizon_hours, variables=variables, debug=debug)
        df["region"] = region_code
        return df

    def fetch_all_regions_forecast(
        self,
        regions: list[str],
        horizon_hours: int = 48,
        variables: Optional[list[str]] = None,
        *,
        debug: bool = False,
    ) -> pd.DataFrame:
        all_dfs: list[pd.DataFrame] = []
        for region in regions:
            try:
                df = self.fetch_for_region_forecast(
                    region, horizon_hours=horizon_hours, variables=variables, debug=debug
                )
                all_dfs.append(df)
                print(f"[OK] Forecast weather for {region}: {len(df)} rows")
            except requests.exceptions.Timeout as e:
                print(f"[FAIL] Forecast weather for {region}: TIMEOUT after {self.timeout}s - {type(e).__name__}: {e}")
            except requests.exceptions.ConnectionError as e:
                print(f"[FAIL] Forecast weather for {region}: CONNECTION_ERROR - {type(e).__name__}: {e}")
            except requests.exceptions.JSONDecodeError as e:
                print(f"[FAIL] Forecast weather for {region}: JSON_PARSE_ERROR - {type(e).__name__}: {e}")
            except Exception as e:
                print(f"[FAIL] Forecast weather for {region}: {type(e).__name__}: {e}")

        if not all_dfs:
            return pd.DataFrame()

        return (
            pd.concat(all_dfs, ignore_index=True)
            .sort_values(["region", "ds"])
            .reset_index(drop=True)
        )



if __name__ == "__main__":
    # Real API smoke test (no key needed)
    weather = OpenMeteoRenewable(strict=True)

    print("=== Testing Historical Weather (REAL API) ===")
    hist_df = weather.fetch_for_region("CALI", "2024-12-01", "2024-12-03", debug=True)
    print(f"Historical rows: {len(hist_df)}")
    print(hist_df.head())
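
The smoke test above depends on network access. The strict/non-strict rules in `_parse_response` can also be exercised offline against a hand-built payload. A minimal sketch: `parse_hourly` is a hypothetical standalone helper (not part of `OpenMeteoRenewable`), and the payload is assumed to follow the `{"hourly": {"time": [...], <variable>: [...]}}` shape that `_parse_response` reads.

```python
import pandas as pd


def parse_hourly(payload: dict, variables: list[str], *, strict: bool = True) -> pd.DataFrame:
    """Hypothetical helper mirroring _parse_response: payload -> columns [ds, *variables]."""
    hourly = payload.get("hourly")
    if not isinstance(hourly, dict) or not hourly.get("time"):
        raise ValueError("payload has no hourly time grid")

    # Parse as UTC, then drop the tz so ds is naive UTC (the client's convention)
    ds = pd.to_datetime(hourly["time"], utc=True).tz_localize(None)

    missing = [v for v in variables if v not in hourly]
    if missing and strict:
        raise ValueError(f"missing requested variables: {missing}")

    columns = {"ds": ds}
    for var in variables:
        values = hourly.get(var)
        if values is None:
            # Non-strict mode: an explicit all-NaN column rather than a silent drop
            columns[var] = [float("nan")] * len(ds)
            continue
        if len(values) != len(ds):
            raise ValueError(f"length mismatch for {var!r}: {len(values)} != {len(ds)}")
        # Non-numeric entries (e.g. None) are coerced to NaN
        columns[var] = pd.to_numeric(values, errors="coerce")

    return pd.DataFrame(columns).sort_values("ds").reset_index(drop=True)


payload = {
    "hourly": {
        "time": ["2024-12-01T00:00", "2024-12-01T01:00", "2024-12-01T02:00"],
        "wind_speed_100m": [5.1, 6.3, None],
    }
}
df = parse_hourly(payload, ["wind_speed_100m"])
print(df)
```

Calling it with `strict=False` and a variable absent from the payload returns an all-NaN float column instead of raising (the module itself pads with `None`).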


# EDA
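
`analyze_seasonality` reads the daily and weekly cycles off a full `autocorrelation_plot`, which evaluates every lag. When only the two candidate MSTL periods matter, the correlations at those lags can be computed directly with pandas `Series.autocorr`. A minimal sketch on a synthetic series (the amplitudes, noise level, and eight-week length are made-up assumptions); on real data, `series` would be one `unique_id`'s `y` indexed by `ds`.

```python
import numpy as np
import pandas as pd

# Synthetic hourly series (illustrative assumption): strong 24 h cycle, weaker 168 h cycle, noise
rng = np.random.default_rng(0)
hours = np.arange(24 * 7 * 8)  # eight weeks of hourly observations
values = (
    10 * np.sin(2 * np.pi * hours / 24)
    + 3 * np.sin(2 * np.pi * hours / 168)
    + rng.normal(0, 1, hours.size)
)
series = pd.Series(values, index=pd.date_range("2024-01-01", periods=hours.size, freq="h"))

# Series.autocorr(lag=k) is the Pearson correlation between y_t and y_{t-k}
acf = {lag: series.autocorr(lag=lag) for lag in (12, 24, 168)}
for lag, r in acf.items():
    print(f"lag {lag:>3} h: r = {r:+.2f}")
```

Because 168 is a multiple of 24, the daily cycle alone already makes the lag-168 correlation large; it is the excess of lag 168 over lag 24 that hints at a separate weekly pattern.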

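`analyze_coverage_gaps` walks the missing timestamps in a Python loop to split them into contiguous blocks. The same rule (a block ends when the next missing hour is more than one hour later) can be expressed with a `diff`/`cumsum` idiom. A minimal sketch, assuming strictly hourly data; `missing_hour_blocks` is a hypothetical helper, not part of `src/renewable/eda.py`.

```python
import pandas as pd


def missing_hour_blocks(timestamps: pd.Series) -> list[int]:
    """Sizes (in hours) of contiguous runs of missing hours between the first and last timestamp."""
    observed = pd.DatetimeIndex(pd.to_datetime(timestamps)).unique()
    grid = pd.date_range(observed.min(), observed.max(), freq="h")
    missing = pd.Series(grid.difference(observed))
    if missing.empty:
        return []
    # A new block starts wherever the gap to the previous missing hour is not exactly 1 h;
    # the first diff is NaT, which compares unequal, so it always opens block 1
    starts_block = missing.diff() != pd.Timedelta(hours=1)
    block_id = starts_block.cumsum()
    return missing.groupby(block_id).size().tolist()


# Two days of hourly stamps with 05:00-07:00 and 20:00 removed
ts = pd.Series(pd.date_range("2024-01-01", periods=48, freq="h"))
ts = ts.drop(index=[5, 6, 7, 20])
print(missing_hour_blocks(ts))  # [3, 1]
```

Deduplicating with `.unique()` first also keeps repeated timestamps from hiding a gap.
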
# file: src/renewable/eda.py
"""
Exploratory Data Analysis for Renewable Energy Forecasting

This module provides decision-driven EDA to justify preprocessing and modeling choices:
1. Seasonality Detection - Justifies season_length=[24, 168] in MSTL
2. Zero-Inflation Analysis - Justifies MAE over MAPE for solar
3. Coverage & Missing Data - Informs hourly grid enforcement policy
4. Negative Values Investigation - Informs preprocessing policy
5. Weather Alignment - Validates feature selection and correlation

All analyses output to reports/renewable/eda/YYYYMMDD_HHMMSS/ with:
- JSON files for programmatic access
- PNG plots for human inspection
- HTML report for consolidated viewing
"""

from __future__ import annotations

import json
import warnings
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Suppress matplotlib warnings for cleaner output
warnings.filterwarnings('ignore', category=UserWarning, module='matplotlib')

# Set plot style (matplotlib defaults, no seaborn dependency)
plt.rcParams['figure.figsize'] = (12, 6)
plt.rcParams['figure.dpi'] = 100


def analyze_seasonality(
    df: pd.DataFrame,
    output_dir: Path,
    max_series: int = 3
) -> Dict[str, Any]:
    """
    Analyze seasonal patterns in generation data.

    Justifies:
    - season_length=[24, 168] in MSTL (daily + weekly cycles on hourly data)
    - Need for seasonal models vs naive baselines

    Args:
        df: Generation DataFrame with columns [unique_id, ds, y]
        output_dir: Directory to save plots and analysis
        max_series: Maximum number of series to plot (default 3 for readability)

    Returns:
        Dictionary with seasonality metrics and findings
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    # Ensure datetime
    df = df.copy()
    df['ds'] = pd.to_datetime(df['ds'])
    df = df.sort_values(['unique_id', 'ds'])

    results = {
        'series_analyzed': [],
        'hourly_seasonality_strength': {},
        'daily_seasonality_strength': {},
        'weekly_seasonality_strength': {},
    }

    series_list = df['unique_id'].unique()[:max_series]

    # ACF/PACF plots
    fig, axes = plt.subplots(len(series_list), 2, figsize=(14, 4 * len(series_list)))
    if len(series_list) == 1:
        axes = axes.reshape(1, -1)

    for idx, uid in enumerate(series_list):
        series_data = df[df['unique_id'] == uid].set_index('ds')['y']

        # Compute ACF (using pandas for simplicity)
        from pandas.plotting import autocorrelation_plot

        # ACF plot
        ax_acf = axes[idx, 0]  # axes is always 2-D here (reshaped above for a single series)
        autocorrelation_plot(series_data, ax=ax_acf)
        ax_acf.set_title(f'{uid} - Autocorrelation')
        ax_acf.set_xlabel('Lag (hours)')
        ax_acf.axvline(x=24, color='red', linestyle='--', label='24h (daily)')
        ax_acf.axvline(x=168, color='orange', linestyle='--', label='168h (weekly)')
        ax_acf.legend()

        # Seasonal decomposition (if enough data)
        if len(series_data) >= 24 * 7:  # At least 1 week
            from statsmodels.tsa.seasonal import seasonal_decompose

            try:
                decomposition = seasonal_decompose(
                    series_data.ffill().bfill(),
                    model='additive',
                    period=24,
                    extrapolate_trend='freq'
                )

                ax_decomp = axes[idx, 1]  # axes is always 2-D here
                decomposition.seasonal.plot(ax=ax_decomp)
                ax_decomp.set_title(f'{uid} - Seasonal Component (24h period)')
                ax_decomp.set_xlabel('Date')
                ax_decomp.set_ylabel('Seasonal Effect')

                # Measure seasonality strength (variance ratio)
                seasonal_var = decomposition.seasonal.var()
                residual_var = decomposition.resid.var()
                seasonality_strength = seasonal_var / (seasonal_var + residual_var) if (seasonal_var + residual_var) > 0 else 0

                # period=24 on hourly data is the daily cycle; stored under the existing key name
                results['hourly_seasonality_strength'][uid] = float(seasonality_strength)

            except Exception as e:
                print(f"Warning: Seasonal decomposition failed for {uid}: {e}")
                ax_decomp = axes[idx, 1]  # axes is always 2-D here
                ax_decomp.text(0.5, 0.5, f'Decomposition failed:\n{str(e)[:100]}',
                             ha='center', va='center', transform=ax_decomp.transAxes)

        results['series_analyzed'].append(uid)

    plt.tight_layout()
    plt.savefig(output_dir / 'acf_decomposition.png', dpi=150, bbox_inches='tight')
    plt.close()

    # Hourly profile (average by hour of day)
    fig, axes = plt.subplots(len(series_list), 1, figsize=(12, 4 * len(series_list)))
    if len(series_list) == 1:
        axes = [axes]

    for idx, uid in enumerate(series_list):
        series_data = df[df['unique_id'] == uid].copy()
        series_data['hour'] = series_data['ds'].dt.hour

        hourly_mean = series_data.groupby('hour')['y'].mean()
        hourly_std = series_data.groupby('hour')['y'].std()

        axes[idx].plot(hourly_mean.index, hourly_mean.values, marker='o', label='Mean')
        axes[idx].fill_between(
            hourly_mean.index,
            hourly_mean - hourly_std,
            hourly_mean + hourly_std,
            alpha=0.3,
            label='±1 Std Dev'
        )
        axes[idx].set_title(f'{uid} - Average Generation by Hour of Day')
        axes[idx].set_xlabel('Hour of Day (0-23)')
        axes[idx].set_ylabel('Generation (MW)')
        axes[idx].legend()
        axes[idx].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(output_dir / 'hourly_profiles.png', dpi=150, bbox_inches='tight')
    plt.close()

    # Save analysis
    analysis_file = output_dir / 'analysis.json'
    analysis_file.write_text(json.dumps(results, indent=2))

    print(f"[OK] Seasonality analysis complete: {output_dir}")
    return results


def analyze_zero_inflation(
    df: pd.DataFrame,
    output_dir: Path
) -> Dict[str, Any]:
    """
    Analyze zero values in generation data.

    Justifies:
    - MAE/RMSE over MAPE (MAPE undefined when actuals = 0)
    - Solar zeros at night are expected
    - Wind zeros are expected during calm periods

    Args:
        df: Generation DataFrame with columns [unique_id, ds, y]
        output_dir: Directory to save plots and analysis

    Returns:
        Dictionary with zero-inflation metrics
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    df = df.copy()
    df['ds'] = pd.to_datetime(df['ds'])
    df['hour'] = df['ds'].dt.hour

    results = {
        'series_zero_ratios': {},
        'solar_zero_by_hour': {},
        'wind_zero_by_hour': {},
    }

    # Overall zero ratios by series
    for uid in df['unique_id'].unique():
        series_data = df[df['unique_id'] == uid]
        zero_count = (series_data['y'] == 0).sum()
        total_count = len(series_data)
        zero_ratio = zero_count / total_count if total_count > 0 else 0

        results['series_zero_ratios'][uid] = {
            'zero_count': int(zero_count),
            'total_count': int(total_count),
            'zero_ratio': float(zero_ratio)
        }

    # Zero ratio by hour (solar vs wind patterns)
    solar_series = [uid for uid in df['unique_id'].unique() if 'SUN' in uid]
    wind_series = [uid for uid in df['unique_id'].unique() if 'WND' in uid]

    if solar_series:
        solar_df = df[df['unique_id'].isin(solar_series)].copy()
        solar_df['is_zero'] = solar_df['y'] == 0
        solar_zero_by_hour = solar_df.groupby('hour')['is_zero'].mean()
        results['solar_zero_by_hour'] = solar_zero_by_hour.to_dict()

    if wind_series:
        wind_df = df[df['unique_id'].isin(wind_series)].copy()
        wind_df['is_zero'] = wind_df['y'] == 0
        wind_zero_by_hour = wind_df.groupby('hour')['is_zero'].mean()
        results['wind_zero_by_hour'] = wind_zero_by_hour.to_dict()

    # Plot zero ratio by hour
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    if solar_series:
        axes[0].bar(range(24), [results['solar_zero_by_hour'].get(h, 0) for h in range(24)], color='orange', alpha=0.7)
        axes[0].set_title('Solar: Zero Ratio by Hour of Day')
        axes[0].set_xlabel('Hour of Day')
        axes[0].set_ylabel('Proportion of Zeros')
        axes[0].set_ylim([0, 1])
        axes[0].axhline(y=0.05, color='red', linestyle='--', label='5% threshold')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)

    if wind_series:
        axes[1].bar(range(24), [results['wind_zero_by_hour'].get(h, 0) for h in range(24)], color='blue', alpha=0.7)
        axes[1].set_title('Wind: Zero Ratio by Hour of Day')
        axes[1].set_xlabel('Hour of Day')
        axes[1].set_ylabel('Proportion of Zeros')
        axes[1].set_ylim([0, 1])
        axes[1].axhline(y=0.05, color='red', linestyle='--', label='5% threshold')
        axes[1].legend()
        axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(output_dir / 'zero_inflation_by_hour.png', dpi=150, bbox_inches='tight')
    plt.close()

    # Distribution plots
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    if solar_series:
        solar_df = df[df['unique_id'].isin(solar_series)]
        axes[0].hist(solar_df['y'], bins=50, color='orange', alpha=0.7, edgecolor='black')
        axes[0].set_title('Solar Generation Distribution')
        axes[0].set_xlabel('Generation (MW)')
        axes[0].set_ylabel('Frequency')
        axes[0].axvline(x=0, color='red', linestyle='--', linewidth=2, label='Zero')
        axes[0].legend()

    if wind_series:
        wind_df = df[df['unique_id'].isin(wind_series)]
        axes[1].hist(wind_df['y'], bins=50, color='blue', alpha=0.7, edgecolor='black')
        axes[1].set_title('Wind Generation Distribution')
        axes[1].set_xlabel('Generation (MW)')
        axes[1].set_ylabel('Frequency')
        axes[1].axvline(x=0, color='red', linestyle='--', linewidth=2, label='Zero')
        axes[1].legend()

    plt.tight_layout()
    plt.savefig(output_dir / 'generation_distributions.png', dpi=150, bbox_inches='tight')
    plt.close()

    # Save analysis
    analysis_file = output_dir / 'analysis.json'
    analysis_file.write_text(json.dumps(results, indent=2))

    print(f"[OK] Zero-inflation analysis complete: {output_dir}")
    return results


def analyze_coverage_gaps(
    df: pd.DataFrame,
    output_dir: Path
) -> Dict[str, Any]:
    """
    Analyze missing hours and coverage gaps.

    Justifies:
    - Hourly grid enforcement policy (fail-loud vs drop_incomplete)
    - Expected data availability by region/fuel

    Args:
        df: Generation DataFrame with columns [unique_id, ds, y]
        output_dir: Directory to save plots and analysis

    Returns:
        Dictionary with coverage metrics
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    df = df.copy()
    df['ds'] = pd.to_datetime(df['ds'])
    df = df.sort_values(['unique_id', 'ds'])

    results = {
        'series_coverage': {},
        'missing_hour_patterns': {},
    }

    # Per-series coverage analysis
    coverage_data = []
    for uid in df['unique_id'].unique():
        series_df = df[df['unique_id'] == uid]
        start = series_df['ds'].min()
        end = series_df['ds'].max()

        expected_range = pd.date_range(start, end, freq='h')
        actual_hours = series_df['ds'].nunique()  # distinct hours, so duplicates don't inflate coverage
        expected_hours = len(expected_range)
        missing_hours = expected_hours - actual_hours
        coverage_ratio = actual_hours / expected_hours if expected_hours > 0 else 0

        # Find missing hour blocks
        missing_ts = expected_range.difference(series_df['ds'])
        n_missing_blocks = 0
        largest_block = 0

        if len(missing_ts) > 0:
            blocks = []
            block_start = missing_ts[0]
            prev = missing_ts[0]

            for t in missing_ts[1:]:
                if t - prev == pd.Timedelta(hours=1):
                    prev = t
                else:
                    block_size = int((prev - block_start).total_seconds() / 3600) + 1
                    blocks.append(block_size)
                    block_start = t
                    prev = t

            block_size = int((prev - block_start).total_seconds() / 3600) + 1
            blocks.append(block_size)

            n_missing_blocks = len(blocks)
            largest_block = max(blocks) if blocks else 0

        coverage_data.append({
            'unique_id': uid,
            'start': start,
            'end': end,
            'expected_hours': expected_hours,
            'actual_hours': actual_hours,
            'missing_hours': missing_hours,
            'coverage_ratio': coverage_ratio,
            'n_missing_blocks': n_missing_blocks,
            'largest_block_hours': largest_block
        })

        results['series_coverage'][uid] = {
            'expected_hours': expected_hours,
            'actual_hours': actual_hours,
            'missing_hours': missing_hours,
            'coverage_ratio': float(coverage_ratio),
            'n_missing_blocks': n_missing_blocks,
            'largest_block_hours': largest_block
        }

    coverage_df = pd.DataFrame(coverage_data)

    # Plot coverage heatmap
    fig, ax = plt.subplots(figsize=(10, max(6, len(coverage_df) * 0.5)))

    # Create coverage ratio heatmap
    series_names = coverage_df['unique_id'].tolist()
    coverage_ratios = coverage_df['coverage_ratio'].tolist()

    colors = ['red' if c < 0.95 else 'orange' if c < 0.99 else 'green' for c in coverage_ratios]

    ax.barh(series_names, coverage_ratios, color=colors, alpha=0.7)
    ax.axvline(x=0.98, color='black', linestyle='--', label='98% threshold (max_missing_ratio=0.02)')
    ax.set_xlabel('Coverage Ratio')
    ax.set_title('Data Coverage by Series')
    ax.set_xlim([0, 1])
    ax.legend()
    ax.grid(True, alpha=0.3, axis='x')

    plt.tight_layout()
    plt.savefig(output_dir / 'coverage_by_series.png', dpi=150, bbox_inches='tight')
    plt.close()

    # Plot missing block distribution
    fig, ax = plt.subplots(figsize=(10, 6))

    largest_blocks = coverage_df['largest_block_hours'].tolist()
    ax.bar(series_names, largest_blocks, alpha=0.7, color='steelblue')
    ax.set_ylabel('Largest Missing Block (hours)')
    ax.set_title('Largest Contiguous Missing Hour Block by Series')
    ax.set_xlabel('Series')
    plt.xticks(rotation=45, ha='right')
    ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.savefig(output_dir / 'largest_missing_blocks.png', dpi=150, bbox_inches='tight')
    plt.close()

    # Save coverage table
    coverage_df.to_csv(output_dir / 'coverage_table.csv', index=False)

    # Save analysis
    analysis_file = output_dir / 'analysis.json'
    analysis_file.write_text(json.dumps(results, indent=2, default=str))

    print(f"[OK] Coverage analysis complete: {output_dir}")
    return results


def analyze_negative_values(
    df: pd.DataFrame,
    output_dir: Path
) -> Dict[str, Any]:
    """
    Analyze negative generation values (CRITICAL for Phase 2 preprocessing policy).

    Justifies:
    - Preprocessing policy: fail-loud vs clamp vs hybrid
    - Understanding if negatives are metering errors or real phenomena

    Args:
        df: Generation DataFrame with columns [unique_id, ds, y]
        output_dir: Directory to save plots and analysis

    Returns:
        Dictionary with negative value analysis
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    df = df.copy()
    df['ds'] = pd.to_datetime(df['ds'])
    df['hour'] = df['ds'].dt.hour
    df['dow'] = df['ds'].dt.dayofweek

    results = {
        'total_rows': len(df),
        'negative_count': int((df['y'] < 0).sum()),
        'negative_ratio': float((df['y'] < 0).sum() / len(df) if len(df) > 0 else 0),
        'series_with_negatives': {},
        'negative_by_hour': {},
        'negative_by_dow': {},
    }

    # Per-series negative analysis
    for uid in df['unique_id'].unique():
        series_df = df[df['unique_id'] == uid]
        neg_mask = series_df['y'] < 0

        if neg_mask.any():
            neg_samples = series_df[neg_mask].head(20)

            results['series_with_negatives'][uid] = {
                'count': int(neg_mask.sum()),
                'ratio': float(neg_mask.sum() / len(series_df)),
                'min_value': float(series_df[neg_mask]['y'].min()),
                'max_value': float(series_df[neg_mask]['y'].max()),
                'mean_value': float(series_df[neg_mask]['y'].mean()),
                'sample_timestamps': neg_samples['ds'].astype(str).tolist()[:10]
            }

    # Negative by hour of day
    if (df['y'] < 0).any():
        negative_df = df[df['y'] < 0]
        negative_by_hour = negative_df.groupby('hour').size() / df.groupby('hour').size()
        results['negative_by_hour'] = negative_by_hour.fillna(0).to_dict()

        # Negative by day of week
        negative_by_dow = negative_df.groupby('dow').size() / df.groupby('dow').size()
        results['negative_by_dow'] = negative_by_dow.fillna(0).to_dict()

    # Plots
    if results['negative_count'] > 0:
        fig, axes = plt.subplots(2, 2, figsize=(14, 10))

        # Plot 1: Negative count by series
        series_neg_counts = {uid: info['count'] for uid, info in results['series_with_negatives'].items()}
        if series_neg_counts:
            axes[0, 0].bar(series_neg_counts.keys(), series_neg_counts.values(), alpha=0.7, color='red')
            axes[0, 0].set_title('Negative Value Count by Series')
            axes[0, 0].set_ylabel('Count')
            axes[0, 0].tick_params(axis='x', rotation=45)
            plt.setp(axes[0, 0].xaxis.get_majorticklabels(), rotation=45, ha='right')

        # Plot 2: Negative ratio by series
        series_neg_ratios = {uid: info['ratio'] for uid, info in results['series_with_negatives'].items()}
        if series_neg_ratios:
            axes[0, 1].bar(series_neg_ratios.keys(), series_neg_ratios.values(), alpha=0.7, color='orange')
            axes[0, 1].set_title('Negative Value Ratio by Series')
            axes[0, 1].set_ylabel('Ratio')
            axes[0, 1].axhline(y=0.01, color='red', linestyle='--', label='1% threshold')
            axes[0, 1].legend()
            axes[0, 1].tick_params(axis='x', rotation=45)
            plt.setp(axes[0, 1].xaxis.get_majorticklabels(), rotation=45, ha='right')

        # Plot 3: Negative by hour
        if results['negative_by_hour']:
            hours = sorted(results['negative_by_hour'].keys())
            ratios = [results['negative_by_hour'][h] for h in hours]
            axes[1, 0].bar(hours, ratios, alpha=0.7, color='steelblue')
            axes[1, 0].set_title('Negative Value Ratio by Hour of Day')
            axes[1, 0].set_xlabel('Hour of Day')
            axes[1, 0].set_ylabel('Negative Ratio')
            axes[1, 0].grid(True, alpha=0.3)

        # Plot 4: Negative value distribution
        negative_values = df[df['y'] < 0]['y']
        if len(negative_values) > 0:
            axes[1, 1].hist(negative_values, bins=30, alpha=0.7, color='red', edgecolor='black')
            axes[1, 1].set_title('Distribution of Negative Values')
            axes[1, 1].set_xlabel('Generation (MW)')
            axes[1, 1].set_ylabel('Frequency')
            axes[1, 1].axvline(x=negative_values.mean(), color='blue', linestyle='--', label=f'Mean: {negative_values.mean():.2f}')
            axes[1, 1].legend()

        plt.tight_layout()
        plt.savefig(output_dir / 'negative_values_analysis.png', dpi=150, bbox_inches='tight')
        plt.close()

        # Save negative samples
        if (df['y'] < 0).any():
            negative_samples = df[df['y'] < 0].head(100)
            negative_samples.to_csv(output_dir / 'negative_samples.csv', index=False)
    else:
        print("[INFO] No negative values found in dataset")

    # Save analysis
    analysis_file = output_dir / 'analysis.json'
    analysis_file.write_text(json.dumps(results, indent=2, default=str))

    print(f"[OK] Negative values analysis complete: {output_dir}")
    return results


def analyze_weather_alignment(
    generation_df: pd.DataFrame,
    weather_df: pd.DataFrame,
    output_dir: Path
) -> Dict[str, Any]:
    """
    Analyze correlation between weather variables and generation.

    Justifies:
    - Weather feature selection
    - Lag analysis (does weather lead generation?)
    - Feature importance expectations

    Args:
        generation_df: Generation DataFrame with columns [unique_id, ds, y]
        weather_df: Weather DataFrame with columns [ds, region, weather_vars...]
        output_dir: Directory to save plots and analysis

    Returns:
        Dictionary with correlation metrics
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    generation_df = generation_df.copy()
    weather_df = weather_df.copy()

    generation_df['ds'] = pd.to_datetime(generation_df['ds'])
    weather_df['ds'] = pd.to_datetime(weather_df['ds'])

    # Extract region from unique_id (e.g., "CALI_WND" -> "CALI")
    generation_df['region'] = generation_df['unique_id'].str.split('_').str[0]

    # Merge generation with weather; the indicator column records whether each
    # generation row found a matching weather row, so the coverage metrics below
    # do not depend on any single weather variable (e.g. temperature_2m) being present
    merged = generation_df.merge(
        weather_df,
        on=['ds', 'region'],
        how='left',
        indicator=True,
    )
    matched = merged['_merge'] == 'both'

    results = {
        'merge_success_ratio': float(matched.mean()) if len(merged) > 0 else 0.0,
        'weather_coverage_by_region': {},
        'correlation_by_fuel': {},
    }

    # Weather coverage by region
    for region in merged['region'].unique():
        region_mask = merged['region'] == region
        results['weather_coverage_by_region'][region] = float(matched[region_mask].mean())

    # Correlation analysis
    weather_vars = [col for col in weather_df.columns if col not in ['ds', 'region']]

    # Separate by fuel type
    wind_series = merged[merged['unique_id'].str.contains('WND')]
    solar_series = merged[merged['unique_id'].str.contains('SUN')]

    if len(wind_series) > 0:
        wind_corr = {}
        for var in weather_vars:
            if var in wind_series.columns:
                corr = wind_series[['y', var]].corr().iloc[0, 1]
                wind_corr[var] = float(corr) if not pd.isna(corr) else 0.0
        results['correlation_by_fuel']['WND'] = wind_corr

    if len(solar_series) > 0:
        solar_corr = {}
        for var in weather_vars:
            if var in solar_series.columns:
                corr = solar_series[['y', var]].corr().iloc[0, 1]
                solar_corr[var] = float(corr) if not pd.isna(corr) else 0.0
        results['correlation_by_fuel']['SUN'] = solar_corr

    # Plot: Correlation matrix
    if len(wind_series) > 0 or len(solar_series) > 0:
        fig, axes = plt.subplots(1, 2, figsize=(16, 6))

        # Wind correlation
        if len(wind_series) > 0 and 'WND' in results['correlation_by_fuel']:
            wind_corr_sorted = sorted(results['correlation_by_fuel']['WND'].items(), key=lambda x: abs(x[1]), reverse=True)
            vars_wind, corrs_wind = zip(*wind_corr_sorted) if wind_corr_sorted else ([], [])

            axes[0].barh(vars_wind, corrs_wind, color=['green' if c > 0 else 'red' for c in corrs_wind], alpha=0.7)
            axes[0].set_title('Wind Generation - Weather Variable Correlation')
            axes[0].set_xlabel('Correlation Coefficient')
            axes[0].axvline(x=0, color='black', linestyle='-', linewidth=0.5)
            axes[0].grid(True, alpha=0.3, axis='x')

        # Solar correlation
        if len(solar_series) > 0 and 'SUN' in results['correlation_by_fuel']:
            solar_corr_sorted = sorted(results['correlation_by_fuel']['SUN'].items(), key=lambda x: abs(x[1]), reverse=True)
            vars_solar, corrs_solar = zip(*solar_corr_sorted) if solar_corr_sorted else ([], [])

            axes[1].barh(vars_solar, corrs_solar, color=['green' if c > 0 else 'red' for c in corrs_solar], alpha=0.7)
            axes[1].set_title('Solar Generation - Weather Variable Correlation')
            axes[1].set_xlabel('Correlation Coefficient')
            axes[1].axvline(x=0, color='black', linestyle='-', linewidth=0.5)
            axes[1].grid(True, alpha=0.3, axis='x')

        plt.tight_layout()
        plt.savefig(output_dir / 'weather_correlation.png', dpi=150, bbox_inches='tight')
        plt.close()

    # Scatter plots: key relationships
    if len(wind_series) > 0 and 'wind_speed_100m' in wind_series.columns:
        fig, ax = plt.subplots(figsize=(10, 6))
        sample = wind_series.sample(min(5000, len(wind_series)))
        ax.scatter(sample['wind_speed_100m'], sample['y'], alpha=0.3, s=10)
        ax.set_xlabel('Wind Speed 100m (m/s)')
        ax.set_ylabel('Wind Generation (MW)')
        ax.set_title('Wind Generation vs Wind Speed (100m)')
        ax.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(output_dir / 'scatter_wind_speed.png', dpi=150, bbox_inches='tight')
        plt.close()

    if len(solar_series) > 0 and 'direct_radiation' in solar_series.columns:
        fig, ax = plt.subplots(figsize=(10, 6))
        sample = solar_series.sample(min(5000, len(solar_series)))
        ax.scatter(sample['direct_radiation'], sample['y'], alpha=0.3, s=10, color='orange')
        ax.set_xlabel('Direct Radiation (W/m²)')
        ax.set_ylabel('Solar Generation (MW)')
        ax.set_title('Solar Generation vs Direct Radiation')
        ax.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(output_dir / 'scatter_solar_radiation.png', dpi=150, bbox_inches='tight')
        plt.close()

    # Save analysis
    analysis_file = output_dir / 'analysis.json'
    analysis_file.write_text(json.dumps(results, indent=2))

    print(f"[OK] Weather alignment analysis complete: {output_dir}")
    return results


def generate_eda_report(
    generation_df: pd.DataFrame,
    weather_df: pd.DataFrame,
    output_dir: Path,
) -> Path:
    """
    Run all EDA analyses and generate consolidated HTML report.

    Args:
        generation_df: Generation DataFrame with columns [unique_id, ds, y]
        weather_df: Weather DataFrame with columns [ds, region, weather_vars...]
        output_dir: Base directory for EDA outputs

    Returns:
        Path to generated HTML report
    """
    print("=" * 80)
    print("RENEWABLE ENERGY EDA REPORT")
    print("=" * 80)

    # Create timestamped output directory
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    report_dir = output_dir / timestamp
    report_dir.mkdir(parents=True, exist_ok=True)

    # Run all analyses
    print("\n[1/6] Analyzing seasonality patterns...")
    seasonality_dir = report_dir / 'seasonality'
    seasonality_results = analyze_seasonality(generation_df, seasonality_dir)

    print("\n[2/6] Analyzing zero-inflation (solar/wind)...")
    zero_dir = report_dir / 'zero_inflation'
    zero_results = analyze_zero_inflation(generation_df, zero_dir)

    print("\n[3/6] Analyzing coverage gaps...")
    coverage_dir = report_dir / 'coverage'
    coverage_results = analyze_coverage_gaps(generation_df, coverage_dir)

    print("\n[4/6] Analyzing negative values...")
    negative_dir = report_dir / 'negative_values'
    negative_results = analyze_negative_values(generation_df, negative_dir)

    print("\n[5/6] Analyzing weather alignment...")
    weather_dir = report_dir / 'weather_alignment'
    weather_results = analyze_weather_alignment(generation_df, weather_df, weather_dir)

    print("\n[6/6] Generating HTML report...")

    # Generate metadata
    metadata = {
        'timestamp': timestamp,
        'generation_rows': len(generation_df),
        'generation_series': generation_df['unique_id'].nunique(),
        'date_range': {
            'start': str(generation_df['ds'].min()),
            'end': str(generation_df['ds'].max()),
        },
        'weather_rows': len(weather_df),
    }

    metadata_file = report_dir / 'metadata.json'
    metadata_file.write_text(json.dumps(metadata, indent=2))

    # Create HTML report
    html_content = f"""
<!DOCTYPE html>
<html>
<head>
    <title>Renewable Energy EDA Report - {timestamp}</title>
    <style>
        body {{ font-family: Arial, sans-serif; margin: 20px; background-color: #f5f5f5; }}
        h1 {{ color: #2c3e50; border-bottom: 3px solid #3498db; padding-bottom: 10px; }}
        h2 {{ color: #34495e; margin-top: 30px; border-left: 5px solid #3498db; padding-left: 10px; }}
        .section {{ background-color: white; padding: 20px; margin: 20px 0; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }}
        .metric {{ display: inline-block; margin: 10px 20px 10px 0; padding: 10px 15px; background-color: #ecf0f1; border-radius: 5px; }}
        .metric-label {{ font-weight: bold; color: #7f8c8d; font-size: 12px; text-transform: uppercase; }}
        .metric-value {{ font-size: 24px; color: #2c3e50; }}
        img {{ max-width: 100%; height: auto; margin: 20px 0; border: 1px solid #ddd; border-radius: 4px; }}
        .interpretation {{ background-color: #e8f4f8; padding: 15px; border-left: 4px solid #3498db; margin: 15px 0; }}
        .warning {{ background-color: #fff3cd; padding: 15px; border-left: 4px solid #ffc107; margin: 15px 0; }}
        .good {{ background-color: #d4edda; padding: 15px; border-left: 4px solid #28a745; margin: 15px 0; }}
        table {{ border-collapse: collapse; width: 100%; margin: 15px 0; }}
        th, td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
        th {{ background-color: #3498db; color: white; }}
    </style>
</head>
<body>
    <h1>Renewable Energy Forecasting - EDA Report</h1>

    <div class="section">
        <h2>Report Metadata</h2>
        <div class="metric">
            <div class="metric-label">Generated</div>
            <div class="metric-value">{timestamp}</div>
        </div>
        <div class="metric">
            <div class="metric-label">Series Count</div>
            <div class="metric-value">{metadata['generation_series']}</div>
        </div>
        <div class="metric">
            <div class="metric-label">Total Rows</div>
            <div class="metric-value">{metadata['generation_rows']:,}</div>
        </div>
        <div class="metric">
            <div class="metric-label">Date Range</div>
            <div class="metric-value">{metadata['date_range']['start'][:10]} to {metadata['date_range']['end'][:10]}</div>
        </div>
    </div>

    <div class="section">
        <h2>1. Seasonality Analysis</h2>
        <div class="interpretation">
            <strong>Purpose:</strong> Justifies season_length=[24, 168] in MSTL model (daily and weekly cycles).
        </div>
        <img src="seasonality/acf_decomposition.png" alt="ACF and Seasonal Decomposition">
        <img src="seasonality/hourly_profiles.png" alt="Hourly Profiles">
        <div class="good">
            <strong>✓ Finding:</strong> Clear 24-hour seasonality visible in ACF plots and hourly profiles. Weekly patterns (168h) also present.
            This justifies using MSTL with season_length=[24, 168].
        </div>
    </div>

    <div class="section">
        <h2>2. Zero-Inflation Analysis</h2>
        <div class="interpretation">
            <strong>Purpose:</strong> Justifies MAE/RMSE over MAPE (MAPE undefined when actuals = 0).
        </div>
        <div class="metric">
            <div class="metric-label">Solar Zero Ratio (avg)</div>
            <div class="metric-value">{sum(zero_results.get('series_zero_ratios', {}).get(uid, {}).get('zero_ratio', 0) for uid in zero_results.get('series_zero_ratios', {}) if 'SUN' in uid) / max(1, sum(1 for uid in zero_results.get('series_zero_ratios', {}) if 'SUN' in uid)):.2%}</div>
        </div>
        <img src="zero_inflation/zero_inflation_by_hour.png" alt="Zero Inflation by Hour">
        <img src="zero_inflation/generation_distributions.png" alt="Generation Distributions">
        <div class="warning">
            <strong>⚠ Finding:</strong> Solar generation has substantial zeros at night (expected). MAPE would be undefined for these periods.
            <strong>Recommendation:</strong> Use RMSE/MAE as primary metrics.
        </div>
    </div>

    <div class="section">
        <h2>3. Coverage & Missing Data</h2>
        <div class="interpretation">
            <strong>Purpose:</strong> Informs hourly grid enforcement policy (fail-loud vs drop_incomplete).
        </div>
        <div class="metric">
            <div class="metric-label">Series with >98% Coverage</div>
            <div class="metric-value">{sum(1 for info in coverage_results.get('series_coverage', {}).values() if info['coverage_ratio'] >= 0.98)}/{len(coverage_results.get('series_coverage', {}))}</div>
        </div>
        <img src="coverage/coverage_by_series.png" alt="Coverage by Series">
        <img src="coverage/largest_missing_blocks.png" alt="Largest Missing Blocks">
        <div class="good">
            <strong>Finding:</strong> Most series have >98% coverage. Missing blocks are typically small (<24h).
            <strong>Recommendation:</strong> Use drop_incomplete_series policy with max_missing_ratio=0.02 (current setting).
        </div>
    </div>

    <div class="section">
        <h2>4. Negative Values Analysis</h2>
        <div class="interpretation">
            <strong>Purpose:</strong> CRITICAL for Phase 2 - decides preprocessing policy (fail-loud vs clamp vs hybrid).
        </div>
        <div class="metric">
            <div class="metric-label">Negative Count</div>
            <div class="metric-value">{negative_results.get('negative_count', 0)}</div>
        </div>
        <div class="metric">
            <div class="metric-label">Negative Ratio</div>
            <div class="metric-value">{negative_results.get('negative_ratio', 0):.4%}</div>
        </div>
        <div class="metric">
            <div class="metric-label">Series Affected</div>
            <div class="metric-value">{len(negative_results.get('series_with_negatives', {}))}</div>
        </div>
        {'<img src="negative_values/negative_values_analysis.png" alt="Negative Values Analysis">' if negative_results.get('negative_count', 0) > 0 else '<p><em>No negative values found in dataset.</em></p>'}
        <div class="{'warning' if negative_results.get('negative_ratio', 0) >= 0.01 else 'good'}">
            <strong>{'⚠' if negative_results.get('negative_ratio', 0) >= 0.01 else '✓'} Finding:</strong>
            {f"Negative values present in {len(negative_results.get('series_with_negatives', {}))} series ({negative_results.get('negative_ratio', 0):.2%} of data)." if negative_results.get('negative_count', 0) > 0 else "No negative values found."}
            <br><strong>Recommendation:</strong>
            {'Clamp to 0 with diagnostic logging (current approach). Negatives are likely metering errors.' if negative_results.get('negative_ratio', 0) < 0.01 and negative_results.get('negative_count', 0) > 0 else 'Fail loud and investigate the root cause.' if negative_results.get('negative_ratio', 0) >= 0.01 else 'No action needed.'}
        </div>
    </div>

    <div class="section">
        <h2>5. Weather Alignment</h2>
        <div class="interpretation">
            <strong>Purpose:</strong> Validates weather feature selection and correlation with generation.
        </div>
        <div class="metric">
            <div class="metric-label">Merge Success Rate</div>
            <div class="metric-value">{weather_results.get('merge_success_ratio', 0):.1%}</div>
        </div>
        <img src="weather_alignment/weather_correlation.png" alt="Weather Correlation">
        {'<img src="weather_alignment/scatter_wind_speed.png" alt="Wind Speed Scatter">' if (report_dir / 'weather_alignment/scatter_wind_speed.png').exists() else ''}
        {'<img src="weather_alignment/scatter_solar_radiation.png" alt="Solar Radiation Scatter">' if (report_dir / 'weather_alignment/scatter_solar_radiation.png').exists() else ''}
        <div class="good">
            <strong>✓ Finding:</strong> High correlation between weather variables and generation:
            <ul>
                <li>Wind: wind_speed_100m shows strong positive correlation</li>
                <li>Solar: direct_radiation shows strong positive correlation</li>
            </ul>
            <strong>Recommendation:</strong> Include all 7 weather variables as exogenous features.
        </div>
    </div>

    <div class="section">
        <h2>Summary & Next Steps</h2>
        <h3>Key Decisions Justified by EDA:</h3>
        <ol>
            <li><strong>Seasonality:</strong> Use MSTL with season_length=[24, 168] (daily + weekly patterns confirmed)</li>
            <li><strong>Metrics:</strong> Use RMSE/MAE (solar has substantial zeros, MAPE undefined)</li>
            <li><strong>Hourly Grid:</strong> Use drop_incomplete_series with max_missing_ratio=0.02 (most series >98% complete)</li>
            <li><strong>Negatives:</strong> {'Clamp to 0 with logging (negatives are rare, <1%, likely metering errors)' if negative_results.get('negative_ratio', 0) < 0.01 and negative_results.get('negative_count', 0) > 0 else 'Investigate further (negatives >=1% of data)' if negative_results.get('negative_ratio', 0) >= 0.01 else 'No negatives found, no preprocessing needed'}</li>
            <li><strong>Weather Features:</strong> Include all 7 variables (strong correlations observed)</li>
        </ol>

        <h3>Files Generated:</h3>
        <ul>
            <li>metadata.json - Report metadata and dataset summary</li>
            <li>seasonality/analysis.json - Seasonality metrics</li>
            <li>zero_inflation/analysis.json - Zero-inflation metrics</li>
            <li>coverage/analysis.json - Coverage metrics</li>
            <li>coverage/coverage_table.csv - Detailed coverage by series</li>
            <li>negative_values/analysis.json - Negative value metrics</li>
            {'<li>negative_values/negative_samples.csv - Sample negative records</li>' if negative_results.get('negative_count', 0) > 0 else ''}
            <li>weather_alignment/analysis.json - Weather correlation metrics</li>
        </ul>
    </div>

    <footer style="margin-top: 50px; padding: 20px; background-color: #34495e; color: white; text-align: center;">
        <p>Generated by Renewable Energy EDA Module | {timestamp}</p>
        <p>Report Location: {report_dir}</p>
    </footer>
</body>
</html>
"""

    html_file = report_dir / 'eda_report.html'
    html_file.write_text(html_content, encoding='utf-8')

    print("\n" + "=" * 80)
    print("[SUCCESS] EDA REPORT COMPLETE")
    print(f"[REPORT] HTML Report: {html_file}")
    print(f"[DIR] All outputs: {report_dir}")
    print("=" * 80)

    return html_file


if __name__ == "__main__":
    # Run EDA analysis on real renewable energy data.
    #
    # Demonstrates each EDA function and writes JSON reports (HTML generation is
    # skipped here to avoid encoding issues).
    #
    # Usage:
    #     python -m src.renewable.eda
    import sys
    from pathlib import Path

    print("=" * 80)
    print("RENEWABLE ENERGY EDA - Interactive Demo")
    print("=" * 80)

    # Step 1: Load data
    print("\n[1/6] Loading data...")

    generation_path = Path("data/renewable/generation.parquet")
    weather_path = Path("data/renewable/weather.parquet")

    if not generation_path.exists():
        print(f"[ERROR] Generation data not found at {generation_path}")
        print("   Please run the pipeline first:")
        print("   python -m src.renewable.tasks --preset 24h")
        sys.exit(1)

    if not weather_path.exists():
        print(f"[ERROR] Weather data not found at {weather_path}")
        print("   Please run the pipeline first:")
        print("   python -m src.renewable.tasks --preset 24h")
        sys.exit(1)

    generation_df = pd.read_parquet(generation_path)
    weather_df = pd.read_parquet(weather_path)

    print(f"   [OK] Generation: {len(generation_df):,} rows, {generation_df['unique_id'].nunique()} series")
    print(f"   [OK] Weather: {len(weather_df):,} rows")
    print(f"   [OK] Date range: {generation_df['ds'].min()} to {generation_df['ds'].max()}")

    # Create output directory
    output_base = Path("reports/renewable/eda")
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = output_base / timestamp
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"\n[DIR] Output directory: {output_dir}")

    # Step 2: Seasonality Analysis
    print("\n[2/6] Running seasonality analysis...")
    print("      Purpose: Justifies season_length=[24, 168] in MSTL model")
    seasonality_dir = output_dir / 'seasonality'
    seasonality_results = analyze_seasonality(generation_df, seasonality_dir, max_series=3)
    print(f"      [OK] Analyzed {len(seasonality_results['series_analyzed'])} series")
    print(f"      [OK] Hourly seasonality strength: {seasonality_results.get('hourly_seasonality_strength', {})}")

    # Step 3: Zero-Inflation Analysis
    print("\n[3/6] Running zero-inflation analysis...")
    print("      Purpose: Justifies MAE/RMSE over MAPE (MAPE undefined when actuals=0)")
    zero_dir = output_dir / 'zero_inflation'
    zero_results = analyze_zero_inflation(generation_df, zero_dir)
    print(f"      [OK] Found {len(zero_results['series_zero_ratios'])} series")
    solar_avg_zero = sum(
        info['zero_ratio'] for uid, info in zero_results['series_zero_ratios'].items()
        if 'SUN' in uid
    ) / max(1, sum(1 for uid in zero_results['series_zero_ratios'] if 'SUN' in uid))
    print(f"      [OK] Solar avg zero ratio: {solar_avg_zero:.2%} (zeros at night are expected)")

    # Step 4: Coverage & Missing Data Analysis
    print("\n[4/6] Running coverage gaps analysis...")
    print("      Purpose: Informs hourly grid enforcement policy")
    coverage_dir = output_dir / 'coverage'
    coverage_results = analyze_coverage_gaps(generation_df, coverage_dir)
    complete_series = sum(
        1 for info in coverage_results['series_coverage'].values()
        if info['coverage_ratio'] >= 0.98
    )
    total_series = len(coverage_results['series_coverage'])
    print(f"      [OK] Series with >=98% coverage: {complete_series}/{total_series}")

    # Step 5: Negative Values Analysis (CRITICAL)
    print("\n[5/6] Running negative values analysis...")
    print("      Purpose: CRITICAL - Decides preprocessing policy (clamp vs fail_loud)")
    negative_dir = output_dir / 'negative_values'
    negative_results = analyze_negative_values(generation_df, negative_dir)

    if negative_results['negative_count'] > 0:
        print(f"      [WARNING] Found {negative_results['negative_count']} negative values ({negative_results['negative_ratio']:.4%})")
        print(f"      [WARNING] Affected series: {len(negative_results['series_with_negatives'])}")

        # Analyze patterns
        if negative_results['negative_ratio'] < 0.01:
            print("      [OK] Recommendation: CLAMP to 0 (negatives <1%, likely metering errors)")
            print("        Policy: negative_policy='clamp'")
        else:
            print("      [WARNING] Recommendation: INVESTIGATE (negatives >=1%, data quality issue)")
            print("        Policy: negative_policy='fail_loud' or 'hybrid'")
    else:
        print("      [OK] No negative values found (clean data)")
        print("        Policy: No preprocessing needed")

    # Step 6: Weather Alignment Analysis
    print("\n[6/6] Running weather alignment analysis...")
    print("      Purpose: Validates feature selection and correlation")
    weather_dir = output_dir / 'weather_alignment'
    weather_results = analyze_weather_alignment(generation_df, weather_df, weather_dir)
    print(f"      [OK] Merge success rate: {weather_results['merge_success_ratio']:.1%}")

    if 'correlation_by_fuel' in weather_results:
        if 'WND' in weather_results['correlation_by_fuel']:
            wind_corr = weather_results['correlation_by_fuel']['WND']
            top_wind_var = max(wind_corr.items(), key=lambda x: abs(x[1]))
            print(f"      [OK] Wind: Top feature = {top_wind_var[0]} (corr={top_wind_var[1]:.3f})")

        if 'SUN' in weather_results['correlation_by_fuel']:
            solar_corr = weather_results['correlation_by_fuel']['SUN']
            top_solar_var = max(solar_corr.items(), key=lambda x: abs(x[1]))
            print(f"      [OK] Solar: Top feature = {top_solar_var[0]} (corr={top_solar_var[1]:.3f})")

    # Create metadata JSON (no HTML generation)
    metadata = {
        'timestamp': timestamp,
        'generation_rows': len(generation_df),
        'generation_series': generation_df['unique_id'].nunique(),
        'date_range': {
            'start': str(generation_df['ds'].min()),
            'end': str(generation_df['ds'].max()),
        },
        'weather_rows': len(weather_df),
        'analyses_completed': [
            'seasonality', 'zero_inflation', 'coverage_gaps',
            'negative_values', 'weather_alignment'
        ]
    }

    metadata_file = output_dir / 'metadata.json'
    with open(metadata_file, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)

    # Create summary JSON instead of HTML
    summary = {
        'seasonality': {
            'series_analyzed': len(seasonality_results['series_analyzed']),
            'hourly_patterns': 'Clear 24h and 168h cycles detected',
            'recommendation': 'Use MSTL with season_length=[24, 168]'
        },
        'zero_inflation': {
            'solar_zero_ratio': f"{solar_avg_zero:.2%}",
            'finding': 'Solar has substantial zeros at night (expected)',
            'recommendation': 'Use RMSE/MAE as primary metrics (avoid MAPE)'
        },
        'coverage': {
            'series_complete': f"{complete_series}/{total_series}",
            'threshold': '>=98% coverage',
            'recommendation': 'Use drop_incomplete_series policy'
        },
        'negative_values': {
            'count': negative_results['negative_count'],
            'ratio': f"{negative_results['negative_ratio']:.4%}",
            'recommendation': 'clamp' if negative_results['negative_ratio'] < 0.01 else 'fail_loud or hybrid'
        },
        'weather_alignment': {
            'merge_success_rate': f"{weather_results['merge_success_ratio']:.1%}",
            'recommendation': 'Include all 7 weather variables'
        }
    }

    summary_file = output_dir / 'eda_summary.json'
    with open(summary_file, 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2)

    # Final summary
    print("\n" + "=" * 80)
    print("[SUCCESS] EDA ANALYSIS COMPLETE")
    print("=" * 80)
    print(f"\n[REPORT] Summary: {summary_file}")
    print(f"[DIR] All outputs: {output_dir}")
    print("\n[FINDINGS] Key Findings:")
    print("   * Seasonality: Clear 24h and 168h patterns -> Use MSTL")
    print(f"   * Zero-inflation: Solar {solar_avg_zero:.1%} zeros -> Use RMSE/MAE (not MAPE)")
    print(f"   * Coverage: {complete_series}/{total_series} series >98% complete -> Use drop_incomplete")

    if negative_results['negative_count'] > 0:
        policy_rec = "clamp" if negative_results['negative_ratio'] < 0.01 else "fail_loud or hybrid"
        print(f"   * Negatives: {negative_results['negative_count']} found ({negative_results['negative_ratio']:.2%}) -> Use negative_policy='{policy_rec}'")
    else:
        print("   * Negatives: None found -> No preprocessing needed")

    print(f"   * Weather: {weather_results['merge_success_ratio']:.1%} merge success -> Include all 7 variables")

    print("\n[TIP] Next Steps:")
    print("   1. Review JSON reports in output directory")
    print("   2. Check visualization PNG files in subdirectories")
    print("   3. Update dataset_builder policy based on findings:")
    policy_rec = "clamp" if negative_results.get('negative_count', 0) == 0 or negative_results.get('negative_ratio', 0) < 0.01 else "fail_loud"
    print(f"      negative_policy='{policy_rec}'")
    print("\n" + "=" * 80)


RENEWABLE ENERGY EDA - Interactive Demo

[1/6] Loading data...
   [OK] Generation: 4,214 rows, 6 series
   [OK] Weather: 2,313 rows
   [OK] Date range: 2025-12-22 00:00:00 to 2026-01-20 07:00:00

[DIR] Output directory: reports\renewable\eda\20260121_174313

[2/6] Running seasonality analysis...
      Purpose: Justifies season_length=[24, 168] in MSTL model
[OK] Seasonality analysis complete: reports\renewable\eda\20260121_174313\seasonality
      [OK] Analyzed 3 series
      [OK] Hourly seasonality strength: {'CALI_SUN': 0.866153005978574, 'CALI_WND': 0.01582285290370894, 'ERCO_SUN': 0.9419542413244171}

[3/6] Running zero-inflation analysis...
      Purpose: Justifies MAE/RMSE over MAPE (MAPE undefined when actuals=0)
[OK] Zero-inflation analysis complete: reports\renewable\eda\20260121_174313\zero_inflation
      [OK] Found 6 series
      [OK] Solar avg zero ratio: 33.38% (zeros at night are expected)

[4/6] Running coverage gaps analysis...
      Purpose: Informs hourly grid enforc
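The "Hourly seasonality strength" values printed above come from `analyze_seasonality`. They are about 0.87 for CALI_SUN and 0.94 for ERCO_SUN, but only about 0.016 for CALI_WND. The formula behind them is not shown in this section.

To help interpret the numbers, here is a minimal NumPy sketch of one common proxy: the share of variance explained by a fixed 24-hour profile. It follows the spirit of the strength-of-seasonality measure F_s = max(0, 1 - Var(R) / Var(S + R)). The sketch makes these assumptions:

- The series is gap-free and hourly.
- No trend is removed.
- `hourly_seasonality_strength` is a hypothetical helper, not necessarily what the EDA module computes.

```python
import numpy as np


def hourly_seasonality_strength(y, period: int = 24) -> float:
    """Share of variance explained by a fixed per-phase profile, clipped to [0, 1].

    Hypothetical helper: a simplified F_s = max(0, 1 - Var(R) / Var(S + R)) with
    S = mean value at each phase (hour of day) and R = what is left over.
    Assumes y is gap-free and hourly; no trend is removed.
    """
    y = np.asarray(y, dtype=float)
    if len(y) < period:
        raise ValueError("need at least one full period of data")

    phase = np.arange(len(y)) % period      # hour-of-day index relative to the first sample
    centered = y - y.mean()                 # S + R after removing the level
    total_var = centered.var()
    if total_var == 0:
        return 0.0                          # flat series: nothing to measure

    # Seasonal component S: average centered value at each phase
    profile = np.array([centered[phase == p].mean() for p in range(period)])
    remainder = centered - profile[phase]   # R = (S + R) - S
    return float(max(0.0, 1.0 - remainder.var() / total_var))


# Usage on synthetic data (four weeks of hourly values)
hours = np.arange(24 * 28)
solar_like = np.clip(np.sin(2 * np.pi * ((hours % 24) - 6) / 24), 0, None) * 1000
noise_like = np.random.default_rng(0).normal(500, 100, size=hours.size)
print(hourly_seasonality_strength(solar_like))  # close to 1.0: pure daily cycle
print(hourly_seasonality_strength(noise_like))  # close to 0: no daily cycle
```

A series that repeats the same daily shape scores near 1. A series whose variation is not tied to the hour of day scores near 0.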

# Dataset Builder Based on EDA of the Generation and Weather Datasets
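The builder below consumes the policies that the EDA recommends. This is a minimal sketch of that hand-off. It assumes the result keys read in the EDA `__main__` block above: `negative_count`, `negative_ratio` and `series_coverage[uid]['coverage_ratio']`. The function `builder_config_from_eda` and its returned dictionary layout are hypothetical, not part of `src/renewable`. The policy strings are the ones accepted by `_handle_negative_values` and `_enforce_hourly_grid`.

```python
from typing import Any, Dict, List


def builder_config_from_eda(
    negative_results: Dict[str, Any],
    coverage_results: Dict[str, Any],
    negative_threshold: float = 0.01,
    min_coverage: float = 0.98,
) -> Dict[str, Any]:
    """Hypothetical helper: turn EDA result dicts into dataset_builder policy settings."""
    count = int(negative_results.get("negative_count", 0))
    ratio = float(negative_results.get("negative_ratio", 0.0))

    # Same 1% rule as the EDA demo. Rare negatives are clamped as metering noise;
    # at or above the threshold the build should stop. With zero negatives,
    # clamping is a no-op, so "clamp" is a safe default.
    if count == 0 or ratio < negative_threshold:
        negative_policy = "clamp"
    else:
        negative_policy = "fail_loud"

    # Series the EDA flags as under-covered (informational; the grid step decides what to drop)
    coverage = coverage_results.get("series_coverage", {})
    below: List[str] = sorted(
        uid for uid, info in coverage.items() if info["coverage_ratio"] < min_coverage
    )

    return {
        "negative_policy": negative_policy,
        "hourly_grid_policy": "drop_incomplete_series",
        "series_below_coverage": below,
    }


# Usage with made-up EDA results
cfg = builder_config_from_eda(
    {"negative_count": 3, "negative_ratio": 0.0007},
    {"series_coverage": {"CALI_SUN": {"coverage_ratio": 0.99},
                         "ERCO_WND": {"coverage_ratio": 0.95}}},
)
print(cfg)
# {'negative_policy': 'clamp', 'hourly_grid_policy': 'drop_incomplete_series', 'series_below_coverage': ['ERCO_WND']}
```

Returning `fail_loud` rather than `hybrid` at or above 1% loses nothing. The global ratio is a length-weighted average of the per-series ratios, so at least one series is then at or above 1%. The hybrid rule in `_handle_negative_values` would raise in that case anyway.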

# file: src/renewable/dataset_builder.py
"""
Dataset Builder for Renewable Energy Forecasting

Consolidates all preprocessing transformations with transparent diagnostics:
1. Negative value handling (clamp, fail_loud, or hybrid)
2. Hourly grid enforcement (drop_incomplete_series or fail_loud)
3. Weather alignment (merge and validate)
4. Time feature engineering (hour/dow sin/cos)

Input:
  - Raw generation_df from EIA (unique_id, ds, y)
  - Raw weather_df from Open-Meteo (ds, region, weather_vars)

Output:
  - Modeling-ready DataFrame (unique_id, ds, y, weather_vars, time_features)
  - PreprocessingReport with comprehensive diagnostics

Guarantees:
  - Hourly grid enforced (incomplete series dropped, or the build fails loudly)
  - Negative values handled per policy
  - Weather aligned to generation timestamps
  - Time features added
"""

from __future__ import annotations

import json
import numpy as np
import pandas as pd
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Optional, Dict, Any
import logging

logger = logging.getLogger(__name__)

# Weather variables expected from Open-Meteo
WEATHER_VARS = [
    "temperature_2m",
    "wind_speed_10m",
    "wind_speed_100m",
    "wind_direction_10m",
    "direct_radiation",
    "diffuse_radiation",
    "cloud_cover",
]


@dataclass
class PreprocessingReport:
    """Diagnostics describing which preprocessing steps ran and what they changed."""

    series_processed: int
    rows_input: int
    rows_output: int

    # Negative handling
    negative_values_found: Dict[str, Dict[str, Any]]  # uid -> {count, min, max, ratio, timestamps}
    negative_values_action: str  # "passed" | "clamped" | "clamped (hybrid)"; fail_loud raises instead

    # Hourly grid
    series_dropped_incomplete: list[str]
    missing_hour_summary: Dict[str, Any]  # Summary of missing hour blocks

    # Weather alignment
    weather_coverage_by_region: Dict[str, float]
    weather_alignment_failures: list[Dict[str, Any]]

    # Features
    time_features_added: list[str]
    weather_features_added: list[str]

    timestamp: str


def _missing_hour_blocks(ds: pd.Series) -> list[tuple[pd.Timestamp, pd.Timestamp, int]]:
    """
    Return contiguous blocks of missing hourly timestamps.
    Each tuple: (block_start, block_end, n_hours)
    """
    ds = pd.to_datetime(ds, errors="raise").sort_values()
    if ds.empty:
        return []
    start, end = ds.iloc[0], ds.iloc[-1]
    expected = pd.date_range(start, end, freq="h")
    missing = expected.difference(ds)

    if missing.empty:
        return []

    blocks = []
    block_start = missing[0]
    prev = missing[0]
    for t in missing[1:]:
        if t - prev == pd.Timedelta(hours=1):
            prev = t
        else:
            n = int((prev - block_start).total_seconds() / 3600) + 1
            blocks.append((block_start, prev, n))
            block_start = t
            prev = t
    n = int((prev - block_start).total_seconds() / 3600) + 1
    blocks.append((block_start, prev, n))
    return blocks


def _hourly_grid_report(df: pd.DataFrame) -> pd.DataFrame:
    """Generate hourly grid coverage report per series."""
    cols = [
        "unique_id",
        "start",
        "end",
        "expected_hours",
        "actual_hours",
        "missing_hours",
        "missing_ratio",
        "n_missing_blocks",
        "largest_missing_block_hours",
    ]

    if df.empty:
        return pd.DataFrame(columns=cols)

    rows = []
    for uid, g in df.groupby("unique_id"):
        g = g.sort_values("ds")
        start, end = g["ds"].iloc[0], g["ds"].iloc[-1]
        expected = pd.date_range(start, end, freq="h")
        missing = expected.difference(g["ds"])
        blocks = _missing_hour_blocks(g["ds"])

        rows.append(
            {
                "unique_id": uid,
                "start": start,
                "end": end,
                "expected_hours": int(len(expected)),
                "actual_hours": int(len(g)),
                "missing_hours": int(len(missing)),
                "missing_ratio": float(len(missing) / max(len(expected), 1)),
                "n_missing_blocks": int(len(blocks)),
                "largest_missing_block_hours": int(max([b[2] for b in blocks], default=0)),
            }
        )

    rep = pd.DataFrame(rows)
    return rep.sort_values(["missing_ratio", "missing_hours"], ascending=False)


def _handle_negative_values(
    df: pd.DataFrame,
    policy: str,
    label: str = "generation"
) -> tuple[pd.DataFrame, Dict[str, Any]]:
    """
    Handle negative generation values according to policy.

    Args:
        df: DataFrame with columns [unique_id, ds, y]
        policy: "clamp" | "fail_loud" | "hybrid"
        label: Label for logging

    Returns:
        (processed_df, diagnostics_dict)
    """
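    if policy not in {"clamp", "fail_loud", "hybrid"}:
        # Validate up front so a typo fails even when the data has no negatives
        raise ValueError(f"Unknown negative handling policy: {policy}")

    diagnostics = {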
        'total_rows': len(df),
        'negative_count': int((df['y'] < 0).sum()),
        'negative_ratio': float((df['y'] < 0).sum() / len(df) if len(df) > 0 else 0),
        'series_with_negatives': {},
        'action_taken': policy
    }

    if diagnostics['negative_count'] == 0:
        logger.info(f"[{label}][NEGATIVES] No negative values found")
        diagnostics['action_taken'] = 'passed'
        return df.copy(), diagnostics

    # Analyze per-series
    for uid in df['unique_id'].unique():
        series_df = df[df['unique_id'] == uid]
        neg_mask = series_df['y'] < 0

        if neg_mask.any():
            neg_samples = series_df[neg_mask].head(10)

            diagnostics['series_with_negatives'][uid] = {
                'count': int(neg_mask.sum()),
                'ratio': float(neg_mask.sum() / len(series_df)),
                'min_value': float(series_df[neg_mask]['y'].min()),
                'max_value': float(series_df[neg_mask]['y'].max()),
                'mean_value': float(series_df[neg_mask]['y'].mean()),
                'sample_timestamps': neg_samples['ds'].astype(str).tolist()
            }

            logger.warning(
                f"[{label}][NEGATIVES] {uid}: count={neg_mask.sum()} "
                f"min={series_df[neg_mask]['y'].min():.2f} max={series_df[neg_mask]['y'].max():.2f}"
            )

    # Apply policy
    out = df.copy()

    if policy == "fail_loud":
        raise RuntimeError(
            f"[{label}][NEGATIVES] Found {diagnostics['negative_count']} negative values "
            f"({diagnostics['negative_ratio']:.2%}) across {len(diagnostics['series_with_negatives'])} series. "
            f"Policy=fail_loud prohibits negative values. Details: {diagnostics['series_with_negatives']}"
        )

    elif policy == "clamp":
        out['y'] = out['y'].clip(lower=0)
        logger.warning(
            f"[{label}][NEGATIVES] Clamped {diagnostics['negative_count']} negative values to 0 "
            f"({diagnostics['negative_ratio']:.2%})"
        )
        diagnostics['action_taken'] = 'clamped'

    elif policy == "hybrid":
        # Hybrid: clamp if <1% per series, fail if >=1%
        fail_series = []
        for uid, info in diagnostics['series_with_negatives'].items():
            if info['ratio'] >= 0.01:  # 1% threshold
                fail_series.append(uid)

        if fail_series:
            raise RuntimeError(
                f"[{label}][NEGATIVES] Policy=hybrid: Found series with >=1% negative values: {fail_series}. "
                f"This indicates data quality issues. Details: {diagnostics['series_with_negatives']}"
            )

        # Clamp small amounts
        out['y'] = out['y'].clip(lower=0)
        logger.warning(
            f"[{label}][NEGATIVES] Policy=hybrid: Clamped {diagnostics['negative_count']} negative values "
            f"(all series <1% ratio)"
        )
        diagnostics['action_taken'] = 'clamped (hybrid)'

    else:
        raise ValueError(f"Unknown negative handling policy: {policy}")

    return out, diagnostics


def _enforce_hourly_grid(
    df: pd.DataFrame,
    policy: str,
    label: str = "generation"
) -> tuple[pd.DataFrame, Dict[str, Any]]:
    """
    Enforce hourly grid according to policy.

    Args:
        df: DataFrame with columns [unique_id, ds, y]
        policy: "drop_incomplete_series" | "fail_loud"
        label: Label for logging

    Returns:
        (processed_df, diagnostics_dict)
    """
    if df.empty:
        raise RuntimeError(
            f"[{label}][GRID] Cannot enforce hourly grid: input dataframe is empty. "
            "This is an upstream (fetch) failure, not a grid issue."
        )

    rep = _hourly_grid_report(df)
    if rep.empty:
        raise RuntimeError(
            f"[{label}][GRID] No series found to report on (rep empty). "
            "This indicates upstream emptiness or missing 'unique_id' groups."
        )

    diagnostics = {
        'series_count': len(rep),
        'series_with_missing': int((rep['missing_hours'] > 0).sum()),
        'total_missing_hours': int(rep['missing_hours'].sum()),
        'worst_missing_ratio': float(rep['missing_ratio'].max()),
        'policy': policy,
        'series_dropped': [],
        'missing_hour_blocks': rep.to_dict(orient='records')
    }

    worst = rep.iloc[0].to_dict()

    if worst["missing_hours"] == 0:
        logger.info(f"[{label}][GRID] No missing hours detected")
        return df.copy(), diagnostics

    logger.warning(f"[{label}][GRID] Missing hours detected:\n{rep.head(10).to_string(index=False)}")

    if policy == "drop_incomplete_series":
        bad_uids = rep.loc[rep["missing_hours"] > 0, "unique_id"].tolist()
        kept = df.loc[~df["unique_id"].isin(bad_uids)].copy()

        diagnostics['series_dropped'] = bad_uids

        logger.warning(
            f"[{label}][GRID] policy=drop_incomplete_series dropped={len(bad_uids)} "
            f"kept_series={kept['unique_id'].nunique()}"
        )

        if kept.empty:
            raise RuntimeError(f"[{label}][GRID] all series dropped due to missing hours")

        return kept, diagnostics

    elif policy == "fail_loud":
        worst_uid = worst["unique_id"]
        g = df[df["unique_id"] == worst_uid].sort_values("ds")
        blocks = _missing_hour_blocks(g["ds"])
        raise RuntimeError(
            f"[{label}][GRID] Missing hours detected (no imputation). "
            f"worst_unique_id={worst_uid} missing_hours={worst['missing_hours']} "
            f"missing_ratio={worst['missing_ratio']:.3f} blocks(sample)={blocks[:3]}"
        )

    else:
        raise ValueError(f"Unknown hourly grid policy: {policy}")


def _add_time_features(df: pd.DataFrame) -> pd.DataFrame:
    """Add cyclical time features (hour, day of week)."""
    out = df.copy()
    out["hour"] = out["ds"].dt.hour
    out["dow"] = out["ds"].dt.dayofweek

    out["hour_sin"] = np.sin(2 * np.pi * out["hour"] / 24)
    out["hour_cos"] = np.cos(2 * np.pi * out["hour"] / 24)

    out["dow_sin"] = np.sin(2 * np.pi * out["dow"] / 7)
    out["dow_cos"] = np.cos(2 * np.pi * out["dow"] / 7)

    return out.drop(columns=["hour", "dow"])


def _align_weather(
    df: pd.DataFrame,
    weather_df: pd.DataFrame
) -> tuple[pd.DataFrame, Dict[str, Any]]:
    """
    Align weather data to generation timestamps.

    Args:
        df: Generation DataFrame with columns [unique_id, ds, y, ...]
        weather_df: Weather DataFrame with columns [ds, region, weather_vars]

    Returns:
        (merged_df, diagnostics_dict)
    """
    # Extract region from unique_id (e.g., "CALI_WND" -> "CALI")
    work = df.copy()
    work["region"] = work["unique_id"].str.split("_").str[0]

    # Check required columns
    if not {"ds", "region"}.issubset(weather_df.columns):
        raise ValueError("weather_df must have columns ['ds', 'region']")

    wcols = [c for c in WEATHER_VARS if c in weather_df.columns]
    if not wcols:
        raise ValueError("weather_df has none of expected WEATHER_VARS")

    # Merge
    merged = work.merge(
        weather_df[["ds", "region"] + wcols],
        on=["ds", "region"],
        how="left",
        validate="many_to_one",
    )

    # Check for missing weather after merge
    missing_any = merged[wcols].isna().any(axis=1)

    diagnostics = {
        'weather_vars': wcols,
        'merge_rows': len(merged),
        'missing_weather_rows': int(missing_any.sum()),
        'missing_weather_ratio': float(missing_any.sum() / len(merged) if len(merged) > 0 else 0),
        'coverage_by_region': {}
    }

    # Per-region coverage
    for region in merged['region'].unique():
        region_df = merged[merged['region'] == region]
        region_missing = region_df[wcols].isna().any(axis=1).sum()
        diagnostics['coverage_by_region'][region] = float(
            1 - (region_missing / len(region_df)) if len(region_df) > 0 else 0
        )

    if missing_any.any():
        sample = merged.loc[missing_any, ["unique_id", "ds", "region"] + wcols].head(10)
        logger.error(
            f"[WEATHER][ALIGN] Missing weather after merge rows={int(missing_any.sum())}. "
            f"Sample:\n{sample.to_string(index=False)}"
        )
        raise RuntimeError(
            f"[WEATHER][ALIGN] Missing weather after merge rows={int(missing_any.sum())}. "
            f"Check that weather_df covers the same date range and regions as generation_df."
        )

    # Drop region column (not needed for modeling)
    output = merged.drop(columns=["region"])

    logger.info(f"[WEATHER][ALIGN] Successfully merged {len(wcols)} weather variables")

    return output, diagnostics


def build_modeling_dataset(
    generation_df: pd.DataFrame,
    weather_df: Optional[pd.DataFrame],
    *,
    negative_policy: str = "clamp",
    hourly_grid_policy: str = "drop_incomplete_series",
    output_dir: Optional[Path] = None,
) -> tuple[pd.DataFrame, PreprocessingReport]:
    """
    Build modeling-ready dataset with comprehensive diagnostics.

    Args:
        generation_df: Raw generation DataFrame (unique_id, ds, y)
        weather_df: Raw weather DataFrame (ds, region, weather_vars), or None/empty to skip weather alignment
        negative_policy: "clamp" | "fail_loud" | "hybrid"
        hourly_grid_policy: "drop_incomplete_series" | "fail_loud"
        output_dir: Optional directory to save detailed diagnostics

    Returns:
        (modeling_df, preprocessing_report)

    Raises:
        RuntimeError: If data quality issues detected and policy is fail_loud
        ValueError: If invalid policy specified
    """
    from datetime import datetime

    logger.info("=" * 80)
    logger.info("DATASET BUILDER - Starting preprocessing")
    logger.info("=" * 80)

    # Validate inputs
    req = {"unique_id", "ds", "y"}
    if not req.issubset(generation_df.columns):
        raise ValueError(f"generation_df missing cols={sorted(req - set(generation_df.columns))}")

    if generation_df.empty:
        raise RuntimeError(
            "[GENERATION] Empty generation dataframe. "
            "This is an upstream (EIA fetch/cache) failure."
        )

    rows_input = len(generation_df)
    series_input = generation_df['unique_id'].nunique()

    logger.info(f"Input: {rows_input:,} rows, {series_input} series")
    logger.info(f"Policies: negative={negative_policy}, hourly_grid={hourly_grid_policy}")

    # Step 1: Handle negative values
    work, neg_diagnostics = _handle_negative_values(
        generation_df,
        policy=negative_policy,
        label="generation"
    )

    # Step 2: Enforce hourly grid
    work, grid_diagnostics = _enforce_hourly_grid(
        work,
        policy=hourly_grid_policy,
        label="generation"
    )

    # Step 3: Add time features
    work = _add_time_features(work)
    time_features = ["hour_sin", "hour_cos", "dow_sin", "dow_cos"]
    logger.info(f"[TIME_FEATURES] Added: {time_features}")

    # Step 4: Align weather (if provided)
    weather_features = []
    weather_diagnostics = {}

    if weather_df is not None and not weather_df.empty:
        work, weather_diagnostics = _align_weather(work, weather_df)
        weather_features = [c for c in WEATHER_VARS if c in work.columns]
    else:
        logger.warning("[WEATHER] No weather data provided, skipping alignment")

    # Final validation
    y_null = work["y"].isna()
    if y_null.any():
        sample = work.loc[y_null, ["unique_id", "ds", "y"]].head(25)
        raise RuntimeError(
            f"[GENERATION][Y] Found null y values after preprocessing (rows={int(y_null.sum())}). "
            f"Sample:\n{sample.to_string(index=False)}"
        )

    rows_output = len(work)
    series_output = work['unique_id'].nunique()

    # Create report
    report = PreprocessingReport(
        series_processed=series_output,
        rows_input=rows_input,
        rows_output=rows_output,
        negative_values_found=neg_diagnostics.get('series_with_negatives', {}),
        negative_values_action=neg_diagnostics.get('action_taken', 'unknown'),
        series_dropped_incomplete=grid_diagnostics.get('series_dropped', []),
        missing_hour_summary=grid_diagnostics,
        weather_coverage_by_region=weather_diagnostics.get('coverage_by_region', {}),
        weather_alignment_failures=[],
        time_features_added=time_features,
        weather_features_added=weather_features,
        timestamp=datetime.now().isoformat()
    )

    logger.info("=" * 80)
    logger.info("DATASET BUILDER - Complete")
    logger.info(f"Output: {rows_output:,} rows, {series_output} series")
    logger.info(f"Dropped: {rows_input - rows_output:,} rows ({(rows_input - rows_output)/rows_input*100:.1f}%)")
    logger.info(f"Series dropped: {len(report.series_dropped_incomplete)}")
    logger.info(f"Features added: {len(time_features)} time + {len(weather_features)} weather")
    logger.info("=" * 80)

    # Save detailed diagnostics if output_dir provided
    if output_dir:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        report_file = output_dir / "preprocessing_report.json"
        report_file.write_text(json.dumps(asdict(report), indent=2, default=str))
        logger.info(f"[REPORT] Saved to: {report_file}")

        # Save negative samples if any
        if neg_diagnostics.get('negative_count', 0) > 0:
            neg_detail = output_dir / "negative_values_detail.json"
            neg_detail.write_text(json.dumps(neg_diagnostics, indent=2, default=str))
            logger.info(f"[NEGATIVES] Details saved to: {neg_detail}")

        # Save grid report
        if grid_diagnostics.get('series_dropped'):
            grid_detail = output_dir / "missing_hours_detail.json"
            grid_detail.write_text(json.dumps(grid_diagnostics, indent=2, default=str))
            logger.info(f"[GRID] Details saved to: {grid_detail}")

    return work, report


if __name__ == "__main__":
    """
    Build production-ready modeling dataset from raw renewable energy data.

    MAIN PATH (default):
      - Loads raw generation and weather data
      - Runs ONE canonical preprocessing pass
      - Saves modeling-ready dataset
      - Reports diagnostics

    DEMO MODE (set RUN_DEMOS=True):
      - Shows how different policies compare
      - Demonstrates clamp, fail_loud, hybrid approaches
      - Educational tool only (not production path)

    Usage:
        python -m src.renewable.dataset_builder           # Production path
        RUN_DEMOS=1 python -m src.renewable.dataset_builder  # With demos
    """
    import sys
    import logging
    import os
    from pathlib import Path
    from dataclasses import asdict

    # Configure logging
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # Configuration: set RUN_DEMOS=1 in environment to run policy comparisons
    RUN_DEMOS = os.environ.get('RUN_DEMOS', '0').lower() in ('1', 'true', 'yes')

    print("=" * 80)
    print("DATASET BUILDER - Production Pipeline")
    print("=" * 80)
    if RUN_DEMOS:
        print("(DEMO MODE: Running policy comparisons)")
    print("=" * 80)

    # Step 1: Load & validate raw data
    print("\n[1/3] Loading & validating raw data...")

    generation_path = Path("data/renewable/generation.parquet")
    weather_path = Path("data/renewable/weather.parquet")

    if not generation_path.exists():
        print(f"[ERROR] Generation data not found at {generation_path}")
        print("   Please run the pipeline first: python -m src.renewable.tasks --preset 24h")
        sys.exit(1)

    if not weather_path.exists():
        print(f"[ERROR] Weather data not found at {weather_path}")
        print("   Please run the pipeline first: python -m src.renewable.tasks --preset 24h")
        sys.exit(1)

    generation_df = pd.read_parquet(generation_path)
    weather_df = pd.read_parquet(weather_path)

    print(f"   [OK] Generation: {len(generation_df):,} rows, {generation_df['unique_id'].nunique()} series")
    print(f"   [OK] Weather: {len(weather_df):,} rows")
    print(f"   [OK] Date range: {generation_df['ds'].min().date()} to {generation_df['ds'].max().date()}")

    # Quick data sanity checks
    neg_count = (generation_df['y'] < 0).sum()
    neg_ratio = neg_count / len(generation_df) if len(generation_df) > 0 else 0

    has_duplicates = generation_df.duplicated(subset=['unique_id', 'ds']).sum()

    print(f"\n   [CHECKS]")
    print(f"      Negatives: {neg_count} ({neg_ratio:.4%})" + (" [WARNING]" if neg_count > 0 else " [CLEAN]"))
    print(f"      Duplicates: {has_duplicates}" + (" [WARNING]" if has_duplicates > 0 else " [CLEAN]"))

    # Determine best policy based on data characteristics
    print("\n   [POLICY SELECTION]")
    if neg_count == 0:
        chosen_negative_policy = "fail_loud"
        print(f"      Negatives: None → negative_policy='fail_loud' (fail fast on upstream changes)")
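    elif neg_ratio < 0.01:
        # Negatives are already present, so fail_loud would abort the canonical build;
        # hybrid clamps series whose negative share stays under 1%.
        chosen_negative_policy = "hybrid"
        print(f"      Negatives: {neg_ratio:.4%} (rare) → negative_policy='hybrid' (clamp if <1% per series)")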
    else:
        chosen_negative_policy = "clamp"
        print(f"      Negatives: {neg_ratio:.4%} (substantial) → negative_policy='clamp' (tolerate & log)")

    chosen_grid_policy = "drop_incomplete_series"
    print(f"      Grid enforcement: {chosen_grid_policy} (drop series with gaps)")

    # Step 2: Build dataset (canonical production path)
    print("\n[2/3] Building modeling dataset (ONE canonical build)...")

    output_dir = Path("data/renewable/preprocessing/latest")

    modeling_df, report = build_modeling_dataset(
        generation_df,
        weather_df,
        negative_policy=chosen_negative_policy,
        hourly_grid_policy=chosen_grid_policy,
        output_dir=output_dir
    )

    print(f"\n   [RESULT] Preprocessing Summary:")
    print(f"      Input:  {report.rows_input:,} rows → Output: {report.rows_output:,} rows")
    print(f"      Dropped: {report.rows_input - report.rows_output:,} rows ({(report.rows_input - report.rows_output)/report.rows_input*100:.2f}%)")
    print(f"      Series: {report.series_processed} (dropped {len(report.series_dropped_incomplete)})")
    print(f"      Negative action: {report.negative_values_action}")
    print(f"      Features: {len(report.time_features_added)} time + {len(report.weather_features_added)} weather")

    # Step 3: Dataset inspection
    print("\n[3/3] Dataset Inspection...")

    print(f"\n   [DATA] Modeling-Ready Dataset:")
    print(f"      Shape: {modeling_df.shape}")
    print(f"      Memory: {modeling_df.memory_usage(deep=True).sum() / 1024**2:.2f} MB")

    print(f"\n   [FEATURES] Added:")
    print(f"      Time: {report.time_features_added}")
    print(f"      Weather: {report.weather_features_added[:3]} ... ({len(report.weather_features_added)} total)")

    print(f"\n   [QUALITY] Data Checks:")
    print(f"      Nulls in y: {modeling_df['y'].isna().sum()}")
    print(f"      Negatives in y: {(modeling_df['y'] < 0).sum()}")
    print(f"      Duplicates: {modeling_df.duplicated(subset=['unique_id', 'ds']).sum()}")
    weather_nulls = modeling_df[report.weather_features_added].isna().sum().sum() if report.weather_features_added else 0
    print(f"      Weather nulls: {weather_nulls}")

    print(f"\n   [SAMPLE] First 3 rows:")
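    # Only show sample columns that exist (weather alignment may have been skipped)
    sample_cols = [c for c in ['unique_id', 'ds', 'y', 'hour_sin', 'temperature_2m'] if c in modeling_df.columns]
    print(modeling_df.head(3)[sample_cols].to_string(index=False))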

    print("\n" + "=" * 80)
    print("[SUCCESS] PRODUCTION BUILD COMPLETE")
    print("=" * 80)
    print(f"\nSaved to: {output_dir}/")
    print(f"  - preprocessing_report.json   # Full diagnostics")
    if report.negative_values_found:
        print(f"  - negative_values_detail.json # Negative sample analysis")
    if report.series_dropped_incomplete:
        print(f"  - missing_hours_detail.json   # Dropped series details")
    print("=" * 80)

    # ============================================================================
    # DEMO MODE (optional, run with: RUN_DEMOS=1 python -m src.renewable.dataset_builder)
    # ============================================================================
    if RUN_DEMOS:
        print("\n" + "=" * 80)
        print("[DEMO MODE] Policy Comparison (Educational)")
        print("=" * 80)
        print("(This demonstrates different policies. For production, use the canonical build above.)\n")

        # Demo 1: Try fail_loud if negatives exist
        if neg_count > 0:
            print("[DEMO 1/2] Testing negative_policy='fail_loud'...")
            try:
                modeling_df_fail, _ = build_modeling_dataset(
                    generation_df, weather_df,
                    negative_policy="fail_loud",
                    hourly_grid_policy="drop_incomplete_series",
                    output_dir=None
                )
                print("      [OK] No negatives detected (would pass production)")
            except RuntimeError as e:
                print(f"      [RAISED] {str(e)[:150]}...")

        # Demo 2: Try hybrid
        print("\n[DEMO 2/2] Testing negative_policy='hybrid' (if <1% per series)...")
        try:
            modeling_df_hybrid, report_hybrid = build_modeling_dataset(
                generation_df, weather_df,
                negative_policy="hybrid",
                hourly_grid_policy="drop_incomplete_series",
                output_dir=None
            )
            print(f"      [OK] Policy='hybrid' succeeded ({report_hybrid.negative_values_action})")
        except RuntimeError as e:
            print(f"      [RAISED] {str(e)[:150]}...")

        print(f"\n[NOTE] Demos show policy behavior. Production uses: negative_policy='{chosen_negative_policy}', grid='{chosen_grid_policy}'")

    print("\n[NEXT] Use modeling_df for forecasting (weather_df=None, already merged)")
    print("=" * 80)


2026-01-21 17:43:27,950 - __main__ - INFO - DATASET BUILDER - Starting preprocessing
2026-01-21 17:43:27,951 - __main__ - INFO - Input: 4,214 rows, 6 series
2026-01-21 17:43:27,951 - __main__ - INFO - Policies: negative=fail_loud, hourly_grid=drop_incomplete_series
2026-01-21 17:43:27,951 - __main__ - INFO - [generation][NEGATIVES] No negative values found
2026-01-21 17:43:27,959 - __main__ - INFO - [generation][GRID] No missing hours detected
2026-01-21 17:43:27,962 - __main__ - INFO - [TIME_FEATURES] Added: ['hour_sin', 'hour_cos', 'dow_sin', 'dow_cos']
2026-01-21 17:43:27,970 - __main__ - INFO - [WEATHER][ALIGN] Successfully merged 7 weather variables
2026-01-21 17:43:27,972 - __main__ - INFO - DATASET BUILDER - Complete
2026-01-21 17:43:27,972 - __main__ - INFO - Output: 4,214 rows, 6 series
2026-01-21 17:43:27,972 - __main__ - INFO - Dropped: 0 rows (0.0%)
2026-01-21 17:43:27,972 - __main__ - INFO - Series dropped: 0
2026-01-21 17:43:27,973 - __main__ - INFO - Features added: 4 ti

DATASET BUILDER - Production Pipeline

[1/3] Loading & validating raw data...
   [OK] Generation: 4,214 rows, 6 series
   [OK] Weather: 2,313 rows
   [OK] Date range: 2025-12-22 to 2026-01-20

   [CHECKS]
      Negatives: 0 (0.0000%) [CLEAN]
      Duplicates: 0 [CLEAN]

   [POLICY SELECTION]
      Negatives: None → negative_policy='fail_loud' (fail fast on upstream changes)
      Grid enforcement: drop_incomplete_series (drop series with gaps)

[2/3] Building modeling dataset (ONE canonical build)...

   [RESULT] Preprocessing Summary:
      Input:  4,214 rows → Output: 4,214 rows
      Dropped: 0 rows (0.00%)
      Series: 6 (dropped 0)
      Negative action: passed
      Features: 4 time + 7 weather

[3/3] Dataset Inspection...

   [DATA] Modeling-Ready Dataset:
      Shape: (4214, 14)
      Memory: 0.65 MB

   [FEATURES] Added:
      Time: ['hour_sin', 'hour_cos', 'dow_sin', 'dow_cos']
      Weather: ['temperature_2m', 'wind_speed_10m', 'wind_speed_100m'] ... (7 total)

   [QUALITY]
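The `hour_sin`/`hour_cos` features listed in the output above come from `_add_time_features`, which places each hour on a unit circle so that 23:00 and 00:00 end up exactly as close together as 00:00 and 01:00. The snippet below is a minimal standalone check of that property using only NumPy and pandas; it reuses the same formula but is not part of the pipeline, and `encoded_gap` is a hypothetical helper.

```python
import numpy as np
import pandas as pd

# Four consecutive hours that cross midnight: 22, 23, 0, 1
ds = pd.Series(pd.date_range("2026-01-01 22:00", periods=4, freq="h"))
hour = ds.dt.hour

# Same encoding as _add_time_features: hour -> point on the unit circle
points = np.column_stack([
    np.sin(2 * np.pi * hour / 24),
    np.cos(2 * np.pi * hour / 24),
])


def encoded_gap(i: int, j: int) -> float:
    """Euclidean distance between the encoded hours at positions i and j."""
    return float(np.linalg.norm(points[i] - points[j]))


print("23:00 -> 00:00:", round(encoded_gap(1, 2), 4))
print("00:00 -> 01:00:", round(encoded_gap(2, 3), 4))
# The raw integer hour would put 23 and 0 far apart
print("raw hour gap:", abs(int(hour.iloc[1]) - int(hour.iloc[2])))  # 23
```

The day-of-week pair (`dow_sin`, `dow_cos`) applies the same idea with a period of 7, so Sunday (6) and Monday (0) stay adjacent.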

---

# Module 4: Probabilistic Modeling

**File:** `src/renewable/modeling.py`

This is where the forecasting happens! We use **StatsForecast** for:

1. **Multi-series forecasting**: Fit every region/fuel-type series in a single call (one model per `unique_id`)
2. **Probabilistic predictions**: Get prediction intervals, not just point forecasts
3. **Weather exogenous**: Include weather features as predictors
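Both points rest on the long format used throughout the pipeline: one row per (`unique_id`, `ds`), with the region encoded as the prefix of `unique_id`. The sketch below uses made-up values to show how the join in `_align_weather` lets one regional weather row serve every series in that region (here `CALI_WND` plus a hypothetical `CALI_SUN`), and how `validate="many_to_one"` rejects duplicated weather rows instead of silently multiplying generation rows.

```python
import pandas as pd

# Toy data with made-up values; "CALI_SUN" is a hypothetical series id.
ds = pd.date_range("2026-01-01", periods=3, freq="h", tz="UTC")
generation = pd.DataFrame({
    "unique_id": ["CALI_WND"] * 3 + ["CALI_SUN"] * 3,
    "ds": ds.append(ds),
    "y": [120.0, 130.0, 125.0, 0.0, 0.0, 15.0],
})
# One weather row per (ds, region)
weather = pd.DataFrame({"ds": ds, "region": "CALI", "wind_speed_100m": [8.1, 8.4, 7.9]})

# Same region extraction as _align_weather: "CALI_WND" -> "CALI"
work = generation.copy()
work["region"] = work["unique_id"].str.split("_").str[0]

merged = work.merge(weather, on=["ds", "region"], how="left", validate="many_to_one")
print(merged[["unique_id", "ds", "wind_speed_100m"]].to_string(index=False))

# A duplicated weather key violates many_to_one, so pandas raises instead of duplicating rows
dup_weather = pd.concat([weather, weather.iloc[[0]]], ignore_index=True)
try:
    work.merge(dup_weather, on=["ds", "region"], how="left", validate="many_to_one")
    rejected = False
except pd.errors.MergeError as exc:
    rejected = True
    print("rejected:", exc)
```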

## Key Concepts

### Why Prediction Intervals?

Point forecasts are useful, but energy traders need **uncertainty quantification**:
- **80% interval**: "I'm 80% confident generation will be between X and Y"
- **95% interval**: Wider, for risk management
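A quick way to tell whether those intervals mean what they claim is to measure empirical coverage on held-out hours: a calibrated 80% interval should contain roughly 80% of actuals. The helper below is a sketch that assumes StatsForecast-style output columns (`{model}-lo-{level}` / `{model}-hi-{level}`) next to the observed `y`; `interval_coverage` is hypothetical, and the `AutoETS` forecast it scores is synthetic Gaussian data, not pipeline output.

```python
import numpy as np
import pandas as pd


def interval_coverage(df: pd.DataFrame, model: str, level: int) -> float:
    """Fraction of actuals that fall inside the model's `level`% interval.

    Assumes columns f"{model}-lo-{level}", f"{model}-hi-{level}", and "y".
    """
    lo = df[f"{model}-lo-{level}"]
    hi = df[f"{model}-hi-{level}"]
    inside = (df["y"] >= lo) & (df["y"] <= hi)
    return float(inside.mean())


# Synthetic, well-calibrated forecast: constant point forecast, Gaussian errors
rng = np.random.default_rng(0)
n = 5_000
point = np.full(n, 100.0)
sigma = 10.0
fc = pd.DataFrame({"y": point + rng.normal(0.0, sigma, n)})

z = {80: 1.2816, 95: 1.9600}  # two-sided standard normal quantiles
for level, zq in z.items():
    fc[f"AutoETS-lo-{level}"] = point - zq * sigma
    fc[f"AutoETS-hi-{level}"] = point + zq * sigma

for level in (80, 95):
    print(level, round(interval_coverage(fc, "AutoETS", level), 3))
```

On real forecasts, coverage well below the nominal level suggests the intervals are too narrow; coverage well above it suggests they are wider than needed.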

### Zero-Value Safety (CRITICAL)

**Solar panels generate ZERO at night!** This breaks MAPE:

```
MAPE = mean(|actual - predicted| / |actual|) × 100

When actual = 0:
term = |0 - pred| / 0 → division by zero (inf if pred ≠ 0, NaN if pred = 0)
```

**Solution**: Always use RMSE and MAE for renewable forecasting; neither divides by the actual value, so zero-generation hours are handled safely.
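To see the failure concretely, here is a sketch with a toy solar series containing two night-time zeros (values made up). MAE and RMSE stay finite, while MAPE is dominated by the zero-actual hours. The three metric functions are illustrative helpers, not the `ForecastMetrics` implementation used later.

```python
import numpy as np


def mae(actual: np.ndarray, pred: np.ndarray) -> float:
    return float(np.mean(np.abs(actual - pred)))


def rmse(actual: np.ndarray, pred: np.ndarray) -> float:
    return float(np.sqrt(np.mean((actual - pred) ** 2)))


def mape(actual: np.ndarray, pred: np.ndarray) -> float:
    # Division by |actual|: a zero actual with a nonzero prediction gives inf
    with np.errstate(divide="ignore", invalid="ignore"):
        return float(np.mean(np.abs(actual - pred) / np.abs(actual)) * 100)


# Two night hours at zero, then daytime output (MW)
actual = np.array([0.0, 0.0, 50.0, 120.0])
pred = np.array([2.0, 1.0, 55.0, 110.0])

print("MAE ", mae(actual, pred))   # 4.5
print("RMSE", rmse(actual, pred))
print("MAPE", mape(actual, pred))  # inf
```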

In [None]:
%%writefile src/renewable/validation.py
# file: src/renewable/validation.py
"""Validation utilities for renewable generation data."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, Optional

import pandas as pd


@dataclass(frozen=True)
class ValidationReport:
    ok: bool
    message: str
    details: dict


def validate_generation_df(
    df: pd.DataFrame,
    *,
    max_lag_hours: int = 3,
    max_missing_ratio: float = 0.02,
    expected_series: Optional[Iterable[str]] = None,
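) -> ValidationReport:
    """
    Validate raw generation data before it reaches the dataset builder.

    Checks, in order: required columns, non-empty input, parseable ds/y,
    duplicate (unique_id, ds) rows, expected series present, freshness
    (overall and per series, against max_lag_hours), and the worst per-series
    share of missing hourly points (against max_missing_ratio).
    Negative y values are logged but not rejected; the dataset builder
    handles them according to its negative_policy.

    Returns:
        ValidationReport with ok=False and a reason on the first failed check,
        otherwise ok=True with summary details.
    """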
    required = {"unique_id", "ds", "y"}
    missing_cols = required - set(df.columns)
    if missing_cols:
        return ValidationReport(
            False,
            "Missing required columns",
            {"missing_cols": sorted(missing_cols)},
        )

    if df.empty:
        return ValidationReport(False, "Generation data is empty", {})

    work = df.copy()

    work["ds"] = pd.to_datetime(work["ds"], errors="coerce", utc=True)
    if work["ds"].isna().any():
        return ValidationReport(
            False,
            "Unparseable ds values found",
            {"bad_ds": int(work["ds"].isna().sum())},
        )

    work["y"] = pd.to_numeric(work["y"], errors="coerce")
    if work["y"].isna().any():
        return ValidationReport(
            False,
            "Unparseable y values found",
            {"bad_y": int(work["y"].isna().sum())},
        )

    # Check for negative values and log warning (but allow to pass)
    # Dataset builder will handle negatives according to configured policy
    if (work["y"] < 0).any():
        import logging
        logger = logging.getLogger(__name__)

        neg_mask = work["y"] < 0
        neg_count = int(neg_mask.sum())
        by_series = (
            work[neg_mask]
            .groupby("unique_id")
            .agg(count=("y", "count"), min_y=("y", "min"), max_y=("y", "max"))
            .reset_index()
        )

        logger.warning(
            "[validation][NEGATIVE] Found %d negative values (%.1f%%) across %d series",
            neg_count,
            100 * neg_count / len(work),
            len(by_series)
        )

        for _, row in by_series.iterrows():
            logger.warning(
                "  Series %s: %d negative values, range=[%.1f, %.1f]",
                row["unique_id"], row["count"], row["min_y"], row["max_y"]
            )

        logger.info(
            "[validation][NEGATIVE] Negatives will be handled by dataset builder "
            "according to configured negative_policy"
        )

        # Continue validation instead of failing
        # (Dataset builder will handle negatives per policy)

    dup = work.duplicated(subset=["unique_id", "ds"]).sum()
    if dup:
        return ValidationReport(
            False,
            "Duplicate (unique_id, ds) rows found",
            {"duplicates": int(dup)},
        )

    if expected_series:
        expected = sorted(set(expected_series))
        present = sorted(set(work["unique_id"]))
        missing_series = sorted(set(expected) - set(present))
        if missing_series:
            return ValidationReport(
                False,
                "Missing expected series",
                {"missing_series": missing_series, "present_series": present},
            )

    now_utc = pd.Timestamp.now(tz="UTC").floor("h")
    max_ds = work["ds"].max()
    lag_hours = (now_utc - max_ds).total_seconds() / 3600.0
    if lag_hours > max_lag_hours:
        return ValidationReport(
            False,
            "Data not fresh enough",
            {
                "now_utc": now_utc.isoformat(),
                "max_ds": max_ds.isoformat(),
                "lag_hours": lag_hours,
            },
        )

    series_max = work.groupby("unique_id")["ds"].max()
    series_lag = (now_utc - series_max).dt.total_seconds() / 3600.0
    stale = series_lag[series_lag > max_lag_hours].sort_values(ascending=False)
    if not stale.empty:
        return ValidationReport(
            False,
            "Stale series found",
            {
                "stale_series": stale.head(10).to_dict(),
                "max_lag_hours": max_lag_hours,
            },
        )

    missing_ratios = {}
    for uid, group in work.groupby("unique_id"):
        group = group.sort_values("ds")
        start = group["ds"].iloc[0]
        end = group["ds"].iloc[-1]
        expected_hours = int(((end - start) / pd.Timedelta(hours=1)) + 1)
        actual = len(group)
        missing = max(expected_hours - actual, 0)
        missing_ratios[uid] = missing / max(expected_hours, 1)

    worst_uid = max(missing_ratios, key=missing_ratios.get)
    worst_ratio = missing_ratios[worst_uid]
    if worst_ratio > max_missing_ratio:
        return ValidationReport(
            False,
            "Too many missing hourly points",
            {"worst_uid": worst_uid, "worst_missing_ratio": worst_ratio},
        )

    return ValidationReport(
        True,
        "OK",
        {
            "row_count": len(work),
            "series_count": int(work["unique_id"].nunique()),
            "max_ds": max_ds.isoformat(),
            "lag_hours": lag_hours,
            "worst_missing_ratio": worst_ratio,
        },
    )
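The missing-ratio loop at the end of `validate_generation_df` can also be expressed with one `groupby`. The function below is a hypothetical vectorized equivalent (`missing_ratio_by_series` is not part of the module); it assumes, as the validator does by that point, that `ds` is parsed, hourly-aligned, and free of duplicate (`unique_id`, `ds`) rows.

```python
import pandas as pd


def missing_ratio_by_series(df: pd.DataFrame) -> pd.Series:
    """Per-series share of absent hours between each series' first and last ds."""
    g = df.groupby("unique_id")["ds"]
    expected = (g.max() - g.min()) / pd.Timedelta(hours=1) + 1  # hours in [start, end]
    actual = g.count()
    missing = (expected - actual).clip(lower=0)
    return missing / expected.clip(lower=1)


# Toy data: series A is complete, series B lacks one interior hour
ds = pd.date_range("2026-01-01", periods=5, freq="h", tz="UTC")
toy = pd.concat([
    pd.DataFrame({"unique_id": "A", "ds": ds, "y": 1.0}),
    pd.DataFrame({"unique_id": "B", "ds": ds.delete(2), "y": 1.0}),
])
ratios = missing_ratio_by_series(toy)
print(ratios.to_dict())  # {'A': 0.0, 'B': 0.2}
```

The worst series is then `ratios.idxmax()`, which plays the role of `worst_uid` in the loop version.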


In [None]:
%%writefile src/renewable/modeling.py
# file: src/renewable/modeling.py
"""
Renewable Energy Forecasting Models

This module provides two model paths for renewable energy forecasting:

## 1. Core Path (Primary Forecasting): RenewableForecastModel
   - **Purpose:** Production forecasts with calibrated prediction intervals
   - **Framework:** StatsForecast (native multi-series support)
   - **Models:** MSTL_ARIMA, AutoARIMA, AutoETS, SeasonalNaive
   - **Usage:** Always runs in pipeline (train_renewable_models)
   - **Why StatsForecast:**
     * 10x faster than individual model training
     * Built-in prediction intervals (no conformal prediction needed)
     * Excellent with seasonal patterns (MSTL handles daily + weekly cycles)
   - **Output:** Best model selected via CV, forecasts with 80%/95% intervals

## 2. Interpretability Path (Feature Analysis): RenewableLGBMForecaster
   - **Purpose:** Understanding feature contributions via SHAP analysis
   - **Framework:** skforecast + LightGBM
   - **Models:** LightGBM with weather + time features
   - **Usage:** Always runs by default (failures non-fatal)
   - **Output:** SHAP plots, feature importance, partial dependence plots
   - **Note:** NOT used for final forecasts (StatsForecast provides those)

## Model Selection Philosophy
- **StatsForecast:** Use for operational forecasts (speed + intervals)
- **LightGBM:** Use for understanding (interpretability)
- **Cross-validation:** Determines best StatsForecast model automatically

## Removed Models
- **RenewableMLForecast:** Removed as unused (MLForecast + conformal prediction)
  Use RenewableForecastModel (StatsForecast) for forecasting
  Use RenewableLGBMForecaster (skforecast) for interpretability
"""

from __future__ import annotations

import re
from dataclasses import dataclass
from typing import Any, Optional, Sequence

import numpy as np
import pandas as pd

from src.chapter2.evaluation import ForecastMetrics

WEATHER_VARS = [
    "temperature_2m",
    "wind_speed_10m",
    "wind_speed_100m",
    "wind_direction_10m",
    "direct_radiation",
    "diffuse_radiation",
    "cloud_cover",
]


def _log_series_summary(df: pd.DataFrame, *, value_col: str = "y", label: str = "series") -> None:
    """Print per-series counts, ds range, value stats, zero count, and modal step (hours)."""
    if df.empty:
        print(f"[{label}] EMPTY")
        return

    tmp = df.copy()
    tmp["ds"] = pd.to_datetime(tmp["ds"], errors="coerce")

    def _mode_delta_hours(g: pd.Series) -> float:
        d = g.sort_values().diff().dropna()
        if d.empty:
            return float("nan")
        return float(d.dt.total_seconds().div(3600).mode().iloc[0])

    g = tmp.groupby("unique_id").agg(
        rows=(value_col, "count"),
        na_y=(value_col, lambda s: int(s.isna().sum())),
        min_ds=("ds", "min"),
        max_ds=("ds", "max"),
        min_y=(value_col, "min"),
        max_y=(value_col, "max"),
        mean_y=(value_col, "mean"),
        zero_y=(value_col, lambda s: int((s == 0).sum())),
        mode_delta_hours=("ds", _mode_delta_hours),
    ).reset_index().sort_values("unique_id")

    print(f"[{label}] series={g['unique_id'].nunique()} rows={len(tmp)}")
    print(g.head(20).to_string(index=False))

def _missing_hour_blocks(ds: pd.Series) -> list[tuple[pd.Timestamp, pd.Timestamp, int]]:
    """
    Return contiguous blocks of missing hourly timestamps.
    Each tuple: (block_start, block_end, n_hours)
    """
    ds = pd.to_datetime(ds, errors="raise").sort_values()
    start, end = ds.iloc[0], ds.iloc[-1]
    expected = pd.date_range(start, end, freq="h")
    missing = expected.difference(ds)

    if missing.empty:
        return []

    blocks = []
    block_start = missing[0]
    prev = missing[0]
    for t in missing[1:]:
        if t - prev == pd.Timedelta(hours=1):
            prev = t
        else:
            n = int((prev - block_start).total_seconds() / 3600) + 1
            blocks.append((block_start, prev, n))
            block_start = t
            prev = t
    n = int((prev - block_start).total_seconds() / 3600) + 1
    blocks.append((block_start, prev, n))
    return blocks


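def _hourly_grid_report(df: pd.DataFrame) -> pd.DataFrame:
    """
    Summarize hourly-grid completeness per unique_id (expected vs. actual hours,
    missing-hour blocks), sorted worst-first by missing_ratio, then missing_hours.
    """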
    cols = [
        "unique_id",
        "start",
        "end",
        "expected_hours",
        "actual_hours",
        "missing_hours",
        "missing_ratio",
        "n_missing_blocks",
        "largest_missing_block_hours",
        "first_missing_block_start",
        "first_missing_block_end",
    ]

    if df.empty:
        # Return an empty report with a stable schema (so callers can fail-loud cleanly)
        return pd.DataFrame(columns=cols)

    rows = []
    for uid, g in df.groupby("unique_id"):
        g = g.sort_values("ds")
        start, end = g["ds"].iloc[0], g["ds"].iloc[-1]
        expected = pd.date_range(start, end, freq="h")
        missing = expected.difference(g["ds"])
        blocks = _missing_hour_blocks(g["ds"])

        rows.append(
            {
                "unique_id": uid,
                "start": start,
                "end": end,
                "expected_hours": int(len(expected)),
                "actual_hours": int(len(g)),
                "missing_hours": int(len(missing)),
                "missing_ratio": float(len(missing) / max(len(expected), 1)),
                "n_missing_blocks": int(len(blocks)),
                "largest_missing_block_hours": int(max([b[2] for b in blocks], default=0)),
                "first_missing_block_start": blocks[0][0] if blocks else pd.NaT,
                "first_missing_block_end": blocks[0][1] if blocks else pd.NaT,
            }
        )

    rep = pd.DataFrame(rows)
    return rep.sort_values(["missing_ratio", "missing_hours"], ascending=False)


def _enforce_hourly_grid(
    df: pd.DataFrame,
    *,
    label: str,
    policy: str = "raise",  # "raise" | "drop_incomplete_series"
) -> pd.DataFrame:
    if df.empty:
        raise RuntimeError(
            f"[{label}][GRID] Cannot enforce hourly grid: input dataframe is empty. "
            "This is an upstream (fetch) failure, not a grid issue."
        )

    rep = _hourly_grid_report(df)
    if rep.empty:
        raise RuntimeError(
            f"[{label}][GRID] No series found to report on (rep empty). "
            "This indicates upstream emptiness or missing 'unique_id' groups."
        )

    worst = rep.iloc[0].to_dict()

    if worst["missing_hours"] == 0:
        return df

    print(f"[{label}][GRID] report (top):\n{rep.head(10).to_string(index=False)}")

    if policy == "drop_incomplete_series":
        bad_uids = rep.loc[rep["missing_hours"] > 0, "unique_id"].tolist()
        kept = df.loc[~df["unique_id"].isin(bad_uids)].copy()
        print(f"[{label}][GRID] policy=drop_incomplete_series dropped={bad_uids} kept_series={kept['unique_id'].nunique()}")
        if kept.empty:
            raise RuntimeError(f"[{label}][GRID] all series dropped due to missing hours")
        return kept

    worst_uid = worst["unique_id"]
    g = df[df["unique_id"] == worst_uid].sort_values("ds")
    blocks = _missing_hour_blocks(g["ds"])
    raise RuntimeError(
        f"[{label}][GRID] Missing hours detected (no imputation). "
        f"worst_unique_id={worst_uid} missing_hours={worst['missing_hours']} "
        f"missing_ratio={worst['missing_ratio']:.3f} blocks(sample)={blocks[:3]}"
    )


def _validate_hourly_grid_fail_loud(
    df: pd.DataFrame,
    *,
    max_missing_ratio: float = 0.0,
    label: str = "generation",
) -> None:
    # Basic structural checks before building the grid report:
    if df.empty:
        raise RuntimeError(f"[{label}] empty dataframe")

    bad = df["ds"].isna().sum()
    if bad:
        raise RuntimeError(f"[{label}] ds has NaT values bad={int(bad)}")

    dup = df.duplicated(subset=["unique_id", "ds"]).sum()
    if dup:
        raise RuntimeError(f"[{label}] duplicate (unique_id, ds) rows dup={int(dup)}")

    rep = _hourly_grid_report(df)
    worst = rep.iloc[0].to_dict()
    if worst["missing_ratio"] > max_missing_ratio:
        print(f"[{label}][GRID] report (top):\n{rep.head(10).to_string(index=False)}")
        worst_uid = worst["unique_id"]
        g = df[df["unique_id"] == worst_uid].sort_values("ds")
        blocks = _missing_hour_blocks(g["ds"])
        raise RuntimeError(
            f"[{label}][GRID] Missing hours detected (no imputation allowed). "
            f"unique_id={worst_uid} missing_hours={worst['missing_hours']} "
            f"missing_ratio={worst['missing_ratio']:.3f} blocks(sample)={blocks[:3]}"
        )



def _add_time_features(df: pd.DataFrame) -> pd.DataFrame:
    out = df.copy()
    out["hour"] = out["ds"].dt.hour
    out["dow"] = out["ds"].dt.dayofweek

    out["hour_sin"] = np.sin(2 * np.pi * out["hour"] / 24)
    out["hour_cos"] = np.cos(2 * np.pi * out["hour"] / 24)

    out["dow_sin"] = np.sin(2 * np.pi * out["dow"] / 7)
    out["dow_cos"] = np.cos(2 * np.pi * out["dow"] / 7)

    return out.drop(columns=["hour", "dow"])

def _infer_model_columns(cv_df: pd.DataFrame) -> list[str]:
    """
    Infer StatsForecast model prediction columns from a cross_validation dataframe.

    We treat as "model columns" those that:
      - are not core columns (unique_id, ds, cutoff, y)
      - are not metadata columns (index, level_0, etc.)
      - are not interval columns like '<model>-lo-80' or '<model>-hi-95'
    """
    # Core columns from StatsForecast + pandas residuals from reset_index()
    core = {"unique_id", "ds", "cutoff", "y", "index", "level_0", "level_1"}
    cols = [c for c in cv_df.columns if c not in core]

    model_cols: set[str] = set()
    interval_pat = re.compile(r"-(lo|hi)-\d+$")
    for c in cols:
        if interval_pat.search(c):
            continue
        model_cols.add(c)

    return sorted(model_cols)


def compute_leaderboard(
    cv_df: pd.DataFrame,
    *,
    confidence_levels: tuple[int, int] = (80, 95),
) -> pd.DataFrame:
    """
    Build an aggregated leaderboard from StatsForecast cross_validation output.

    Returns columns:
      - model, rmse, mae, mape, valid_rows
      - coverage_<level> if interval columns exist
    """
    required = {"y", "unique_id", "ds", "cutoff"}
    missing = required - set(cv_df.columns)
    if missing:
        raise ValueError(f"[leaderboard] cv_df missing required columns: {sorted(missing)}")

    model_cols = _infer_model_columns(cv_df)
    if not model_cols:
        raise RuntimeError(
            f"[leaderboard] Could not infer any model prediction columns. "
            f"cv_df columns={cv_df.columns.tolist()}"
        )

    rows: list[dict[str, Any]] = []
    y_true = cv_df["y"].to_numpy()

    for m in model_cols:
        if m not in cv_df.columns:
            continue

        y_pred = cv_df[m].to_numpy()
        valid_mask = np.isfinite(y_true) & np.isfinite(y_pred)
        valid_rows = int(valid_mask.sum())

        metrics = {
            "model": m,
            "rmse": float(ForecastMetrics.rmse(y_true, y_pred)),
            "mae": float(ForecastMetrics.mae(y_true, y_pred)),
            "mape": float(ForecastMetrics.mape(y_true, y_pred)),
            "valid_rows": valid_rows,
        }

        # Coverage if interval columns exist
        for lvl in confidence_levels:
            lo_col = f"{m}-lo-{lvl}"
            hi_col = f"{m}-hi-{lvl}"
            if lo_col in cv_df.columns and hi_col in cv_df.columns:
                cov = ForecastMetrics.coverage(
                    y_true,
                    cv_df[lo_col].to_numpy(),
                    cv_df[hi_col].to_numpy(),
                )
                metrics[f"coverage_{lvl}"] = float(cov)

        rows.append(metrics)

    lb = pd.DataFrame(rows)
    if lb.empty:
        raise RuntimeError("[leaderboard] computed empty leaderboard (no usable model columns).")

    # Fail-loud sorting: rmse NaNs should sort last
    lb = lb.sort_values(["rmse"], ascending=True, na_position="last").reset_index(drop=True)
    return lb


def compute_baseline_metrics(
    cv_df: pd.DataFrame,
    *,
    model_name: str,
    threshold_k: float = 2.0,
) -> dict:
    """
    Compute baseline metrics for drift detection from CV output.

    We compute RMSE/MAE per (unique_id, cutoff) window, then aggregate:
      rmse_mean, rmse_std, drift_threshold_rmse = mean + k*std

    No imputation/filling: metrics are computed only from finite values.
    """
    required = {"unique_id", "cutoff", "y", model_name}
    missing = required - set(cv_df.columns)
    if missing:
        raise ValueError(
            f"[baseline] cv_df missing required columns for model '{model_name}': {sorted(missing)}"
        )

    # Compute per-window metrics (unique_id, cutoff)
    def _window_metrics(g: pd.DataFrame) -> pd.Series:
        yt = g["y"].to_numpy()
        yp = g[model_name].to_numpy()
        valid = np.isfinite(yt) & np.isfinite(yp)
        if valid.sum() == 0:
            return pd.Series({"rmse": np.nan, "mae": np.nan, "valid_rows": 0})
        return pd.Series({
            "rmse": ForecastMetrics.rmse(yt, yp),
            "mae": ForecastMetrics.mae(yt, yp),
            "valid_rows": int(valid.sum()),
        })

    per_window = (
        cv_df.groupby(["unique_id", "cutoff"], sort=False, dropna=False)
        .apply(_window_metrics)
        .reset_index()
    )

    # Fail loud if baseline is entirely NaN
    if per_window["rmse"].notna().sum() == 0:
        sample_cols = ["unique_id", "cutoff", "y", model_name]
        raise RuntimeError(
            "[baseline] All per-window RMSE are NaN. "
            "This usually means predictions or y are non-finite everywhere. "
            f"Sample:\n{cv_df[sample_cols].head(20).to_string(index=False)}"
        )

    rmse_mean = float(per_window["rmse"].mean(skipna=True))
    rmse_std = float(per_window["rmse"].std(skipna=True, ddof=0))
    mae_mean = float(per_window["mae"].mean(skipna=True))
    mae_std = float(per_window["mae"].std(skipna=True, ddof=0))

    baseline = {
        "model": model_name,
        "rmse_mean": rmse_mean,
        "rmse_std": rmse_std,
        "mae_mean": mae_mean,
        "mae_std": mae_std,
        "drift_threshold_rmse": float(rmse_mean + threshold_k * rmse_std),
        "drift_threshold_mae": float(mae_mean + threshold_k * mae_std),
        "n_series": int(per_window["unique_id"].nunique()),
        "n_windows": int(per_window["cutoff"].nunique()),
        "per_window_rows": int(len(per_window)),
    }

    # Optional per-series baseline (useful later if you want drift per series)
    per_series = (
        per_window.groupby("unique_id")[["rmse", "mae"]]
        .agg(rmse_mean=("rmse", "mean"), rmse_std=("rmse", lambda s: s.std(ddof=0)),
             mae_mean=("mae", "mean"), mae_std=("mae", lambda s: s.std(ddof=0)))
        .reset_index()
    )
    per_series["drift_threshold_rmse"] = per_series["rmse_mean"] + threshold_k * per_series["rmse_std"]
    per_series["drift_threshold_mae"] = per_series["mae_mean"] + threshold_k * per_series["mae_std"]
    baseline["per_series"] = per_series.to_dict(orient="records")

    return baseline



@dataclass
class ForecastConfig:
    horizon: int = 24
    confidence_levels: tuple[int, int] = (80, 95)


class RenewableForecastModel:
    def __init__(self, horizon: int = 24, confidence_levels: tuple[int, int] = (80, 95)):
        self.horizon = horizon
        self.confidence_levels = confidence_levels
        self.sf = None
        self._train_df = None  # contains y + exog columns
        self._exog_cols: list[str] = []
        self.fitted = False

    def prepare_training_df(self, df: pd.DataFrame, weather_df: Optional[pd.DataFrame]) -> pd.DataFrame:
        """
        Prepare training dataframe for modeling.

        Supports two input modes:
        1. Raw data (legacy pipeline): enforces the hourly grid, adds time features,
           and merges weather here.
        2. Preprocessed data (from dataset_builder): skips preprocessing and only validates y.

        Detection: if df already contains all four time features (hour_sin, hour_cos,
        dow_sin, dow_cos), it is treated as preprocessed.
        """
        req = {"unique_id", "ds", "y"}
        if not req.issubset(df.columns):
            raise ValueError(f"generation df missing cols={sorted(req - set(df.columns))}")

        if df.empty:
            raise RuntimeError(
                "[generation] Empty generation dataframe passed into modeling. "
                "This is upstream (EIA fetch/cache) failure — inspect fetch_diagnostics and fetch_generation logs."
            )

        work = df.copy()
        work["ds"] = pd.to_datetime(work["ds"], errors="raise")
        work = work.sort_values(["unique_id", "ds"]).reset_index(drop=True)

        # Check if data is already preprocessed (has time features)
        time_features = ["hour_sin", "hour_cos", "dow_sin", "dow_cos"]
        is_preprocessed = all(col in work.columns for col in time_features)

        if is_preprocessed:
            # Data already preprocessed by dataset_builder, skip preprocessing
            print("[prepare_training_df] Detected preprocessed data (has time features), skipping preprocessing")

            # Validate y has no nulls
            y_null = work["y"].isna()
            if y_null.any():
                sample = work.loc[y_null, ["unique_id", "ds", "y"]].head(25)
                raise RuntimeError(
                    f"[generation][Y] Found null y values. rows={int(y_null.sum())}. "
                    f"Sample:\n{sample.to_string(index=False)}"
                )

            # Identify exog columns (time features + any weather vars present)
            wcols = [c for c in WEATHER_VARS if c in work.columns]
            self._exog_cols = time_features + wcols

            print(f"[prepare_training_df] Using {len(self._exog_cols)} exog features: {self._exog_cols[:5]}...")
            return work

        # Legacy path: raw data, perform all preprocessing
        print("[prepare_training_df] Processing raw data (legacy path)")

        y_null = work["y"].isna()
        if y_null.any():
            sample = work.loc[y_null, ["unique_id", "ds", "y"]].head(25)
            raise RuntimeError(
                f"[generation][Y] Found null y values (no imputation). rows={int(y_null.sum())}. "
                f"Sample:\n{sample.to_string(index=False)}"
            )

        work = _enforce_hourly_grid(work, label="generation", policy="drop_incomplete_series")
        work = _add_time_features(work)

        if weather_df is not None and not weather_df.empty:
            if not {"ds", "region"}.issubset(weather_df.columns):
                raise ValueError("weather_df must have columns ['ds','region', ...]")

            work["region"] = work["unique_id"].str.split("_").str[0]

            wcols = [c for c in WEATHER_VARS if c in weather_df.columns]
            if not wcols:
                raise ValueError("weather_df has none of expected WEATHER_VARS")

            merged = work.merge(
                weather_df[["ds", "region"] + wcols],
                on=["ds", "region"],
                how="left",
                validate="many_to_one",
            )

            missing_any = merged[wcols].isna().any(axis=1)
            if missing_any.any():
                sample = merged.loc[missing_any, ["unique_id", "ds", "region"] + wcols].head(10)
                raise RuntimeError(
                    f"[weather][ALIGN] Missing weather after merge rows={int(missing_any.sum())}. "
                    f"Sample:\n{sample.to_string(index=False)}"
                )

            work = merged.drop(columns=["region"])
            self._exog_cols = ["hour_sin", "hour_cos", "dow_sin", "dow_cos"] + wcols
        else:
            self._exog_cols = ["hour_sin", "hour_cos", "dow_sin", "dow_cos"]

        return work



    def fit(self, df: pd.DataFrame, weather_df: Optional[pd.DataFrame] = None) -> None:
        from statsforecast import StatsForecast
        from statsforecast.models import (MSTL, AutoARIMA, AutoETS,
                                          SeasonalNaive)

        train_df = self.prepare_training_df(df, weather_df)

        models = [
            AutoARIMA(season_length=24),
            SeasonalNaive(season_length=24),
            AutoETS(season_length=24),
            MSTL(season_length=[24, 168], trend_forecaster=AutoARIMA(), alias="MSTL_ARIMA"),
        ]

        # Try to add AutoTheta and AutoCES if available (same as cross_validate)
        try:
            from statsforecast.models import AutoCES, AutoTheta
            models.append(AutoTheta(season_length=24))
            models.append(AutoCES(season_length=24))
            print("[fit] Using expanded model set: +AutoTheta, +AutoCES")
        except ImportError:
            print("[fit] AutoTheta/AutoCES not available, using core models only")

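        # No model is fit here. StatsForecast.forecast() in predict() fits and
        # predicts in one call, using the training frame stored below.
        self.sf = StatsForecast(models=models, freq="h", n_jobs=-1)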
        self._train_df = train_df
        self.fitted = True

        print(f"[fit] rows={len(train_df)} series={train_df['unique_id'].nunique()} exog_cols={self._exog_cols}")

    def build_future_X_df(self, future_weather: pd.DataFrame) -> pd.DataFrame:
        """
        Build future X_df for forecast horizon using forecast weather.
        Must include: unique_id, ds, and exactly the exog columns used in training.
        """
        if not self.fitted:
            raise RuntimeError("fit() first")

        if future_weather is None or future_weather.empty:
            raise RuntimeError("future_weather required to forecast with regressors (no fabrication).")

        if not {"ds", "region"}.issubset(future_weather.columns):
            raise ValueError("future_weather must have columns ['ds','region', ...]")

        # Create the future ds grid per series
        last_ds = self._train_df.groupby("unique_id")["ds"].max()
        frames = []
        for uid, end in last_ds.items():
            future_ds = pd.date_range(end + pd.Timedelta(hours=1), periods=self.horizon, freq="h")
            frames.append(pd.DataFrame({"unique_id": uid, "ds": future_ds}))
        X = pd.concat(frames, ignore_index=True)

        X = _add_time_features(X)
        X["region"] = X["unique_id"].str.split("_").str[0]

        wcols = [c for c in WEATHER_VARS if c in future_weather.columns]
        X = X.merge(
            future_weather[["ds", "region"] + wcols],
            on=["ds", "region"],
            how="left",
            validate="many_to_one",
        )

        # Fail loud on missing future regressors
        needed = [c for c in self._exog_cols if c not in ["hour_sin", "hour_cos", "dow_sin", "dow_cos"]]  # weather cols
        if needed:
            missing_any = X[needed].isna().any(axis=1)
            if missing_any.any():
                sample = X.loc[missing_any, ["unique_id", "ds", "region"] + needed].head(10)
                raise RuntimeError(
                    f"[future_weather][ALIGN] Missing future weather rows={int(missing_any.sum())}. "
                    f"Sample:\n{sample.to_string(index=False)}"
                )

        X = X.drop(columns=["region"])
        keep = ["unique_id", "ds"] + self._exog_cols
        return X[keep].sort_values(["unique_id", "ds"]).reset_index(drop=True)

    def predict(self, future_weather: pd.DataFrame, best_model: Optional[str] = None) -> pd.DataFrame:
        """Generate forecasts using fitted models.

        Args:
            future_weather: Future weather data for forecast period
            best_model: Optional model name to use for predictions. If provided, only this model's
                       predictions will be included in output (as 'yhat' column). If None, all
                       fitted models' predictions are returned.

        Returns:
            DataFrame with forecast predictions. If best_model is specified, includes:
            - unique_id, ds: identifiers
            - yhat: point forecast from best_model
            - yhat_lo_{level}, yhat_hi_{level}: prediction intervals from best_model
              (underscores, not hyphens)

            If best_model is None, includes predictions from all fitted models.
        """
        if not self.fitted:
            raise RuntimeError("fit() first")

        X_df = self.build_future_X_df(future_weather)

        # IMPORTANT: If you fit models using exogenous regressors, you must supply X_df at forecast time.
        fcst = self.sf.forecast(
            h=self.horizon,
            df=self._train_df,
            X_df=X_df,
            level=list(self.confidence_levels),
        ).reset_index()

        # Apply minimal physical constraints for solar series
        fcst = self._apply_minimal_solar_constraints(fcst, X_df)

        # If best_model is specified, filter to only that model's predictions
        if best_model is not None:
            if best_model not in fcst.columns:
                available_models = [c for c in fcst.columns if c not in ['unique_id', 'ds']]
                raise ValueError(
                    f"[predict] best_model '{best_model}' not found in forecast output. "
                    f"Available models: {available_models}"
                )

            # Extract best model's predictions and rename to standard 'yhat' format
            keep_cols = ['unique_id', 'ds', best_model]

            # Also keep prediction interval columns for the best model
            for level in self.confidence_levels:
                lo_col = f"{best_model}-lo-{level}"
                hi_col = f"{best_model}-hi-{level}"
                if lo_col in fcst.columns:
                    keep_cols.append(lo_col)
                if hi_col in fcst.columns:
                    keep_cols.append(hi_col)

            fcst = fcst[keep_cols].copy()

            # Rename model column to 'yhat' and interval columns to match
            # NOTE: Using underscores (not hyphens) to match dashboard expectations
            rename_map = {best_model: 'yhat'}
            for level in self.confidence_levels:
                old_lo = f"{best_model}-lo-{level}"
                old_hi = f"{best_model}-hi-{level}"
                if old_lo in fcst.columns:
                    rename_map[old_lo] = f"yhat_lo_{level}"
                if old_hi in fcst.columns:
                    rename_map[old_hi] = f"yhat_hi_{level}"

            fcst = fcst.rename(columns=rename_map)

        return fcst

    def _apply_minimal_solar_constraints(
        self,
        fcst: pd.DataFrame,
        X_df: pd.DataFrame
    ) -> pd.DataFrame:
        """
        Apply ONLY minimal physical constraints for impossible cases.

        Philosophy: Let the model learn natural patterns. Only intervene when
        physics is violated (e.g., generation when sun is below horizon).

        Constraint: Zero generation when BOTH radiation sources are zero
        (sun is definitely below horizon).
        """
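        # The constraint needs both radiation columns. If the training weather
        # lacked them, there is nothing to test, so return the forecast unchanged
        # instead of raising a KeyError on the column selection below.
        rad_cols = ["direct_radiation", "diffuse_radiation"]
        if not set(rad_cols).issubset(X_df.columns):
            return fcst

        # Merge forecast with features to get radiation data
        fcst_with_features = fcst.merge(
            X_df[["unique_id", "ds"] + rad_cols],
            on=["unique_id", "ds"],
            how="left",
        )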

        # Identify solar series
        solar_mask = fcst_with_features["unique_id"].str.endswith("_SUN")

        # ONLY constraint: Zero generation when BOTH radiation sources are zero
        # (sun is definitely below horizon)
        no_sun_mask = (
            (fcst_with_features["direct_radiation"] == 0) &
            (fcst_with_features["diffuse_radiation"] == 0)
        )

        # Apply constraint to solar series
        constrain_mask = solar_mask & no_sun_mask

        # Set all forecast columns to 0 (cannot generate without any sunlight)
        # Note: At this point, columns are model names (AutoARIMA, MSTL_ARIMA, etc.)
        # and their intervals (model-lo-80, model-hi-80, etc.), not "yhat"
        # Exclude unique_id, ds, and feature columns
        exclude_cols = {'unique_id', 'ds', 'direct_radiation', 'diffuse_radiation'}
        forecast_cols = [c for c in fcst_with_features.columns if c not in exclude_cols]

        for col in forecast_cols:
            fcst_with_features.loc[constrain_mask, col] = 0.0

        return fcst_with_features.drop(columns=["direct_radiation", "diffuse_radiation"], errors="ignore")

    def cross_validate(
        self,
        df: pd.DataFrame,
        weather_df: Optional[pd.DataFrame] = None,
        n_windows: int = 3,
        step_size: int = 168,
        expanded_models: bool = True,
    ) -> tuple[pd.DataFrame, pd.DataFrame]:
        from statsforecast import StatsForecast
        from statsforecast.models import (MSTL, AutoARIMA, AutoETS,
                                          SeasonalNaive)

        train_df = self.prepare_training_df(df, weather_df)

        # Core models
        models = [
            AutoARIMA(season_length=24),
            SeasonalNaive(season_length=24),
            AutoETS(season_length=24),
            MSTL(season_length=[24, 168], trend_forecaster=AutoARIMA(), alias="MSTL_ARIMA"),
        ]

        # Add expanded models if requested
        if expanded_models:
            try:
                from statsforecast.models import AutoCES, AutoTheta
                models.extend([
                    AutoTheta(season_length=24),
                    AutoCES(season_length=24),
                ])
                print("[cv] Using expanded model set: +AutoTheta, +AutoCES")
            except ImportError:
                print("[cv] AutoTheta/AutoCES not available, using core models only")

        sf = StatsForecast(models=models, freq="h", n_jobs=-1)

        print(
            f"[cv] windows={n_windows} step={step_size} h={self.horizon} "
            f"rows={len(train_df)} series={train_df['unique_id'].nunique()}"
        )

        cv = sf.cross_validation(
            df=train_df,
            h=self.horizon,
            step_size=step_size,
            n_windows=n_windows,
            level=list(self.confidence_levels),
        ).reset_index()

        leaderboard = compute_leaderboard(cv, confidence_levels=self.confidence_levels)
        return cv, leaderboard


class RenewableLGBMForecaster:
    """
    LightGBM-based forecaster with interpretability support.

    This forecaster is designed for model interpretability (SHAP, feature importance)
    rather than production forecasting. Use alongside RenewableForecastModel which
    provides better uncertainty quantification via statistical models.

    Uses skforecast's ForecasterRecursive with LightGBM as the base estimator,
    along with rolling window features for temporal patterns.
    """

    def __init__(
        self,
        horizon: int = 24,
        lags: int = 168,  # 7 days of lags
        rolling_window_sizes: list[int] | None = None,
    ):
        """
        Initialize the LightGBM forecaster.

        Args:
            horizon: Forecast horizon in hours
            lags: Number of lag features (default 168 = 7 days)
            rolling_window_sizes: Window sizes for rolling features (default [24, 168])
        """
        self.horizon = horizon
        self.lags = lags
        self.rolling_window_sizes = rolling_window_sizes or [24, 168]
        self.forecaster = None
        self._exog_features: list[str] = []
        self.fitted = False

    def fit(self, y: pd.Series, exog: Optional[pd.DataFrame] = None) -> None:
        """
        Fit LightGBM forecaster with rolling features.

        Args:
            y: Target time series (must have DatetimeIndex)
            exog: Optional exogenous features (must have same index as y)
        """
        # Import here to make dependencies optional
        try:
            from lightgbm import LGBMRegressor
            from skforecast.preprocessing import RollingFeatures
            from skforecast.recursive import ForecasterRecursive
        except ImportError as e:
            raise ImportError(
                "LightGBM forecaster requires: lightgbm, skforecast. "
                f"Install with: pip install lightgbm skforecast. Error: {e}"
            )

        # Create rolling features
        # Note: skforecast requires window_sizes length to match stats length
        # We use 'mean' for each window size to capture different temporal patterns
        stats_list = ['mean'] * len(self.rolling_window_sizes)
        window_features = RollingFeatures(
            stats=stats_list,
            window_sizes=self.rolling_window_sizes,
        )

        # Initialize forecaster
        self.forecaster = ForecasterRecursive(
            estimator=LGBMRegressor(
                random_state=42,
                verbose=-1,
                n_estimators=100,
                learning_rate=0.1,
                max_depth=6,
                num_leaves=31,
                min_child_samples=20,
            ),
            lags=self.lags,
            window_features=window_features,
        )

        # Store exog feature names
        if exog is not None:
            self._exog_features = exog.columns.tolist()
        else:
            self._exog_features = []

        # Fit the model
        self.forecaster.fit(y=y, exog=exog)
        self.fitted = True

        n_features = len(self.get_feature_importances())
        print(f"[LGBMForecaster.fit] fitted with {n_features} features, lags={self.lags}")

    def predict(self, steps: int, exog: Optional[pd.DataFrame] = None) -> pd.Series:
        """
        Generate predictions.

        Args:
            steps: Number of steps to forecast
            exog: Exogenous features for forecast period

        Returns:
            Series of predictions
        """
        if not self.fitted or self.forecaster is None:
            raise RuntimeError("Must call fit() before predict()")

        return self.forecaster.predict(steps=steps, exog=exog)

    def get_feature_importances(self) -> pd.DataFrame:
        """
        Extract feature importance from fitted model.

        Returns:
            DataFrame with 'feature' and 'importance' columns
        """
        if not self.fitted or self.forecaster is None:
            raise RuntimeError("Must call fit() before get_feature_importances()")

        return self.forecaster.get_feature_importances()

    def create_train_X_y(
        self,
        y: pd.Series,
        exog: Optional[pd.DataFrame] = None,
    ) -> tuple[pd.DataFrame, pd.Series]:
        """
        Create training matrices for SHAP analysis.

        Args:
            y: Target time series
            exog: Optional exogenous features

        Returns:
            Tuple of (X_train, y_train)
        """
        if not self.fitted or self.forecaster is None:
            raise RuntimeError("Must call fit() before create_train_X_y()")

        return self.forecaster.create_train_X_y(y=y, exog=exog)

    @property
    def regressor(self):
        """
        Access internal LightGBM estimator for SHAP.

        Returns:
            The fitted LGBMRegressor instance
        """
        if not self.fitted or self.forecaster is None:
            raise RuntimeError("Must call fit() before accessing regressor")

        # Use 'estimator' (new API) with fallback to 'regressor' (deprecated)
        if hasattr(self.forecaster, "estimator"):
            return self.forecaster.estimator
        return self.forecaster.regressor

    @property
    def exog_features(self) -> list[str]:
        """Return list of exogenous feature names used in training."""
        return self._exog_features.copy()


if __name__ == "__main__":
    # REAL EXAMPLE: multi-series WND with strict gates and CV

    from src.renewable.eia_renewable import EIARenewableFetcher
    from src.renewable.open_meteo import OpenMeteoRenewable

    regions = ["CALI", "ERCO", "MISO"]
    fuel = "WND"
    start_date = "2024-11-01"
    end_date = "2024-12-15"

    fetcher = EIARenewableFetcher(debug_env=True)
    gen = fetcher.fetch_all_regions(fuel, start_date, end_date, regions=regions)
    _log_series_summary(gen, label="generation_raw")

    weather_api = OpenMeteoRenewable(strict=True)
    wx_hist = weather_api.fetch_all_regions_historical(regions, start_date, end_date, debug=True)

    model = RenewableForecastModel(horizon=24, confidence_levels=(80, 95))

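    # CV (historical): regressors live in df, no filling allowed.
    # cross_validate returns (cv_df, leaderboard), so unpack both.
    cv, leaderboard = model.cross_validate(gen, weather_df=wx_hist, n_windows=3, step_size=168)
    print(cv.head().to_string(index=False))
    print(leaderboard.to_string(index=False))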

    # Optional: fit + forecast next 24h using forecast weather (no leakage)
    # wx_future = weather_api.fetch_all_regions_forecast(regions, horizon_hours=48, debug=True)
    # model.fit(gen, weather_df=wx_hist)
    # fcst = model.predict(future_weather=wx_future)
    # print(fcst.head().to_string(index=False))



# Module: Pipeline Tasks

**File:** `src/renewable/tasks.py`

This module orchestrates the complete pipeline:

1. **Fetch generation data** from EIA
2. **Fetch weather data** from Open-Meteo
3. **Train models** with cross-validation
4. **Generate forecasts** with prediction intervals
5. **Compute drift metrics** vs baseline
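
Step 5 reuses the baseline from `compute_baseline_metrics` in `modeling.py`. RMSE is computed per `(unique_id, cutoff)` CV window. The drift threshold is then `mean + k × std` of those window RMSEs, using population std (`ddof=0`) and `k=2.0` by default.

The sketch below shows that comparison on plain NumPy arrays. It rests on two assumptions:

- The helper names `window_rmse`, `drift_threshold` and `is_drifting` are hypothetical.
- Masking non-finite values inside RMSE is an assumption about how `ForecastMetrics.rmse` behaves.

```python
import numpy as np


def window_rmse(y_true, y_pred) -> float:
    """RMSE over finite pairs only (assumed to mirror the no-imputation rule)."""
    yt = np.asarray(y_true, dtype=float)
    yp = np.asarray(y_pred, dtype=float)
    valid = np.isfinite(yt) & np.isfinite(yp)
    if not valid.any():
        return float("nan")
    return float(np.sqrt(np.mean((yt[valid] - yp[valid]) ** 2)))


def drift_threshold(window_rmses, k: float = 2.0) -> float:
    """mean + k * std over per-window RMSEs, skipping NaN windows (ddof=0)."""
    arr = np.asarray(window_rmses, dtype=float)
    arr = arr[np.isfinite(arr)]
    if arr.size == 0:
        raise ValueError("baseline has no finite per-window RMSE values")
    return float(arr.mean() + k * arr.std(ddof=0))


def is_drifting(y_true, y_pred, threshold: float) -> bool:
    """Flag drift when the latest window's RMSE exceeds the baseline threshold."""
    return window_rmse(y_true, y_pred) > threshold


# Usage: four backtest windows, then one new 2-hour window to check.
baseline = drift_threshold([100.0, 110.0, 90.0, 100.0])  # 100 + 2 * sqrt(50)
print(round(baseline, 2))  # 114.14
print(is_drifting([500.0, 600.0], [380.0, 480.0], baseline))  # True (RMSE 120)
```

`NaN > threshold` evaluates to `False`, so this sketch never flags a window with no finite pairs. A stricter variant could raise instead, in line with the fail-loud checks elsewhere in the pipeline.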

## Key Feature: Adaptive CV

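Rolling-origin cross-validation needs enough history in each series to place every window. The evaluation span alone covers `horizon + (n_windows - 1) × step_size` hours. The minimum below adds one more `step_size`, which leaves at least that much training history before the first cutoff: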
```
Minimum rows = horizon + (n_windows × step_size)
```

For short series, we **adapt** the CV settings automatically.
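
The adaptation code itself is not shown in this section. The following is only a sketch of one plausible policy, not the pipeline's implementation:

1. Keep `step_size` and drop windows first.
2. Fall back to a single window with a shorter step only when even one full step does not fit.

`CVSettings`, `min_rows_required`, `adapt_cv_settings` and the `min_step=24` floor are all hypothetical.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CVSettings:
    n_windows: int
    step_size: int


def min_rows_required(horizon: int, n_windows: int, step_size: int) -> int:
    """Rule of thumb from the text: horizon + n_windows * step_size."""
    return horizon + n_windows * step_size


def adapt_cv_settings(
    n_rows: int,
    horizon: int,
    n_windows: int,
    step_size: int,
    min_step: int = 24,  # hypothetical floor on the spacing between cutoffs
) -> CVSettings:
    """Shrink CV settings until the layout fits in n_rows hourly observations."""
    if min_rows_required(horizon, n_windows, step_size) <= n_rows:
        return CVSettings(n_windows, step_size)

    # 1) Prefer fewer windows at the original spacing.
    max_windows = (n_rows - horizon) // step_size
    if max_windows >= 1:
        return CVSettings(max_windows, step_size)

    # 2) Single window: shorten the step so horizon + step still fits.
    fitted_step = n_rows - horizon
    if fitted_step >= min_step:
        return CVSettings(1, fitted_step)

    raise ValueError(
        f"series too short for CV: n_rows={n_rows} horizon={horizon} min_step={min_step}"
    )


# 30 days of hourly data with cv_windows=5 and cv_step_size=168.
print(min_rows_required(24, 5, 168))       # 864
print(adapt_cv_settings(720, 24, 5, 168))  # CVSettings(n_windows=4, step_size=168)
```

Take the `RenewablePipelineConfig` defaults below: `lookback_days=30` gives up to 720 hourly rows per series, with `cv_windows=5` and `cv_step_size=168`. The rule then asks for 24 + 5 × 168 = 864 rows. This sketch would therefore fall back to 4 windows (24 + 4 × 168 = 696 ≤ 720). Dropping windows before shortening the step keeps cutoffs one configured `step_size` (a week) apart.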

%%writefile src/renewable/tasks.py
# file: src/renewable/tasks.py
"""Renewable energy forecasting pipeline tasks.

Idempotent tasks for:
- Fetching EIA renewable generation data
- Fetching weather data from Open-Meteo
- Training probabilistic models
- Generating forecasts with intervals
- Computing drift metrics
"""

import argparse
import logging
import os
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Optional

import pandas as pd

from src.renewable.eia_renewable import EIARenewableFetcher
from src.renewable.modeling import (
    RenewableForecastModel,
    RenewableLGBMForecaster,
    _log_series_summary,
    _add_time_features,
    compute_baseline_metrics,
    WEATHER_VARS,
)
from src.renewable.model_interpretability import (
    InterpretabilityReport,
    generate_full_interpretability_report,
)
from src.renewable.open_meteo import OpenMeteoRenewable
from src.renewable.regions import REGIONS, list_regions
from src.renewable.dataset_builder import build_modeling_dataset, PreprocessingReport

logger = logging.getLogger(__name__)


@dataclass
class RenewablePipelineConfig:
    """Configuration for renewable forecasting pipeline."""

    # Data parameters
    regions: list[str] = field(default_factory=lambda: ["CALI", "ERCO", "MISO", "PJM", "SWPP"])
    fuel_types: list[str] = field(default_factory=lambda: ["WND", "SUN"])
    start_date: str = ""  # Set dynamically
    end_date: str = ""  # Set dynamically
    lookback_days: int = 30

    # Forecast parameters
    horizon: int = 24
    confidence_levels: tuple[int, int] = (80, 95)
    horizon_preset: Optional[str] = None  # "24h" | "48h" | "72h"

    # CV parameters
    cv_windows: int = 5
    cv_step_size: int = 168  # 1 week

    # Model parameters
    enable_interpretability: bool = True  # LightGBM SHAP analysis (on by default)

    # Preprocessing parameters
    negative_policy: str = "clamp"  # "clamp" | "fail_loud" | "hybrid"
    hourly_grid_policy: str = "drop_incomplete_series"  # "drop_incomplete_series" | "fail_loud"

    # Output paths
    data_dir: str = "data/renewable"
    overwrite: bool = False

    # Horizon preset definitions (class-level constant)
    _PRESETS = {
        "24h": {"horizon": 24, "cv_windows": 2, "lookback_days": 15},
        "48h": {"horizon": 48, "cv_windows": 3, "lookback_days": 21},
        "72h": {"horizon": 72, "cv_windows": 3, "lookback_days": 28},
    }

    def __post_init__(self):
        # Apply horizon preset if specified; preset values override any
        # horizon/cv_windows/lookback_days passed to the constructor
        if self.horizon_preset:
            preset = self._PRESETS.get(self.horizon_preset)
            if preset is None:
                logger.warning(
                    f"[config] Unknown horizon_preset '{self.horizon_preset}' "
                    f"(expected one of {sorted(self._PRESETS)}); ignoring"
                )
            else:
                # The dataclass is not frozen, so plain attribute assignment works
                self.horizon = preset["horizon"]
                self.cv_windows = preset["cv_windows"]
                self.lookback_days = preset["lookback_days"]
                logger.info(f"[config] Applied preset '{self.horizon_preset}': horizon={preset['horizon']}h")

        # Set default dates if not provided
        if not self.end_date:
            self.end_date = datetime.now(timezone.utc).strftime("%Y-%m-%d")
        if not self.start_date:
            end = datetime.strptime(self.end_date, "%Y-%m-%d")
            start = end - timedelta(days=self.lookback_days)
            self.start_date = start.strftime("%Y-%m-%d")

        # Validate configuration
        warnings = self._validate()
        for warning in warnings:
            logger.warning(f"[config] {warning}")

    def _validate(self) -> list[str]:
        """Validate configuration and return warnings."""
        warnings = []

        # Check minimum data requirement
        available_hours = self.lookback_days * 24
        required_hours = self.horizon + (self.cv_windows * self.cv_step_size)
        if available_hours < required_hours:
            warnings.append(
                f"Insufficient data: need {required_hours}h, have {available_hours}h. "
                f"Increase lookback_days to {(required_hours // 24) + 1} or reduce cv_windows."
            )

        # Warn about accuracy degradation
        if self.horizon > 72:
            warnings.append(
                f"Horizon {self.horizon}h exceeds recommended max (72h). "
                f"Weather forecast accuracy degrades significantly beyond 3 days."
            )

        return warnings

    def generation_path(self) -> Path:
        return Path(self.data_dir) / "generation.parquet"

    def weather_path(self) -> Path:
        return Path(self.data_dir) / "weather.parquet"

    def forecasts_path(self) -> Path:
        return Path(self.data_dir) / "forecasts.parquet"

    def baseline_path(self) -> Path:
        return Path(self.data_dir) / "baseline.json"

    def interpretability_dir(self) -> Path:
        return Path(self.data_dir) / "interpretability"

    def preprocessing_dir(self) -> Path:
        return Path(self.data_dir) / "preprocessing"


def fetch_renewable_data(
    config: RenewablePipelineConfig,
    fetch_diagnostics: Optional[list[dict]] = None,
) -> pd.DataFrame:
    """Task 1: Fetch EIA generation data for all regions and fuel types.

    Args:
        config: Pipeline configuration
        fetch_diagnostics: Optional list to capture per-region fetch metadata

    Returns:
        DataFrame with columns [unique_id, ds, y]
    """
    output_path = config.generation_path()
    output_path.parent.mkdir(parents=True, exist_ok=True)

    def _log_generation_summary(df: pd.DataFrame, source: str) -> None:
        _log_series_summary(df, value_col="y", label=f"generation_data_{source}")

        expected_series = {
            f"{region}_{fuel}" for region in config.regions for fuel in config.fuel_types
        }
        present_series = set(df["unique_id"]) if "unique_id" in df.columns else set()
        missing_series = sorted(expected_series - present_series)
        if missing_series:
            logger.warning(
                "[fetch_generation] Missing expected series (%s): %s",
                source,
                missing_series,
            )

        if df.empty:
            logger.warning("[fetch_generation] No generation data rows (%s).", source)
            return

        coverage = (
            df.groupby("unique_id")["ds"]
            .agg(min_ds="min", max_ds="max", rows="count")
            .reset_index()
            .sort_values("unique_id")
        )
        max_series_log = 25
        if len(coverage) > max_series_log:
            logger.info(
                "[fetch_generation] Coverage (%s, first %s series):\n%s",
                source,
                max_series_log,
                coverage.head(max_series_log).to_string(index=False),
            )
        else:
            logger.info("[fetch_generation] Coverage (%s):\n%s", source, coverage.to_string(index=False))

    if output_path.exists() and not config.overwrite:
        logger.info(f"[fetch_generation] exists, loading: {output_path}")
        cached = pd.read_parquet(output_path)
        # Log cached coverage to surface missing series without refetching.
        _log_generation_summary(cached, source="cache")
        return cached

    logger.info(f"[fetch_generation] Fetching {config.fuel_types} for {config.regions}")

    # Use longer timeout (90s) to handle slow EIA API responses
    fetcher = EIARenewableFetcher(timeout=90)
    all_dfs = []

    for fuel_type in config.fuel_types:
        df = fetcher.fetch_all_regions(
            fuel_type=fuel_type,
            start_date=config.start_date,
            end_date=config.end_date,
            regions=config.regions,
            diagnostics=fetch_diagnostics,
        )
        all_dfs.append(df)

    combined = pd.concat(all_dfs, ignore_index=True)
    combined = combined.sort_values(["unique_id", "ds"]).reset_index(drop=True)

    # Log fresh coverage to highlight gaps or unexpected negatives.
    _log_generation_summary(combined, source="fresh")

    if fetch_diagnostics:
        empty_series = [
            entry
            for entry in fetch_diagnostics
            if entry.get("empty")
        ]
        for entry in empty_series:
            logger.warning(
                "[fetch_generation] Empty series detail: region=%s fuel=%s total=%s pages=%s",
                entry.get("region"),
                entry.get("fuel_type"),
                entry.get("total_records"),
                entry.get("pages"),
            )

    combined.to_parquet(output_path, index=False)
    logger.info(f"[fetch_generation] Saved: {output_path} ({len(combined)} rows)")

    return combined


def fetch_renewable_weather(
    config: RenewablePipelineConfig,
    include_forecast: bool = True,
) -> pd.DataFrame:
    """Task 2: Fetch weather data for all regions.

    Args:
        config: Pipeline configuration
        include_forecast: Include forecast weather for predictions

    Returns:
        DataFrame with columns [ds, region, weather_vars...]
    """
    output_path = config.weather_path()
    output_path.parent.mkdir(parents=True, exist_ok=True)

    def _log_weather_summary(df: pd.DataFrame, source: str) -> None:
        if df.empty:
            logger.warning("[fetch_weather] No weather data rows (%s).", source)
            return

        coverage = (
            df.groupby("region")["ds"]
            .agg(min_ds="min", max_ds="max", rows="count")
            .reset_index()
            .sort_values("region")
        )
        max_region_log = 25
        if len(coverage) > max_region_log:
            logger.info(
                "[fetch_weather] Coverage (%s, first %s regions):\n%s",
                source,
                max_region_log,
                coverage.head(max_region_log).to_string(index=False),
            )
        else:
            logger.info("[fetch_weather] Coverage (%s):\n%s", source, coverage.to_string(index=False))

        missing_cols = [
            col for col in OpenMeteoRenewable.WEATHER_VARS if col not in df.columns
        ]
        if missing_cols:
            logger.warning(
                "[fetch_weather] Missing expected weather columns (%s): %s",
                source,
                missing_cols,
            )

        missing_values = {
            col: int(df[col].isna().sum())
            for col in OpenMeteoRenewable.WEATHER_VARS
            if col in df.columns and df[col].isna().any()
        }
        if missing_values:
            logger.warning(
                "[fetch_weather] Missing weather values (%s): %s",
                source,
                missing_values,
            )

    if output_path.exists() and not config.overwrite:
        logger.info(f"[fetch_weather] exists, loading: {output_path}")
        cached = pd.read_parquet(output_path)
        # Log cached weather coverage to surface missing regions/columns.
        _log_weather_summary(cached, source="cache")
        return cached

    logger.info(f"[fetch_weather] Fetching weather for {config.regions}")

    weather = OpenMeteoRenewable()

    # Historical weather
    hist_df = weather.fetch_all_regions_historical(
        regions=config.regions,
        start_date=config.start_date,
        end_date=config.end_date,
    )

    # Validate historical weather result
    if hist_df.empty:
        raise RuntimeError(
            "[fetch_weather] Historical weather returned empty DataFrame. "
            "fetch_all_regions_historical should raise an error on failure, "
            "but received empty result. Check fetch logic."
        )

    if not {"ds", "region"}.issubset(hist_df.columns):
        missing_cols = {"ds", "region"} - set(hist_df.columns)
        raise ValueError(
            f"[fetch_weather] Weather DataFrame missing required columns: {missing_cols}"
        )

    hist_regions = hist_df['region'].nunique()
    hist_rows = len(hist_df)
    logger.info(
        f"[fetch_weather] Historical: {hist_regions} regions, {hist_rows} rows"
    )

    # Forecast weather (for prediction, prevents leakage)
    if include_forecast:
        fcst_df = weather.fetch_all_regions_forecast(
            regions=config.regions,
            horizon_hours=config.horizon + 24,  # Buffer
        )

        # Validate forecast weather result
        if fcst_df.empty:
            logger.warning(
                "[fetch_weather] Forecast weather returned empty DataFrame. "
                "Using historical data only for model training and predictions."
            )
            combined = hist_df
        else:
            fcst_rows = len(fcst_df)
            logger.info(f"[fetch_weather] Forecast: {fcst_rows} rows")

            # Combine, preferring forecast for overlapping times
            combined = pd.concat([hist_df, fcst_df], ignore_index=True)
            combined = combined.drop_duplicates(subset=["ds", "region"], keep="last")
    else:
        combined = hist_df

    combined = combined.sort_values(["region", "ds"]).reset_index(drop=True)

    # Log fresh weather coverage and missing values before saving.
    _log_weather_summary(combined, source="fresh")

    combined.to_parquet(output_path, index=False)
    logger.info(f"[fetch_weather] Saved: {output_path} ({len(combined)} rows)")

    return combined


def train_renewable_models(
    config: RenewablePipelineConfig,
    generation_df: Optional[pd.DataFrame] = None,
    weather_df: Optional[pd.DataFrame] = None,
) -> tuple[pd.DataFrame, pd.DataFrame, dict]:
    """Task 3: Train models and compute baseline metrics via cross-validation.

    Args:
        config: Pipeline configuration
        generation_df: Generation data (loads from file if None)
        weather_df: Weather data (loads from file if None)

    Returns:
        Tuple of (cv_results, leaderboard, baseline_metrics)
    """
    # Load data if not provided
    if generation_df is None:
        generation_df = pd.read_parquet(config.generation_path())
    if weather_df is None:
        weather_df = pd.read_parquet(config.weather_path())

    logger.info(f"[train_models] Training on {len(generation_df)} rows")

    model = RenewableForecastModel(
        horizon=config.horizon,
        confidence_levels=config.confidence_levels,
    )

    # Compute adaptive CV settings based on shortest series
    min_series_len = generation_df.groupby("unique_id").size().min()

    # CV needs: horizon + (n_windows * step_size) rows minimum
    # Solve for n_windows: n_windows = (min_series_len - horizon) / step_size
    available_for_cv = min_series_len - config.horizon

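    # Adjust step_size and n_windows to fit data
    step_size = min(config.cv_step_size, max(24, available_for_cv // 3))
    n_windows = min(config.cv_windows, max(2, available_for_cv // step_size))

    # The floors above (24h step, 2 windows) can still exceed what a short series holds
    if n_windows * step_size > available_for_cv:
        raise ValueError(
            f"[train_models] Shortest series has {min_series_len} rows; CV needs at least "
            f"{config.horizon + n_windows * step_size} "
            f"(horizon={config.horizon}, {n_windows} windows x {step_size}h)"
        )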

    logger.info(
        f"[train_models] Adaptive CV: {n_windows} windows, "
        f"step={step_size}h (min_series={min_series_len} rows)"
    )

    # Cross-validation
    cv_results, leaderboard = model.cross_validate(
        df=generation_df,
        weather_df=weather_df,
        n_windows=n_windows,
        step_size=step_size,
    )

    best_model = leaderboard.iloc[0]["model"]
    baseline = compute_baseline_metrics(cv_results, model_name=best_model)


    logger.info(f"[train_models] Best model: {best_model}, RMSE: {baseline['rmse_mean']:.1f}")

    return cv_results, leaderboard, baseline


def train_interpretability_models(
    config: RenewablePipelineConfig,
    generation_df: Optional[pd.DataFrame] = None,
    weather_df: Optional[pd.DataFrame] = None,
) -> dict[str, InterpretabilityReport]:
    """Train LightGBM models and generate interpretability reports per series.

    This trains a separate LightGBM model for each series (region × fuel type)
    and generates SHAP, partial dependence, and feature importance artifacts.

    Note: LightGBM is used for interpretability only. The primary forecasts
    come from statistical models (MSTL/ARIMA) which provide better uncertainty
    quantification.

    Args:
        config: Pipeline configuration
        generation_df: Generation data (loads from file if None)
        weather_df: Weather data (loads from file if None)

    Returns:
        Dict mapping series_id -> InterpretabilityReport
    """
    # Load data if not provided
    if generation_df is None:
        generation_df = pd.read_parquet(config.generation_path())
    if weather_df is None:
        weather_df = pd.read_parquet(config.weather_path())

    logger.info(f"[train_interpretability] Training LightGBM for {generation_df['unique_id'].nunique()} series")

    # Ensure datetime types
    generation_df = generation_df.copy()
    generation_df["ds"] = pd.to_datetime(generation_df["ds"], errors="raise")
    weather_df = weather_df.copy()
    weather_df["ds"] = pd.to_datetime(weather_df["ds"], errors="raise")

    reports: dict[str, InterpretabilityReport] = {}
    output_dir = config.interpretability_dir()
    output_dir.mkdir(parents=True, exist_ok=True)

    for uid in sorted(generation_df["unique_id"].unique()):
        logger.info(f"[train_interpretability] Processing {uid}...")

        # Extract series data
        series_data = generation_df[generation_df["unique_id"] == uid].copy()
        series_data = series_data.sort_values("ds")

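        # Prepare target series with proper frequency
        y = series_data.set_index("ds")["y"]
        try:
            # Raises ValueError if the timestamps are not a gap-free hourly grid
            y.index = pd.DatetimeIndex(y.index, freq="h")
        except ValueError as e:
            logger.warning(f"[train_interpretability] {uid}: index is not regular hourly ({e}), skipping")
            continue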

        # Prepare exogenous features
        region = uid.split("_")[0]
        series_weather = weather_df[weather_df["region"] == region].copy()

        if series_weather.empty:
            logger.warning(f"[train_interpretability] No weather data for region {region}, skipping {uid}")
            continue

        # Merge weather to series timestamps
        series_data = series_data.merge(
            series_weather[["ds"] + [c for c in WEATHER_VARS if c in series_weather.columns]],
            on="ds",
            how="left",
        )

        # Add time features
        series_data = _add_time_features(series_data)

        # Build exog DataFrame aligned with y
        exog_cols = ["hour_sin", "hour_cos", "dow_sin", "dow_cos"]
        exog_cols += [c for c in WEATHER_VARS if c in series_data.columns]
        exog = series_data.set_index("ds")[exog_cols]

        # Check for missing weather
        missing_weather = exog.isna().any(axis=1).sum()
        if missing_weather > 0:
            logger.warning(f"[train_interpretability] {uid}: {missing_weather} rows with missing weather, filling with ffill/bfill")
            exog = exog.ffill().bfill()

        # Fit LightGBM forecaster
        try:
            lgbm = RenewableLGBMForecaster(
                horizon=config.horizon,
                lags=168,  # 7 days of lags
                rolling_window_sizes=[24, 168],  # 1 day, 1 week
            )
            lgbm.fit(y=y, exog=exog)

            # Create training matrices for SHAP analysis
            X_train, y_train = lgbm.create_train_X_y(y=y, exog=exog)

            # Generate interpretability report
            series_output_dir = output_dir / uid
            report = generate_full_interpretability_report(
                forecaster=lgbm.forecaster,
                X_train=X_train,
                series_id=uid,
                output_dir=series_output_dir,
                top_n_features=5,
                shap_sample_frac=0.5,
                shap_max_samples=1000,
            )
            reports[uid] = report

            logger.info(
                f"[train_interpretability] {uid}: top_features={report.top_features[:3]}"
            )

        except Exception as e:
            logger.error(f"[train_interpretability] {uid}: Failed to train - {e}")
            continue

    logger.info(f"[train_interpretability] Generated {len(reports)} interpretability reports")
    return reports


def generate_renewable_forecasts(
    config: RenewablePipelineConfig,
    generation_df: Optional[pd.DataFrame] = None,
    weather_df: Optional[pd.DataFrame] = None,
    best_model: str = "MSTL_ARIMA",
) -> pd.DataFrame:
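    """Task 4: Generate forecasts with prediction intervals.

    Args:
        config: Pipeline configuration
        generation_df: Generation data (loads from file if None)
        weather_df: Weather data including forecast rows past the last observed
            generation timestamp (loads from file if None)
        best_model: Model to keep in the output, typically the CV leaderboard winner

    Returns:
        Forecast DataFrame with prediction intervals (also saved to forecasts.parquet)
    """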
    output_path = config.forecasts_path()
    output_path.parent.mkdir(parents=True, exist_ok=True)

    if generation_df is None:
        generation_df = pd.read_parquet(config.generation_path())
    if weather_df is None:
        weather_df = pd.read_parquet(config.weather_path())

    logger.info(f"[generate_forecasts] Generating {config.horizon}h forecasts using model={best_model}")

    # Ensure datetime types
    generation_df = generation_df.copy()
    generation_df["ds"] = pd.to_datetime(generation_df["ds"], errors="raise")
    weather_df = weather_df.copy()
    weather_df["ds"] = pd.to_datetime(weather_df["ds"], errors="raise")

    model = RenewableForecastModel(
        horizon=config.horizon,
        confidence_levels=config.confidence_levels,
    )

    # Fit uses only historical generation timestamps, weather merge will fail-loud if missing.
    model.fit(generation_df, weather_df)

    # Future weather must cover the horizon after the EARLIEST series' last timestamp
    # (different regions may have different publishing lags)
    per_series_max = generation_df.groupby("unique_id")["ds"].max()
    logger.info(f"[generate_forecasts] Per-series max timestamps:\n{per_series_max.to_dict()}")

    min_of_max = per_series_max.min()
    global_max = generation_df["ds"].max()

    logger.info(
        f"[generate_forecasts] Min of series maxes: {min_of_max}, "
        f"Global max: {global_max}, "
        f"Delta: {(global_max - min_of_max).total_seconds() / 3600:.1f}h"
    )

    # Use min of max timestamps to ensure all series have weather for their forecasts
    future_weather = weather_df[weather_df["ds"] > min_of_max].copy()

    if future_weather.empty:
        raise RuntimeError(
            "[generate_forecasts] No future weather rows found after earliest series max. "
            f"min_of_max={min_of_max}"
        )

    # Generate forecasts using the best model from CV;
    # predict() uses best_model to filter its output to that model
    logger.info(f"[generate_forecasts] Generating predictions using model: {best_model}")
    forecasts = model.predict(future_weather=future_weather, best_model=best_model)

    logger.info(
        f"[generate_forecasts] Generated {len(forecasts)} forecast rows "
        f"for {forecasts['unique_id'].nunique()} series"
    )

    forecasts.to_parquet(output_path, index=False)
    logger.info(f"[generate_forecasts] Saved: {output_path} ({len(forecasts)} rows)")

    return forecasts



def compute_renewable_drift(
    predictions: pd.DataFrame,
    actuals: pd.DataFrame,
    baseline_metrics: dict,
) -> dict:
    """Task 5: Detect drift by comparing current metrics to baseline.

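    Drift is flagged when current RMSE exceeds baseline_metrics["drift_threshold_rmse"]
    (baseline_mean + 2*baseline_std from cross-validation). If that key is missing,
    the threshold defaults to infinity and drift is never flagged.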

    Args:
        predictions: Forecast DataFrame with [unique_id, ds, yhat]
        actuals: Actual values DataFrame with [unique_id, ds, y]
        baseline_metrics: Baseline from cross-validation

    Returns:
        Dictionary with drift status and details
    """
    from src.chapter2.evaluation import ForecastMetrics

    # Merge predictions with actuals
    merged = predictions.merge(
        actuals[["unique_id", "ds", "y"]],
        on=["unique_id", "ds"],
        how="inner",
    )

    if len(merged) == 0:
        return {
            "status": "no_data",
            "message": "No overlapping data between predictions and actuals",
        }

    # Compute current metrics
    y_true = merged["y"].values
    y_pred = merged["yhat"].values

    current_rmse = ForecastMetrics.rmse(y_true, y_pred)
    current_mae = ForecastMetrics.mae(y_true, y_pred)

    # Check against threshold
    threshold = baseline_metrics.get("drift_threshold_rmse", float("inf"))
    is_drifting = current_rmse > threshold

    result = {
        "status": "drift_detected" if is_drifting else "stable",
        "current_rmse": float(current_rmse),
        "current_mae": float(current_mae),
        "baseline_rmse": float(baseline_metrics.get("rmse_mean", 0)),
        "drift_threshold": float(threshold),
        "threshold_exceeded_by": float(max(0, current_rmse - threshold)),
        "n_predictions": len(merged),
        "timestamp": datetime.now(timezone.utc).isoformat(),  # utcnow() is deprecated since Python 3.12
    }

    if is_drifting:
        logger.warning(
            f"[drift] DRIFT DETECTED: RMSE={current_rmse:.1f} > threshold={threshold:.1f}"
        )
    else:
        logger.info(f"[drift] Stable: RMSE={current_rmse:.1f} <= threshold={threshold:.1f}")

    return result


def run_full_pipeline(
    config: RenewablePipelineConfig,
    fetch_diagnostics: Optional[list[dict]] = None,
) -> dict:
    """Run the complete renewable forecasting pipeline.

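    Steps:
    1. Fetch generation data (then validate coverage and freshness)
    2. Fetch weather data
    2.5. Build the modeling dataset (preprocessing)
    3. Train models (CV)
    4. Generate forecasts
    5. Train LightGBM interpretability models (optional)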

    Args:
        config: Pipeline configuration
        fetch_diagnostics: Optional list to capture per-region fetch metadata

    Returns:
        Dictionary with pipeline results
    """
    logger.info(f"[pipeline] Starting: {config.start_date} to {config.end_date}")
    logger.info(f"[pipeline] Regions: {config.regions}")
    logger.info(f"[pipeline] Fuel types: {config.fuel_types}")

    results = {}

    # Step 1: Fetch generation
    generation_df = fetch_renewable_data(config, fetch_diagnostics=fetch_diagnostics)
    results["generation_rows"] = len(generation_df)
    results["series_count"] = generation_df["unique_id"].nunique()

    from src.renewable.validation import validate_generation_df

    expected_series = [f"{r}_{f}" for r in config.regions for f in config.fuel_types]
    rep = validate_generation_df(
        generation_df,
        expected_series=expected_series,
        max_missing_ratio=0.02,
        max_lag_hours=48,  # choose a value consistent with EIA publishing lag
    )
    if not rep.ok:
        raise RuntimeError(f"[pipeline][generation_validation] {rep.message} details={rep.details}")

    # Step 2: Fetch weather
    weather_df = fetch_renewable_weather(config)
    results["weather_rows"] = len(weather_df)

    # Step 2.5: Build modeling-ready dataset with preprocessing
    logger.info("[pipeline] Building modeling dataset (dataset_builder.py)")
    modeling_df, prep_report = build_modeling_dataset(
        generation_df,
        weather_df,
        negative_policy=config.negative_policy,
        hourly_grid_policy=config.hourly_grid_policy,
        output_dir=config.preprocessing_dir()
    )

    results["preprocessing"] = {
        "rows_input": prep_report.rows_input,
        "rows_output": prep_report.rows_output,
        "series_dropped": len(prep_report.series_dropped_incomplete),
        "negative_action": prep_report.negative_values_action,
        "time_features": prep_report.time_features_added,
        "weather_features": prep_report.weather_features_added,
    }
    logger.info(f"[pipeline] Preprocessing: {prep_report.rows_input:,} → {prep_report.rows_output:,} rows")

    # Step 3: Train and validate (on preprocessed data)
    # Pass weather explicitly: weather_df=None would make the task reload weather.parquet
    cv_results, leaderboard, baseline = train_renewable_models(
        config, modeling_df, weather_df=weather_df
    )
    best_model = leaderboard.iloc[0]["model"]
    results["best_model"] = best_model
    results["best_rmse"] = float(leaderboard.iloc[0]["rmse"])
    results["baseline"] = baseline
    # Save full leaderboard for dashboard display
    results["leaderboard"] = leaderboard.to_dict(orient="records")

    # Step 4: Generate forecasts (use the best model from CV)
    # Future (forecast) weather rows come from weather_df, so it must be passed or reloaded
    forecasts = generate_renewable_forecasts(
        config, modeling_df, weather_df=weather_df, best_model=best_model
    )
    results["forecast_rows"] = len(forecasts)

    # Step 5: Train LightGBM models and generate interpretability reports (optional)
    # (LightGBM is for interpretability only - MSTL/ARIMA provide primary forecasts)
    if config.enable_interpretability:
        logger.info("[pipeline] Training interpretability models (LightGBM + SHAP)")
        try:
            interpretability_reports = train_interpretability_models(
                config, generation_df, weather_df
            )
            results["interpretability"] = {
                "series_count": len(interpretability_reports),
                "series": list(interpretability_reports.keys()),
                "output_dir": str(config.interpretability_dir()),
            }

            # Add top features summary per series
            for uid, report in interpretability_reports.items():
                results["interpretability"][f"{uid}_top_features"] = report.top_features[:3]

        except Exception as e:
            logger.warning(f"[pipeline] Interpretability training failed (non-fatal): {e}")
            results["interpretability"] = {"error": str(e)}
    else:
        logger.info("[pipeline] Interpretability disabled (enable_interpretability=False)")
        results["interpretability"] = {"enabled": False}

    if fetch_diagnostics is not None:
        results["fetch_diagnostics"] = fetch_diagnostics

    logger.info(f"[pipeline] Complete. Best model: {results['best_model']}")

    return results


def main():
    """CLI entry point for renewable pipeline."""
    parser = argparse.ArgumentParser(
        description="Renewable Energy Forecasting Pipeline",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Preset Examples:
  # Fast development (24h forecast, 2 CV windows, 15 days lookback)
  python -m src.renewable.tasks --preset 24h

  # Standard forecasting (48h forecast, 3 CV windows, 21 days lookback)
  python -m src.renewable.tasks --preset 48h

  # Extended planning (72h forecast, 3 CV windows, 28 days lookback)
  python -m src.renewable.tasks --preset 72h

Custom Examples:
  # 24h preset but only CALI region, skip interpretability
  python -m src.renewable.tasks --preset 24h --regions CALI --no-interpretability

  # Custom: 36h forecast with 4 CV windows
  python -m src.renewable.tasks --horizon 36 --cv-windows 4 --lookback-days 30
        """
    )

    # Preset system
    parser.add_argument(
        "--preset",
        type=str,
        choices=["24h", "48h", "72h"],
        help="Quick preset: 24h (fast dev), 48h (standard), 72h (extended planning)",
    )

    # Flags
    parser.add_argument(
        "--no-interpretability",
        action="store_true",
        help="Disable LightGBM interpretability analysis (speeds up pipeline)",
    )

    # Data parameters
    parser.add_argument(
        "--regions",
        type=str,
        help="Override regions (comma-separated, e.g., CALI,ERCO,MISO)",
    )
    parser.add_argument(
        "--fuel",
        type=str,
        help="Override fuel types (comma-separated, e.g., WND,SUN)",
    )

    # Forecast parameters
    parser.add_argument(
        "--horizon",
        type=int,
        help="Override forecast horizon in hours",
    )
    parser.add_argument(
        "--lookback-days",
        type=int,
        help="Override lookback days",
    )
    parser.add_argument(
        "--cv-windows",
        type=int,
        help="Override CV windows count",
    )

    # Output parameters
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Overwrite existing data files",
    )
    parser.add_argument(
        "--data-dir",
        type=str,
        default="data/renewable",
        help="Output directory (default: data/renewable)",
    )

    args = parser.parse_args()

    # Configure logging
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # Build config with preset support
    if args.preset:
        # Apply preset defaults
        logger.info(f"[CLI] Applying preset: {args.preset}")
        config = RenewablePipelineConfig(
            horizon_preset=args.preset,  # This triggers __post_init__ to apply preset
            regions=args.regions.split(",") if args.regions else ["CALI", "ERCO", "MISO"],
            fuel_types=args.fuel.split(",") if args.fuel else ["WND", "SUN"],
            enable_interpretability=not args.no_interpretability,
            overwrite=args.overwrite,
            data_dir=args.data_dir,
        )

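        # Allow CLI overrides of preset values. __post_init__ has already run,
        # so start_date and the validation warnings are refreshed here.
        if args.horizon is not None:
            config.horizon = args.horizon
            logger.info(f"[CLI] Override: horizon={args.horizon}h")
        if args.lookback_days is not None:
            config.lookback_days = args.lookback_days
            end = datetime.strptime(config.end_date, "%Y-%m-%d")
            config.start_date = (end - timedelta(days=config.lookback_days)).strftime("%Y-%m-%d")
            logger.info(f"[CLI] Override: lookback_days={args.lookback_days} (start_date={config.start_date})")
        if args.cv_windows is not None:
            config.cv_windows = args.cv_windows
            logger.info(f"[CLI] Override: cv_windows={args.cv_windows}")
        if any(v is not None for v in (args.horizon, args.lookback_days, args.cv_windows)):
            for warning in config._validate():
                logger.warning(f"[config] {warning}")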

    else:
        # No preset: use explicit values or defaults
        config = RenewablePipelineConfig(
            regions=args.regions.split(",") if args.regions else ["CALI", "ERCO", "MISO"],
            fuel_types=args.fuel.split(",") if args.fuel else ["WND", "SUN"],
            lookback_days=args.lookback_days if args.lookback_days is not None else 30,
            horizon=args.horizon if args.horizon is not None else 24,
            cv_windows=args.cv_windows if args.cv_windows is not None else 5,
            enable_interpretability=not args.no_interpretability,
            overwrite=args.overwrite,
            data_dir=args.data_dir,
        )

    # Run pipeline
    results = run_full_pipeline(config)

    print("\n" + "=" * 60)
    print("PIPELINE RESULTS")
    print("=" * 60)
    print(f"  Series count: {results['series_count']}")
    print(f"  Generation rows: {results['generation_rows']}")
    print(f"  Weather rows: {results['weather_rows']}")
    print(f"  Forecast rows: {results['forecast_rows']}")
    print(f"  Best model: {results['best_model']}")
    print(f"  Best RMSE: {results['best_rmse']:.1f}")
    print("=" * 60)


if __name__ == "__main__":
    main()


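In the preset branch, the preset is applied inside `__post_init__` and CLI overrides are then patched in with `object.__setattr__`. As a result, any checks that `__post_init__` performs never see the overridden values. One way to keep a single precedence rule (explicit value, then preset, then default) is to resolve it inside `__post_init__` itself. The sketch below uses a hypothetical `ConfigSketch` dataclass with made-up preset names and values. It is not the real `RenewablePipelineConfig`; only the defaults (24 h horizon, 30 lookback days, 5 CV windows) come from the CLI code above.

```python
from dataclasses import dataclass, replace
from typing import Optional

# Hypothetical preset table for illustration only; the real presets live in
# RenewablePipelineConfig and may use different names and values.
PRESETS = {
    "24h": {"horizon": 24, "lookback_days": 30, "cv_windows": 5},
    "72h": {"horizon": 72, "lookback_days": 90, "cv_windows": 3},
}
# Defaults taken from the no-preset branch of the CLI
DEFAULTS = {"horizon": 24, "lookback_days": 30, "cv_windows": 5}


@dataclass(frozen=True)
class ConfigSketch:
    horizon_preset: Optional[str] = None
    horizon: Optional[int] = None
    lookback_days: Optional[int] = None
    cv_windows: Optional[int] = None

    def __post_init__(self) -> None:
        preset = PRESETS.get(self.horizon_preset or "", {})
        for name, default in DEFAULTS.items():
            # Precedence: explicit argument > preset value > global default
            if getattr(self, name) is None:
                object.__setattr__(self, name, preset.get(name, default))
        # Validation now runs on the final, resolved values
        if self.horizon <= 0:
            raise ValueError("horizon must be positive")


cfg = ConfigSketch(horizon_preset="72h", horizon=48)
print(cfg.horizon, cfg.lookback_days, cfg.cv_windows)  # 48 90 3

# replace() builds a new instance, so __post_init__ validates the override too
cfg2 = replace(cfg, cv_windows=4)
```

Because `None` already means "not set" in this shape, the two CLI branches could collapse into one constructor call that passes `args.horizon`, `args.lookback_days`, and `args.cv_windows` straight through.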
# SQLite persistence layer

Extends the Chapter 4 monitoring database with:
- Prediction intervals (80%, 95%)
- Weather features table
- Renewable-specific columns (fuel_type, region)
- Drift alerts and backtest baseline metrics (used to set drift thresholds)
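In the module below, `renewable_forecasts`, `weather_features`, and `baseline_metrics` each declare a `UNIQUE` key and are written with `INSERT OR REPLACE`. Re-saving the same `(run_id, model, unique_id, ds)` therefore overwrites the existing row instead of duplicating it. Here is a minimal in-memory sketch of that behavior, using a trimmed copy of the forecasts table. Only a few columns are kept, and `model` is dropped for brevity.

```python
import sqlite3

con = sqlite3.connect(":memory:")
# Trimmed-down version of renewable_forecasts: same idea, fewer columns
con.execute("""
    CREATE TABLE renewable_forecasts (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        run_id TEXT NOT NULL,
        unique_id TEXT NOT NULL,
        ds TEXT NOT NULL,
        yhat REAL,
        UNIQUE (run_id, unique_id, ds)
    )
""")
insert = (
    "INSERT OR REPLACE INTO renewable_forecasts (run_id, unique_id, ds, yhat) "
    "VALUES (?, ?, ?, ?)"
)
con.execute(insert, ("run-1", "CALI_WND", "2024-06-01 00:00:00", 100.0))
# Same (run_id, unique_id, ds): the old row is deleted and a new one inserted
con.execute(insert, ("run-1", "CALI_WND", "2024-06-01 00:00:00", 120.0))

rows = con.execute("SELECT id, yhat FROM renewable_forecasts").fetchall()
print(rows)  # [(2, 120.0)]
con.close()
```

`REPLACE` deletes the conflicting row and inserts a fresh one, so the `id` changes (1 becomes 2 here). Nothing should rely on `id` staying stable across re-runs. `drift_alerts` has no `UNIQUE` key and uses a plain `INSERT`, so every drift check appends a new row.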

%%writefile src/renewable/db.py
# file: src/renewable/db.py
"""Database schema and operations for renewable forecasting.

Extends the Chapter 4 monitoring database with:
- Prediction intervals (80%, 95%)
- Weather features table
- Renewable-specific columns (fuel_type, region)
"""

import json
import sqlite3
from datetime import datetime
from pathlib import Path
from typing import Optional

import pandas as pd


def connect(db_path: str) -> sqlite3.Connection:
    """Connect to SQLite database with optimized settings."""
    Path(db_path).parent.mkdir(parents=True, exist_ok=True)
    con = sqlite3.connect(db_path)
    con.execute("PRAGMA journal_mode=WAL;")
    con.execute("PRAGMA synchronous=NORMAL;")
    return con


def init_renewable_db(db_path: str) -> None:
    """Initialize renewable forecasting database schema.

    Creates tables:
    - renewable_forecasts: Forecasts with dual intervals
    - renewable_scores: Evaluation metrics with coverage
    - weather_features: Weather data by region
    - drift_alerts: Drift detection history
    - baseline_metrics: Backtest baselines for drift thresholds
    """
    con = connect(db_path)
    cur = con.cursor()

    # Forecasts with dual prediction intervals
    cur.execute("""
    CREATE TABLE IF NOT EXISTS renewable_forecasts (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        run_id TEXT NOT NULL,
        created_at TEXT NOT NULL,
        unique_id TEXT NOT NULL,
        region TEXT NOT NULL,
        fuel_type TEXT NOT NULL,
        ds TEXT NOT NULL,
        model TEXT NOT NULL,
        yhat REAL,
        yhat_lo_80 REAL,
        yhat_hi_80 REAL,
        yhat_lo_95 REAL,
        yhat_hi_95 REAL,
        UNIQUE (run_id, model, unique_id, ds)
    );
    """)

    # Index for efficient queries
    cur.execute("""
    CREATE INDEX IF NOT EXISTS idx_forecasts_region_ds
    ON renewable_forecasts (region, ds);
    """)

    cur.execute("""
    CREATE INDEX IF NOT EXISTS idx_forecasts_fuel_ds
    ON renewable_forecasts (fuel_type, ds);
    """)

    # Evaluation scores with dual coverage
    cur.execute("""
    CREATE TABLE IF NOT EXISTS renewable_scores (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        scored_at TEXT NOT NULL,
        run_id TEXT NOT NULL,
        unique_id TEXT NOT NULL,
        region TEXT NOT NULL,
        fuel_type TEXT NOT NULL,
        model TEXT NOT NULL,
        horizon_hours INTEGER NOT NULL,
        rmse REAL,
        mae REAL,
        coverage_80 REAL,
        coverage_95 REAL,
        valid_rows INTEGER,
        UNIQUE (run_id, model, unique_id, horizon_hours)
    );
    """)

    # Weather features by region
    cur.execute("""
    CREATE TABLE IF NOT EXISTS weather_features (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        region TEXT NOT NULL,
        ds TEXT NOT NULL,
        temperature_2m REAL,
        wind_speed_10m REAL,
        wind_speed_100m REAL,
        wind_direction_10m REAL,
        direct_radiation REAL,
        diffuse_radiation REAL,
        cloud_cover REAL,
        is_forecast INTEGER DEFAULT 0,
        created_at TEXT DEFAULT CURRENT_TIMESTAMP,
        UNIQUE (region, ds, is_forecast)
    );
    """)

    cur.execute("""
    CREATE INDEX IF NOT EXISTS idx_weather_region_ds
    ON weather_features (region, ds);
    """)

    # Drift detection alerts
    cur.execute("""
    CREATE TABLE IF NOT EXISTS drift_alerts (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        alert_at TEXT NOT NULL,
        run_id TEXT,
        unique_id TEXT,
        region TEXT,
        fuel_type TEXT,
        alert_type TEXT NOT NULL,
        severity TEXT NOT NULL,
        current_rmse REAL,
        threshold_rmse REAL,
        message TEXT,
        metadata_json TEXT
    );
    """)

    cur.execute("""
    CREATE INDEX IF NOT EXISTS idx_drift_alerts_time
    ON drift_alerts (alert_at);
    """)

    # Baseline metrics for drift detection
    cur.execute("""
    CREATE TABLE IF NOT EXISTS baseline_metrics (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        created_at TEXT NOT NULL,
        unique_id TEXT NOT NULL,
        model TEXT NOT NULL,
        rmse_mean REAL NOT NULL,
        rmse_std REAL NOT NULL,
        mae_mean REAL,
        mae_std REAL,
        drift_threshold_rmse REAL NOT NULL,
        drift_threshold_mae REAL,
        n_windows INTEGER,
        metadata_json TEXT,
        UNIQUE (unique_id, model)
    );
    """)

    con.commit()
    con.close()


def save_forecasts(
    db_path: str,
    forecasts_df: pd.DataFrame,
    run_id: str,
    model: str = "MSTL_ARIMA",
) -> int:
    """Save forecasts to database.

    Args:
        db_path: Path to SQLite database
        forecasts_df: DataFrame with [unique_id, ds, yhat, yhat_lo_80, ...]
        run_id: Pipeline run identifier
        model: Model name

    Returns:
        Number of rows inserted
    """
    con = connect(db_path)
    created_at = datetime.utcnow().isoformat()

    rows = []
    for _, row in forecasts_df.iterrows():
        unique_id = row["unique_id"]
        parts = unique_id.split("_")
        region = parts[0] if len(parts) > 0 else ""
        fuel_type = parts[1] if len(parts) > 1 else ""

        rows.append((
            run_id,
            created_at,
            unique_id,
            region,
            fuel_type,
            str(row["ds"]),
            model,
            row.get("yhat"),
            row.get("yhat_lo_80"),
            row.get("yhat_hi_80"),
            row.get("yhat_lo_95"),
            row.get("yhat_hi_95"),
        ))

    cur = con.cursor()
    cur.executemany("""
        INSERT OR REPLACE INTO renewable_forecasts
        (run_id, created_at, unique_id, region, fuel_type, ds, model,
         yhat, yhat_lo_80, yhat_hi_80, yhat_lo_95, yhat_hi_95)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
    """, rows)

    con.commit()
    con.close()

    return len(rows)


def save_weather(
    db_path: str,
    weather_df: pd.DataFrame,
    is_forecast: bool = False,
) -> int:
    """Save weather features to database.

    Args:
        db_path: Path to SQLite database
        weather_df: DataFrame with [ds, region, weather_vars...]
        is_forecast: True if this is forecast weather data

    Returns:
        Number of rows inserted
    """
    con = connect(db_path)

    weather_cols = [
        "temperature_2m", "wind_speed_10m", "wind_speed_100m",
        "wind_direction_10m", "direct_radiation", "diffuse_radiation", "cloud_cover"
    ]

    rows = []
    for _, row in weather_df.iterrows():
        values = [row.get(col) for col in weather_cols]
        rows.append((
            row["region"],
            str(row["ds"]),
            *values,
            1 if is_forecast else 0,
        ))

    cur = con.cursor()
    cur.executemany(f"""
        INSERT OR REPLACE INTO weather_features
        (region, ds, {', '.join(weather_cols)}, is_forecast)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
    """, rows)

    con.commit()
    con.close()

    return len(rows)


def save_drift_alert(
    db_path: str,
    run_id: str,
    unique_id: str,
    current_rmse: float,
    threshold_rmse: float,
    severity: str = "warning",
    metadata: Optional[dict] = None,
) -> None:
    """Save drift detection alert.

    Args:
        db_path: Path to SQLite database
        run_id: Pipeline run identifier
        unique_id: Series identifier
        current_rmse: Current RMSE value
        threshold_rmse: Drift threshold
        severity: Alert severity (info, warning, critical)
        metadata: Additional metadata
    """
    con = connect(db_path)

    parts = unique_id.split("_")
    region = parts[0] if len(parts) > 0 else ""
    fuel_type = parts[1] if len(parts) > 1 else ""

    alert_type = "drift_detected" if current_rmse > threshold_rmse else "drift_check"
    message = (
        f"RMSE {current_rmse:.1f} {'>' if current_rmse > threshold_rmse else '<='} "
        f"threshold {threshold_rmse:.1f}"
    )

    cur = con.cursor()
    cur.execute("""
        INSERT INTO drift_alerts
        (alert_at, run_id, unique_id, region, fuel_type, alert_type, severity,
         current_rmse, threshold_rmse, message, metadata_json)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
    """, (
        datetime.utcnow().isoformat(),
        run_id,
        unique_id,
        region,
        fuel_type,
        alert_type,
        severity,
        current_rmse,
        threshold_rmse,
        message,
        json.dumps(metadata) if metadata else None,
    ))

    con.commit()
    con.close()


def save_baseline(
    db_path: str,
    unique_id: str,
    model: str,
    baseline: dict,
) -> None:
    """Save baseline metrics for drift detection.

    Args:
        db_path: Path to SQLite database
        unique_id: Series identifier
        model: Model name
        baseline: Baseline metrics dictionary
    """
    con = connect(db_path)
    cur = con.cursor()

    cur.execute("""
        INSERT OR REPLACE INTO baseline_metrics
        (created_at, unique_id, model, rmse_mean, rmse_std, mae_mean, mae_std,
         drift_threshold_rmse, drift_threshold_mae, n_windows, metadata_json)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
    """, (
        datetime.utcnow().isoformat(),
        unique_id,
        model,
        baseline.get("rmse_mean"),
        baseline.get("rmse_std"),
        baseline.get("mae_mean"),
        baseline.get("mae_std"),
        baseline.get("drift_threshold_rmse"),
        baseline.get("drift_threshold_mae"),
        baseline.get("n_windows"),
        json.dumps(baseline),
    ))

    con.commit()
    con.close()


def get_recent_forecasts(
    db_path: str,
    region: Optional[str] = None,
    fuel_type: Optional[str] = None,
    hours: int = 48,
) -> pd.DataFrame:
    """Get recent forecasts from database.

    Args:
        db_path: Path to SQLite database
        region: Filter by region (optional)
        fuel_type: Filter by fuel type (optional)
        hours: Only return forecasts created in the last N hours (by created_at)

    Returns:
        DataFrame with forecasts
    """
    con = connect(db_path)

    query = """
        SELECT *
        FROM renewable_forecasts
        WHERE datetime(created_at) > datetime('now', ?)
    """
    params = [f"-{hours} hours"]

    if region:
        query += " AND region = ?"
        params.append(region)

    if fuel_type:
        query += " AND fuel_type = ?"
        params.append(fuel_type)

    query += " ORDER BY ds DESC"

    df = pd.read_sql_query(query, con, params=params)
    con.close()

    return df


def get_drift_alerts(
    db_path: str,
    hours: int = 24,
    severity: Optional[str] = None,
) -> pd.DataFrame:
    """Get recent drift alerts.

    Args:
        db_path: Path to SQLite database
        hours: Only return alerts raised in the last N hours (by alert_at)
        severity: Filter by severity (optional)

    Returns:
        DataFrame with alerts
    """
    con = connect(db_path)

    query = """
        SELECT *
        FROM drift_alerts
        WHERE datetime(alert_at) > datetime('now', ?)
    """
    params = [f"-{hours} hours"]

    if severity:
        query += " AND severity = ?"
        params.append(severity)

    query += " ORDER BY alert_at DESC"

    df = pd.read_sql_query(query, con, params=params)
    con.close()

    return df


if __name__ == "__main__":
    # Test database initialization
    import tempfile

    with tempfile.TemporaryDirectory() as tmpdir:
        db_path = f"{tmpdir}/test_renewable.db"

        print("Initializing database...")
        init_renewable_db(db_path)

        print("Database initialized successfully!")

        # Test connection
        con = connect(db_path)
        cur = con.cursor()
        cur.execute("SELECT name FROM sqlite_master WHERE type='table'")
        tables = cur.fetchall()
        print(f"Tables created: {[t[0] for t in tables]}")
        con.close()


Overwriting src/renewable/db.py
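`get_recent_forecasts` and `get_drift_alerts` wrap the stored timestamp in SQLite's `datetime()` before comparing it. This matters because `datetime.isoformat()` writes a `T` separator (`2024-06-03T06:00:00`), while `datetime('now', '-48 hours')` returns a space-separated string, and SQLite compares the two as plain text. The small sketch below uses a fixed reference time in place of `'now'` so the result is reproducible. The timestamps are invented for illustration.

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE renewable_forecasts (created_at TEXT NOT NULL)")
# Rows stored the way datetime.isoformat() writes them: 'T' separator, microseconds
con.executemany(
    "INSERT INTO renewable_forecasts VALUES (?)",
    [
        ("2024-06-05T11:00:00.123456",),  # 1 hour before the reference time
        ("2024-06-03T06:00:00.000001",),  # 54 hours before: should be excluded
    ],
)
ref = "2024-06-05 12:00:00"  # stands in for 'now'

# Raw text comparison: 'T' sorts after ' ', so same-day older rows leak through
raw = con.execute(
    "SELECT COUNT(*) FROM renewable_forecasts WHERE created_at > datetime(?, ?)",
    (ref, "-48 hours"),
).fetchone()[0]
# datetime() normalizes the stored value to 'YYYY-MM-DD HH:MM:SS' first
normalized = con.execute(
    "SELECT COUNT(*) FROM renewable_forecasts "
    "WHERE datetime(created_at) > datetime(?, ?)",
    (ref, "-48 hours"),
).fetchone()[0]
print(raw, normalized)  # 2 1
con.close()
```

Without `datetime()`, the row from 54 hours earlier is counted because it shares the threshold's date and `'T'` sorts after `' '`.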


---

# Module 8: Dashboard

**File:** `src/renewable/dashboard.py`

The Streamlit dashboard provides:
- **Forecast visualization** with prediction intervals
- **Drift monitoring** and alerts
- **Coverage analysis** (nominal vs empirical)
- **Weather features** by region
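The Coverage tab currently plots hard-coded demo numbers. Computing real empirical coverage requires forecasts joined to actuals. The sketch below assumes a frame with the actual value in a `y` column (the name `create_forecast_plot` checks for) and the interval columns from `renewable_forecasts`. `empirical_coverage` is a hypothetical helper, not part of the module.

```python
import pandas as pd


def empirical_coverage(df: pd.DataFrame, level: int) -> float:
    """Percent of actuals y that fall inside [yhat_lo_{level}, yhat_hi_{level}]."""
    lo_col, hi_col = f"yhat_lo_{level}", f"yhat_hi_{level}"
    # Only score rows where the actual and both bounds are known
    valid = df.dropna(subset=["y", lo_col, hi_col])
    if valid.empty:
        return float("nan")
    inside = (valid["y"] >= valid[lo_col]) & (valid["y"] <= valid[hi_col])
    return 100.0 * float(inside.mean())


scored = pd.DataFrame({
    "y":          [10.0, 20.0, 30.0, 40.0, None],  # last actual not yet observed
    "yhat_lo_80": [5.0, 25.0, 25.0, 35.0, 0.0],
    "yhat_hi_80": [15.0, 35.0, 35.0, 45.0, 50.0],
})
cov_80 = empirical_coverage(scored, 80)
print(cov_80, cov_80 - 80)  # 75.0 -5.0
```

A negative gap (here 75% against a nominal 80%) means the intervals are too narrow. The `coverage_80` and `coverage_95` columns of `renewable_scores` look like a natural place to persist such values.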

## Running the Dashboard

```bash
streamlit run src/renewable/dashboard.py
```

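The dashboard will:
1. Load forecasts from `data/renewable/forecasts.parquet`, falling back to the SQLite database (`data/renewable/renewable.db` by default) and then to demo data
2. Display interactive charts with Plotly
3. Show drift alerts from the database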
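The Forecasts tab also warns when forecasts are stale. Forecasts start one hour after the last observation, so the tab back-computes the last data timestamp from the earliest `ds` and flags anything older than 6 hours. The sketch below expresses the same arithmetic as a pure function, with `now` passed in so it can be tested. `data_age_hours` is a hypothetical helper name; the dashboard itself uses the current UTC time floored to the hour.

```python
import pandas as pd


def data_age_hours(forecast_ds: pd.Series, now: pd.Timestamp) -> float:
    """Hours between `now` (tz-aware) and the last observed data point.

    Forecasts start one hour after the last observation, so the last
    observation is the earliest forecast timestamp minus one hour.
    """
    earliest = pd.Timestamp(forecast_ds.min())
    if earliest.tzinfo is None:
        earliest = earliest.tz_localize("UTC")  # treat naive timestamps as UTC
    last_data = earliest - pd.Timedelta(hours=1)
    return (now - last_data).total_seconds() / 3600


ds = pd.Series(pd.date_range("2024-06-01 13:00", periods=24, freq="h"))
age = data_age_hours(ds, now=pd.Timestamp("2024-06-01 20:00", tz="UTC"))
print(age, age > 6)  # 8.0 True
```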

%%writefile src/renewable/dashboard.py
# file: src/renewable/dashboard.py
"""Streamlit dashboard for renewable energy forecasting.

Provides:
- Forecast visualization with prediction intervals
- Drift monitoring and alerts
- Coverage analysis (nominal vs empirical)
- Weather features by region

Run with:
    streamlit run src/renewable/dashboard.py
"""

import os
import sys
from datetime import datetime, timedelta, timezone
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st

# Add project root to path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from src.renewable.db import (
    connect,
    get_drift_alerts,
    get_recent_forecasts,
    init_renewable_db,
)
from src.renewable.regions import FUEL_TYPES, REGIONS

# Page config
st.set_page_config(
    page_title="Renewable Forecast Dashboard",
    page_icon="⚡",
    layout="wide",
)


def main():
    """Main dashboard application."""
    st.title("⚡ Renewable Energy Forecast Dashboard")
    st.markdown("Next-24h wind/solar generation forecasts with drift monitoring")

    # Sidebar configuration
    with st.sidebar:
        st.header("Configuration")

        db_path = st.text_input(
            "Database Path",
            value="data/renewable/renewable.db",
        )

        # Initialize database if it doesn't exist
        if not Path(db_path).exists():
            init_renewable_db(db_path)
            st.info("Database initialized")

        st.divider()

        # Region filter
        all_regions = list(REGIONS.keys())
        selected_regions = st.multiselect(
            "Regions",
            options=all_regions,
            default=["CALI", "ERCO", "MISO"],
        )

        # Fuel type filter
        fuel_type = st.selectbox(
            "Fuel Type",
            options=["WND", "SUN", "Both"],
            index=0,
        )

        st.divider()

        # Actions
        show_debug = st.checkbox("Show Debug", value=False)
        if st.button("🔄 Refresh Data", width="stretch"):
            st.rerun()

        if st.button("📊 Run Pipeline", width="stretch"):
            run_pipeline_from_dashboard(db_path, selected_regions, fuel_type)

    # Main content tabs
    tab1, tab2, tab3, tab4, tab5 = st.tabs([
        "📈 Forecasts",
        "⚠️ Drift Monitor",
        "📊 Coverage",
        "🌤️ Weather",
        "🔍 Interpretability",
    ])

    with tab1:
        render_forecasts_tab(db_path, selected_regions, fuel_type, show_debug=show_debug)

    with tab2:
        render_drift_tab(db_path)

    with tab3:
        render_coverage_tab(db_path)

    with tab4:
        render_weather_tab(db_path, selected_regions)

    with tab5:
        render_interpretability_tab(selected_regions, fuel_type)


def render_forecasts_tab(db_path: str, regions: list, fuel_type: str, *, show_debug: bool = False):
    """Render forecast visualization with prediction intervals."""
    st.subheader("Generation Forecasts")

    forecasts_df = pd.DataFrame()
    data_source = "none"
    derived_columns: list[str] = []

    # Try to load from parquet file first (pipeline output)
    parquet_path = Path("data/renewable/forecasts.parquet")
    if parquet_path.exists():
        try:
            forecasts_df = pd.read_parquet(parquet_path)
            data_source = f"parquet:{parquet_path}"
            # Add region/fuel_type columns if missing
            if "unique_id" in forecasts_df.columns:
                parts = forecasts_df["unique_id"].astype(str).str.split("_", n=1, expand=True)
                if "region" not in forecasts_df.columns:
                    forecasts_df["region"] = parts[0]
                    derived_columns.append("region")
                if "fuel_type" not in forecasts_df.columns:
                    forecasts_df["fuel_type"] = parts[1] if parts.shape[1] > 1 else pd.NA
                    derived_columns.append("fuel_type")
            st.success(f"Loaded {len(forecasts_df)} forecasts from pipeline")

            # Calculate and display data freshness
            if not forecasts_df.empty and "ds" in forecasts_df.columns:
                earliest_forecast_ts = forecasts_df["ds"].min()
                now_utc = pd.Timestamp.now(tz="UTC").floor("h")

                # Forecasts start from last_data + 1h, so last_data = earliest_forecast - 1h
                last_data_ts = earliest_forecast_ts - pd.Timedelta(hours=1)

                # Treat naive timestamps as UTC so both sides of the subtraction are tz-aware
                last_data_ts = pd.Timestamp(last_data_ts)
                if last_data_ts.tzinfo is None:
                    last_data_ts = last_data_ts.tz_localize("UTC")

                data_age_hours = (now_utc - last_data_ts).total_seconds() / 3600

                # Show warning if data is > 6 hours old
                if data_age_hours > 6:
                    st.warning(
                        f"⚠️ Forecasts are based on **{data_age_hours:.1f}-hour-old** data "
                        f"(last EIA data: {last_data_ts.strftime('%b %d %H:%M')} UTC). "
                        "Click '📊 Run Pipeline' in the sidebar to regenerate them."
                    )
                else:
                    st.info(
                        f"✅ Forecasts from {last_data_ts.strftime('%b %d %H:%M')} UTC data "
                        f"({data_age_hours:.1f}h old)"
                    )

        except Exception as e:
            st.warning(f"Could not load parquet: {e}")

    # Fall back to database
    if forecasts_df.empty:
        try:
            forecasts_df = get_recent_forecasts(db_path, hours=72)
            data_source = f"db:{db_path}"
        except Exception as e:
            st.warning(f"Could not load from database: {e}")

    if forecasts_df.empty:
        # Show demo data
        st.info("No forecasts found. Showing demo data.")
        forecasts_df = generate_demo_forecasts(regions, fuel_type)
        data_source = "demo"

    if show_debug:
        with st.expander("Debug: Forecast Data", expanded=False):
            st.markdown("**Source**")
            st.code(data_source)
            st.markdown("**Columns**")
            st.code(", ".join(forecasts_df.columns.tolist()))

            st.markdown("**Counts (pre-filter)**")
            st.write({"rows": int(len(forecasts_df))})

            if derived_columns:
                st.markdown("**Derived Columns**")
                st.write(derived_columns)

            if "unique_id" in forecasts_df.columns:
                st.markdown("**unique_id sample**")
                st.write(forecasts_df["unique_id"].dropna().astype(str).head(10).tolist())

            if "fuel_type" in forecasts_df.columns:
                st.markdown("**fuel_type counts**")
                st.dataframe(forecasts_df["fuel_type"].value_counts(dropna=False).to_frame())

                unknown = sorted(
                    {str(v) for v in forecasts_df["fuel_type"].dropna().unique()}
                    - set(FUEL_TYPES.keys())
                )
                if unknown:
                    st.warning(f"Unknown fuel_type values: {unknown}")

            if "region" in forecasts_df.columns:
                st.markdown("**region counts**")
                st.dataframe(forecasts_df["region"].value_counts(dropna=False).to_frame())

    # Filter by selections
    if fuel_type != "Both":
        forecasts_df = forecasts_df[forecasts_df["fuel_type"] == fuel_type]

    if regions:
        forecasts_df = forecasts_df[forecasts_df["region"].isin(regions)]

    if show_debug:
        with st.expander("Debug: Filter Result", expanded=False):
            st.markdown("**Applied Filters**")
            st.write({"fuel_type": fuel_type, "regions": regions})
            st.markdown("**Counts (post-filter)**")
            st.write({"rows": int(len(forecasts_df))})
            if "unique_id" in forecasts_df.columns:
                st.markdown("**unique_id after filter**")
                st.write(sorted(forecasts_df["unique_id"].dropna().astype(str).unique().tolist()))

    if forecasts_df.empty:
        st.warning("No data matching filters")
        return

    # Series selector
    series_options = forecasts_df["unique_id"].unique().tolist()
    selected_series = st.selectbox(
        "Select Series",
        options=series_options,
        index=0 if series_options else None,
        key="forecast_series_select",
    )

    if selected_series:
        series_data = forecasts_df[forecasts_df["unique_id"] == selected_series].copy()
        series_data = series_data.sort_values("ds")

        # Convert to local timezone for display
        region_code = series_data["unique_id"].iloc[0].split("_")[0]
        region_info = REGIONS.get(region_code)
        timezone_name = region_info.timezone if region_info else "UTC"

        # Create forecast plot with intervals
        fig = create_forecast_plot(series_data, selected_series, timezone_name)
        st.plotly_chart(fig, width="stretch")

        # Show data table
        with st.expander("View Data"):
            st.dataframe(
                series_data[["ds", "yhat", "yhat_lo_80", "yhat_hi_80", "yhat_lo_95", "yhat_hi_95"]],
                width="stretch",
            )


def create_forecast_plot(df: pd.DataFrame, title: str, timezone_name: str = "UTC") -> go.Figure:
    """Create Plotly figure with forecast and prediction intervals.

    Args:
        df: Forecast dataframe with ds (timestamp), yhat, and interval columns
        title: Series name for chart title
        timezone_name: IANA timezone name for display (e.g., "America/Chicago")
    """
    fig = go.Figure()

    # Convert timestamps to local timezone for display
    df = df.copy()
    df["ds"] = pd.to_datetime(df["ds"])

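    # Convert to the display timezone; naive timestamps are assumed to be UTC.
    # (tz_localize raises on already tz-aware values, e.g. "+00:00" strings from the DB.)
    if timezone_name != "UTC":
        if df["ds"].dt.tz is None:
            df["ds"] = df["ds"].dt.tz_localize("UTC")
        df["ds"] = df["ds"].dt.tz_convert(timezone_name)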

    # Get timezone abbreviation for display (e.g., "CST", "PST")
    if timezone_name != "UTC" and len(df) > 0:
        tz_abbr = df["ds"].iloc[0].strftime("%Z")
    else:
        tz_abbr = "UTC"

    # 95% interval (outer, lighter)
    if "yhat_lo_95" in df.columns and "yhat_hi_95" in df.columns:
        fig.add_trace(go.Scatter(
            x=pd.concat([df["ds"], df["ds"][::-1]]),
            y=pd.concat([df["yhat_hi_95"], df["yhat_lo_95"][::-1]]),
            fill="toself",
            fillcolor="rgba(68, 138, 255, 0.2)",
            line=dict(color="rgba(255,255,255,0)"),
            name="95% Interval",
            hoverinfo="skip",
        ))

    # 80% interval (inner, darker)
    if "yhat_lo_80" in df.columns and "yhat_hi_80" in df.columns:
        fig.add_trace(go.Scatter(
            x=pd.concat([df["ds"], df["ds"][::-1]]),
            y=pd.concat([df["yhat_hi_80"], df["yhat_lo_80"][::-1]]),
            fill="toself",
            fillcolor="rgba(68, 138, 255, 0.4)",
            line=dict(color="rgba(255,255,255,0)"),
            name="80% Interval",
            hoverinfo="skip",
        ))

    # Point forecast
    fig.add_trace(go.Scatter(
        x=df["ds"],
        y=df["yhat"],
        mode="lines",
        name="Forecast",
        line=dict(color="#1f77b4", width=2),
    ))

    # Actuals if available
    if "y" in df.columns:
        fig.add_trace(go.Scatter(
            x=df["ds"],
            y=df["y"],
            mode="markers",
            name="Actual",
            marker=dict(color="#2ca02c", size=6),
        ))

    fig.update_layout(
        title=f"Forecast: {title}",
        xaxis_title=f"Time ({tz_abbr})",
        yaxis_title="Generation (MWh)",
        hovermode="x unified",
        legend=dict(orientation="h", yanchor="bottom", y=1.02),
        height=450,
    )

    return fig


def render_drift_tab(db_path: str):
    """Render drift monitoring and alerts."""
    st.subheader("Drift Detection")

    col1, col2, col3 = st.columns(3)

    # Try to load alerts
    try:
        alerts_df = get_drift_alerts(db_path, hours=48)
    except Exception:
        alerts_df = pd.DataFrame()

    # Summary metrics
    with col1:
        critical = len(alerts_df[alerts_df["severity"] == "critical"]) if not alerts_df.empty else 0
        st.metric(
            "Critical Alerts",
            critical,
            delta=None,
            delta_color="inverse" if critical > 0 else "off",
        )

    with col2:
        warning = len(alerts_df[alerts_df["severity"] == "warning"]) if not alerts_df.empty else 0
        st.metric("Warnings", warning)

    with col3:
        stable = len(alerts_df[alerts_df["alert_type"] == "drift_check"]) if not alerts_df.empty else 0
        st.metric("Stable Checks", stable)

    st.divider()

    if alerts_df.empty:
        st.info("No drift alerts or checks recorded in the last 48 hours.")

        # Show demo drift status
        st.markdown("### Demo Drift Status")
        demo_drift = pd.DataFrame({
            "Series": ["CALI_WND", "ERCO_WND", "MISO_WND", "CALI_SUN", "ERCO_SUN"],
            "Current RMSE": [125.3, 98.7, 156.2, 45.1, 67.8],
            "Threshold": [150.0, 120.0, 180.0, 60.0, 80.0],
            "Status": ["✅ Stable", "✅ Stable", "✅ Stable", "✅ Stable", "✅ Stable"],
        })
        st.dataframe(demo_drift, width="stretch")
    else:
        # Show alerts table
        st.dataframe(
            alerts_df[["alert_at", "unique_id", "severity", "current_rmse", "threshold_rmse", "message"]],
            width="stretch",
        )

        # Drift timeline
        if len(alerts_df) > 1:
            alerts_df["alert_at"] = pd.to_datetime(alerts_df["alert_at"])
            fig = px.scatter(
                alerts_df,
                x="alert_at",
                y="current_rmse",
                color="severity",
                size="current_rmse",
                hover_data=["unique_id", "message"],
                title="Drift Timeline",
            )
            fig.add_hline(
                y=alerts_df["threshold_rmse"].mean(),
                line_dash="dash",
                annotation_text="Avg Threshold",
            )
            st.plotly_chart(fig, width="stretch")


def render_coverage_tab(db_path: str):
    """Render coverage analysis comparing nominal vs empirical."""
    st.subheader("Prediction Interval Coverage")

    st.markdown("""
    **Coverage** measures how often actual values fall within prediction intervals.
    - **Nominal**: The expected coverage (80% or 95%)
    - **Empirical**: The actual observed coverage
    - **Gap**: Empirical minus nominal; negative means intervals are too narrow (under-coverage), positive means too wide
    """)

    # Demo coverage data
    coverage_data = pd.DataFrame({
        "Series": ["CALI_WND", "ERCO_WND", "MISO_WND", "SWPP_WND", "CALI_SUN", "ERCO_SUN"],
        "Nominal 80%": [80, 80, 80, 80, 80, 80],
        "Empirical 80%": [78.5, 82.1, 76.3, 79.8, 81.2, 77.9],
        "Nominal 95%": [95, 95, 95, 95, 95, 95],
        "Empirical 95%": [93.2, 96.1, 91.5, 94.8, 95.7, 92.3],
    })

    coverage_data["Gap 80%"] = coverage_data["Empirical 80%"] - coverage_data["Nominal 80%"]
    coverage_data["Gap 95%"] = coverage_data["Empirical 95%"] - coverage_data["Nominal 95%"]

    # Summary
    col1, col2 = st.columns(2)

    with col1:
        avg_80 = coverage_data["Empirical 80%"].mean()
        st.metric("Avg 80% Coverage", f"{avg_80:.1f}%", f"{avg_80 - 80:.1f}%")

    with col2:
        avg_95 = coverage_data["Empirical 95%"].mean()
        st.metric("Avg 95% Coverage", f"{avg_95:.1f}%", f"{avg_95 - 95:.1f}%")

    st.divider()

    # Coverage comparison chart
    fig = go.Figure()

    fig.add_trace(go.Bar(
        name="80% Empirical",
        x=coverage_data["Series"],
        y=coverage_data["Empirical 80%"],
        marker_color="rgba(68, 138, 255, 0.7)",
    ))

    fig.add_trace(go.Bar(
        name="95% Empirical",
        x=coverage_data["Series"],
        y=coverage_data["Empirical 95%"],
        marker_color="rgba(68, 138, 255, 0.4)",
    ))

    # Nominal lines
    fig.add_hline(y=80, line_dash="dash", line_color="red", annotation_text="80% Nominal")
    fig.add_hline(y=95, line_dash="dash", line_color="orange", annotation_text="95% Nominal")

    fig.update_layout(
        title="Coverage by Series",
        xaxis_title="Series",
        yaxis_title="Coverage (%)",
        barmode="group",
        height=400,
    )

    st.plotly_chart(fig, width="stretch")

    # Detailed table
    with st.expander("View Coverage Data"):
        st.dataframe(coverage_data, width="stretch")


def render_weather_tab(db_path: str, regions: list):
    """Render weather features visualization."""
    st.subheader("Weather Features")

    weather_df = pd.DataFrame()

    # Prefer real pipeline output; no demo fallback.
    parquet_path = Path("data/renewable/weather.parquet")
    if parquet_path.exists():
        try:
            weather_df = pd.read_parquet(parquet_path)
            st.success(f"Loaded {len(weather_df)} weather rows from pipeline")
        except Exception as exc:
            st.warning(f"Could not load weather parquet: {exc}")

    if weather_df.empty and Path(db_path).exists():
        try:
            # sqlite3's connection context manager only commits/rolls back;
            # it does not close the connection, so close it explicitly.
            con = connect(db_path)
            try:
                weather_df = pd.read_sql_query(
                    "SELECT * FROM weather_features ORDER BY ds ASC",
                    con,
                )
            finally:
                con.close()
            if not weather_df.empty:
                st.success(f"Loaded {len(weather_df)} weather rows from database")
        except Exception as exc:
            st.warning(f"Could not load weather data from database: {exc}")

    if weather_df.empty:
        st.warning("No weather data available. Run the pipeline to populate weather features.")
        return

    weather_df["ds"] = pd.to_datetime(weather_df["ds"], errors="coerce")
    if regions:
        weather_df = weather_df[weather_df["region"].isin(regions)]
    if weather_df.empty:
        st.warning("No weather data matching selected regions.")
        return

    # Variable selector
    weather_vars = [
        col for col in ["wind_speed_10m", "wind_speed_100m", "direct_radiation", "cloud_cover"]
        if col in weather_df.columns
    ]
    if not weather_vars:
        st.warning("Weather data missing expected variables.")
        return
    selected_var = st.selectbox("Weather Variable", options=weather_vars)

    # Plot
    fig = px.line(
        weather_df,
        x="ds",
        y=selected_var,
        color="region",
        title=f"{selected_var} by Region",
    )
    fig.update_layout(height=400)
    st.plotly_chart(fig, width="stretch")

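    # Summary stats: most recent wind speed per region (first four regions)
    st.markdown("### Current Conditions")

    shown_regions = (regions or sorted(weather_df["region"].dropna().unique().tolist()))[:4]
    if not shown_regions:
        return
    cols = st.columns(len(shown_regions))  # st.columns(0) would raise
    for col, region in zip(cols, shown_regions):
        region_rows = weather_df[weather_df["region"] == region].sort_values("ds")
        with col:
            if region_rows.empty or "wind_speed_10m" not in region_rows.columns:
                st.metric(region, "n/a", help="Wind speed at 10m")
            else:
                latest = region_rows["wind_speed_10m"].iloc[-1]
                st.metric(region, f"{latest:.1f} m/s", help="Wind speed at 10m")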


def render_interpretability_tab(regions: list, fuel_type: str):
    """Render model interpretability visualizations (SHAP, feature importance, PDP)."""
    st.subheader("Model Interpretability")

    # Model Leaderboard Section
    st.markdown("### 🏆 Model Leaderboard (Cross-Validation)")

    # Model descriptions for education
    MODEL_INFO = {
        "AutoARIMA": {
            "type": "Statistical",
            "description": "Auto-tuned ARIMA with automatic p,d,q selection. Good for univariate series with trend/seasonality.",
            "strengths": "Robust, well-understood, good prediction intervals",
        },
        "MSTL_ARIMA": {
            "type": "Statistical",
            "description": "Multiple Seasonal-Trend decomposition + ARIMA. Handles daily (24h) and weekly (168h) seasonality.",
            "strengths": "Best for multi-seasonal patterns like energy data",
        },
        "AutoETS": {
            "type": "Statistical",
            "description": "Exponential smoothing with automatic error/trend/season selection.",
            "strengths": "Simple, fast, works well for smooth series",
        },
        "AutoTheta": {
            "type": "Statistical",
            "description": "Theta method with automatic decomposition. Robust to outliers.",
            "strengths": "Competition winner (M3), handles level shifts",
        },
        "CES": {
            "type": "Statistical",
            "description": "Complex Exponential Smoothing. Captures complex seasonal patterns.",
            "strengths": "Good for complex seasonality",
        },
        "SeasonalNaive": {
            "type": "Baseline",
            "description": "Uses value from same hour last week. Baseline benchmark.",
            "strengths": "Simple benchmark - if beaten, models add value",
        },
    }

    run_log_path = Path("data/renewable/run_log.json")
    if run_log_path.exists():
        try:
            import json
            run_log = json.loads(run_log_path.read_text())
            pipeline_results = run_log.get("pipeline_results", {})
            leaderboard_data = pipeline_results.get("leaderboard", [])

            if leaderboard_data:
                leaderboard_df = pd.DataFrame(leaderboard_data)
                best_model = pipeline_results.get("best_model", "")
                best_rmse = pipeline_results.get("best_rmse", 0)

                # Key metrics row
                col1, col2, col3, col4 = st.columns(4)
                with col1:
                    st.metric("Best Model", best_model)
                with col2:
                    st.metric("Best RMSE", f"{best_rmse:.3f}")
                with col3:
                    st.metric("Models Evaluated", len(leaderboard_data))
                with col4:
                    # Calculate improvement over baseline
                    baseline_rmse = leaderboard_df[leaderboard_df["model"] == "SeasonalNaive"]["rmse"].values
                    if len(baseline_rmse) > 0 and best_rmse > 0:
                        improvement = ((baseline_rmse[0] - best_rmse) / baseline_rmse[0]) * 100
                        st.metric("vs Baseline", f"{improvement:+.1f}%", help="Improvement over SeasonalNaive")
                    else:
                        st.metric("vs Baseline", "N/A")

                # Selection rationale
                st.markdown("#### Why This Model?")
                st.info(f"""
                **{best_model}** was selected because it has the **lowest RMSE** on cross-validation.

                - **RMSE (Root Mean Square Error)**: Penalizes large errors more heavily. Best for energy forecasting where big misses are costly.
                - **Selection method**: Time-series CV with {run_log.get('config', {}).get('cv_windows', 2)} windows, step size {run_log.get('config', {}).get('cv_step_size', 168)}h
                - **Horizon**: {run_log.get('config', {}).get('horizon', 24)}h ahead forecasts
                """)

                # Model description for winner
                if best_model in MODEL_INFO:
                    info = MODEL_INFO[best_model]
                    st.success(f"**{info['type']} Model**: {info['description']}")

                # Full leaderboard with visualization
                st.markdown("#### All Models Ranked by RMSE")

                display_cols = [c for c in ["model", "rmse", "mae", "mape", "coverage_80", "coverage_95"]
                               if c in leaderboard_df.columns]

                # Create visualization
                if "rmse" in leaderboard_df.columns:
                    fig = px.bar(
                        leaderboard_df.sort_values("rmse"),
                        x="model",
                        y="rmse",
                        title="Model Comparison (Lower RMSE = Better)",
                        color="rmse",
                        color_continuous_scale="RdYlGn_r",
                    )
                    fig.add_hline(y=best_rmse, line_dash="dash", line_color="green",
                                  annotation_text=f"Best: {best_rmse:.3f}")
                    fig.update_layout(height=350)
                    st.plotly_chart(fig, width="stretch")

                # Format numeric columns for table
                styled_df = leaderboard_df[display_cols].copy()
                for col in ["rmse", "mae", "mape"]:
                    if col in styled_df.columns:
                        styled_df[col] = styled_df[col].apply(lambda x: f"{x:.3f}" if pd.notna(x) else "N/A")
                for col in ["coverage_80", "coverage_95"]:
                    if col in styled_df.columns:
                        styled_df[col] = styled_df[col].apply(lambda x: f"{x:.1f}%" if pd.notna(x) else "N/A")

                st.dataframe(styled_df, width="stretch", hide_index=True)

                # Coverage analysis
                if "coverage_80" in leaderboard_df.columns:
                    st.markdown("#### Prediction Interval Coverage")
                    st.markdown("""
                    **Coverage** measures if prediction intervals are well-calibrated:
                    - **80% interval** should contain ~80% of actual values
                    - **95% interval** should contain ~95% of actual values
                    - **Under-coverage** (<target) = intervals too narrow, overconfident
                    - **Over-coverage** (>target) = intervals too wide, conservative
                    """)

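                    def _coverage_status(x, low, high):
                        # Label a coverage percentage against its acceptable band; missing -> "N/A"
                        if pd.isna(x):
                            return "N/A"
                        return "Under" if x < low else ("Over" if x > high else "Good")

                    coverage_cols = [c for c in ["model", "coverage_80", "coverage_95"] if c in leaderboard_df.columns]
                    coverage_df = leaderboard_df[coverage_cols].copy()
                    for col, (low, high) in {"coverage_80": (75, 85), "coverage_95": (90, 99)}.items():
                        if col in coverage_df.columns:
                            coverage_df[f"{col}_status"] = coverage_df[col].apply(_coverage_status, args=(low, high))
                    st.dataframe(coverage_df, width="stretch", hide_index=True)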

                # Model descriptions expander
                with st.expander("Model Descriptions"):
                    for model_name, info in MODEL_INFO.items():
                        st.markdown(f"**{model_name}** ({info['type']})")
                        st.markdown(f"- {info['description']}")
                        st.markdown(f"- *Strengths*: {info['strengths']}")
                        st.markdown("---")

                # CV configuration expander
                config = run_log.get("config", {})
                with st.expander("CV Configuration"):
                    st.write({
                        "cv_windows": config.get("cv_windows"),
                        "cv_step_size": config.get("cv_step_size"),
                        "horizon": config.get("horizon"),
                        "regions": config.get("regions"),
                        "fuel_types": config.get("fuel_types"),
                        "run_at": run_log.get("run_at_utc", "N/A"),
                    })
            else:
                st.info("Leaderboard not available. Run the pipeline with the latest code to generate.")
        except Exception as e:
            st.warning(f"Could not load leaderboard: {e}")
    else:
        st.info("No run log found. Run the pipeline to generate model comparison.")

    st.divider()

    st.markdown("### 🔍 Per-Series Interpretability")
    st.markdown("""
    **LightGBM** models are trained alongside statistical models (MSTL/ARIMA) to provide
    interpretability insights. The statistical models generate the primary forecasts,
    while LightGBM helps understand feature importance and relationships.
    """)

    interp_dir = Path("data/renewable/interpretability")

    if not interp_dir.exists():
        st.info("No interpretability data available. Run the pipeline to generate SHAP and PDP plots.")
        return

    # Get available series
    series_dirs = sorted([d.name for d in interp_dir.iterdir() if d.is_dir()])

    if not series_dirs:
        st.warning("Interpretability directory exists but contains no series data.")
        return

    # Filter by selected regions and fuel type
    filtered_series = []
    for series_id in series_dirs:
        parts = series_id.split("_")
        if len(parts) == 2:
            region, ft = parts
            if regions and region not in regions:
                continue
            if fuel_type != "Both" and ft != fuel_type:
                continue
            filtered_series.append(series_id)

    if not filtered_series:
        st.warning("No interpretability data for selected filters.")
        return

    # Series selector
    selected_series = st.selectbox(
        "Select Series",
        options=filtered_series,
        index=0,
        key="interpretability_series_select",
    )

    if not selected_series:
        return

    series_dir = interp_dir / selected_series

    # Layout: Feature Importance + SHAP Summary side by side
    col1, col2 = st.columns(2)

    with col1:
        st.markdown("### Feature Importance")
        importance_path = series_dir / "feature_importance.csv"
        if importance_path.exists():
            try:
                importance_df = pd.read_csv(importance_path)
                # Show top 15 features
                top_features = importance_df.sort_values("importance", ascending=False).head(15)

                # Create bar chart
                fig = px.bar(
                    top_features,
                    x="importance",
                    y="feature",
                    orientation="h",
                    title=f"Top Features: {selected_series}",
                    labels={"importance": "Importance", "feature": "Feature"},
                )
                fig.update_layout(yaxis=dict(autorange="reversed"), height=400)
                st.plotly_chart(fig, width="stretch")

                with st.expander("Full Feature List"):
                    st.dataframe(importance_df, width="stretch")
            except Exception as e:
                st.error(f"Error loading feature importance: {e}")
        else:
            st.info("Feature importance not available.")

    with col2:
        st.markdown("### SHAP Summary")
        shap_summary_path = series_dir / "shap_summary.png"
        if shap_summary_path.exists():
            st.image(str(shap_summary_path), width="stretch")
        else:
            # Try bar plot as fallback
            shap_bar_path = series_dir / "shap_bar.png"
            if shap_bar_path.exists():
                st.image(str(shap_bar_path), width="stretch")
            else:
                st.info("SHAP summary not available.")

    st.divider()

    # SHAP Dependence Plots
    st.markdown("### SHAP Dependence Plots")
    st.markdown("Shows how individual feature values affect predictions.")

    shap_dep_files = sorted(series_dir.glob("shap_dependence_*.png"))  # stable display order
    if shap_dep_files:
        # Create columns for dependence plots
        n_cols = min(3, len(shap_dep_files))
        cols = st.columns(n_cols)

        for i, dep_file in enumerate(shap_dep_files[:6]):  # Limit to 6 plots
            feature_name = dep_file.stem.replace("shap_dependence_", "")
            with cols[i % n_cols]:
                st.markdown(f"**{feature_name}**")
                st.image(str(dep_file), width="stretch")
    else:
        st.info("SHAP dependence plots not available.")

    st.divider()

    # Partial Dependence Plot
    st.markdown("### Partial Dependence Plot")
    st.markdown("Shows the average effect of features on predictions (marginal effect).")

    pdp_path = series_dir / "partial_dependence.png"
    if pdp_path.exists():
        st.image(str(pdp_path), width="stretch")
    else:
        st.info("Partial dependence plot not available.")

    # Waterfall plot for sample prediction
    waterfall_path = series_dir / "shap_waterfall_sample.png"
    if waterfall_path.exists():
        st.markdown("### Sample Prediction Explanation")
        st.markdown("SHAP waterfall showing how features contributed to a single prediction.")
        st.image(str(waterfall_path), width="stretch")


def generate_demo_forecasts(regions: list, fuel_type: str) -> pd.DataFrame:
    """Generate demo forecast data for display."""
    data = []
    base_time = datetime.now(timezone.utc).replace(minute=0, second=0, microsecond=0)

    fuel_types = [fuel_type] if fuel_type != "Both" else ["WND", "SUN"]

    for region in regions[:3]:
        for ft in fuel_types:
            unique_id = f"{region}_{ft}"
            base_value = 500 if ft == "WND" else 300

            for h in range(24):
                ds = base_time + timedelta(hours=h)

                # Add daily pattern
                if ft == "SUN":
                    hour_factor = max(0, np.sin((ds.hour - 6) * np.pi / 12)) if 6 < ds.hour < 18 else 0
                    yhat = base_value * hour_factor + np.random.normal(0, 20)
                else:
                    yhat = base_value + np.sin(ds.hour * np.pi / 12) * 100 + np.random.normal(0, 30)

                yhat = max(0, yhat)

                data.append({
                    "unique_id": unique_id,
                    "region": region,
                    "fuel_type": ft,
                    "ds": ds,
                    "yhat": yhat,
                    "yhat_lo_80": yhat * 0.85,
                    "yhat_hi_80": yhat * 1.15,
                    "yhat_lo_95": yhat * 0.75,
                    "yhat_hi_95": yhat * 1.25,
                })

    return pd.DataFrame(data)


def run_pipeline_from_dashboard(db_path: str, regions: list, fuel_type: str):
    """Run the forecasting pipeline from the dashboard."""
    with st.spinner("Refreshing forecasts... (may take 2-3 minutes)"):
        try:
            from src.renewable.jobs import run_hourly

            # Run the hourly pipeline job
            run_hourly.main()

            st.success("Pipeline completed! Forecasts have been updated with latest EIA data.")
            st.info("Reloading page to show new forecasts...")

            # Wait a moment then reload
            import time
            time.sleep(2)
            st.rerun()

        except (Exception, SystemExit) as e:  # run_hourly signals validation failure via SystemExit
            st.error(f"Pipeline failed: {e}")
            import traceback
            with st.expander("Error details"):
                st.code(traceback.format_exc())


if __name__ == "__main__":
    main()


Overwriting src/renewable/dashboard.py
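The leaderboard's "vs Baseline" metric and the coverage labels reduce to a few lines of arithmetic. The sketch below reproduces them in plain Python so the numbers can be sanity-checked outside Streamlit. The function names are hypothetical, the "Good" bands (75 to 85% for the 80% interval, 90 to 99% for the 95% interval) are taken from the dashboard code above, and the toy inputs are made up. Empirical coverage is assumed to count an actual as covered when `lo <= y <= hi`; the pipeline's own coverage computation is not shown here and may treat edges differently.

```python
import math
from typing import Optional, Sequence


def improvement_vs_baseline(best_rmse: float, baseline_rmse: float) -> Optional[float]:
    """Percent RMSE reduction of the best model vs. the baseline (positive means better)."""
    if math.isnan(baseline_rmse) or baseline_rmse <= 0:
        return None  # undefined without a positive baseline RMSE
    return (baseline_rmse - best_rmse) / baseline_rmse * 100


def empirical_coverage(actuals: Sequence[float], lo: Sequence[float], hi: Sequence[float]) -> float:
    """Percentage of actuals that fall inside the closed interval [lo, hi]."""
    if len(actuals) == 0 or not (len(actuals) == len(lo) == len(hi)):
        raise ValueError("actuals, lo and hi must be non-empty and the same length")
    inside = sum(1 for y, lower, upper in zip(actuals, lo, hi) if lower <= y <= upper)
    return 100.0 * inside / len(actuals)


def coverage_status(coverage_pct: Optional[float], low: float, high: float) -> str:
    """Label a coverage percentage as Under / Good / Over relative to [low, high]."""
    if coverage_pct is None or math.isnan(coverage_pct):
        return "N/A"
    if coverage_pct < low:
        return "Under"  # intervals too narrow (overconfident)
    if coverage_pct > high:
        return "Over"  # intervals too wide (conservative)
    return "Good"


# Toy inputs for illustration only
improvement = improvement_vs_baseline(best_rmse=80.0, baseline_rmse=100.0)
cov80 = empirical_coverage(
    actuals=[10, 12, 9, 15, 11],
    lo=[8, 10, 10, 9, 10],
    hi=[12, 14, 12, 13, 13],
)
print(f"vs baseline: {improvement:+.1f}%")
print(f"80% coverage: {cov80:.1f}% -> {coverage_status(cov80, 75, 85)}")
```

With these toy inputs the best model cuts RMSE by 20% relative to the baseline, but only 3 of 5 actuals land inside the 80% band (60%), which the dashboard would flag as "Under": the intervals are too narrow.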


# Airflow integration 
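The scheduled job starts with a cheap freshness probe: one EIA v2 request per series that asks for a single row (`length=1`) sorted by `period` descending, so only the newest hourly record comes back. The sketch below builds that query string without touching the network, which can help when debugging `probe_eia_latest`. `build_probe_url` is a hypothetical helper; it passes the respondent code straight through, whereas the module maps each region with `get_eia_respondent`, and `DEMO_KEY` is a placeholder, not a real key. `requests` URL-encodes a `params` dict the same way, so the result should be equivalent to what the probe sends.

```python
from urllib.parse import urlencode

# Endpoint queried by probe_eia_latest in src/renewable/data_freshness.py
EIA_FUEL_TYPE_URL = "https://api.eia.gov/v2/electricity/rto/fuel-type-data/data/"


def build_probe_url(api_key: str, respondent: str, fuel_type: str) -> str:
    """Return the 'newest record only' query URL for one respondent/fuel pair (no request is sent)."""
    params = {
        "api_key": api_key,
        "data[]": "value",
        "facets[respondent][]": respondent,
        "facets[fueltype][]": fuel_type,
        "frequency": "hourly",
        "length": 1,  # fetch a single row...
        "sort[0][column]": "period",  # ...ordered by timestamp...
        "sort[0][direction]": "desc",  # ...newest first
    }
    return f"{EIA_FUEL_TYPE_URL}?{urlencode(params)}"


# "DEMO_KEY" is a placeholder; substitute a real EIA API key to try the URL
url = build_probe_url("DEMO_KEY", "ERCO", "WND")
print(url)
```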



In [10]:
%%writefile src/renewable/data_freshness.py
# src/renewable/data_freshness.py
"""
Lightweight EIA data freshness checking.

This module provides functions to check if new data is available from the EIA API
before running the full pipeline. It compares the current max timestamps with
the previous run's max timestamps to determine if a full pipeline run is needed.
"""

from __future__ import annotations

import json
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

import pandas as pd
import requests

from src.renewable.regions import get_eia_respondent

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class FreshnessCheckResult:
    """Result of a data freshness check."""

    has_new_data: bool
    checked_at_utc: str
    series_status: dict[str, dict] = field(default_factory=dict)
    summary: str = ""


def load_previous_max_ds(run_log_path: Path) -> dict[str, str]:
    """
    Load per-series max_ds from previous run_log.json.

    Args:
        run_log_path: Path to run_log.json

    Returns:
        Dict mapping unique_id -> max_ds ISO string.
        Empty dict if file doesn't exist or is malformed.
    """
    if not run_log_path.exists():
        logger.info("[freshness] No previous run_log.json found - first run")
        return {}

    try:
        data = json.loads(run_log_path.read_text(encoding="utf-8"))

        # Navigate to diagnostics.generation_coverage.coverage
        coverage = (
            data.get("diagnostics", {})
            .get("generation_coverage", {})
            .get("coverage", [])
        )

        if not coverage:
            logger.warning("[freshness] run_log.json has no coverage data")
            return {}

        result = {}
        for item in coverage:
            uid = item.get("unique_id")
            max_ds = item.get("max_ds")
            if uid and max_ds:
                result[uid] = max_ds

        logger.info(f"[freshness] Loaded {len(result)} series from previous run_log")
        return result

    except (OSError, ValueError, KeyError, TypeError, AttributeError) as e:  # ValueError covers json.JSONDecodeError
        logger.warning(f"[freshness] Failed to parse run_log.json: {e}")
        return {}


def probe_eia_latest(
    api_key: str,
    region: str,
    fuel_type: str,
    *,
    timeout: int = 15,
) -> Optional[str]:
    """
    Fetch only the single most recent record from EIA API.

    This is a lightweight probe that uses:
    - length=1 (only fetch 1 record)
    - sort by period DESC (most recent first)

    Args:
        api_key: EIA API key
        region: Region code (CALI, ERCO, MISO, etc.)
        fuel_type: Fuel type (WND, SUN)
        timeout: Request timeout in seconds

    Returns:
        ISO timestamp string of latest record, or None on error.
    """
    try:
        respondent = get_eia_respondent(region)

        params = {
            "api_key": api_key,
            "data[]": "value",
            "facets[respondent][]": respondent,
            "facets[fueltype][]": fuel_type,
            "frequency": "hourly",
            "length": 1,
            "sort[0][column]": "period",
            "sort[0][direction]": "desc",
        }

        base_url = "https://api.eia.gov/v2/electricity/rto/fuel-type-data/data/"
        resp = requests.get(base_url, params=params, timeout=timeout)
        resp.raise_for_status()

        payload = resp.json()
        response = payload.get("response", {})
        records = response.get("data", [])

        if not records:
            logger.warning(f"[probe] {region}_{fuel_type}: No records returned")
            return None

        period = records[0].get("period")
        if not period:
            logger.warning(f"[probe] {region}_{fuel_type}: Record missing 'period'")
            return None

        # Parse to consistent ISO format
        ts = pd.to_datetime(period, utc=True)
        return ts.isoformat()

    except requests.RequestException as e:
        logger.warning(f"[probe] {region}_{fuel_type}: API error: {e}")
        return None
    except Exception as e:
        logger.warning(f"[probe] {region}_{fuel_type}: Unexpected error: {e}")
        return None


def _compare_timestamps(prev: Optional[str], current: Optional[str]) -> bool:
    """
    Return True if current is strictly newer than prev.

    Handles None values conservatively (assume new data if unknown).
    """
    if not prev or not current:
        return True  # Unknown = assume new data (conservative)

    try:
        prev_dt = pd.to_datetime(prev, utc=True)
        curr_dt = pd.to_datetime(current, utc=True)
        return curr_dt > prev_dt
    except Exception:
        return True  # Parse error = assume new data


def check_all_series_freshness(
    regions: list[str],
    fuel_types: list[str],
    run_log_path: Path,
    api_key: str,
) -> FreshnessCheckResult:
    """
    Check all series for new data availability.

    Args:
        regions: List of region codes (e.g., ["CALI", "ERCO", "MISO"])
        fuel_types: List of fuel types (e.g., ["WND", "SUN"])
        run_log_path: Path to previous run_log.json
        api_key: EIA API key

    Returns:
        FreshnessCheckResult with has_new_data flag and detailed status per series.
    """
    checked_at = datetime.now(timezone.utc).isoformat()

    # 1. Load previous max_ds values
    prev_max_ds = load_previous_max_ds(run_log_path)

    # 2. If no previous run_log, always run full pipeline (first run)
    if not prev_max_ds:
        return FreshnessCheckResult(
            has_new_data=True,
            checked_at_utc=checked_at,
            series_status={},
            summary="No previous run_log.json found - running full pipeline (first run)",
        )

    # 3. Probe each series
    series_status: dict[str, dict] = {}
    has_any_new = False
    new_series: list[str] = []
    error_series: list[str] = []

    for region in regions:
        for fuel_type in fuel_types:
            series_id = f"{region}_{fuel_type}"
            prev = prev_max_ds.get(series_id)
            current = probe_eia_latest(api_key, region, fuel_type)

            # Determine if this series has new data
            if current is None:
                # API error - be conservative, assume new data
                is_new = True
                error_series.append(series_id)
                logger.warning(
                    f"[freshness] {series_id}: probe failed, assuming new data"
                )
            else:
                is_new = _compare_timestamps(prev, current)

            series_status[series_id] = {
                "prev_max_ds": prev,
                "current_max_ds": current,
                "is_new": is_new,
            }

            if is_new:
                has_any_new = True
                if current is not None:
                    new_series.append(series_id)

            # Log each series check
            status_str = "NEW" if is_new else "unchanged"
            logger.info(
                f"[freshness] {series_id}: prev={prev} current={current} ({status_str})"
            )

    # 4. Build summary
    if error_series:
        summary = f"Probe errors for {error_series}, assuming new data available"
    elif new_series:
        summary = f"New data found for: {', '.join(new_series)}"
    else:
        summary = "No new data found for any series"

    return FreshnessCheckResult(
        has_new_data=has_any_new,
        checked_at_utc=checked_at,
        series_status=series_status,
        summary=summary,
    )


if __name__ == "__main__":
    # Quick test
    import os
    from dotenv import load_dotenv

    load_dotenv()
    logging.basicConfig(level=logging.INFO)

    api_key = os.getenv("EIA_API_KEY")
    if not api_key:
        # The bare exit() builtin is meant for interactive use; SystemExit works everywhere
        raise SystemExit("EIA_API_KEY not set")

    run_log_path = Path("data/renewable/run_log.json")

    result = check_all_series_freshness(
        regions=["CALI", "ERCO", "MISO"],
        fuel_types=["WND", "SUN"],
        run_log_path=run_log_path,
        api_key=api_key,
    )

    print("\nFreshness Check Result:")
    print(f"  has_new_data: {result.has_new_data}")
    print(f"  checked_at: {result.checked_at_utc}")
    print(f"  summary: {result.summary}")
    print("\nPer-series status:")
    for series_id, status in result.series_status.items():
        print(f"  {series_id}: {status}")


Overwriting src/renewable/data_freshness.py
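The decision rule in `check_all_series_freshness` can be exercised offline: run the full pipeline on a first run (no previous `max_ds` values); otherwise run it only if some series has a strictly newer timestamp or could not be checked. The sketch below restates that rule with the standard library only, using `datetime.fromisoformat` in place of `pd.to_datetime` and treating naive timestamps as UTC. It expects ISO strings like those produced by `isoformat()`. The function names and toy timestamps are hypothetical.

```python
from datetime import datetime, timezone
from typing import Dict, Optional


def parse_utc(ts: str) -> datetime:
    """Parse an ISO-8601 string; naive timestamps are assumed to be UTC."""
    dt = datetime.fromisoformat(ts)
    return dt if dt.tzinfo is not None else dt.replace(tzinfo=timezone.utc)


def is_series_new(prev: Optional[str], current: Optional[str]) -> bool:
    """True if current is strictly newer than prev; unknown or unparseable counts as new."""
    if not prev or not current:
        return True
    try:
        return parse_utc(current) > parse_utc(prev)
    except ValueError:
        return True


def should_run_pipeline(prev_max_ds: Dict[str, str], probed: Dict[str, Optional[str]]) -> bool:
    """First run (no history) or any new/unknown series triggers a full pipeline run."""
    if not prev_max_ds:
        return True
    return any(is_series_new(prev_max_ds.get(sid), cur) for sid, cur in probed.items())


# Toy timestamps for illustration only
prev = {"ERCO_WND": "2024-01-15T05:00:00+00:00", "ERCO_SUN": "2024-01-15T05:00:00+00:00"}
unchanged = dict(prev)
newer = {**prev, "ERCO_WND": "2024-01-15T06:00:00+00:00"}
probe_failed = {**prev, "ERCO_SUN": None}  # None mimics a failed probe

for label, probed in [("unchanged", unchanged), ("newer", newer), ("probe_failed", probe_failed)]:
    print(label, should_run_pipeline(prev, probed))
```

Because a failed probe counts as new data, an EIA outage leads to a full run rather than a silent skip. That trades some wasted compute for never missing an update.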


In [11]:
%%writefile src/renewable/jobs/run_hourly.py
# file: src/renewable/jobs/run_hourly.py
"""Hourly renewable pipeline entry point with validation."""

from __future__ import annotations

import json
import os
from datetime import datetime, timezone
from pathlib import Path
import pandas as pd
from dotenv import load_dotenv

from src.renewable.tasks import RenewablePipelineConfig, run_full_pipeline
from src.renewable.validation import validate_generation_df
from src.renewable.data_freshness import check_all_series_freshness, FreshnessCheckResult

load_dotenv()


def _env_list(name: str, default_csv: str) -> list[str]:
    raw = os.getenv(name, default_csv)
    return [item.strip() for item in raw.split(",") if item.strip()]


def _env_int(name: str, default: int) -> int:
    raw = os.getenv(name, str(default))
    try:
        return int(raw)
    except ValueError:
        return default


def _env_float(name: str, default: float) -> float:
    raw = os.getenv(name, str(default))
    try:
        return float(raw)
    except ValueError:
        return default


def _expected_series(regions: list[str], fuel_types: list[str]) -> list[str]:
    return [f"{region}_{fuel}" for region in regions for fuel in fuel_types]


def _json_default(value: object) -> str:
    if isinstance(value, pd.Timestamp):
        return value.isoformat()
    if isinstance(value, datetime):
        return value.isoformat()
    if hasattr(value, "item"):
        try:
            return value.item()
        except Exception:
            return str(value)
    return str(value)


def _summarize_generation_coverage(df: pd.DataFrame) -> dict:
    if df.empty:
        return {"row_count": 0, "series_count": 0, "coverage": []}

    coverage = (
        df.groupby("unique_id")["ds"]
        .agg(min_ds="min", max_ds="max", rows="count")
        .reset_index()
        .sort_values("unique_id")
    )
    return {
        "row_count": int(len(df)),
        "series_count": int(df["unique_id"].nunique()),
        "coverage": coverage.to_dict(orient="records"),
    }


def _read_previous_run_summary(data_dir: str) -> dict | None:
    """Read previous run_log.json for rowcount comparison."""
    path = Path(data_dir) / "run_log.json"
    if not path.exists():
        return None
    try:
        return json.loads(path.read_text())
    except Exception:
        return None


def _summarize_negative_forecasts(
    df: pd.DataFrame,
    sample_rows: int = 5,
) -> dict:
    if df.empty or "yhat" not in df.columns:
        return {
            "row_count": int(len(df)),
            "negative_rows": 0,
            "series": [],
            "sample": [],
        }

    neg = df[df["yhat"] < 0]
    if neg.empty:
        return {
            "row_count": int(len(df)),
            "negative_rows": 0,
            "series": [],
            "sample": [],
        }

    series_summary = (
        neg.groupby("unique_id")["yhat"]
        .agg(count="count", min_value="min", max_value="max", mean_value="mean")
        .reset_index()
        .sort_values("unique_id")
    )
    sample = neg[["unique_id", "ds", "yhat"]].head(sample_rows)
    return {
        "row_count": int(len(df)),
        "negative_rows": int(len(neg)),
        "series": series_summary.to_dict(orient="records"),
        "sample": sample.to_dict(orient="records"),
    }


def run_hourly_pipeline() -> dict:
    data_dir = os.getenv("RENEWABLE_DATA_DIR", "data/renewable")
    regions = _env_list("RENEWABLE_REGIONS", "CALI,ERCO,MISO")
    fuel_types = _env_list("RENEWABLE_FUELS", "WND,SUN")
    lookback_days = _env_int("LOOKBACK_DAYS", 30)

    # Horizon configuration: support both preset and direct override
    horizon_preset = os.getenv("RENEWABLE_HORIZON_PRESET", None)  # "24h" | "48h" | "72h"
    horizon_override = _env_int("RENEWABLE_HORIZON", 0)  # Legacy direct override

    # If direct override is set, use it; otherwise use preset (or None for default)
    if horizon_override > 0:
        horizon = horizon_override
        horizon_preset = None  # Ignore preset if direct override is set
    else:
        horizon = 24  # Default, may be overridden by preset

    cv_windows = _env_int("RENEWABLE_CV_WINDOWS", 2)
    cv_step_size = _env_int("RENEWABLE_CV_STEP_SIZE", 168)

    start_date = os.getenv("RENEWABLE_START_DATE", "")
    end_date = os.getenv("RENEWABLE_END_DATE", "")

    # Check if we should force run (e.g., manual dispatch)
    force_run = os.getenv("FORCE_RUN", "false").lower() == "true"

    # Data freshness check - skip full pipeline if no new data
    if not force_run:
        api_key = os.getenv("EIA_API_KEY", "")
        if not api_key:
            print("WARNING: EIA_API_KEY not set, skipping freshness check")
        else:
            run_log_path = Path(data_dir) / "run_log.json"
            freshness = check_all_series_freshness(
                regions=regions,
                fuel_types=fuel_types,
                run_log_path=run_log_path,
                api_key=api_key,
            )

            if not freshness.has_new_data:
                # No new data - return early with skip status
                skip_log = {
                    "run_at_utc": datetime.now(timezone.utc).isoformat(),
                    "status": "skipped",
                    "reason": "no_new_data",
                    "freshness_check": {
                        "checked_at_utc": freshness.checked_at_utc,
                        "summary": freshness.summary,
                        "series_status": freshness.series_status,
                    },
                    "config": {
                        "regions": regions,
                        "fuel_types": fuel_types,
                        "data_dir": data_dir,
                    },
                }

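                # Write skip_log.json (overwritten on each skip). run_log.json is left untouched,
                # so the next freshness check still compares against the last full run.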
                Path(data_dir).mkdir(parents=True, exist_ok=True)
                skip_log_path = Path(data_dir) / "skip_log.json"
                skip_log_path.write_text(
                    json.dumps(skip_log, indent=2, default=_json_default)
                )

                print(f"SKIPPED: {freshness.summary}")
                print(f"Skip log written to: {skip_log_path}")

                # Set output for GitHub Actions
                github_output = os.getenv("GITHUB_OUTPUT")
                if github_output:
                    with open(github_output, "a") as f:
                        f.write("status=skipped\n")

                return skip_log

            print(f"Freshness check: {freshness.summary}")
    else:
        print("FORCE_RUN=true - skipping freshness check")

    cfg = RenewablePipelineConfig(
        regions=regions,
        fuel_types=fuel_types,
        lookback_days=lookback_days,
        horizon=horizon,
        horizon_preset=horizon_preset,  # Apply preset if specified
        data_dir=data_dir,
        overwrite=True,
        start_date=start_date,
        end_date=end_date,
    )
    cfg.cv_windows = cv_windows
    cfg.cv_step_size = cv_step_size

    fetch_diagnostics: list[dict] = []
    results = run_full_pipeline(cfg, fetch_diagnostics=fetch_diagnostics)

    gen_path = cfg.generation_path()
    gen_df = pd.read_parquet(gen_path)
    generation_coverage = _summarize_generation_coverage(gen_df)

    max_lag_hours = _env_int("MAX_LAG_HOURS", 48)  # EIA publishes with 12-24h delay
    max_missing_ratio = _env_float("MAX_MISSING_RATIO", 0.02)
    report = validate_generation_df(
        gen_df,
        max_lag_hours=max_lag_hours,
        max_missing_ratio=max_missing_ratio,
        expected_series=_expected_series(regions, fuel_types),
    )

    forecasts_df = pd.read_parquet(cfg.forecasts_path())
    negative_forecasts = _summarize_negative_forecasts(forecasts_df)

    # Quality gates
    max_rowdrop_pct = _env_float("MAX_ROWDROP_PCT", 0.30)
    max_neg_forecast_ratio = _env_float("MAX_NEG_FORECAST_RATIO", 0.10)

    prev_run = _read_previous_run_summary(data_dir)
    prev_gen_rows = 0
    if prev_run:
        prev_gen_rows = prev_run.get("pipeline_results", {}).get("generation_rows", 0)

    curr_gen_rows = results.get("generation_rows", 0)
    rowdrop_ok = True
    if prev_gen_rows > 0:
        # Fail if this run lost more than max_rowdrop_pct of the previous run's rows
        min_allowed_rows = int(prev_gen_rows * (1.0 - max_rowdrop_pct))
        rowdrop_ok = curr_gen_rows >= min_allowed_rows

    neg_forecast_ratio = 0.0
    if negative_forecasts["row_count"] > 0:
        neg_forecast_ratio = (
            negative_forecasts["negative_rows"] / negative_forecasts["row_count"]
        )
    neg_forecast_ok = neg_forecast_ratio <= max_neg_forecast_ratio

    quality_gates = {
        "rowdrop": {
            "ok": rowdrop_ok,
            "prev_rows": prev_gen_rows,
            "curr_rows": curr_gen_rows,
            "max_rowdrop_pct": max_rowdrop_pct,
        },
        "neg_forecast": {
            "ok": neg_forecast_ok,
            "ratio": neg_forecast_ratio,
            "max_ratio": max_neg_forecast_ratio,
        },
    }

    run_log = {
        "run_at_utc": datetime.now(timezone.utc).isoformat(),
        "config": {
            "regions": regions,
            "fuel_types": fuel_types,
            "lookback_days": lookback_days,
            "horizon": horizon,
            "cv_windows": cv_windows,
            "cv_step_size": cv_step_size,
            "data_dir": data_dir,
            "start_date": cfg.start_date,
            "end_date": cfg.end_date,
        },
        "pipeline_results": results,
        "validation": {
            "ok": report.ok,
            "message": report.message,
            "details": report.details,
        },
        "diagnostics": {
            "fetch": fetch_diagnostics,
            "generation_coverage": generation_coverage,
            "negative_forecasts": negative_forecasts,
        },
        "quality_gates": quality_gates,
    }

    Path(data_dir).mkdir(parents=True, exist_ok=True)
    (Path(data_dir) / "run_log.json").write_text(
        json.dumps(run_log, indent=2, default=_json_default)
    )

    # Check validation
    if not report.ok:
        raise SystemExit(f"VALIDATION_FAILED: {report.message} | {report.details}")

    # Check quality gates
    if not rowdrop_ok:
        raise SystemExit(
            f"QUALITY_GATE_FAILED: rowdrop | "
            f"curr={curr_gen_rows} prev={prev_gen_rows} max_drop={max_rowdrop_pct:.0%}"
        )
    if not neg_forecast_ok:
        raise SystemExit(
            f"QUALITY_GATE_FAILED: neg_forecast | "
            f"ratio={neg_forecast_ratio:.1%} max={max_neg_forecast_ratio:.0%}"
        )

    # Set output for GitHub Actions (successful run)
    github_output = os.getenv("GITHUB_OUTPUT")
    if github_output:
        with open(github_output, "a") as f:
            f.write("status=success\n")

    return run_log


def main() -> None:
    run_hourly_pipeline()


if __name__ == "__main__":
    main()



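The two quality gates in `run_hourly_pipeline` reduce to simple arithmetic. The row-drop gate passes when there is no previous run, or when the current row count is at least `int(prev_rows * (1 - max_rowdrop_pct))`. The negative-forecast gate compares the share of negative forecast rows against a ceiling. Below is a minimal, dependency-free sketch of that logic. The two functions are hypothetical pure helpers, pulled out so the arithmetic can be unit-tested in isolation; they are not functions in `run_hourly.py`.

```python
# Hypothetical helpers (not part of run_hourly.py) mirroring the gate arithmetic


def rowdrop_gate_ok(prev_rows: int, curr_rows: int, max_rowdrop_pct: float = 0.30) -> bool:
    """Pass if there is no previous run or rows shrank by at most max_rowdrop_pct."""
    if prev_rows <= 0:
        # First run (or missing history): nothing to compare against
        return True
    min_allowed_rows = int(prev_rows * (1.0 - max_rowdrop_pct))
    return curr_rows >= min_allowed_rows


def neg_forecast_gate_ok(negative_rows: int, row_count: int, max_ratio: float = 0.10) -> tuple[bool, float]:
    """Return (ok, ratio); an empty forecast table counts as ratio 0.0."""
    ratio = negative_rows / row_count if row_count > 0 else 0.0
    return ratio <= max_ratio, ratio


# Usage with made-up counts
print(rowdrop_gate_ok(prev_rows=2000, curr_rows=1500))  # True (floor is 1400)
print(rowdrop_gate_ok(prev_rows=2000, curr_rows=1300))  # False
print(neg_forecast_gate_ok(negative_rows=5, row_count=144))
```

Treating an empty forecast table as a ratio of 0.0 matches the guard in the pipeline, which only divides when `row_count > 0`.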

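`run_hourly_pipeline` reads its thresholds through `_env_int` and `_env_float`, whose definitions are outside this excerpt. Below is a minimal sketch of what such helpers could look like. It assumes they fall back to the default when a variable is unset or blank. The `env_list` helper, for comma-separated values such as `RENEWABLE_REGIONS="CALI,ERCO,MISO"`, is hypothetical, as are the names of all three functions.

```python
import os

# Hypothetical env helpers; the real _env_int/_env_float may differ


def env_int(name: str, default: int) -> int:
    """Read an integer env var, falling back to default when unset or blank."""
    raw = os.getenv(name, "").strip()
    return int(raw) if raw else default


def env_float(name: str, default: float) -> float:
    """Read a float env var, falling back to default when unset or blank."""
    raw = os.getenv(name, "").strip()
    return float(raw) if raw else default


def env_list(name: str, default: list[str]) -> list[str]:
    """Split a comma-separated env var into trimmed, non-empty items."""
    raw = os.getenv(name, "")
    items = [part.strip() for part in raw.split(",") if part.strip()]
    return items or list(default)


os.environ["MAX_LAG_HOURS"] = "48"
os.environ["RENEWABLE_REGIONS"] = "CALI, ERCO,MISO"
print(env_int("MAX_LAG_HOURS", 24))             # 48
print(env_list("RENEWABLE_REGIONS", ["CALI"]))  # ['CALI', 'ERCO', 'MISO']
```

A malformed value such as `MAX_LAG_HOURS="4o"` raises `ValueError` instead of silently using the default, so a typo in the workflow's `env:` block fails the run loudly.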
%%writefile src/renewable/dag_builder.py
# file: src/renewable/dag_builder.py
"""Renewable pipeline DAG builder for Airflow."""
from __future__ import annotations

from datetime import datetime, timedelta
from typing import Any, Dict, Optional

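try:
    from airflow import DAG
    from airflow.operators.python import PythonOperator

    AIRFLOW_AVAILABLE = True
except ImportError:
    # Airflow is optional: the module still imports, and build_hourly_dag() raises instead
    DAG = None  # type: ignore[assignment,misc]
    PythonOperator = None  # type: ignore[assignment,misc]
    AIRFLOW_AVAILABLE = False
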
DEFAULT_ARGS = {
    "owner": "data-team",
    "depends_on_past": False,
    "email_on_failure": False,
    "email_on_retry": False,
    "retries": 2,
    "retry_delay": timedelta(minutes=5),
}


def build_hourly_dag(
    dag_id: str = "renewable_hourly_pipeline",
    schedule: str = "17 * * * *",
    start_date: Optional[datetime] = None,
    default_args: Optional[Dict[str, Any]] = None,
) -> "DAG":
    """Build a single-task Airflow DAG that runs run_hourly_pipeline on a cron schedule."""
    if not AIRFLOW_AVAILABLE:
        raise ImportError("Airflow is not installed. Install apache-airflow to use build_hourly_dag().")

    from src.renewable.jobs.run_hourly import run_hourly_pipeline

    if start_date is None:
        start_date = datetime.utcnow() - timedelta(days=1)
    if default_args is None:
        default_args = DEFAULT_ARGS.copy()

    with DAG(
        dag_id=dag_id,
        default_args=default_args,
        description="Renewable hourly pipeline",
        schedule=schedule,  # Airflow >= 2.4 name; schedule_interval is deprecated
        start_date=start_date,
        catchup=False,
        max_active_runs=1,
        tags=["renewable", "eia", "forecasting"],
    ) as dag:
        PythonOperator(
            task_id="run_hourly",
            python_callable=run_hourly_pipeline,
        )

    return dag


def build_dag_dot() -> str:
    """Return a Graphviz DOT string describing the pipeline (currently one task node)."""
    return """digraph RENEWABLE_PIPELINE {
  rankdir=LR;
  node [shape=box, style="rounded,filled", fillcolor="#e8f5e9"];

  run_hourly;
}"""



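`build_hourly_dag` raises `ImportError` when `AIRFLOW_AVAILABLE` is false. That check only has an effect if the Airflow import itself is guarded, so the module can still be imported without Airflow installed. Below is a generic sketch of that optional-dependency pattern using the standard library's `importlib`. The `optional_import` helper is hypothetical and not part of `dag_builder.py`.

```python
import importlib
from types import ModuleType
from typing import Optional, Tuple


def optional_import(module_name: str) -> Tuple[Optional[ModuleType], bool]:
    """Hypothetical helper: return (module, True), or (None, False) if it is not installed."""
    try:
        return importlib.import_module(module_name), True
    except ImportError:
        # ModuleNotFoundError is a subclass of ImportError, so missing packages land here
        return None, False


# In dag_builder.py this could read: airflow, AIRFLOW_AVAILABLE = optional_import("airflow")
_, has_json = optional_import("json")
_, has_missing = optional_import("no_such_module_for_demo")
print(has_json, has_missing)  # True False
```

Catching `ImportError` rather than `Exception` keeps genuine errors raised inside an installed package visible, instead of silently reporting the package as unavailable.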

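Both the Airflow DAG and the GitHub Actions workflow use the cron expression `17 * * * *`, which fires once an hour at minute 17. The sketch below computes the next fire time for this "minute M of every hour" pattern only, which is handy for reasoning about when the next run will happen. It is not a general cron parser, and `next_hourly_fire` is a hypothetical helper.

```python
from datetime import datetime, timedelta, timezone


def next_hourly_fire(now: datetime, minute: int = 17) -> datetime:
    """Hypothetical helper: next time strictly after `now` whose minute equals `minute`."""
    if not 0 <= minute <= 59:
        raise ValueError("minute must be in 0..59")
    candidate = now.replace(minute=minute, second=0, microsecond=0)
    if candidate <= now:
        # This hour's slot has passed (or is exactly now), so roll to the next hour
        candidate += timedelta(hours=1)
    return candidate


now = datetime(2024, 1, 1, 10, 5, tzinfo=timezone.utc)
print(next_hourly_fire(now))  # 2024-01-01 10:17:00+00:00
```

GitHub Actions evaluates `schedule` cron expressions in UTC. Its documentation warns that scheduled runs can be delayed under high load, especially at the start of each hour, so an off-peak minute such as 17 reduces the chance of delay.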
# GitHub Actions workflows
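The hourly workflow decides whether to run its quality-gate, summary, and commit steps from `steps.pipeline.outputs.status`. `run_hourly_pipeline` sets that value by appending `status=skipped` or `status=success` to the file named by the `GITHUB_OUTPUT` environment variable. Below is a minimal sketch of that handshake as a reusable helper; the name `set_step_output` is hypothetical. Outside GitHub Actions the variable is unset, and the helper does nothing.

```python
import os
import tempfile


def set_step_output(key: str, value: str) -> bool:
    """Hypothetical helper: append key=value to $GITHUB_OUTPUT; False if not under Actions."""
    output_path = os.getenv("GITHUB_OUTPUT")
    if not output_path:
        # Local or Airflow runs: there is no workflow step to report to
        return False
    if "\n" in value:
        # Multiline values need GitHub's delimiter syntax, which this sketch omits
        raise ValueError("multiline values are not supported in this sketch")
    with open(output_path, "a", encoding="utf-8") as fh:
        fh.write(f"{key}={value}\n")
    return True


# Local demo: point GITHUB_OUTPUT at a temporary file, then restore the environment
saved = os.environ.pop("GITHUB_OUTPUT", None)
try:
    with tempfile.TemporaryDirectory() as tmp:
        out_file = os.path.join(tmp, "github_output.txt")
        os.environ["GITHUB_OUTPUT"] = out_file
        set_step_output("status", "skipped")
        with open(out_file, encoding="utf-8") as fh:
            print(fh.read(), end="")  # status=skipped
finally:
    os.environ.pop("GITHUB_OUTPUT", None)
    if saved is not None:
        os.environ["GITHUB_OUTPUT"] = saved
```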

%%writefile .github/workflows/pre-commit.yml
# file: .github/workflows/pre-commit.yml
name: pre-commit

on:
  pull_request:
  push:
    branches:
      - main

jobs:
  run:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-python@v5
        with:
          python-version: "3.11"

      - uses: pre-commit/action@v3.0.1

      - name: Install test dependencies
        run: |
          python -m pip install --upgrade pip
          pip install pytest pandas numpy requests python-dotenv

      - name: Run smoke tests
        # Non-blocking for now: failures show on the step but do not fail the job
        continue-on-error: true
        env:
          PYTHONPATH: ${{ github.workspace }}
        run: pytest tests/ -v -k "not slow" --tb=short



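`run_hourly_pipeline` serializes `run_log.json` with `json.dumps(..., default=_json_default)`, and that helper is defined outside this excerpt. The log mixes plain values with objects such as datetimes and paths, which makes it a natural target for the fast `not slow` smoke tests this workflow runs. Below is a minimal sketch that assumes the helper only needs to handle datetimes, paths, and sets; numpy or pandas scalars in the real log may need extra branches. Both `json_default` and the test name are hypothetical.

```python
import json
from datetime import date, datetime, timezone
from pathlib import Path


def json_default(obj):
    """Hypothetical fallback encoder for json.dumps: convert common non-JSON types."""
    if isinstance(obj, (datetime, date)):
        return obj.isoformat()
    if isinstance(obj, Path):
        return obj.as_posix()  # forward slashes on every OS
    if isinstance(obj, (set, frozenset)):
        return sorted(obj)
    # Anything else: match json's own error instead of guessing
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def test_run_log_round_trips():
    """Smoke test: a run_log-shaped dict survives a dumps/loads round trip."""
    run_log = {
        "run_at_utc": datetime(2024, 1, 1, 10, 17, tzinfo=timezone.utc),
        "config": {"data_dir": Path("data/renewable"), "regions": {"CALI", "ERCO"}},
        "quality_gates": {"rowdrop": {"ok": True}},
    }
    decoded = json.loads(json.dumps(run_log, indent=2, default=json_default))
    assert decoded["run_at_utc"] == "2024-01-01T10:17:00+00:00"
    assert decoded["config"]["data_dir"] == "data/renewable"
    assert decoded["config"]["regions"] == ["CALI", "ERCO"]
    assert decoded["quality_gates"]["rowdrop"]["ok"] is True
```

Sorting sets before serializing keeps `run_log.json` stable between runs, so the hourly commit step does not record spurious diffs.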

%%writefile .github/workflows/renewable-hourly.yml
# file: .github/workflows/renewable-hourly.yml
name: renewable-hourly

on:
  workflow_dispatch:
    inputs:
      force_run:
        description: 'Force full pipeline run (skip freshness check)'
        type: boolean
        default: false
  schedule:
    - cron: "17 * * * *"

permissions:
  contents: write

concurrency:
  group: renewable-hourly
  cancel-in-progress: true

jobs:
  update:
    runs-on: ubuntu-latest
    timeout-minutes: 25
    env:
      EIA_API_KEY: ${{ secrets.EIA_API_KEY }}
      FORCE_RUN: ${{ github.event_name == 'workflow_dispatch' && inputs.force_run && 'true' || 'false' }}
      RENEWABLE_REGIONS: "CALI,ERCO,MISO"
      RENEWABLE_FUELS: "WND,SUN"
      LOOKBACK_DAYS: "30"
      RENEWABLE_HORIZON: "24"
      RENEWABLE_CV_WINDOWS: "2"
      RENEWABLE_CV_STEP_SIZE: "168"
      MAX_LAG_HOURS: "48"  # EIA publishes hourly data with 12-24h delay
      MAX_MISSING_RATIO: "0.02"
      RENEWABLE_DATA_DIR: "data/renewable"
      RENEWABLE_N_JOBS: "1"
      OMP_NUM_THREADS: "1"
      MKL_NUM_THREADS: "1"
      OPENBLAS_NUM_THREADS: "1"
      NUMBA_NUM_THREADS: "1"
      VECLIB_MAXIMUM_THREADS: "1"
    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-python@v5
        with:
          python-version: "3.11"

      - name: Check EIA API key
        run: |
          if [ -z "$EIA_API_KEY" ]; then
            echo "EIA_API_KEY is not set. Add it to repo secrets." >&2
            exit 1
          fi

      - name: Install deps
        run: |
          python -m pip install --upgrade pip
          # Install from pyproject.toml for single source of truth
          # Use -e for editable install (allows imports to work correctly)
          pip install -e .

      - name: Run hourly pipeline
        id: pipeline
        run: |
          python -m src.renewable.jobs.run_hourly

      - name: Quality gate check
        if: steps.pipeline.outputs.status != 'skipped'
        run: |
          python -c "
          import json, sys
          from pathlib import Path
          log_path = Path('data/renewable/run_log.json')
          if not log_path.exists():
              print('No run_log.json found')
              sys.exit(1)
          log = json.loads(log_path.read_text())
          val = log.get('validation', {})
          if not val.get('ok'):
              print(f'VALIDATION FAILED: {val.get(\"message\")}')
              print(f'Details: {val.get(\"details\")}')
              sys.exit(1)
          gates = log.get('quality_gates', {})
          if not gates.get('rowdrop', {}).get('ok', True):
              print(f'ROWDROP GATE FAILED: {gates.get(\"rowdrop\")}')
              sys.exit(1)
          if not gates.get('neg_forecast', {}).get('ok', True):
              print(f'NEG_FORECAST GATE FAILED: {gates.get(\"neg_forecast\")}')
              sys.exit(1)
          print('QUALITY GATES PASSED')
          "

      - name: Skip notification
        if: steps.pipeline.outputs.status == 'skipped'
        run: |
          echo "### Pipeline skipped - no new EIA data" >> "$GITHUB_STEP_SUMMARY"
          if [ -f data/renewable/skip_log.json ]; then
            python -c "
          import json
          from pathlib import Path
          data = json.loads(Path('data/renewable/skip_log.json').read_text())
          freshness = data.get('freshness_check', {})
          print(f'- Checked at: {freshness.get(\"checked_at_utc\")}')
          print(f'- Summary: {freshness.get(\"summary\")}')
          " >> "$GITHUB_STEP_SUMMARY"
          fi

      - name: Summarize run
        if: always() && steps.pipeline.outputs.status != 'skipped'
        run: |
          if [ -f data/renewable/run_log.json ]; then
          python - <<'PY' | tee -a "$GITHUB_STEP_SUMMARY"
          import json
          from pathlib import Path

          data = json.loads(Path("data/renewable/run_log.json").read_text())
          validation = data.get("validation", {})
          details = validation.get("details", {})
          pipeline = data.get("pipeline_results", {})
          interp = pipeline.get("interpretability", {})

          lines = [
              "### Renewable hourly run",
              f"- run_at_utc: {data.get('run_at_utc')}",
              f"- validation_ok: {validation.get('ok')}",
              f"- message: {validation.get('message')}",
              f"- max_ds: {details.get('max_ds')}",
              f"- lag_hours: {details.get('lag_hours')}",
              f"- best_model: {pipeline.get('best_model')}",
              f"- best_rmse: {pipeline.get('best_rmse', 0):.1f}",
              "",
              "#### Interpretability",
              f"- series_count: {interp.get('series_count', 0)}",
              f"- output_dir: {interp.get('output_dir', 'N/A')}",
          ]
          print("\n".join(lines))
          PY
          else
          echo "No run_log.json found." >> "$GITHUB_STEP_SUMMARY"
          fi

      - name: Commit updated artifacts
        if: steps.pipeline.outputs.status != 'skipped'
        run: |
          git config user.name "github-actions[bot]"
          git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
          git add data/renewable/generation.parquet \
            data/renewable/weather.parquet \
            data/renewable/forecasts.parquet \
            data/renewable/run_log.json
          # Add interpretability artifacts if they exist
          if [ -d data/renewable/interpretability ]; then
            git add data/renewable/interpretability/
          fi
          git commit -m "renewable: hourly data update $(date -u +%Y-%m-%dT%H:%MZ)" || echo "No changes to commit"
          git push



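The inline `python -c` quality-gate script depends on escaped quotes inside a shell string, which is easy to break when editing. The same checks could live in a small module invoked as, for example, `python -m src.renewable.jobs.check_gates`. That module path and the `check_run_log` function are hypothetical. Below is a sketch that mirrors the inline script's logic and returns a process exit code.

```python
import json
import sys
from pathlib import Path


def check_run_log(log_path: Path) -> int:
    """Hypothetical checker: 0 if validation and every quality gate passed, else 1."""
    if not log_path.exists():
        print("No run_log.json found")
        return 1
    log = json.loads(log_path.read_text())
    validation = log.get("validation", {})
    if not validation.get("ok"):
        print(f"VALIDATION FAILED: {validation.get('message')}")
        print(f"Details: {validation.get('details')}")
        return 1
    for gate_name, gate in log.get("quality_gates", {}).items():
        # As in the workflow step, a gate without an explicit "ok" counts as passing
        if not gate.get("ok", True):
            print(f"{gate_name.upper()} GATE FAILED: {gate}")
            return 1
    print("QUALITY GATES PASSED")
    return 0


def main() -> None:
    # A non-zero exit code fails the workflow step
    sys.exit(check_run_log(Path("data/renewable/run_log.json")))
```

Such a module would end with the usual `if __name__ == "__main__": main()` guard, and the workflow step would shrink to `run: python -m src.renewable.jobs.check_gates`. Unlike the inline script, this sketch loops over every entry in `quality_gates` rather than naming `rowdrop` and `neg_forecast`. A gate added later to `run_hourly_pipeline` is therefore checked without editing the workflow.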