In [None]:
"""
ModelRunLogger: Generic TensorBoard logging for regression models + convenience wrapper.

Features:
- Run directory management
- Config & hyperparameter logging (as scalars + JSON)
- Regression metrics (MSE, RMSE, MAE, R²) for train/test/etc.
- Datetime-indexed series logging as scalars, using each timestamp as the TensorBoard walltime
- Multi-series logging (automatic overlay via consistent tags)
- Matplotlib figure logging
- Optional CSV export of each logged series
- Convenience wrapper function `log_basic_run` to avoid boilerplate

Usage (simple):
    summary = log_basic_run(
        model=model,
        model_name="zoneA_lgbm",
        X_train=X_train, y_train=y_train,
        X_test=X_test,   y_test=y_test,
        model_params=lgbm_params,
        pnl_series=pnl_series,
        sharpe_series=sharpe_series,
        extra_series={"hit_rate": hit_rate_series},
        zone="A",
        target="da_imb_spread",
    )
    print("Test RMSE:", summary["metrics"]["test_final"]["rmse"])
"""

from __future__ import annotations

import json
import math
import time
import shutil
import datetime as dt
from pathlib import Path
from typing import Any, Dict, Mapping

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter


# ---------------------------------------------------------------------------
# Utility functions
# ---------------------------------------------------------------------------

def _utc_timestamp_string() -> str:
    """Return the current UTC time formatted as ``YYYYMMDD_HHMMSS``."""
    # datetime.utcnow() is deprecated since Python 3.12; an aware "now" in
    # UTC formats to exactly the same string.
    return dt.datetime.now(dt.timezone.utc).strftime("%Y%m%d_%H%M%S")


def flatten_dict(d: Mapping[str, Any], parent_key: str = "", sep: str = ".") -> Dict[str, Any]:
    """
    Flatten nested mappings into a single-level dict suitable for logging.

    Args:
        d: Mapping to flatten; nested mappings are descended into recursively.
        parent_key: Key prefix accumulated from enclosing mappings.
        sep: Separator placed between the prefix and each nested key.

    Returns:
        A flat dict keyed by the joined key path. Lists/tuples of at most 8
        primitive items are kept as lists; any other list/tuple is replaced
        by a ``"<type>(len=N)"`` summary string. NumPy integer, floating and
        boolean scalars are converted to native Python values, both as
        top-level values and as items of short sequences.
    """
    max_inline_items = 8
    primitive_types = (int, float, str, bool, type(None))
    # Only these NumPy scalar kinds are unwrapped; other kinds (e.g.
    # datetime64) are left alone because .item() may change their meaning.
    numpy_scalar_types = (np.integer, np.floating, np.bool_)

    def to_native(value: Any) -> Any:
        return value.item() if isinstance(value, numpy_scalar_types) else value

    flat: Dict[str, Any] = {}
    for k, v in d.items():
        nk = f"{parent_key}{sep}{k}" if parent_key else str(k)
        if isinstance(v, Mapping):
            flat.update(flatten_dict(v, parent_key=nk, sep=sep))
        elif isinstance(v, (list, tuple)):
            # Check the length first so long sequences are never copied.
            items = [to_native(x) for x in v] if len(v) <= max_inline_items else None
            if items is not None and all(isinstance(x, primitive_types) for x in items):
                flat[nk] = items
            else:
                flat[nk] = f"{type(v).__name__}(len={len(v)})"
        else:
            flat[nk] = to_native(v)
    return flat


def compute_regression_metrics(y_true: pd.Series, y_pred: pd.Series) -> Dict[str, float]:
    """
    Compute MSE, RMSE, MAE and R² for a pair of series.

    The two series are joined column-wise on their index and every row with
    a missing value on either side is discarded before scoring. If no rows
    remain, all four metrics are NaN. R² is NaN when the retained ``y_true``
    values have zero (population) variance.
    """
    truth = y_true.rename("y_true")
    guess = y_pred.rename("y_pred")
    paired = pd.concat((truth, guess), axis=1)
    paired = paired.dropna()

    # Nothing left to score after alignment.
    if paired.empty:
        return dict.fromkeys(("mse", "rmse", "mae", "r2"), np.nan)

    residual = paired["y_pred"] - paired["y_true"]
    mean_sq = float((residual ** 2).mean())
    mean_abs = float(residual.abs().mean())
    spread = float(paired["y_true"].var(ddof=0))

    if spread > 0:
        r_squared = float(1 - mean_sq / spread)
    else:
        r_squared = np.nan

    return {
        "mse": mean_sq,
        "rmse": float(math.sqrt(mean_sq)),
        "mae": mean_abs,
        "r2": r_squared,
    }


# ---------------------------------------------------------------------------
# Core Logger
# ---------------------------------------------------------------------------

class ModelRunLogger:
    """
    TensorBoard-based, model-agnostic logger for regression workflows.

    - Per-run directory management with timestamped naming
    - Hyperparameter & config capture (JSON + scalars)
    - Regression metrics logging
    - Datetime-indexed time series logging (walltime axis)
    - Multi-series logging
    - Matplotlib figure logging
    - JSON snapshots (config / metrics)
    - Optional CSV export for each series

    Typical usage:
        with ModelRunLogger("my_model") as logger:
            logger.log_hparams(model_params=..., data_params=...)
            logger.log_regression_metrics(y_train, y_train_pred, "train")
            logger.log_regression_metrics(y_test, y_test_pred, "test")
            logger.log_series_datetime("pnl", pnl_series)
    """

    def __init__(
        self,
        model_name: str,
        base_run_dir: str | Path = "runs",
        run_name: str | None = None,
        group_by_date: bool = True,
        overwrite: bool = False,
        save_series_csv: bool = True,
        extra_config: Dict[str, Any] | None = None,
    ):
        """
        Create the run directory, open the TensorBoard writer and save the
        initial ``config.json``.

        Args:
            model_name: Name stored in the config and used in the default run name.
            base_run_dir: Root directory under which runs are created.
            run_name: Directory name of this run; defaults to
                ``<model_name>_<UTC timestamp>``.
            group_by_date: When true, runs are nested under a folder named
                after today's (local) ISO date.
            overwrite: When true, an existing run directory is deleted first.
            save_series_csv: When true, each logged series is also written as CSV.
            extra_config: Extra key/values merged into the run config.
        """
        self.model_name = model_name
        base_dir = Path(base_run_dir)
        if group_by_date:
            base_dir = base_dir / dt.date.today().isoformat()
        if not run_name:
            # Timezone-aware replacement for the deprecated datetime.utcnow().
            stamp = dt.datetime.now(dt.timezone.utc).strftime("%Y%m%d_%H%M%S")
            run_name = f"{model_name}_{stamp}"
        self.run_dir = base_dir / run_name

        if overwrite and self.run_dir.exists():
            shutil.rmtree(self.run_dir)
        self.run_dir.mkdir(parents=True, exist_ok=True)

        self.writer = SummaryWriter(log_dir=str(self.run_dir))
        self.config: Dict[str, Any] = {"model_name": model_name}
        if extra_config:
            self.config.update(extra_config)

        self.metrics_history: Dict[str, Dict[str, float]] = {}
        self.save_series_csv = save_series_csv
        self._closed = False
        self._start_time = time.time()

        self._save_json("config.json", self.config)

    # ----- Config / Hyperparameters -----

    def update_config(self, **kwargs):
        """Merge ``kwargs`` into the run config and rewrite ``config.json``."""
        self.config.update(kwargs)
        self._save_json("config.json", self.config)

    def log_hparams(
        self,
        model_params: Dict[str, Any],
        data_params: Dict[str, Any] | None = None,
        tag_prefix: str = "hparams"
    ):
        """
        Log hyperparameters to TensorBoard and persist them in ``config.json``.

        The parameters are flattened; numeric (non-bool) values are written
        as scalars under ``<tag_prefix>/``, everything else as text under
        ``<tag_prefix>_text/``. A ``None`` ``model_params`` is treated as an
        empty dict.
        """
        combined: Dict[str, Any] = {
            "model_params": model_params if model_params is not None else {}
        }
        if data_params:
            combined["data_params"] = data_params

        flat = flatten_dict(combined)
        for k, v in flat.items():
            # bool is a subclass of int, so it must be excluded explicitly.
            if isinstance(v, (int, float)) and not isinstance(v, bool):
                self.writer.add_scalar(f"{tag_prefix}/{k}", v, global_step=0)
            else:
                self.writer.add_text(f"{tag_prefix}_text/{k}", str(v), global_step=0)

        self.config.update(combined)
        self._save_json("config.json", self.config)

    # ----- Metrics -----

    def log_regression_metrics(
        self,
        y_true: pd.Series,
        y_pred: pd.Series | np.ndarray | list,
        split: str,
        step: int | None = None,
        extra_metrics: Dict[str, float] | None = None,
        prefix: str | None = None,
    ) -> Dict[str, float]:
        """
        Compute regression metrics, log them as scalars and store them in
        ``metrics.json``.

        Args:
            y_true: Ground-truth values.
            y_pred: Predictions. Anything that is not a Series is flattened
                to 1-D and given ``y_true``'s index (positional match).
            split: Split name used in the tag and in the history key.
            step: Optional global step; the history key uses ``final`` when None.
            extra_metrics: Additional metrics merged into the computed ones.
            prefix: Optional tag prefix placed before ``split``.

        Returns:
            The metrics dict that was logged.

        Raises:
            TypeError: If ``y_true`` is not a pandas Series.
        """
        if not isinstance(y_true, pd.Series):
            raise TypeError("y_true must be a pandas Series.")
        if not isinstance(y_pred, pd.Series):
            # ravel() lets (n, 1)-shaped model outputs through.
            y_pred = pd.Series(np.asarray(y_pred).ravel(), index=y_true.index)

        metrics = compute_regression_metrics(y_true, y_pred)
        if extra_metrics:
            metrics.update(extra_metrics)

        tag_root = f"{prefix}/{split}" if prefix else split
        for k, v in metrics.items():
            self.writer.add_scalar(f"{tag_root}/{k}", v, global_step=step)

        hist_key = f"{split}_{step if step is not None else 'final'}"
        self.metrics_history[hist_key] = metrics
        self._save_json("metrics.json", self.metrics_history)
        return metrics

    def log_manual_metric(self, name: str, value: float, split: str = "custom", step: int | None = None):
        """Log a single scalar under ``<split>/<name>``."""
        self.writer.add_scalar(f"{split}/{name}", float(value), global_step=step)

    # ----- Time Series (walltime logging) -----

    def log_series_datetime(
        self,
        name: str,
        series: pd.Series,
        group: str = "timeseries",
        sort: bool = True,
        step_offset: int = 0,
        downsample: str | None = None,
    ):
        """
        Log a datetime-indexed series as scalars under ``<group>/<name>``.

        NaN values are dropped; an empty result logs nothing. The index is
        localized (if naive) or converted to UTC, optionally sorted and
        optionally resampled (last value per ``downsample`` bucket). Each
        point is written with a consecutive step starting at ``step_offset``
        and its timestamp as walltime. When CSV export is enabled the
        processed series is also written into the run directory.

        Raises:
            TypeError: If ``series`` is not a pandas Series.
            ValueError: If the series index is not a DatetimeIndex.
        """
        if not isinstance(series, pd.Series):
            raise TypeError("series must be a pandas Series.")
        if not isinstance(series.index, pd.DatetimeIndex):
            raise ValueError("Series must have a DatetimeIndex.")

        s = series.dropna()
        if s.empty:
            return
        # tz_localize/tz_convert return new objects, so the caller's series
        # is never modified.
        s = s.tz_localize("UTC") if s.index.tz is None else s.tz_convert("UTC")
        if sort:
            s = s.sort_index()
        if downsample:
            s = s.resample(downsample).last().dropna()

        tag = f"{group}/{name}"
        for i, (ts, val) in enumerate(s.items()):
            self.writer.add_scalar(
                tag, float(val), global_step=step_offset + i, walltime=ts.timestamp()
            )

        if self.save_series_csv:
            s.to_frame(name=name).to_csv(self._series_csv_path(name), index_label="timestamp")

    def log_series_datetime_multi(
        self,
        name: str,
        series_map: Dict[str, pd.Series],
        group: str = "timeseries_multi",
        align_inner: bool = True,
        downsample: str | None = None,
    ):
        """
        Log several series under the shared tag root ``<group>/<name>/<label>``.

        Args:
            name: Common name placed before each label.
            series_map: Label -> series. An empty map logs nothing.
            group: Tag group passed through to ``log_series_datetime``.
            align_inner: When true, only timestamps at which every series
                has a non-NaN value are logged.
            downsample: Optional resampling rule passed through.
        """
        if not series_map:
            return
        if align_inner:
            # concat of a dict yields one plain column per label; dropping
            # rows with any NaN keeps only the shared timestamps.
            df = pd.concat(series_map, axis=1).dropna(how="any")
            aligned = {label: df[label] for label in df.columns}
        else:
            aligned = series_map
        for label, s in aligned.items():
            self.log_series_datetime(f"{name}/{label}", s, group=group, downsample=downsample)

    # ----- Figures -----

    def log_figure(self, fig, tag: str, step: int | None = None, close: bool = True):
        """
        Log a Matplotlib figure under ``tag``.

        ``close`` is forwarded to ``add_figure`` (which otherwise closes the
        figure by default), so ``close=False`` really keeps the figure open.
        """
        try:
            self.writer.add_figure(tag, fig, global_step=step, close=close)
        finally:
            # Also release the figure when add_figure itself failed.
            if close:
                plt.close(fig)

    def log_series_overlay_figure(
        self,
        series_map: Dict[str, pd.Series],
        tag: str,
        title: str = "",
        ylabel: str = "",
        normalize: bool = False,
        step: int | None = None,
        figsize=(10, 3),
    ):
        """
        Plot several datetime-indexed series on one axis and log the figure.

        Non-Series values and empty series are skipped. With ``normalize``
        each series is rescaled to ``value / |first value| - 1`` unless its
        first value is zero.

        Raises:
            ValueError: If a series does not have a DatetimeIndex.
        """
        fig, ax = plt.subplots(figsize=figsize)
        try:
            for label, s in series_map.items():
                if not isinstance(s, pd.Series):
                    continue
                if not isinstance(s.index, pd.DatetimeIndex):
                    raise ValueError(f"Series '{label}' must have a DatetimeIndex.")
                t = s.dropna()
                if t.empty:
                    continue
                t = t.tz_localize("UTC") if t.index.tz is None else t.tz_convert("UTC")
                t = t.sort_index()
                if normalize and t.iloc[0] != 0:
                    t = t / abs(t.iloc[0]) - 1
                ax.plot(t.index, t.values, label=label)
            ax.set_title(title or "Series Overlay")
            if ylabel:
                ax.set_ylabel(ylabel)
            ax.grid(alpha=0.3)
            # Only build a legend when at least one series was drawn.
            if ax.get_legend_handles_labels()[0]:
                ax.legend(loc="best")
        except Exception:
            # Do not leave the figure open when plotting fails.
            plt.close(fig)
            raise
        self.log_figure(fig, tag=tag, step=step, close=True)

    # ----- Persistence helpers -----

    def _series_csv_path(self, name: str) -> Path:
        """Return the CSV path for a series, flattening path separators in ``name``."""
        # Names such as "pnl/label" would otherwise point into a
        # sub-directory that does not exist.
        safe_name = name.replace("/", "_").replace("\\", "_")
        return self.run_dir / f"{safe_name}.csv"

    def _save_json(self, filename: str, obj: Any):
        """Write ``obj`` as indented JSON into the run directory (non-JSON types via ``str``)."""
        with open(self.run_dir / filename, "w", encoding="utf-8") as f:
            json.dump(obj, f, indent=2, default=str)

    # ----- Lifecycle -----

    def close(self):
        """Log the run duration, then flush and close the writer (idempotent)."""
        if self._closed:
            return
        runtime = time.time() - self._start_time
        self.writer.add_scalar("meta/runtime_seconds", runtime, global_step=0)
        self.writer.flush()
        self.writer.close()
        self._closed = True

    def __enter__(self):
        return self

    def __exit__(self, exc_type, *_):
        self.close()


# ---------------------------------------------------------------------------
# Convenience wrapper: log_basic_run
# ---------------------------------------------------------------------------

def log_basic_run(
    *,
    model,
    model_name: str,
    X_train: pd.DataFrame,
    y_train: pd.Series,
    X_test: pd.DataFrame,
    y_test: pd.Series,
    model_params: dict | None = None,
    pnl_series: pd.Series | None = None,
    sharpe_series: pd.Series | None = None,
    extra_series: dict[str, pd.Series] | None = None,
    zone: str | None = None,
    target: str | None = None,
    base_run_dir: str | Path = "runs",
    run_name: str | None = None,
    performance_group: str = "performance",
    make_overlay_figure: bool = True,
    overlay_tag: str = "figures/performance_overlay",
    overlay_ylabel: str = "Value",
) -> dict:
    """
    Convenience wrapper for the common regression logging pattern.

    Opens a ``ModelRunLogger``, logs hyperparameters, predicts on the train
    and test sets, logs regression metrics for both splits, logs any
    supplied performance series and optionally an overlay figure of them.

    Args:
        model: Fitted estimator exposing ``predict``; ``get_params`` is used
            (best-effort) when ``model_params`` is not given.
        model_name: Name passed to the logger.
        X_train, y_train, X_test, y_test: Train/test features and targets.
        model_params: Hyperparameters to log; falls back to
            ``model.get_params()`` and finally to an empty dict.
        pnl_series, sharpe_series: Optional datetime-indexed series logged
            as ``pnl`` and ``rolling_sharpe``.
        extra_series: Optional additional name -> series to log.
        zone, target: Optional values stored in the run config.
        base_run_dir, run_name: Passed through to ``ModelRunLogger``.
        performance_group: Tag group for the logged series.
        make_overlay_figure: Whether to log one figure overlaying all series.
        overlay_tag: Tag of the overlay figure.
        overlay_ylabel: Y-axis label of the overlay figure.

    Returns:
        A summary dict with ``run_dir``, ``model_name``, ``metrics`` and ``config``.
    """
    if model_params is None and hasattr(model, "get_params"):
        try:
            model_params = model.get_params()
        except Exception:
            # Deliberately best-effort: a failing get_params must not
            # prevent the run from being logged.
            model_params = {}
    if model_params is None:
        # No get_params available (or it returned None): log no hyperparameters.
        model_params = {}

    extra_config = {}
    if zone is not None:
        extra_config["zone"] = zone
    if target is not None:
        extra_config["target"] = target

    with ModelRunLogger(
        model_name,
        extra_config=extra_config,
        base_run_dir=base_run_dir,
        run_name=run_name,
    ) as logger:
        logger.log_hparams(
            model_params=model_params,
            data_params={"n_train": len(X_train), "n_test": len(X_test)},
        )

        # Predictions are flattened to 1-D and matched to the targets by
        # position, so (n, 1)-shaped outputs and Series with a different
        # index are handled.
        y_train_pred = np.asarray(model.predict(X_train)).ravel()
        y_test_pred = np.asarray(model.predict(X_test)).ravel()

        # Metrics
        logger.log_regression_metrics(
            y_train, pd.Series(y_train_pred, index=y_train.index), split="train"
        )
        logger.log_regression_metrics(
            y_test, pd.Series(y_test_pred, index=y_test.index), split="test"
        )

        # Series logging
        series_logged: Dict[str, pd.Series] = {}
        if pnl_series is not None:
            series_logged["pnl"] = pnl_series
        if sharpe_series is not None:
            series_logged["rolling_sharpe"] = sharpe_series
        if extra_series:
            series_logged.update(extra_series)
        for name, s in series_logged.items():
            logger.log_series_datetime(name, s, group=performance_group)

        if make_overlay_figure and series_logged:
            logger.log_series_overlay_figure(
                series_logged,
                tag=overlay_tag,
                title="Performance Overlay",
                ylabel=overlay_ylabel,
            )

        summary = {
            "run_dir": str(logger.run_dir),
            "model_name": model_name,
            "metrics": logger.metrics_history,
            "config": logger.config,
        }
    return summary


In [None]:
import pandas as pd
from lightgbm import LGBMRegressor
from tracking.model_run_logger import log_basic_run

# Example: fit a LightGBM regressor and log the run with log_basic_run.
#
# This cell assumes the following names already exist in the notebook:
#   X_train, y_train, X_test, y_test  (y_* are pd.Series with DatetimeIndex)
#   pnl_series, sharpe_series         (optional performance series computed
#                                      elsewhere; pd.Series with DatetimeIndex)

lgbm_params = {
    "n_estimators": 400,
    "learning_rate": 0.05,
    "num_leaves": 64,
    "subsample": 0.9,
    "colsample_bytree": 0.8,
    "random_state": 42,
}

model = LGBMRegressor(**lgbm_params)
model.fit(X_train, y_train)

# Log hyperparameters, train/test metrics and the performance series.
summary = log_basic_run(
    model=model,
    model_name="zoneA_lgbm",
    X_train=X_train,
    y_train=y_train,
    X_test=X_test,
    y_test=y_test,
    model_params=lgbm_params,
    pnl_series=pnl_series,
    sharpe_series=sharpe_series,
    zone="A",
    target="da_imb_spread",
)

print("Run stored at:", summary["run_dir"])
print("Test RMSE:", summary["metrics"]["test_final"]["rmse"])