# Additional Model

In [None]:
%%writefile api/src/data_engineering.py
# api/src/data_engineering.py

import pandas as pd

def load_my_dataset(path: str) -> pd.DataFrame:
    """
    Read the raw CSV at *path* and apply this dataset's cleaning steps.

    Rows whose ``important_feature`` is missing are dropped, then a derived
    ``new_feature`` column equal to ``colA / colB`` is added (appended at the
    end, or overwritten in place if the CSV already has it).

    NOTE(review): rows with ``colB == 0`` yield ``inf`` in ``new_feature``;
    many estimators reject non-finite inputs — confirm the data never has it.
    """
    return (
        pd.read_csv(path)
        .dropna(subset=["important_feature"])
        .assign(new_feature=lambda frame: frame["colA"] / frame["colB"])
    )

def split_features_target(df: pd.DataFrame, target_col: str, test_size=0.2, random_state=0):
    """
    Separate *target_col* from the feature columns and make a train/test split.

    Returns ``(X_train, X_test, y_train, y_test)`` exactly as produced by
    scikit-learn's ``train_test_split`` (no stratification is applied).
    """
    from sklearn.model_selection import train_test_split as _split

    return _split(
        df.drop(columns=[target_col]),
        df[target_col],
        test_size=test_size,
        random_state=random_state,
    )


In [None]:
%%writefile api/src/trainers/my_trainer.py
# api/src/trainers/my_trainer.py

from .base import BaseTrainer, TrainResult
import mlflow
from ..data_engineering import load_my_dataset, split_features_target
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import accuracy_score

class MyDatasetTrainer(BaseTrainer):
    """
    Gradient-boosting classifier for the custom CSV dataset.

    Loads data with ``load_my_dataset``, holds out 20 % for evaluation, logs
    the hyper-parameters and accuracy to MLflow, and registers the sklearn
    model under ``name``.
    """
    name = "my_dataset_model"
    model_type = "classification"

    # Hyper-parameters scikit-learn requires to be integers.
    _INT_PARAMS = ("n_estimators", "random_state")

    def default_hyperparams(self):
        """Defaults for GradientBoostingClassifier plus the split/model seed."""
        return {
            "n_estimators": 100,
            "learning_rate": 0.1,
            "random_state": 42,
        }

    @staticmethod
    def _as_int(value):
        """Return ``int(value)`` for integral floats (e.g. ``100.0``), else *value* unchanged."""
        if isinstance(value, float) and value.is_integer():
            return int(value)
        return value

    def train(self, data_path: str = "data/my.csv", **overrides) -> TrainResult:
        """
        Train, evaluate and log the model.

        Args:
            data_path: CSV path passed to ``load_my_dataset``.
            **overrides: hyper-parameter overrides; ``None`` values are ignored
                (see ``merge_hyperparams``). Integral floats for integer-only
                parameters — what a JSON payload typed ``float`` delivers — are
                converted to ``int``, because scikit-learn rejects e.g.
                ``n_estimators=100.0``.

        Returns:
            TrainResult with the MLflow run id and ``{"accuracy": ...}``.
        """
        # Merge overrides, then normalise integer-only params.
        hp = self.merge_hyperparams(overrides)
        for key in self._INT_PARAMS:
            hp[key] = self._as_int(hp[key])
        # Load & split (target column is fixed to "label").
        df = load_my_dataset(data_path)
        X_tr, X_te, y_tr, y_te = split_features_target(
            df, target_col="label", test_size=0.2, random_state=hp["random_state"]
        )
        # Fit model
        model = GradientBoostingClassifier(
            n_estimators=hp["n_estimators"],
            learning_rate=hp["learning_rate"],
            random_state=hp["random_state"],
        ).fit(X_tr, y_tr)
        # Evaluate on the held-out split
        preds = model.predict(X_te)
        metrics = {"accuracy": accuracy_score(y_te, preds)}
        # Log to MLflow and register the sklearn-flavor model
        with mlflow.start_run(run_name=self.name) as run:
            mlflow.log_params(hp)
            mlflow.log_metrics(metrics)
            mlflow.sklearn.log_model(model, artifact_path="model", registered_model_name=self.name)
        return TrainResult(run_id=run.info.run_id, metrics=metrics)


In [None]:
%%writefile api/src/registry/registry.py
from src.registry.registry import register
from src.trainers.my_trainer import MyDatasetTrainer
from src.registry.types import TrainerSpec

# NOTE(review): this cell's %%writefile target is api/src/registry/registry.py,
# the very module it imports `register` from (a self-import), and a later cell
# overwrites that file with the real registry implementation. This snippet
# presumably belongs in its own module that is imported at startup — confirm
# the intended location.
# Register MyDatasetTrainer under its class-level `name`; default_params is
# captured by instantiating the trainer once, at import time.
spec = TrainerSpec(
  name=MyDatasetTrainer.name,
  cls=MyDatasetTrainer,
  default_params=MyDatasetTrainer().default_hyperparams()
)
register(spec)


In [None]:
%%writefile api/app/schemas/my_dataset.py
# api/app/schemas/my_dataset.py

from pydantic import BaseModel, Field
from typing import List

class MyDataFeatures(BaseModel):
    """One input row for ``my_dataset_model`` (template — fill in real features)."""
    # TODO: descriptions are placeholders ("…"). The field set presumably has to
    # match the feature columns produced by load_my_dataset — confirm.
    feature1: float = Field(..., description="…")
    feature2: float = Field(..., description="…")
    # …etc

class MyDataTrainRequest(BaseModel):
    """Request body for training ``my_dataset_model``."""
    # CSV location handed to the trainer.
    data_path: str = Field("data/my.csv", description="Path to CSV")
    # Hyper-parameter overrides. `int | float` (instead of plain `float`) keeps
    # integer inputs such as {"n_estimators": 200} as ints; a bare `float` field
    # coerces them to 200.0, which scikit-learn rejects for integer-only
    # parameters like n_estimators / random_state.
    hyperparams: dict[str, int | float] | None = None
    # True => the endpoint schedules training in the background.
    async_training: bool = False

class MyDataPredictRequest(BaseModel):
    """Batch prediction request: one MyDataFeatures entry per row to score."""
    samples: List[MyDataFeatures]

class MyDataPredictResponse(BaseModel):
    """Prediction results, index-aligned with the request's ``samples``."""
    # Predicted label per sample.
    predictions: List[str]
    # One probability per sample (NOTE(review): which class it refers to is
    # defined by the model service — confirm).
    probabilities: List[float]


In [None]:
%%writefile api/app/new_main.py
from .schemas.my_dataset import (
  MyDataTrainRequest, MyDataPredictRequest, MyDataPredictResponse
)
from .services.ml.model_service import model_service

@app.post("/api/v1/my/train", status_code=202)
async def train_mydata(req: MyDataTrainRequest, bt: BackgroundTasks):
    """
    Train ``my_dataset_model`` through the model-service registry.

    With ``async_training`` the job is handed to background tasks and
    ``{"status": "queued"}`` is returned immediately; otherwise training is
    awaited and the run id returned by ``train_via_registry`` is included.
    """
    # NOTE(review): req.data_path is accepted by the schema but never forwarded,
    # and the synchronous branch also answers 202. `app` and `BackgroundTasks`
    # are not imported in this cell; presumably it is pasted into the module
    # that defines them — confirm.
    model_name = "my_dataset_model"
    if not req.async_training:
        run_id = await model_service.train_via_registry(model_name, req.hyperparams)
        return {"status": "completed", "run_id": run_id}
    bt.add_task(model_service.train_via_registry, model_name, req.hyperparams)
    return {"status": "queued"}

@app.post("/api/v1/my/predict", response_model=MyDataPredictResponse)
async def predict_mydata(req: MyDataPredictRequest):
    """
    Score every sample with the registered ``my_dataset_model``.

    Returns index-aligned lists of predicted labels and probabilities.
    """
    # NOTE(review): assumes model_service.predict_generic(name, rows) exists and
    # returns a (predictions, probabilities) pair — it is not defined here.
    # `model_dump()` is the pydantic v2 replacement for the deprecated `.dict()`.
    rows = [sample.model_dump() for sample in req.samples]
    preds, probs = await model_service.predict_generic("my_dataset_model", rows)
    return MyDataPredictResponse(predictions=preds, probabilities=probs)


In [None]:
%%writefile api/src/trainers/__init__.py
# api/src/trainers/__init__.py
from .base import BaseTrainer, TrainResult

# Public API of the trainers package (re-exported from .base).
__all__ = ["BaseTrainer", "TrainResult"]

# Original working models

In [None]:
%%writefile api/src/trainers/base.py
# api/src/trainers/base.py
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Protocol
import mlflow

@dataclass
class TrainResult:
    """Outcome of one training run, as returned by ``BaseTrainer.train``."""
    # MLflow run ID the model was logged under (trainers fill it from
    # run.info.run_id).
    run_id: str
    # Evaluation metrics keyed by name, e.g. "accuracy".
    metrics: Dict[str, float]
    # Optional extra artifacts (NOTE(review): presumably name -> path/URI;
    # confirm with producers). Defaults to a fresh empty dict per instance.
    artifacts: Dict[str, str] = field(default_factory=dict)

class SupportsPyFunc(Protocol):
    """Structural type for any object exposing ``predict(X)``."""
    # Minimal protocol if custom loader is needed later
    def predict(self, X): ...

class BaseTrainer:
    """
    Base class shared by the project's trainers.

    Contract for subclasses:
      * set ``name`` (and optionally ``model_type``; the default is "generic")
      * implement ``train(**hyperparams)`` so it returns a ``TrainResult``
      * optionally override ``default_hyperparams()``; ``merge_hyperparams``
        layers caller overrides on top of it
    """
    name: str
    model_type: str = "generic"

    def default_hyperparams(self) -> Dict[str, Any]:
        """Return this trainer's default hyper-parameters (empty for the base class)."""
        return {}

    def merge_hyperparams(self, overrides: Dict[str, Any] | None) -> Dict[str, Any]:
        """
        Build the effective hyper-parameters: defaults updated by *overrides*.

        Overrides whose value is ``None`` are skipped, so ``None`` can never
        replace a default. The dict returned by ``default_hyperparams()`` is
        copied first; that dict itself is not mutated.
        """
        merged = self.default_hyperparams().copy()
        if not overrides:
            return merged
        merged.update(
            (key, value) for key, value in overrides.items() if value is not None
        )
        return merged

    def train(self, **hyperparams) -> TrainResult:  # pragma: no cover - interface
        """Run training; subclasses must override this."""
        raise NotImplementedError

    # Optional hook for trainers that need a custom load path.
    def load_pyfunc(self, run_uri: str):
        """Load the MLflow pyfunc model at *run_uri* via ``mlflow.pyfunc.load_model``."""
        return mlflow.pyfunc.load_model(run_uri)

In [None]:
%%writefile api/src/trainers/iris_rf_trainer.py
# api/src/trainers/iris_rf_trainer.py
from __future__ import annotations
from .base import BaseTrainer, TrainResult
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import mlflow

class IrisRandomForestTrainer(BaseTrainer):
    """
    Random-forest classifier on the Iris dataset, logged to MLflow as a pyfunc
    model (registered under ``name``) whose ``predict`` returns class
    probabilities.
    """
    name = "iris_random_forest"
    model_type = "classification"

    def default_hyperparams(self):
        """Defaults for RandomForestClassifier and the split/model seed."""
        return {
            "n_estimators": 300,
            "max_depth": None,
            "random_state": 42,
        }

    def train(self, **overrides) -> TrainResult:
        """
        Train on a stratified 75/25 Iris split and log the model to MLflow.

        ``overrides`` are merged over the defaults; ``None`` values are dropped
        by ``merge_hyperparams``, so ``max_depth`` can only be overridden with a
        non-None value.

        Returns:
            TrainResult with the run id and accuracy / macro F1 / precision /
            recall on the held-out split.
        """
        hp = self.merge_hyperparams(overrides)
        iris = load_iris(as_frame=True)
        X, y = iris.data, iris.target
        Xtr, Xte, ytr, yte = train_test_split(
            X, y, test_size=0.25, stratify=y, random_state=hp["random_state"]
        )
        # NOTE(review): n_jobs=-1 is unconditional here, while the built-in
        # app trainers fall back to n_jobs=1 when psutil is unhealthy.
        rf = RandomForestClassifier(
            n_estimators=hp["n_estimators"],
            max_depth=hp["max_depth"],
            random_state=hp["random_state"],
            n_jobs=-1,
            class_weight="balanced",
        ).fit(Xtr, ytr)

        preds = rf.predict(Xte)
        metrics = {
            "accuracy": accuracy_score(yte, preds),
            "f1_macro": f1_score(yte, preds, average="macro"),
            "precision_macro": precision_score(yte, preds, average="macro"),
            "recall_macro": recall_score(yte, preds, average="macro"),
        }

        # Serving wrapper: accepts a DataFrame or array-like (re-labelled with
        # the training column names) and returns predict_proba output.
        class _Wrapper(mlflow.pyfunc.PythonModel):
            def __init__(self, model, cols):
                self.model = model
                self.cols = cols
            def predict(self, context, model_input, params=None):
                import pandas as pd, numpy as np
                df = model_input if isinstance(model_input, pd.DataFrame) else pd.DataFrame(model_input, columns=self.cols)
                return self.model.predict_proba(df)

        with mlflow.start_run(run_name=self.name) as run:
            mlflow.log_params({k: v for k, v in hp.items()})
            mlflow.log_metrics(metrics)
            # Signature inferred from the full feature frame and its probabilities.
            sig = mlflow.models.signature.infer_signature(X, rf.predict_proba(X))
            mlflow.pyfunc.log_model(
                artifact_path="model",
                python_model=_Wrapper(rf, list(X.columns)),
                registered_model_name=self.name,
                input_example=X.head(),
                signature=sig,
            )
            return TrainResult(run_id=run.info.run_id, metrics=metrics) 

In [None]:
%%writefile api/src/registry/__init__.py
# api/src/registry/__init__.py
from .registry import register, all_names, get, load_from_entry_point
from .types import TrainerSpec

# Public API of the registry package.
__all__ = ["register", "all_names", "get", "load_from_entry_point", "TrainerSpec"]

In [None]:
%%writefile api/src/registry/types.py
# api/src/registry/types.py
from __future__ import annotations
from dataclasses import dataclass
from typing import Type, Dict, Any
from ..trainers.base import BaseTrainer

@dataclass
class TrainerSpec:
    """Registry entry describing one trainer."""
    # Registry key (trainers use their class-level `name`).
    name: str
    # Trainer class (not an instance); a BaseTrainer subclass.
    cls: Type[BaseTrainer]
    # Snapshot of cls().default_hyperparams() taken at registration time.
    default_params: Dict[str, Any]

In [None]:
%%writefile api/src/registry/registry.py
# api/src/registry/registry.py
from __future__ import annotations
from importlib import import_module
from typing import Dict, Iterable, Type
from .types import TrainerSpec
from ..trainers.base import BaseTrainer

# Process-wide map of trainer name -> spec.
_REGISTRY: Dict[str, TrainerSpec] = {}


def register(spec: TrainerSpec) -> None:
    """Add *spec* under ``spec.name``; an existing entry with that name is replaced."""
    _REGISTRY[spec.name] = spec


def all_names() -> Iterable[str]:
    """Return the registered trainer names (a live view of the registry keys)."""
    return _REGISTRY.keys()


def get(name: str) -> TrainerSpec:
    """
    Return the spec registered as *name*.

    Raises:
        KeyError: if nothing is registered under *name*; the message lists the
            names that are registered, to make typos easy to spot.
    """
    try:
        return _REGISTRY[name]
    except KeyError:
        known = ", ".join(sorted(_REGISTRY)) or "<none>"
        raise KeyError(f"Unknown trainer {name!r}; registered: {known}") from None


def load_from_entry_point(dotted: str, name: str | None = None) -> TrainerSpec:
    """
    Load 'pkg.module:ClassName' into registry.

    The class is imported, instantiated once to capture its default
    hyper-parameters, and registered under *name* (or the class's ``name``
    attribute, or the lower-cased class name).

    Raises:
        ValueError: if *dotted* is not exactly ``module:ClassName`` with both
            parts non-empty.
        ImportError / AttributeError: if the module or class cannot be found.
    """
    mod_path, sep, cls_name = dotted.partition(":")
    if not sep or not mod_path or not cls_name or ":" in cls_name:
        raise ValueError(f"Expected 'package.module:ClassName', got {dotted!r}")
    mod = import_module(mod_path)
    cls: Type[BaseTrainer] = getattr(mod, cls_name)
    inst_name = name or getattr(cls, "name", cls_name.lower())
    spec = TrainerSpec(name=inst_name, cls=cls, default_params=cls().default_hyperparams())
    register(spec)
    return spec

In [None]:
%%writefile api/app/schemas/train.py
from typing import Optional, Dict, Any
from pydantic import BaseModel, Field
from .bayes import BayesCancerParams

class IrisTrainRequest(BaseModel):
    """
    Kick off Iris model training.

    • `model_type` – 'rf' (Random-Forest) | 'logreg'  
    • `hyperparams` – optional scikit-learn overrides, e.g. {"n_estimators": 500}  
    • `async_training` – true ⇒ returns job_id immediately
    """
    # Free-form string: the schema does not reject values other than
    # 'rf' / 'logreg' — any such check presumably happens downstream.
    model_type: str = Field(
        default="rf",
        description="Which Iris trainer to run: 'rf' or 'logreg'"
    )
    # Values are typed Any, so ints stay ints (e.g. n_estimators=500).
    hyperparams: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Optional hyper‑parameter overrides"
    )
    async_training: bool = Field(
        default=False,
        description="Run in background and return job ID"
    )

class CancerTrainRequest(BaseModel):
    """
    Train Breast-Cancer classifiers.

    • `model_type` – 'bayes' (hierarchical Bayes) | 'stub' (quick LogisticRegression)  
    • `params` – validated Bayesian hyper-parameters (only used when model_type='bayes')  
    • `async_training` – background flag
    """
    # Free-form string; the schema does not restrict it to 'bayes' / 'stub'.
    model_type: str = Field(
        default="bayes",
        description="Which cancer model to train: 'bayes' or 'stub'"
    )
    # None => the trainer's own defaults are used (presumably; confirm caller).
    params: Optional[BayesCancerParams] = Field(
        default=None,
        description="Bayesian hyper‑parameters; ignored for stub model"
    )
    async_training: bool = Field(
        default=False,
        description="Run in background and return job ID"
    )

class BayesTrainRequest(BaseModel):
    """Request model for Bayesian cancer model training"""
    # Optional; when None the defaults apply (per the field description).
    params: Optional[BayesCancerParams] = Field(
        default=None, 
        description="Bayesian hyperparameters. If None, uses defaults."
    )
    # Selects fire-and-forget vs. blocking behaviour of the endpoint.
    async_training: bool = Field(
        default=False,
        description="If True, returns job_id immediately. If False, waits for completion."
    )

class BayesTrainResponse(BaseModel):
    """Response model for Bayesian training"""
    # NOTE(review): required even for status='queued', where an MLflow run may
    # not exist yet — confirm what callers put here in the async case.
    run_id: str = Field(description="MLflow run ID")
    job_id: Optional[str] = Field(default=None, description="Background job ID if async")
    # Free-form string; expected values listed in the description.
    status: str = Field(description="Training status: 'completed', 'queued', 'failed'")
    message: Optional[str] = Field(default=None, description="Status message or error")

class BayesConfigResponse(BaseModel):
    """Response model for Bayesian configuration endpoint"""
    defaults: BayesCancerParams = Field(description="Default hyperparameters")
    # The three dicts below are untyped; their key layout is set by the
    # endpoint that builds this response (not visible here).
    bounds: dict = Field(description="Parameter bounds for UI controls")
    descriptions: dict = Field(description="Parameter descriptions for tooltips")
    runtime_estimate: dict = Field(description="Runtime estimation factors")

class BayesRunMetrics(BaseModel):
    """Response model for Bayesian run metrics"""
    run_id: str
    accuracy: float
    # Convergence / model-comparison diagnostics; None when not computed.
    rhat_max: Optional[float] = None
    ess_bulk_min: Optional[float] = None
    ess_tail_min: Optional[float] = None
    waic: Optional[float] = None
    loo: Optional[float] = None
    status: str
    # Human-readable warnings; a fresh empty list per instance by default.
    warnings: list[str] = Field(default_factory=list)


In [None]:
%%writefile api/app/schemas/cancer.py
from pydantic import BaseModel, Field
from typing import List, Optional

class CancerFeatures(BaseModel):
    """
    Breast cancer diagnostic features.

    30 required float fields in three groups of ten: ``mean_*``, ``se_*``
    (standard error) and ``worst_*``. No range constraints are applied.
    NOTE(review): names look like snake_case versions of scikit-learn's
    load_breast_cancer feature names (e.g. "mean texture"); confirm the
    mapping used when building model input.
    """
    # Mean features
    mean_radius: float = Field(..., description="Mean of distances from center to points on perimeter")
    mean_texture: float = Field(..., description="Standard deviation of gray-scale values")
    mean_perimeter: float = Field(..., description="Mean size of the core tumor")
    mean_area: float = Field(..., description="Mean area of the core tumor")
    mean_smoothness: float = Field(..., description="Mean of local variation in radius lengths")
    mean_compactness: float = Field(..., description="Mean of perimeter^2 / area - 1.0")
    mean_concavity: float = Field(..., description="Mean of severity of concave portions of the contour")
    mean_concave_points: float = Field(..., description="Mean for number of concave portions of the contour")
    mean_symmetry: float = Field(..., description="Mean symmetry")
    mean_fractal_dimension: float = Field(..., description="Mean for 'coastline approximation' - 1")

    # SE features (standard error)
    se_radius: float = Field(..., description="Standard error of radius")
    se_texture: float = Field(..., description="Standard error of texture")
    se_perimeter: float = Field(..., description="Standard error of perimeter")
    se_area: float = Field(..., description="Standard error of area")
    se_smoothness: float = Field(..., description="Standard error of smoothness")
    se_compactness: float = Field(..., description="Standard error of compactness")
    se_concavity: float = Field(..., description="Standard error of concavity")
    se_concave_points: float = Field(..., description="Standard error of concave points")
    se_symmetry: float = Field(..., description="Standard error of symmetry")
    se_fractal_dimension: float = Field(..., description="Standard error of fractal dimension")

    # Worst features
    worst_radius: float = Field(..., description="Worst radius")
    worst_texture: float = Field(..., description="Worst texture")
    worst_perimeter: float = Field(..., description="Worst perimeter")
    worst_area: float = Field(..., description="Worst area")
    worst_smoothness: float = Field(..., description="Worst smoothness")
    worst_compactness: float = Field(..., description="Worst compactness")
    worst_concavity: float = Field(..., description="Worst concavity")
    worst_concave_points: float = Field(..., description="Worst concave points")
    worst_symmetry: float = Field(..., description="Worst symmetry")
    worst_fractal_dimension: float = Field(..., description="Worst fractal dimension")

class CancerPredictRequest(BaseModel):
    """
    Cancer prediction request (allows 'rows' alias).

    ``samples`` is declared with alias "rows"; ``populate_by_name`` lets
    clients send either key. Unknown keys are rejected (``extra="forbid"``).
    """
    # Free-form string. NOTE(review): the training schema lists 'bayes' /
    # 'stub', this one 'bayes' / 'logreg' / 'rf' — confirm which are served.
    model_type: str = Field("bayes", description="Model type: 'bayes', 'logreg', or 'rf'")
    samples: List[CancerFeatures] = Field(
        ...,
        description="Breast-cancer feature vectors",
        alias="rows",
    )
    # Optional; bounded to 10..10_000 when given.
    posterior_samples: Optional[int] = Field(
        None, ge=10, le=10_000, description="Posterior draws for uncertainty"
    )

    class Config:
        populate_by_name = True
        extra = "forbid"

class CancerPredictResponse(BaseModel):
    """Cancer prediction response; all lists are index-aligned with the input."""
    predictions: List[str] = Field(..., description="Predicted diagnosis (M=malignant, B=benign)")
    # NOTE(review): confirm each producing model's positive class actually is
    # "malignant" before trusting this description.
    probabilities: List[float] = Field(..., description="Probability of malignancy")
    uncertainties: Optional[List[float]] = Field(None, description="Uncertainty estimates (if requested)")
    input_received: List[CancerFeatures] = Field(..., description="Echo of input features")

In [None]:
%%writefile api/app/schemas/iris.py
from pydantic import BaseModel, Field
from typing import List, Optional

class IrisFeatures(BaseModel):
    """Iris measurement features (each required, validated to 0–10 cm)."""
    sepal_length: float = Field(..., description="Sepal length in cm", ge=0, le=10)
    sepal_width: float = Field(..., description="Sepal width in cm", ge=0, le=10)
    petal_length: float = Field(..., description="Petal length in cm", ge=0, le=10)
    petal_width: float = Field(..., description="Petal width in cm", ge=0, le=10)

class IrisPredictRequest(BaseModel):
    """
    Iris prediction request (accepts legacy 'rows' alias).

    ``populate_by_name`` lets clients send either "samples" or "rows";
    unknown keys are rejected (``extra="forbid"``).
    """
    # Free-form string; the schema does not restrict it to 'rf' / 'logreg'.
    model_type: str = Field("rf", description="Model type: 'rf' or 'logreg'")
    samples: List[IrisFeatures] = Field(
        ...,
        description="List of iris measurements",
        alias="rows",
    )

    class Config:
        populate_by_name = True
        extra = "forbid"

class IrisPredictResponse(BaseModel):
    """Iris prediction response; lists are index-aligned with the input."""
    predictions: List[str] = Field(..., description="Predicted iris species")
    # One probability vector per sample.
    probabilities: List[List[float]] = Field(..., description="Class probabilities")
    input_received: List[IrisFeatures] = Field(..., description="Echo of input features")

In [None]:
%%writefile api/app/schemas/bayes.py
from pydantic import BaseModel, Field, validator
from typing import Optional

class BayesCancerParams(BaseModel):
    """
    Validated hyper-parameters for the Bayesian breast-cancer model.

    ``draws``, ``tune`` and ``target_accept`` are the sampler settings returned
    by :meth:`to_kwargs`; the remaining fields toggle optional WAIC / LOO
    computation and set the R-hat / ESS thresholds used for convergence
    warnings.
    """
    draws: int = Field(1000, ge=200, le=20_000, description="Posterior draws retained")
    tune: int = Field(1000, ge=200, le=20_000, description="Tuning (warmup) steps")
    target_accept: float = Field(0.95, ge=0.80, le=0.999, description="NUTS target acceptance")
    compute_waic: bool = Field(True, description="Attempt WAIC (may be slow)")
    compute_loo: bool = Field(False, description="Attempt LOO (slower); auto-off by default")
    max_rhat_warn: float = Field(1.01, ge=1.0, le=1.1)
    min_ess_warn: int = Field(400, ge=50, le=5000)

    @validator("tune")
    def tune_reasonable(cls, v, values):
        """
        Accept any in-range ``tune``, but log a warning when it is below 20 %
        of ``draws``.

        Previously this branch was a no-op (``pass``) despite promising a
        warning. ``draws`` is absent from *values* if it failed its own
        validation, in which case no comparison is made.
        """
        draws = values.get("draws")
        if draws is not None and v < 0.2 * draws:
            # Gentle warning, not rejection: the value is still accepted.
            import logging
            logging.getLogger(__name__).warning(
                "BayesCancerParams: tune=%d is less than 20%% of draws=%d; "
                "sampler adaptation may be insufficient.",
                v,
                draws,
            )
        return v

    def to_kwargs(self):
        """Return the subset of fields forwarded to the sampler."""
        return {
            "draws": self.draws,
            "tune": self.tune,
            "target_accept": self.target_accept,
        }

In [None]:
%%writefile api/app/ml/__init__.py
"""
ML sub-package – exposes built-in trainers so the service can import
`app.ml.builtin_trainers` with an absolute import.
"""

from .builtin_trainers import (
    train_iris_random_forest,
    train_iris_logreg,
    train_breast_cancer_bayes,
    train_breast_cancer_stub,
)

# Built-in trainer functions re-exported by this package.
__all__ = [
    "train_iris_random_forest",
    "train_iris_logreg",
    "train_breast_cancer_bayes",
    "train_breast_cancer_stub",
]


In [None]:
%%writefile api/app/ml/utils.py
# api/app/ml/utils.py

def configure_pytensor_compiler(*_args, **_kwargs):
    """
    No-op kept so that older callers keep working.

    The project runs PyTensor through the **JAX backend**, so no C compiler
    ever has to be located or configured. Any positional or keyword arguments
    are accepted and ignored, and the call always reports success (True).
    """
    del _args, _kwargs  # intentionally unused
    return True


# Backward-compatible alias: some early-boot modules still import
# ``find_compiler``.
find_compiler = configure_pytensor_compiler



In [None]:
%%writefile api/app/ml/builtin_trainers.py
# api/app/ml/builtin_trainers.py
"""
Built-in trainers for Iris RF and Breast-Cancer Bayesian LogReg.
Executed automatically by ModelService when a model is missing.
"""

import logging
logger = logging.getLogger(__name__)

from pathlib import Path
import mlflow, mlflow.sklearn, mlflow.pyfunc
from sklearn.datasets import load_iris, load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import pandas as pd
import numpy as np
import tempfile
import pickle
import warnings
import subprocess
import os
import platform
from app.core.config import settings

# Conditional imports for heavy dependencies
if os.getenv("UNIT_TESTING") != "1" and os.getenv("SKIP_BACKGROUND_TRAINING") != "1":
    import pymc as pm
    import arviz as az
else:
    pm = None
    az = None

# --- MLflow experiment helper ----------------------------------------------
def _ensure_experiment(name: str = "ml_fullstack_models") -> str:
    """
    Make sure the MLflow experiment *name* exists, activate it, and return its ID.

    The ID comes from the tracking client (lookup or create) rather than from
    ``mlflow.set_experiment``. This avoids the MLflow race where
    set_experiment can return a dangling ID when the experiment folder has not
    been written yet.

    If ``create_experiment`` fails because a concurrent worker created the same
    experiment between our lookup and our create call, the experiment is looked
    up again and its ID is used. Any other failure is re-raised.

    NOTE(review): an experiment that exists but is in the *deleted* lifecycle
    stage is returned as-is, and set_experiment will likely refuse it — restore
    it (MlflowClient.restore_experiment) if that case matters.
    """
    client = mlflow.tracking.MlflowClient()
    exp = client.get_experiment_by_name(name)
    if exp is not None:
        exp_id = exp.experiment_id
    else:
        try:
            exp_id = client.create_experiment(name)
        except mlflow.exceptions.MlflowException:
            # Lost the create race to another process: use the winner's experiment.
            exp = client.get_experiment_by_name(name)
            if exp is None:
                raise
            exp_id = exp.experiment_id
    mlflow.set_experiment(name)          # marks it the active one
    return exp_id

# ------------------------------------------------------------------
# Honour whatever Settings or the shell already provided; then
# fall back if the host part cannot be resolved quickly.
# ------------------------------------------------------------------
from urllib.parse import urlparse
import socket, time

def _fast_resolve(uri: str) -> bool:
    if uri.startswith("http"):
        host = urlparse(uri).hostname
        try:
            t0 = time.perf_counter()
            socket.getaddrinfo(host, None, proto=socket.IPPROTO_TCP)
            return (time.perf_counter() - t0) < 0.05
        except socket.gaierror:
            return False
    return True



# The MLflow tracking URI is not configured in this module; per the original
# design it is resolved in ModelService.initialize() (not visible here).
# Trainers only make sure their experiment exists (via _ensure_experiment).
# This log line runs once, at import time.
logger.info("📦 Trainers ready - MLflow URI will be resolved at service startup")

# Experiment that the built-in trainers log their runs under; keep in sync
# with the default argument of _ensure_experiment.
MLFLOW_EXPERIMENT = "ml_fullstack_models"

# No MLflow calls are made at import time; each trainer function creates or
# selects the experiment on demand.

# --- psutil health probe ----------------------------------------------------
# Memoised result of _psutil_healthy(); None means "not probed yet".
_PSUTIL_HEALTH_CACHE: bool | None = None


def _psutil_healthy() -> bool:
    """
    Return True if psutil imports cleanly *and* exposes a working Process() object.

    The probe runs once per process; later calls return the cached result, so
    the "unhealthy" warning is logged at most once. Any exception raised while
    importing psutil or touching ``psutil.Process()`` counts as unhealthy. This
    is a deliberate best-effort check, so nothing is re-raised.
    """
    global _PSUTIL_HEALTH_CACHE
    if _PSUTIL_HEALTH_CACHE is not None:
        return _PSUTIL_HEALTH_CACHE

    try:
        import psutil  # type: ignore
        ok = hasattr(psutil, "Process")
        if ok:
            _ = psutil.Process().pid  # touch the native layer
    except Exception:  # missing module or broken native extension
        ok = False

    _PSUTIL_HEALTH_CACHE = ok
    if not ok:
        logger.warning("🩺 psutil unhealthy – disabling sklearn/joblib parallelism (n_jobs=1).")
    return ok


# -----------------------------------------------------------------------------
#  IRIS – point-estimate Random-Forest (enhanced with better parameters)
# -----------------------------------------------------------------------------
def train_iris_random_forest(
    n_estimators: int = 300,
    max_depth: int | None = None,
    random_state: int = 42,
) -> str:
    """
    Train a class-balanced random forest on Iris and log it to MLflow.

    Uses a stratified 75/25 split, logs accuracy plus macro F1 / precision /
    recall on the held-out set, and registers a pyfunc model named
    ``iris_random_forest`` whose ``predict`` returns the class-probability
    matrix.

    Args:
        n_estimators, max_depth: passed to RandomForestClassifier.
        random_state: seed for the split and the forest.

    Returns:
        The MLflow run ID.
    """
    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split
    from sklearn.ensemble import RandomForestClassifier
    import mlflow, mlflow.pyfunc

    # 1️⃣  ALWAYS ensure the experiment exists *inside* the function
    _ensure_experiment(MLFLOW_EXPERIMENT)

    iris = load_iris(as_frame=True)
    X, y = iris.data, iris.target
    X_tr, X_te, y_tr, y_te = train_test_split(
        X, y, test_size=0.25, stratify=y, random_state=random_state
    )

    # Parallel fit only when psutil works; otherwise single-threaded.
    safe_jobs = -1 if _psutil_healthy() else 1
    rf = RandomForestClassifier(
        n_estimators=n_estimators,
        max_depth=max_depth,
        random_state=random_state,
        n_jobs=safe_jobs,
        class_weight="balanced",
    ).fit(X_tr, y_tr)

    preds = rf.predict(X_te)
    metrics = {
        "accuracy": accuracy_score(y_te, preds),
        "f1_macro": f1_score(y_te, preds, average="macro"),
        "precision_macro": precision_score(y_te, preds, average="macro"),
        "recall_macro": recall_score(y_te, preds, average="macro"),
    }

    class IrisRFWrapper(mlflow.pyfunc.PythonModel):
        def __init__(self, model):
            # Pins the *same* rf object to n_jobs=1 for serving. This happens
            # when the wrapper is built inside log_model(), i.e. after the
            # signature below was inferred with the training n_jobs.
            if hasattr(model, "n_jobs"):
                model.n_jobs = 1
            self.model = model
            self._cols = list(X.columns)

        def _df(self, arr):
            # Array-like input is re-labelled with the training column names.
            import pandas as pd, numpy as np
            if isinstance(arr, pd.DataFrame):
                return arr
            return pd.DataFrame(np.asarray(arr), columns=self._cols)

        def predict(self, context, model_input, params=None):
            X_ = self._df(model_input)
            return self.model.predict_proba(X_)

    with mlflow.start_run(run_name="iris_random_forest") as run:
        mlflow.log_metrics(metrics)
        mlflow.log_params({
            "n_estimators": n_estimators,
            "max_depth": max_depth,
            "random_state": random_state,
            "safe_n_jobs": safe_jobs,
        })

        sig = mlflow.models.signature.infer_signature(X, rf.predict_proba(X))
        # NOTE(review): newer MLflow releases rename `artifact_path` to `name`;
        # confirm against the pinned MLflow version.
        mlflow.pyfunc.log_model(
            artifact_path="model",            # ✅ correct kw-arg
            python_model=IrisRFWrapper(rf),
            registered_model_name="iris_random_forest",
            input_example=X.head(),
            signature=sig,
        )
        return run.info.run_id


# -----------------------------------------------------------------------------
#  IRIS – logistic-regression trainer (NEW)
# -----------------------------------------------------------------------------

def train_iris_logreg(
    C: float = 1.0,
    max_iter: int = 400,
    random_state: int = 42,
) -> str:
    """
    Train a multinomial logistic-regression classifier on Iris and log it to MLflow.

    Uses a stratified 75/25 split, logs accuracy plus macro F1 / precision /
    recall on the held-out set, and registers a pyfunc model named
    ``iris_logreg`` whose ``predict`` returns the class-probability matrix.

    Args:
        C: regularisation parameter passed to LogisticRegression.
        max_iter: maximum number of lbfgs iterations.
        random_state: seed for the split (also passed to the estimator).

    Returns:
        The MLflow run ID.
    """
    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split
    from sklearn.linear_model import LogisticRegression
    import mlflow, mlflow.pyfunc

    # Always ensure the experiment exists *inside* the function.
    _ensure_experiment(MLFLOW_EXPERIMENT)

    iris = load_iris(as_frame=True)
    X, y = iris.data, iris.target
    X_tr, X_te, y_tr, y_te = train_test_split(
        X, y, test_size=0.25, stratify=y, random_state=random_state
    )

    safe_jobs = -1 if _psutil_healthy() else 1
    clf = LogisticRegression(
        C=C,
        max_iter=max_iter,
        # `multi_class="multinomial"` is deliberately not passed: it has been
        # deprecated since scikit-learn 1.5 and is slated for removal. With the
        # lbfgs solver on a >2-class target the default already fits the
        # multinomial (softmax) model, so the result is unchanged.
        solver="lbfgs",
        # Logged as safe_n_jobs; per the sklearn docs n_jobs only parallelises
        # one-vs-rest fits, so it has no effect on this multinomial fit.
        n_jobs=safe_jobs,
        random_state=random_state,
    ).fit(X_tr, y_tr)

    preds = clf.predict(X_te)
    metrics = {
        "accuracy": accuracy_score(y_te, preds),
        "f1_macro": f1_score(y_te, preds, average="macro"),
        "precision_macro": precision_score(y_te, preds, average="macro"),
        "recall_macro": recall_score(y_te, preds, average="macro"),
    }

    class IrisLogRegWrapper(mlflow.pyfunc.PythonModel):
        def __init__(self, model):
            # Pin the served estimator to a single job.
            if hasattr(model, "n_jobs"):
                model.n_jobs = 1
            self.model = model
            self._cols = list(X.columns)

        def _df(self, arr):
            # Array-like input is re-labelled with the training column names.
            import pandas as pd, numpy as np
            if isinstance(arr, pd.DataFrame):
                return arr
            return pd.DataFrame(np.asarray(arr), columns=self._cols)

        def predict(self, context, model_input, params=None):
            X_ = self._df(model_input)
            return self.model.predict_proba(X_)

    with mlflow.start_run(run_name="iris_logreg") as run:
        mlflow.log_metrics(metrics)
        mlflow.log_params({
            "C": C,
            "max_iter": max_iter,
            "random_state": random_state,
            "safe_n_jobs": safe_jobs,
        })

        # Signature is inferred before the wrapper pins clf.n_jobs to 1.
        sig = mlflow.models.signature.infer_signature(X, clf.predict_proba(X))
        mlflow.pyfunc.log_model(
            artifact_path="model",
            python_model=IrisLogRegWrapper(clf),
            registered_model_name="iris_logreg",
            input_example=X.head(),
            signature=sig,
        )
        return run.info.run_id


# -----------------------------------------------------------------------------
#  BREAST-CANCER STUB – ultra-fast fallback model
# -----------------------------------------------------------------------------
def train_breast_cancer_stub(random_state: int = 42) -> str:
    """
    Ultra-fast fallback binary LogisticRegression on the breast-cancer dataset.

    Fits on a stratified 70/30 split of scikit-learn's breast-cancer data. It
    logs the test accuracy, ``safe_n_jobs`` and a SHA-256 hash of the app
    settings, and registers an MLflow PythonModel as ``breast_cancer_stub``.

    Parallelism: training uses ``n_jobs=-1`` only when psutil is healthy
    (otherwise 1), and the serving wrapper always pins ``n_jobs=1``.

    The logged model's ``predict`` returns ``predict_proba(...)[:, 1]``, i.e.
    the probability of target class 1.

    Args:
        random_state: seed for the train/test split and the estimator.

    Returns:
        The MLflow run ID.
    """
    from sklearn.datasets import load_breast_cancer
    from sklearn.model_selection import train_test_split
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import accuracy_score
    import mlflow

    # Always ensure the experiment exists *inside* the function.
    _ensure_experiment(MLFLOW_EXPERIMENT)

    X, y = load_breast_cancer(return_X_y=True, as_frame=True)
    Xtr, Xte, ytr, yte = train_test_split(
        X, y, test_size=0.3, stratify=y, random_state=random_state
    )

    safe_jobs = -1 if _psutil_healthy() else 1

    clf = LogisticRegression(
        max_iter=200, n_jobs=safe_jobs, random_state=random_state
    ).fit(Xtr, ytr)

    class CancerStubWrapper(mlflow.pyfunc.PythonModel):
        def __init__(self, model):
            # Pin the served estimator to a single job (best effort).
            if hasattr(model, "n_jobs"):
                try:
                    model.n_jobs = 1
                except Exception:
                    pass
            self.model = model
            self._cols = list(X.columns)

        def _df(self, arr):
            # Array-like input is re-labelled with the training column names.
            import pandas as pd, numpy as np
            if isinstance(arr, pd.DataFrame):
                return arr
            return pd.DataFrame(np.asarray(arr), columns=self._cols)

        def predict(self, context, model_input, params=None):
            proba = self.model.predict_proba(self._df(model_input))
            # NOTE(review): column 1 is P(target == 1). scikit-learn's
            # load_breast_cancer encodes 0 = malignant and 1 = benign, so this
            # is likely P(benign), not P(malignant) as previously commented.
            # Confirm how callers interpret it before changing the column.
            return proba[:, 1]

        def predict_proba(self, X):
            return self.model.predict_proba(self._df(X))

    acc = accuracy_score(yte, clf.predict(Xte))

    with mlflow.start_run(run_name="breast_cancer_stub") as run:
        # Log a hash of the app settings for reproducibility and drift detection.
        # mode="json" converts non-JSON types (paths, secrets, datetimes, ...)
        # into JSON-safe values, so json.dumps cannot raise on them. Settings
        # made only of plain JSON types hash exactly as before.
        from app.core.config import settings as _s
        import hashlib, json
        _cfg_hash = hashlib.sha256(
            json.dumps(_s.model_dump(mode="json"), sort_keys=True).encode()
        ).hexdigest()
        mlflow.log_param("train_config_hash", _cfg_hash)

        mlflow.log_metric("accuracy", acc)
        mlflow.log_param("safe_n_jobs", safe_jobs)
        wrapper = CancerStubWrapper(clf)
        sig = mlflow.models.signature.infer_signature(X, wrapper.predict(None, X))
        mlflow.pyfunc.log_model(
            "model",
            python_model=wrapper,
            registered_model_name="breast_cancer_stub",
            input_example=X.head(),
            signature=sig,
        )
        return run.info.run_id


# -----------------------------------------------------------------------------
#  BREAST-CANCER – hierarchical Bayesian logistic regression
# -----------------------------------------------------------------------------

def train_breast_cancer_bayes(
    draws: int = 1000,
    tune: int = 1000,
    target_accept: float = 0.99,
    params_obj=None,   # optional BayesCancerParams
) -> str:
    """
    Fit a Bayesian logistic regression with per-group intercepts on the
    scikit-learn breast-cancer data, log diagnostics and register the model.

    Rows are grouped by quintile of ``mean texture``; each of the five groups
    gets its own intercept ``α[g]`` (independent ``Normal(0, 1)`` priors – no
    learned hyper-prior) and all groups share the coefficient vector ``β``
    over the standardised features.

    Adds:
      * Validated hyperparams via BayesCancerParams (if provided)
      * Convergence diagnostics: R-hat, bulk/tail ESS
      * Optional WAIC / LOO (guarded; can be disabled)
      * Logged warnings if thresholds exceeded

    Parameters
    ----------
    draws, tune : int
        Posterior draws and tuning steps per chain (4 chains, NumPyro NUTS).
    target_accept : float
        NUTS target acceptance probability.
    params_obj : optional
        Object exposing ``draws``, ``tune``, ``target_accept``,
        ``compute_waic``, ``compute_loo``, ``max_rhat_warn`` and
        ``min_ess_warn``; when given it overrides the keyword arguments and
        the diagnostic defaults.

    Returns
    -------
    str
        MLflow run id of the run that logged ``breast_cancer_bayes``.
    """
    import pymc as pm
    import pandas as pd, numpy as np
    from sklearn.datasets import load_breast_cancer
    from sklearn.preprocessing import StandardScaler
    import mlflow, tempfile, pickle
    from pathlib import Path

    # Schema override if provided
    if params_obj is not None:
        draws = params_obj.draws
        tune = params_obj.tune
        target_accept = params_obj.target_accept
        compute_waic = params_obj.compute_waic
        compute_loo = params_obj.compute_loo
        max_rhat_warn = params_obj.max_rhat_warn
        min_ess_warn = params_obj.min_ess_warn
    else:
        compute_waic = True
        compute_loo = False
        max_rhat_warn = 1.01
        min_ess_warn = 400

    logger.info(
        "BayesCancer: draws=%d tune=%d target_accept=%.3f waic=%s loo=%s",
        draws, tune, target_accept, compute_waic, compute_loo
    )

    _ensure_experiment(MLFLOW_EXPERIMENT)

    X_df, y = load_breast_cancer(as_frame=True, return_X_y=True)
    # pd.qcut bins are right-closed: [e0, e1], (e1, e2], ..., (e4, e5].
    quint, edges = pd.qcut(X_df["mean texture"], 5, labels=False, retbins=True)
    g = np.asarray(quint, dtype="int64")
    scaler = StandardScaler().fit(X_df)
    Xs = scaler.transform(X_df)

    # az.waic / az.loo need the pointwise log-likelihood, which PyMC >= 5
    # does not store unless it is requested explicitly.
    need_loglik = bool(compute_waic or compute_loo)

    coords = {"group": np.arange(5)}
    with pm.Model(coords=coords):
        α = pm.Normal("α", 0.0, 1.0, dims="group")
        β = pm.Normal("β", 0.0, 1.0, shape=Xs.shape[1])
        logit = α[g] + pm.math.dot(Xs, β)
        pm.Bernoulli("obs", logit_p=logit, observed=y)
        idata = pm.sample(
            draws=draws,
            tune=tune,
            chains=4,
            nuts_sampler="numpyro",
            target_accept=target_accept,
            progressbar=False,
            idata_kwargs={"log_likelihood": need_loglik},
        )

    # Diagnostics (computed over the posterior group)
    import arviz as az
    rhat = az.rhat(idata).to_array().values.max()
    ess_bulk = az.ess(idata, method="bulk").to_array().values.min()
    ess_tail = az.ess(idata, method="tail").to_array().values.min()
    waic_val = None
    loo_val = None
    try:
        if compute_waic:
            waic_val = float(az.waic(idata).waic)
    except Exception as e:
        logger.warning("WAIC computation failed: %s", e)
    try:
        if compute_loo:
            loo_val = float(az.loo(idata).loo)
    except Exception as e:
        logger.warning("LOO computation failed: %s", e)

    # The log-likelihood group is only needed for WAIC/LOO; drop it so the
    # pickled wrapper does not carry an extra chains x draws x n_obs array.
    if "log_likelihood" in idata.groups():
        del idata.log_likelihood

    # Wrapper
    class _HierBayesWrapper(mlflow.pyfunc.PythonModel):
        def __init__(self, trace, sc, ed, cols):
            # trace: InferenceData; sc: fitted scaler; ed: interior qcut edges
            self.trace, self.scaler, self.edges, self.cols = trace, sc, ed, cols

        def _quint(self, df):
            col = "mean texture"
            if col not in df.columns and "mean_texture" in df.columns:
                df = df.rename(columns={"mean_texture": col})
            tex = df[col].to_numpy()
            # right=True reproduces qcut's right-closed bins, so a value equal
            # to an interior edge lands in the same group as during training.
            return np.clip(np.digitize(tex, self.edges, right=True), 0, 4)

        def predict(self, context, model_input, params=None):
            df = model_input if isinstance(model_input, pd.DataFrame) else pd.DataFrame(model_input, columns=self.cols)
            xs = self.scaler.transform(df)
            g = self._quint(df)
            # Point estimate: posterior medians over all chains and draws.
            αg = self.trace.posterior["α"].median(("chain", "draw")).values
            β = self.trace.posterior["β"].median(("chain", "draw")).values
            log = αg[g] + np.dot(xs, β)
            return 1.0 / (1.0 + np.exp(-log))

    wrapper = _HierBayesWrapper(idata, scaler, edges[1:-1], X_df.columns.tolist())
    preds = wrapper.predict(None, X_df)
    # NOTE: in-sample accuracy – the model was fitted on all of X_df.
    acc = float(((preds > 0.5).astype(int) == y).mean())

    # Threshold warnings
    if rhat > max_rhat_warn:
        logger.warning("R-hat exceeds threshold: %.4f > %.2f", rhat, max_rhat_warn)
    if ess_bulk < min_ess_warn:
        logger.warning("Bulk ESS below threshold: %.1f < %d", ess_bulk, min_ess_warn)
    if ess_tail < min_ess_warn:
        logger.warning("Tail ESS below threshold: %.1f < %d", ess_tail, min_ess_warn)

    with tempfile.TemporaryDirectory() as td, mlflow.start_run(run_name="breast_cancer_bayes") as run:
        from app.core.config import settings as _s
        import hashlib, json
        _cfg_hash = hashlib.sha256(json.dumps(_s.model_dump(), sort_keys=True).encode()).hexdigest()

        # Metrics
        mlflow.log_metric("accuracy", acc)
        mlflow.log_metric("rhat_max", rhat)
        mlflow.log_metric("ess_bulk_min", ess_bulk)
        mlflow.log_metric("ess_tail_min", ess_tail)
        if waic_val is not None:
            mlflow.log_metric("waic", waic_val)
        if loo_val is not None:
            mlflow.log_metric("loo", loo_val)

        # Params
        mlflow.log_param("train_config_hash", _cfg_hash)
        mlflow.log_param("draws", draws)
        mlflow.log_param("tune", tune)
        mlflow.log_param("target_accept", target_accept)
        mlflow.log_param("compute_waic", compute_waic)
        mlflow.log_param("compute_loo", compute_loo)

        sc_path = Path(td) / "scaler.pkl"
        # Close (and flush) the file before MLflow copies the artifact.
        with open(sc_path, "wb") as fh:
            pickle.dump(scaler, fh)
        mlflow.pyfunc.log_model(
            "model",
            python_model=wrapper,
            artifacts={"scaler": str(sc_path)},
            registered_model_name="breast_cancer_bayes",
            input_example=X_df.head(),
            # predict() is deterministic, so reuse the predictions computed above
            signature=mlflow.models.signature.infer_signature(X_df, preds),
        )
        return run.info.run_id


In [None]:
%%writefile api/app/services/ml/model_service.py
"""
Model service – self-healing startup with background training.
"""

from __future__ import annotations
import asyncio, logging, os, time, socket, shutil, subprocess, hashlib, json
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Any, List, Tuple, Optional
from pathlib import Path

import mlflow, pandas as pd, numpy as np
from mlflow.tracking import MlflowClient
from mlflow.exceptions import MlflowException

# ------------------------------------------------------------------
# IMPORTANT IMPORT NOTE
# ------------------------------------------------------------------
# This module lives in app/services/ml/.
# To reach the sibling top-level package app/core/ we must step
# *two* levels up (services/ml -> services -> app) before importing.
# Rather than counting dots ('from ...core.config import settings'),
# we choose *absolute imports* for clarity & stability across refactors.
# See the Python import-system docs and the Real Python discussion on why
# absolute imports are preferred for larger projects.
# ------------------------------------------------------------------

from app.core.config import settings
from app.ml.builtin_trainers import (
    train_iris_random_forest,
    train_iris_logreg,  # NEW
    train_breast_cancer_bayes,
    train_breast_cancer_stub,
)

# NEW imports for registry integration
from importlib import import_module
try:
    from src.registry import registry as dynamic_registry  # noqa
except Exception:
    dynamic_registry = None  # tolerant if path not yet packaged


# from ..core.config import settings
# from ..ml.builtin_trainers import (
#     train_iris_random_forest,
#     train_iris_logreg,  # NEW
#     train_breast_cancer_bayes,
#     train_breast_cancer_stub,
# )


logger = logging.getLogger(__name__)

# --- safe sklearn predict_proba helper ---------------------------------------
def _safe_sklearn_proba(estimator, X, *, log_prefix=""):
    """
    Call ``estimator.predict_proba(X)`` but recover from environments where
    joblib/loky -> psutil introspection explodes (e.g. AttributeError: psutil.Process).

    Strategy:
    1. Try the fast path.
    2. On *any* exception, set ``estimator.n_jobs = 1`` (when the attribute
       exists and is writable) and retry under a serial joblib backend.
       Failures are logged as a warning unless the message looks like the
       known psutil breakage.
    3. As a last resort, call ``estimator.predict(X)`` and synthesize one-hot
       probabilities. Columns follow ``estimator.classes_`` when present (the
       same column order sklearn's ``predict_proba`` uses); otherwise the
       predictions are treated as integer class indices and the width is
       ``n_classes_`` (or ``max(pred) + 1``).

    Returns a NumPy array of shape (n_samples, n_classes).

    Raises whatever the last-resort ``estimator.predict`` path raises (e.g.
    ``KeyError`` for a label missing from ``classes_``) after logging it.
    """
    import numpy as _np

    # 1st attempt --------------------------------------------------------------
    try:
        return estimator.predict_proba(X)
    except Exception as e:  # broad → we inspect below
        msg = str(e)
        bad_psutil = "psutil" in msg and "Process" in msg
        # psutil breakage is a known environment issue → retry quietly
        if not bad_psutil:
            logger.warning("%s predict_proba failed (%s) – retry single-threaded",
                           log_prefix, e)

    # 2nd attempt: force serial backend ---------------------------------------
    try:
        if hasattr(estimator, "n_jobs"):
            try:
                estimator.n_jobs = 1
            except Exception:  # read-only attr
                pass
        # Imported here so a missing/broken joblib still reaches the fallback.
        from joblib import parallel_backend
        with parallel_backend("threading", n_jobs=1):
            return estimator.predict_proba(X)
    except Exception as e2:
        logger.error("%s serial predict_proba failed (%s) – fallback to classes",
                     log_prefix, e2)

    # 3rd attempt: derive 1-hot from predict ----------------------------------
    try:
        preds = _np.asarray(estimator.predict(X))
        classes = getattr(estimator, "classes_", None)
        if classes is not None:
            # Map labels (ints, strings, …) to their predict_proba column.
            col_of = {c: i for i, c in enumerate(_np.asarray(classes).tolist())}
            idx = _np.array([col_of[p] for p in preds.tolist()], dtype=int)
            n_classes = len(col_of)
        else:
            idx = _np.asarray(preds, dtype=int)
            n_classes = getattr(estimator, "n_classes_", None)
            if n_classes is None:
                # evaluated lazily: .max() raises on an empty prediction array
                n_classes = int(idx.max()) + 1 if idx.size else 0
        probs = _np.zeros((idx.size, int(n_classes)), dtype=float)
        probs[_np.arange(idx.size), idx] = 1.0
        return probs
    except Exception as e3:
        logger.exception("%s fallback predict also failed (%s)", log_prefix, e3)
        raise  # Let caller handle

# ---------------------------------------------------------------------------
# Cancer column mapping: Pydantic field names ➜ training column names
# ---------------------------------------------------------------------------
# Each of the ten base measurements appears three times in the training frame:
# as a mean ("mean <f>"), a standard error ("<f> error") and a worst value
# ("worst <f>").  The API names them "<prefix>_<f>" with underscores.
_CANCER_COLMAP: dict[str, str] = {
    f"{prefix}_{feature}": template.format(feature.replace("_", " "))
    for prefix, template in (
        ("mean", "mean {}"),
        ("se", "{} error"),
        ("worst", "worst {}"),
    )
    for feature in (
        "radius",
        "texture",
        "perimeter",
        "area",
        "smoothness",
        "compactness",
        "concavity",
        "concave_points",
        "symmetry",
        "fractal_dimension",
    )
}

def _rename_cancer_columns(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of *df* with API-style column names (``mean_radius`` …)
    translated to the training names used by the MLflow artefacts
    (``mean radius`` …).

    Labels absent from the mapping keep their names, so frames already in
    the training schema (or carrying extra columns) pass through unchanged.
    """
    return df.rename(mapper=_CANCER_COLMAP, axis="columns")

# Trainer mapping for self-healing
# model name -> zero-argument legacy trainer callable (looked up, for example,
# when the env-drift "retrain" policy schedules a background retrain)
TRAINERS = dict(
    iris_random_forest=train_iris_random_forest,
    iris_logreg=train_iris_logreg,  # NEW
    breast_cancer_bayes=train_breast_cancer_bayes,
    breast_cancer_stub=train_breast_cancer_stub,
)

class ModelService:
    """
    Self-healing model service that loads existing models and schedules
    background training for missing ones.
    """

    _EXECUTOR = ThreadPoolExecutor(max_workers=2)

    def __init__(self) -> None:
        """
        Set up empty bookkeeping; no MLflow connection is made here.

        ``initialize()`` later creates ``self.mlflow_client`` and loads models.
        """
        self._unit_test_mode = settings.UNIT_TESTING
        self.initialized = False

        # 🚫 Heavy clients are never built here.  `mlflow_client` is populated
        # by initialize(); `self.client` is not assigned anywhere in this
        # module's visible code and is kept only for callers that read it.
        self.client = None
        self.mlflow_client = None

        self.models: Dict[str, Any] = {}
        # registry bootstrap flags
        self._registry_loaded = False
        # Per-model state ("missing" / "loaded" / "training" / "failed") plus
        # auxiliary "<name>_last_error" strings and "<name>_dep_audit" dicts,
        # hence the Any value type.  Keys come from TRAINERS so the list of
        # known models is defined once.
        self.status: Dict[str, Any] = {name: "missing" for name in TRAINERS}

    # --- Registry integration (increment 1) ---------------------------------
    def _init_registry_once(self):
        if self._registry_loaded:
            return
        if dynamic_registry is None:
            logger.info("Registry package not available yet; skipping dynamic trainer loading.")
            self._registry_loaded = True
            return
        try:
            # Hardcode first migrated trainer; later we will iterate YAML directory.
            from src.registry.registry import load_from_entry_point  # type: ignore
            load_from_entry_point("src.trainers.iris_rf_trainer:IrisRandomForestTrainer")
            self._registry_loaded = True
            logger.info("Dynamic registry initialized with trainers: %s",
                        list(dynamic_registry.all_names()))
        except Exception as e:
            logger.warning("Failed to initialize registry: %s", e)
            self._registry_loaded = True  # prevent retry storm

    def _get_trainer_or_none(self, name: str):
        if not self._registry_loaded:
            self._init_registry_once()
        try:
            from src.registry.registry import get as reg_get  # type: ignore
            return reg_get(name)
        except Exception:
            return None

    async def train_via_registry(self, name: str, overrides: Dict[str, Any] | None = None) -> Optional[str]:
        """
        Train *name* with its registry trainer in the shared executor, then reload it.

        Parameters
        ----------
        name : model / registry name.
        overrides : hyper-parameter overrides merged over the trainer's
            defaults via ``trainer.merge_hyperparams``.

        Returns
        -------
        The ``run_id`` of the trainer's result, or ``None`` when the registry
        has no trainer for *name*.

        Raises
        ------
        Whatever the trainer raises.  Before re-raising, ``self.status[name]``
        is set to ``"failed"`` and ``"<name>_last_error"`` records the message,
        mirroring ``_train_and_reload``.
        """
        spec = self._get_trainer_or_none(name)
        if spec is None:
            logger.info("No registry trainer for %s", name)
            return None
        trainer = spec.cls()
        # merge overrides
        params = trainer.merge_hyperparams(overrides or {})
        loop = asyncio.get_running_loop()
        logger.info("Training %s via registry with params=%s", name, params)

        self.status[name] = "training"
        try:
            result = await loop.run_in_executor(self._EXECUTOR, lambda: trainer.train(**params))
        except Exception as exc:
            self.status[name] = "failed"
            self.status[f"{name}_last_error"] = str(exc)
            raise

        # After training, force reload of production candidate (latest run fallback).
        # _try_load only uses setdefault when nothing is found, which would
        # leave the status stuck at "training"; reset it explicitly.
        if not await self._try_load(name) and self.status.get(name) == "training":
            self.status[name] = "missing"
        return result.run_id

    def _resolve_tracking_uri(self) -> str:
        """
        Resolve MLflow tracking URI with graceful fallback:
          1. Explicit env var MLFLOW_TRACKING_URI
          2. settings.MLFLOW_TRACKING_URI
          3. local file store 'file:./mlruns_local'

        Unset or empty candidates are skipped.  For http(s) URIs only DNS
        resolution of the host is checked (no connection is attempted); a
        missing host, an invalid port or an unresolvable name moves on to the
        next candidate.  Any other scheme (file store, plain path, …) is
        accepted as-is.
        """
        import socket, urllib.parse

        candidates = [
            ("env", os.getenv("MLFLOW_TRACKING_URI")),
            ("settings", settings.MLFLOW_TRACKING_URI),
            ("fallback", "file:./mlruns_local"),
        ]

        for origin, uri in candidates:
            if not uri:
                # never hand None / "" to MLflow as a "file store"
                continue
            uri = str(uri)  # tolerate URL objects from settings
            parsed = urllib.parse.urlparse(uri)
            if parsed.scheme not in ("http", "https"):
                # file store always acceptable
                logger.info("MLflow file store selected (%s): %s", origin, uri)
                return uri
            try:
                host = parsed.hostname
                if not host:
                    # getaddrinfo(None, …) would silently resolve to loopback
                    raise ValueError("URI has no host")
                default_port = 443 if parsed.scheme == "https" else 80
                # parsed.port raises ValueError for a malformed port
                socket.getaddrinfo(host, parsed.port or default_port)
            except (OSError, UnicodeError, ValueError) as e:  # gaierror is an OSError
                logger.warning("MLflow URI unresolved (%s=%s) -> %s", origin, uri, e)
                continue
            logger.info("MLflow URI ok (%s): %s", origin, uri)
            return uri

        return "file:./mlruns_local"

    async def initialize(self) -> None:
        """
        Connect to MLflow – fall back to local file store if the configured
        tracking URI is unreachable *or* the client lacks the
        ``search_experiments`` method used as the connectivity probe – then
        load existing models.

        Returns immediately once ``self.initialized`` is set.
        """
        if self.initialized:
            return

        # Log critical dependency versions for diagnostics
        try:
            import pytensor
            logger.info("📦 PyTensor version: %s", pytensor.__version__)
        except ImportError:
            logger.warning("⚠️  PyTensor not available")
        except Exception as e:
            logger.warning("⚠️  Could not determine PyTensor version: %s", e)

        def _needs_fallback(client) -> bool:
            # Probe for the method we actually call below.  `list_experiments`
            # was removed in MLflow 2.0, so it cannot serve as a capability
            # check: every modern client would be treated as broken.
            return not callable(getattr(client, "search_experiments", None))

        try:
            resolved = self._resolve_tracking_uri()
            mlflow.set_tracking_uri(resolved)
            logger.info("Using tracking URI: %s", resolved)
            self.mlflow_client = MlflowClient(resolved)

            if _needs_fallback(self.mlflow_client):
                raise AttributeError("search_experiments not implemented – incompatible MLflow client")

            # minimal probe (cheap & always present)
            self.mlflow_client.search_experiments(max_results=1)
            logger.info("🟢  Connected to MLflow @ %s", resolved)

        except (MlflowException, OSError, AttributeError) as exc:
            # OSError covers socket.gaierror and connection-level failures
            logger.warning("🔄  Falling back to local MLflow store – %s", exc)
            mlflow.set_tracking_uri("file:./mlruns_local")
            self.mlflow_client = MlflowClient("file:./mlruns_local")
            logger.info("📂  Using local file store ./mlruns_local")

        await self._load_models()
        self.initialized = True

    async def _load_models(self) -> None:
        """
        Try to load every model listed in ``TRAINERS`` from MLflow.

        ``_try_load`` already records per-model status and swallows load
        errors; the extra guard keeps one unexpected failure from stopping
        the remaining models.
        """
        # TRAINERS is the single list of known model names.
        for name in TRAINERS:
            try:
                await self._try_load(name)
            except Exception as exc:
                logger.error("❌  load %s failed: %s", name, exc)

    async def startup(self, auto_train: bool | None = None) -> None:
        """
        Initialise MLflow, load existing models and optionally train missing ones.

        Faster: serve the stub cancer model immediately; the heavy Bayesian
        job runs in the background.

        Parameters
        ----------
        auto_train:
            Train models that cannot be loaded; ``None`` defers to
            ``settings.AUTO_TRAIN_MISSING``.

        A failing trainer does not abort startup: the error is logged and
        recorded in ``self.status`` (``"failed"`` and ``"<name>_last_error"``),
        and the remaining models are still processed.
        """
        if self._unit_test_mode:
            logger.info("🔒 UNIT_TESTING=1 – skipping model loading")
            return                      # 👉 nothing else runs

        # Initialize MLflow connection first
        await self.initialize()

        if settings.SKIP_BACKGROUND_TRAINING and not settings.AUTO_TRAIN_MISSING:
            logger.warning("⏩ Both training flags disabled – models must already exist")
            # We still *try* to load existing artefacts so prod works
            await self._try_load("iris_random_forest")
            await self._try_load("iris_logreg")
            if not await self._try_load("breast_cancer_bayes"):
                # same fallback as the auto path: serve the stub if it exists
                await self._try_load("breast_cancer_stub")
            return

        auto = auto_train if auto_train is not None else settings.AUTO_TRAIN_MISSING
        logger.info("🔄 Model-service startup (auto_train=%s)", auto)

        # Registry-aware load for migrated models
        self._init_registry_once()
        loop = asyncio.get_running_loop()

        def _record_failure(name: str, exc: Exception) -> None:
            # Keep startup going; surface the error via status instead.
            logger.error("❌ training %s failed: %s", name, exc, exc_info=True)
            self.status[name] = "failed"
            self.status[f"{name}_last_error"] = str(exc)

        async def _train_legacy(name: str, trainer) -> None:
            # Run a blocking legacy trainer in the shared executor, then reload.
            try:
                await loop.run_in_executor(self._EXECUTOR, trainer)
            except Exception as exc:
                _record_failure(name, exc)
                return
            await self._try_load(name)

        # iris_random_forest: registry trainer preferred, legacy trainer as fallback
        if not await self._try_load("iris_random_forest") and auto:
            run_id = None
            try:
                run_id = await self.train_via_registry("iris_random_forest")
            except Exception as exc:
                _record_failure("iris_random_forest", exc)
            if run_id:
                await self._try_load("iris_random_forest")
            else:
                # no registry trainer (or it failed) → legacy path
                logger.info("Training iris random-forest (legacy path)…")
                await _train_legacy("iris_random_forest", train_iris_random_forest)

        # Legacy deterministic model (to be migrated later)
        if not await self._try_load("iris_logreg") and auto:
            logger.info("Training iris logistic-regression (legacy path)…")
            await _train_legacy("iris_logreg", train_iris_logreg)

        # Bayesian path: serve the stub now, fit the real model in background
        if not await self._try_load("breast_cancer_bayes"):
            if not await self._try_load("breast_cancer_stub") and auto:
                logger.info("Training stub cancer model …")
                await _train_legacy("breast_cancer_stub", train_breast_cancer_stub)
            if auto and not settings.SKIP_BACKGROUND_TRAINING:
                logger.info("Scheduling Bayesian retrain in background")
                task = asyncio.create_task(
                    self._train_and_reload("breast_cancer_bayes", train_breast_cancer_bayes)
                )
                # The event loop keeps only a weak reference to tasks; hold a
                # strong one until completion so the job is not GC'd mid-run.
                pending = getattr(self, "_background_tasks", None)
                if pending is None:
                    pending = self._background_tasks = set()
                pending.add(task)
                task.add_done_callback(pending.discard)

    async def _try_load(self, name: str) -> bool:
        """
        Attempt to load *name* and record the outcome in ``self.status``.

        Returns ``True`` (and caches the model in ``self.models``) when
        ``_load_production_model`` yields a model.  Returns ``False`` when it
        yields nothing, marking the model ``"missing"`` only if it has no
        status yet.  Any exception is logged, stored under
        ``"<name>_last_error"``, marks the model ``"failed"`` and also yields
        ``False`` – it is never propagated.
        """
        try:
            loaded = await self._load_production_model(name)
            if not loaded:
                self.status.setdefault(name, "missing")
                return False
            self.models[name] = loaded
            self.status[name] = "loaded"
            logger.info("✅ %s loaded", name)
            return True
        except Exception as err:
            logger.error("❌  load %s failed: %s", name, err)
            self.status[name] = "failed"
            self.status[f"{name}_last_error"] = str(err)
            return False

    async def _train_and_reload(self, name: str, trainer) -> None:
        """
        Run *trainer* in the shared executor, reload *name*, and update status.

        Status goes ``"training"`` → ``"loaded"``.  On any ``Exception`` it
        becomes ``"failed"`` with the message stored in
        ``"<name>_last_error"``; errors are recorded rather than re-raised.
        After a successful reload, ``self._cleanup_runs(name)`` is started in
        the executor without being awaited; if it fails, the failure is logged
        instead of being silently dropped.
        """
        try:
            t0 = time.perf_counter()
            logger.info("🏗️  BEGIN training %s", name)
            self.status[name] = "training"

            loop = asyncio.get_running_loop()
            await loop.run_in_executor(self._EXECUTOR, trainer)

            logger.info("📦 Training %s complete in %.1fs – re-loading", name,
                        time.perf_counter() - t0)
            model = await self._load_production_model(name)
            if not model:
                raise RuntimeError(f"{name} trained but could not be re-loaded")

            self.models[name] = model
            self.status[name] = "loaded"

            # Trigger retention clean‑up in background (fire-and-forget).
            cleanup = loop.run_in_executor(
                self._EXECUTOR, lambda: asyncio.run(self._cleanup_runs(name))
            )

            def _log_cleanup_failure(fut) -> None:
                # Retrieving the exception also silences asyncio's
                # "exception was never retrieved" warning.
                if not fut.cancelled() and fut.exception() is not None:
                    logger.warning("Run clean-up for %s failed: %s", name, fut.exception())

            cleanup.add_done_callback(_log_cleanup_failure)
            logger.info("✅ %s trained & loaded", name)

        except Exception as exc:
            self.status[name] = "failed"
            logger.error("❌ %s failed: %s", name, exc, exc_info=True)  # ← keeps trace
            # persist last_error for UI / debug endpoint
            self.status[f"{name}_last_error"] = str(exc)

# --- DROP-IN REPLACEMENT ----------------------------------------------------
    async def _load_production_model(self, name: str):
        """
        Load the canonical production model with alias support *and* perform an
        **environment audit** of the recorded vs. runtime dependencies.

        We DO NOT install anything automatically.  Instead we:
            • attempt to load in the usual fallback order (@prod → @staging → Production stage → latest run)
            • after a *successful* load, call `_audit_model_env(uri, name)` to diff
              the model's logged environment spec against the current runtime
              (importlib.metadata) and record mismatches in `self.status`.

        The audit is *diagnostic* unless an optional enforcement policy is enabled
        via env/config (MODEL_ENV_ENFORCEMENT = warn|fail|retrain).

        Returns
        -------
        Loaded MLflow model instance *or* None if nothing could be loaded, or
        load was refused under "fail" policy for env mismatch.
        """
        import sys
        import mlflow
        from mlflow.tracking.artifact_utils import _download_artifact_from_uri
        from packaging.version import Version, InvalidVersion
        import importlib.metadata as im
        import json
        import os

        # Use MLOps configuration for enforcement policy
        policy = settings.MODEL_AUDIT_ENFORCEMENT.lower()

        def _warn_model_env(uri: str) -> None:
            # unchanged best‑effort header check (Python version)
            try:
                local_dir = _download_artifact_from_uri(uri)
                mlmodel_path = Path(local_dir) / "MLmodel"
                if not mlmodel_path.is_file():
                    return
                import yaml
                meta = yaml.safe_load(mlmodel_path.read_text())
                py_model_ver = (
                    meta.get("python_env", {}).get("python")
                    or meta.get("flavors", {})
                        .get("python_function", {})
                        .get("loader_module_python_version")
                )
                runtime_py = f"{sys.version_info.major}.{sys.version_info.minor}"
                if py_model_ver and not py_model_ver.startswith(runtime_py):
                    logger.warning(
                        "⚠️ %s logged under Python %s but runtime is %s; "
                        "deserialization may fail. Consider retraining.",
                        name, py_model_ver, runtime_py
                    )
            except Exception as e:  # best-effort
                logger.debug("env check failed for %s (%s)", uri, e)

        def _audit_model_env(uri: str, model_name: str) -> dict:
            """
            Return a dict: {pkg: {'required': spec, 'current': ver, 'match': bool, 'severity': str}}
            Only logs the pip‑install command if there are mismatches.
            """
            import logging
            from pathlib import Path
            from packaging.version import Version, InvalidVersion
            import importlib.metadata as im
            from mlflow.pyfunc import get_model_dependencies

            audit: dict[str, dict] = {}

            # 1️⃣ Suppress MLflow's own INFO log for pip install
            pyfunc_logger = logging.getLogger("mlflow.pyfunc")
            old_level = pyfunc_logger.level
            pyfunc_logger.setLevel(logging.WARNING)
            try:
                try:
                    deps_path = get_model_dependencies(uri)
                except Exception:
                    deps_path = None
            finally:
                pyfunc_logger.setLevel(old_level)

            # 2️⃣ Read the pip requirements.txt
            req_lines: list[str] = []
            if deps_path and Path(deps_path).is_file():
                for ln in Path(deps_path).read_text().splitlines():
                    ln = ln.strip()
                    if not ln or ln.startswith("#"):
                        continue
                    req_lines.append(ln)

            # 3️⃣ Build the audit by comparing to runtime versions
            for spec in req_lines:
                pkg = spec.split("@", 1)[0].split(";", 1)[0].strip()
                pkg_lc = pkg.lower().replace("_", "-")

                req_ver = None
                if "==" in spec:
                    req_ver = spec.split("==", 1)[1].strip()
                elif ">=" in spec:
                    req_ver = spec.split(">=", 1)[1].strip()

                cur_ver = None
                try:
                    cur_ver = im.version(pkg_lc)
                except Exception:
                    cur_ver = None

                match = True
                sev = "OK"
                if req_ver:
                    try:
                        v_req = Version(req_ver)
                        if cur_ver:
                            v_cur = Version(cur_ver)
                            if v_cur.major != v_req.major:
                                sev, match = "MAJOR_DRT", False
                            elif v_cur != v_req:
                                sev, match = "MINOR_DRT", False
                        else:
                            sev, match = "MISSING", False
                    except InvalidVersion:
                        pass
                elif cur_ver is None:
                    sev, match = "MISSING", False

                audit[pkg_lc] = {
                    "required": req_ver,
                    "current": cur_ver,
                    "match": match,
                    "severity": sev,
                }

            # 4️⃣ Record audit in service status
            self.status[f"{model_name}_dep_audit"] = audit

            # 5️⃣ Only show pip‑install hint if there *are* mismatches
            if deps_path and any(not rec["match"] for rec in audit.values()):
                pyfunc_logger.info(
                    "To install the dependencies that were used to train the model, "
                    "run the following command: 'pip install -r %s'",
                    deps_path,
                )

            # 6️⃣ MLOps policy enforcement based on environment
            if policy in ("fail", "retrain"):
                critical = ("numpy", "scipy", "scikit-learn", "psutil")
                majors = [
                    pkg
                    for pkg, rec in audit.items()
                    if pkg in critical and rec["severity"] == "MAJOR_DRT"
                ]
                if majors:
                    msg = f"Critical env drift for {model_name}: {majors}"
                    logger.error(msg)
                    if policy == "fail":
                        self.status[f"{model_name}_last_error"] = msg
                        return {"_REFUSE_LOAD": True}
                    elif policy == "retrain":
                        logger.warning(
                            "Scheduling background retrain for %s due to env drift",
                            model_name,
                        )
                        asyncio.create_task(
                            self._train_and_reload(model_name, TRAINERS[model_name])
                        )

            return audit

        client = self.mlflow_client

        # MLOps-aware loading order based on environment
        env_canon = settings.ENVIRONMENT_CANONICAL
        if env_canon == "production":
            # Production: strict order - only Production stage or @prod alias
            attempts = [
                ("@prod", f"models:/{name}@prod"),
                ("Production stage", None),  # handle below
            ]
        elif env_canon == "staging":
            # Staging: allow staging versions for testing
            attempts = [
                ("@staging", f"models:/{name}@staging"),
                ("@prod", f"models:/{name}@prod"),
                ("Production stage", None),  # handle below
            ]
        else:
            # Development: most permissive
            attempts = [
                ("@prod", f"models:/{name}@prod"),
                ("@staging", f"models:/{name}@staging"),
                ("Production stage", None),  # handle below
                ("latest run", None),        # handle below
            ]

        # 1️⃣ Try aliases first ------------------------------------------------------
        for alias_name, uri in attempts[:2]:  # Only try aliases
            if uri is None:
                continue
            try:
                _warn_model_env(uri)
                logger.info("↪︎  Loading %s from alias %s", name, alias_name)
                mdl = mlflow.pyfunc.load_model(uri)
                audit = _audit_model_env(uri, name)
                if audit.get("_REFUSE_LOAD"):
                    logger.warning("Refusing %s from %s under policy; continuing fallbacks", name, alias_name)
                else:
                    # --- config hash drift check -------------------------------------------------
                    try:
                        # Candidate may be run-based or version-based; we try to extract run_id param from MLflow model flavor metadata.
                        # Fallback: skip silently.
                        from mlflow import get_tracking_uri
                        tracking_client = self.mlflow_client
                        # try to read run params if we have run context
                        # NOTE: _download_artifact_from_uri gave us `uri`; if it's a runs:/ URI we can parse run_id
                        if uri.startswith("runs:/"):
                            run_id = uri.split("/", 2)[1]
                            run = tracking_client.get_run(run_id)
                            train_hash = run.data.params.get("train_config_hash")
                            if train_hash:
                                from app.core.config import settings as _s
                                cur_hash = hashlib.sha256(json.dumps(_s.model_dump(), sort_keys=True).encode()).hexdigest()
                                if train_hash != cur_hash:
                                    logger.warning(
                                        "⚠️ Config drift for %s: train_hash=%s current=%s",
                                        model_name, train_hash[:8], cur_hash[:8]
                                    )
                    except Exception as _e_hash:  # best-effort
                        logger.debug("Config drift check skipped: %s", _e_hash)
                    return mdl
            except Exception as e:
                logger.debug("Alias %s not available for %s: %s", alias_name, name, e)

        # 2️⃣ Try Production stage ---------------------------------------------------
        try:
            versions = client.search_model_versions(f"name='{name}' AND stage='Production'")
            if versions:
                version = versions[0].version
                uri = f"models:/{name}/{version}"
                _warn_model_env(uri)
                logger.info("↪︎  Loading %s from registry: Production v%s", name, version)
                mdl = mlflow.pyfunc.load_model(uri)
                audit = _audit_model_env(uri, name)
                if audit.get("_REFUSE_LOAD"):
                    logger.warning("Refusing %s Production v%s under policy; continuing fallbacks", name, version)
                else:
                    return mdl
        except Exception as e:
            logger.debug("Production stage not available for %s: %s", name, e)

        # 3️⃣ Try Staging stage (only in dev/staging) ------------------------------
        if settings.ENVIRONMENT != "production":
            try:
                versions = client.search_model_versions(f"name='{name}' AND stage='Staging'")
                if versions:
                    version = versions[0].version
                    uri = f"models:/{name}/{version}"
                    _warn_model_env(uri)
                    logger.info("↪︎  Loading %s from registry: Staging v%s", name, version)
                    mdl = mlflow.pyfunc.load_model(uri)
                    audit = _audit_model_env(uri, name)
                    if audit.get("_REFUSE_LOAD"):
                        logger.warning("Refusing %s Staging v%s under policy; continuing fallbacks", name, version)
                    else:
                        return mdl
            except Exception as e:
                logger.debug("Staging stage not available for %s: %s", name, e)

        # 4️⃣  (Possible) Fallback to latest run – now allowed in prod too
        allow_run_fallback = (
            settings.ENVIRONMENT_CANONICAL != "production"
            or settings.ALLOW_PROD_RUN_FALLBACK
        )
        if allow_run_fallback:
            try:
                runs = []
                for exp in client.search_experiments():
                    runs.extend(client.search_runs(
                        [exp.experiment_id],
                        f"tags.mlflow.runName = '{name}'",
                        order_by=["attributes.start_time DESC"],
                        max_results=1))
                if runs:
                    uri = f"runs:/{runs[0].info.run_id}/model"
                    logger.warning(
                        "⚠️  %s: alias/stage missing – loading *latest run* (%s) "
                        "because ALLOW_PROD_RUN_FALLBACK=%d",
                        name, runs[0].info.run_id, allow_run_fallback,
                    )
                    _warn_model_env(uri)
                    mdl = mlflow.pyfunc.load_model(uri)
                    audit = _audit_model_env(uri, name)
                    if audit.get("_REFUSE_LOAD"):
                        logger.warning("Refusing %s latest run under policy", name)
                    else:
                        return mdl
            except Exception as e:
                logger.debug("Latest‑run fallback failed for %s: %s", name, e)

        logger.error("❌ No suitable model found for %s after all fallbacks", name)
        return None
# --- END DROP-IN REPLACEMENT -------------------------------------------------



    async def evaluate_model_quality(
        self,
        model_name: str,
        candidate_run_id: str,
        test_data_path: Optional[str] = None
    ) -> dict:
        """
        Evaluate a candidate model against the production baseline.

        This implements the MLOps quality gate:
        1. Build the held-out test split for the model family
        2. Load the candidate model from ``runs:/<candidate_run_id>/model``
        3. Score the candidate and, if one can be loaded, the production model
        4. Decide promotion and tag the candidate run with the outcome

        Gate rule: with a production baseline the candidate passes when
        neither accuracy nor macro-F1 drops by more than 0.01 (absolute).
        Without a baseline it must reach
        ``settings.QUALITY_GATE_ACCURACY_THRESHOLD`` and
        ``settings.QUALITY_GATE_F1_THRESHOLD``.

        Args:
            model_name: Model name; must start with ``"iris"`` or
                ``"breast_cancer"``, which selects the built-in test dataset.
            candidate_run_id: MLflow run ID of the candidate model.
            test_data_path: Currently not consulted; the built-in sklearn
                datasets are always used. Kept for API compatibility.

        Returns:
            Dict with ``promoted``, ``reason``, ``candidate_metrics``,
            ``production_metrics`` (None without a baseline),
            ``candidate_run_id`` and ``model_name``. If the candidate cannot
            be loaded or scored, a dict with ``promoted=False``, ``error`` and
            both metric fields set to None is returned instead.

        Raises:
            ValueError: If ``model_name`` matches neither supported family.
        """
        import numpy as _np
        from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

        logger.info("🔍 Evaluating quality gate for %s (candidate: %s)", model_name, candidate_run_id)

        is_iris = model_name.startswith("iris")

        # Load test data (same split parameters for candidate and production)
        if is_iris:
            from sklearn.datasets import load_iris
            from sklearn.model_selection import train_test_split

            iris = load_iris(as_frame=True)
            X, y = iris.data, iris.target
            _, X_test, _, y_test = train_test_split(X, y, test_size=0.25, stratify=y, random_state=42)

        elif model_name.startswith("breast_cancer"):
            from sklearn.datasets import load_breast_cancer
            from sklearn.model_selection import train_test_split

            X, y = load_breast_cancer(return_X_y=True, as_frame=True)
            _, X_test, _, y_test = train_test_split(X, y, test_size=0.3, stratify=y, random_state=42)
            X_test = _rename_cancer_columns(X_test)
        else:
            raise ValueError(f"Unknown model type: {model_name}")

        def _score(model) -> dict:
            """Predict class labels for X_test with *model* and compute the gate metrics."""
            # asarray tolerates list / Series / DataFrame outputs from pyfunc models.
            raw = _np.asarray(model.predict(X_test))
            if raw.ndim == 2 and raw.shape[1] == 1:
                raw = raw[:, 0]  # single-column output -> 1-D
            if is_iris:
                # 2-D output is per-class probabilities; 1-D is already labels.
                y_pred = raw.argmax(axis=1) if raw.ndim == 2 else raw
            else:
                # Binary: a 2-D probability matrix contributes its class-1
                # column (same convention as predict_cancer); threshold at 0.5.
                scores = raw[:, 1] if raw.ndim == 2 else raw
                y_pred = (scores > 0.5).astype(int)
            return {
                "accuracy": accuracy_score(y_test, y_pred),
                "f1_macro": f1_score(y_test, y_pred, average="macro"),
                "precision_macro": precision_score(y_test, y_pred, average="macro"),
                "recall_macro": recall_score(y_test, y_pred, average="macro")
            }

        # Load candidate model
        try:
            candidate_uri = f"runs:/{candidate_run_id}/model"
            candidate_model = mlflow.pyfunc.load_model(candidate_uri)
            logger.info("✅ Loaded candidate model from %s", candidate_run_id)
        except Exception as e:
            logger.error("❌ Failed to load candidate model: %s", e)
            return {
                "promoted": False,
                "error": f"Failed to load candidate model: {e}",
                "candidate_metrics": None,
                "production_metrics": None
            }

        # Evaluate candidate
        try:
            candidate_metrics = _score(candidate_model)
            logger.info("📊 Candidate metrics: %s", candidate_metrics)
        except Exception as e:
            logger.error("❌ Failed to evaluate candidate: %s", e)
            return {
                "promoted": False,
                "error": f"Failed to evaluate candidate: {e}",
                "candidate_metrics": None,
                "production_metrics": None
            }

        # Try to score the production model for comparison. Best effort: any
        # failure means the absolute thresholds are used instead.
        production_metrics = None
        try:
            prod_model = await self._load_production_model(model_name)
            if prod_model:
                production_metrics = _score(prod_model)
                logger.info("📊 Production metrics: %s", production_metrics)
            else:
                logger.info("📊 No production model found for comparison")
        except Exception as e:
            logger.warning("⚠️  Could not evaluate production model: %s", e)

        # Quality gate decision
        promoted = False
        reason = ""

        if production_metrics:
            # Compare against production baseline
            acc_improvement = candidate_metrics["accuracy"] - production_metrics["accuracy"]
            f1_improvement = candidate_metrics["f1_macro"] - production_metrics["f1_macro"]

            # Quality gate: must maintain or improve performance
            if acc_improvement >= -0.01 and f1_improvement >= -0.01:  # Allow 1% degradation
                promoted = True
                reason = f"Performance maintained (acc: {acc_improvement:+.3f}, f1: {f1_improvement:+.3f})"
            else:
                reason = f"Performance degraded (acc: {acc_improvement:+.3f}, f1: {f1_improvement:+.3f})"
        else:
            # No production baseline - use absolute thresholds from settings
            if candidate_metrics["accuracy"] >= settings.QUALITY_GATE_ACCURACY_THRESHOLD \
               and candidate_metrics["f1_macro"] >= settings.QUALITY_GATE_F1_THRESHOLD:
                promoted = True
                reason = f"Meets minimum thresholds (acc: {candidate_metrics['accuracy']:.3f} >= {settings.QUALITY_GATE_ACCURACY_THRESHOLD}, f1: {candidate_metrics['f1_macro']:.3f} >= {settings.QUALITY_GATE_F1_THRESHOLD})"
            else:
                reason = f"Below minimum thresholds (acc: {candidate_metrics['accuracy']:.3f} < {settings.QUALITY_GATE_ACCURACY_THRESHOLD} or f1: {candidate_metrics['f1_macro']:.3f} < {settings.QUALITY_GATE_F1_THRESHOLD})"

        result = {
            "promoted": promoted,
            "reason": reason,
            "candidate_metrics": candidate_metrics,
            "production_metrics": production_metrics,
            "candidate_run_id": candidate_run_id,
            "model_name": model_name
        }

        # log evaluation metadata back to MLflow (best effort)
        try:
            client = self.mlflow_client
            client.set_tag(candidate_run_id, f"quality_gate:{settings.ENVIRONMENT_CANONICAL}",
                           "PASSED" if promoted else "FAILED")
            client.set_tag(candidate_run_id, "quality_gate_reason", reason)
        except Exception as e:
            logger.debug("Failed to set MLflow quality gate tags: %s", e)

        logger.info("🎯 Quality gate result: %s - %s", "PASSED" if promoted else "FAILED", reason)
        return result

    async def promote_model_to_staging(
        self,
        model_name: str,
        run_id: str
    ) -> dict:
        """
        Gate a candidate run and, if it passes, register it as the staging model.

        Workflow:
        1. Run ``evaluate_model_quality`` on the candidate run.
        2. On a pass, create a new registered model version whose source is
           ``runs:/<run_id>/model``.
        3. Move that version to the ``Staging`` stage and point the
           ``@staging`` alias at it.

        Args:
            model_name: Registered model name to promote into.
            run_id: MLflow run ID holding the candidate model.

        Returns:
            Dict that always carries ``promoted`` and ``evaluation``; on
            success also ``version``, ``stage`` and ``alias``, otherwise an
            ``error`` message.
        """
        logger.info("🚀 Starting promotion process for %s (run: %s)", model_name, run_id)

        gate = await self.evaluate_model_quality(model_name, run_id)

        # Guard clause: stop here when the quality gate rejects the candidate.
        if not gate["promoted"]:
            logger.warning("❌ Quality gate failed for %s: %s", model_name, gate.get("reason", "Unknown"))
            failure_text = gate.get("error", gate.get("reason", "Quality gate failed"))
            return {
                "promoted": False,
                "error": failure_text,
                "evaluation": gate
            }

        try:
            registry = self.mlflow_client
            new_version = registry.create_model_version(
                name=model_name,
                source=f"runs:/{run_id}/model",
                run_id=run_id
            ).version

            registry.transition_model_version_stage(
                name=model_name,
                version=new_version,
                stage="Staging"
            )
            registry.set_registered_model_alias(
                name=model_name,
                alias="staging",
                version=new_version
            )
        except Exception as exc:
            message = str(exc)
            logger.error("❌ Failed to promote %s to staging: %s", model_name, message)
            return {
                "promoted": False,
                "error": f"Promotion failed: {message}",
                "evaluation": gate
            }
        else:
            logger.info("✅ Successfully promoted %s to staging (version %s)", model_name, new_version)
            return {
                "promoted": True,
                "version": new_version,
                "stage": "Staging",
                "alias": "staging",
                "evaluation": gate
            }

    async def promote_model_to_production(
        self,
        model_name: str,
        version: Optional[int] = None,
        approved_by: Optional[str] = None
    ) -> dict:
        """
        Promote a staging model to production.

        Version resolution:
        1. If ``version`` is given, promote exactly that version.
        2. Otherwise use the version behind the ``@staging`` alias.
        3. If no ``@staging`` alias is set, fall back to the highest-numbered
           version whose registry stage is ``Staging``.

        The chosen version is transitioned to the ``Production`` stage and the
        ``@prod`` alias is pointed at it. Approval metadata is then recorded as
        model-version tags on a best-effort basis.

        Args:
            model_name: Name of the registered model to promote.
            version: Specific version to promote (optional).
            approved_by: Identity of the approver. Mandatory when
                ``settings.REQUIRE_MODEL_APPROVAL`` is set and the canonical
                environment is ``"production"``.

        Returns:
            Dict with ``promoted`` plus ``version``/``stage``/``alias`` on
            success, or ``error`` on failure.
        """
        logger.info("🚀 Promoting %s to production (version: %s)", model_name, version or "staging")

        # Enforce human approval in production if required
        if settings.REQUIRE_MODEL_APPROVAL and settings.ENVIRONMENT_CANONICAL == "production":
            if not approved_by:
                return {
                    "promoted": False,
                    "error": "Approval required: pass approved_by=<user> when promoting to production.",
                }

        try:
            client = self.mlflow_client

            if version is None:
                # Prefer the @staging alias; the alias-only promotion path
                # (promote_model_to_stage) sets nothing else.
                try:
                    version = client.get_model_version_by_alias(model_name, "staging").version
                except Exception as e:
                    logger.debug("No @staging alias for %s (%s); falling back to stage", model_name, e)
                    # search_model_versions cannot filter on stage, so the
                    # Staging filter is applied client-side.
                    staging_versions = [
                        mv for mv in client.search_model_versions(f"name='{model_name}'")
                        if mv.current_stage == "Staging"
                    ]
                    if not staging_versions:
                        return {
                            "promoted": False,
                            "error": f"No staging version found for {model_name}"
                        }
                    version = max(staging_versions, key=lambda mv: int(mv.version)).version

            # Transition to Production
            client.transition_model_version_stage(
                name=model_name,
                version=version,
                stage="Production"
            )

            # Set @prod alias for atomic promotion
            client.set_registered_model_alias(
                name=model_name,
                alias="prod",
                version=version
            )

            # record approval metadata as tags (best effort)
            try:
                client.set_model_version_tag(model_name, version, "approved_by", str(approved_by or "n/a"))
                client.set_model_version_tag(model_name, version, "approved_env", str(settings.ENVIRONMENT_CANONICAL))
            except Exception as e:
                logger.debug("Could not tag model version approval: %s", e)

            logger.info("✅ Successfully promoted %s to production (version %s)", model_name, version)

            return {
                "promoted": True,
                "version": version,
                "stage": "Production",
                "alias": "prod"
            }

        except Exception as e:
            error_msg = str(e)
            logger.error("❌ Failed to promote %s to production: %s", model_name, error_msg)
            return {
                "promoted": False,
                "error": f"Production promotion failed: {error_msg}"
            }

    async def promote_model_to_stage(
        self,
        model_name: str,
        target_stage: str,
        version: Optional[int] = None,
        approved_by: Optional[str] = None
    ) -> dict:
        """
        Promote a model version by pointing an alias at it.

        Only the registered-model alias is set: ``@prod`` for ``"Production"``
        and ``@staging`` for ``"Staging"``. The legacy registry stage is not
        transitioned.

        Args:
            model_name: Name of the registered model to promote.
            target_stage: ``"Staging"`` or ``"Production"`` (case-sensitive).
            version: Specific version to promote. Defaults to the
                highest-numbered registered version.
            approved_by: Identity of the approver. Mandatory for Production when
                ``settings.REQUIRE_MODEL_APPROVAL`` is set and the canonical
                environment is ``"production"``.

        Returns:
            Dict with ``promoted`` plus ``version``/``stage``/``alias`` on
            success, or ``error`` on failure.
        """
        logger.info("🚀 Promoting %s to %s (version: %s)", model_name, target_stage, version or "latest")

        # Validate target stage
        if target_stage not in ["Staging", "Production"]:
            return {
                "promoted": False,
                "error": f"Invalid target stage: {target_stage}. Must be 'Staging' or 'Production'"
            }

        # Enforce human approval in production if required
        if target_stage == "Production" and settings.REQUIRE_MODEL_APPROVAL and settings.ENVIRONMENT_CANONICAL == "production":
            if not approved_by:
                return {
                    "promoted": False,
                    "error": "Approval required: pass approved_by=<user> when promoting to production.",
                }

        try:
            client = self.mlflow_client

            if version is None:
                # "Latest" means the highest version number; the search
                # result order is not relied upon.
                versions = client.search_model_versions(f"name='{model_name}'")
                if not versions:
                    return {
                        "promoted": False,
                        "error": f"No versions found for {model_name}"
                    }
                version = max(versions, key=lambda mv: int(mv.version)).version

            # Set appropriate alias (modern approach - skip deprecated stage transitions)
            alias = "prod" if target_stage == "Production" else "staging"
            client.set_registered_model_alias(
                name=model_name,
                alias=alias,
                version=version
            )

            # record approval metadata as tags (best effort; values must be strings)
            try:
                client.set_model_version_tag(model_name, version, "approved_by", str(approved_by or "n/a"))
                client.set_model_version_tag(model_name, version, "approved_env", str(settings.ENVIRONMENT_CANONICAL))
            except Exception as e:
                logger.debug("Could not tag model version approval: %s", e)

            logger.info("✅ Successfully promoted %s to %s (version %s)", model_name, target_stage, version)

            return {
                "promoted": True,
                "version": version,
                "stage": target_stage,
                "alias": alias
            }

        except Exception as e:
            error_msg = str(e)
            logger.error("❌ Failed to promote %s to %s: %s", model_name, target_stage, error_msg)
            return {
                "promoted": False,
                "error": f"{target_stage} promotion failed: {error_msg}"
            }

    async def get_model_metrics(self, model_name: str) -> List[Dict[str, Any]]:
        """
        Retrieve metrics for all versions of a registered model.

        For each registered version, the backing MLflow run (if any) is queried
        for its metrics and tags. Registry metadata (stage, timestamps,
        description) is read from the version itself, so it is reported even
        when the run lookup fails or the version has no run.

        Args:
            model_name: Name of the registered model.

        Returns:
            One dict per version, sorted by ascending version number, with keys
            ``version`` (int), ``stage``, ``run_id``, ``metrics``
            (name -> float), ``tags``, ``creation_timestamp``,
            ``last_updated_timestamp`` and ``description`` (``""`` when unset).
        """
        client = self.mlflow_client
        results: List[Dict[str, Any]] = []

        for v in client.search_model_versions(f"name='{model_name}'"):
            run_id = v.run_id
            metrics: Dict[str, float] = {}
            tags: Dict[str, str] = {}

            # Versions registered from an external source have no run to query.
            if run_id:
                try:
                    run = client.get_run(run_id)
                    # Accept both Metric-like objects and plain numbers.
                    metrics = {
                        key: float(getattr(metric, "value", metric))
                        for key, metric in run.data.metrics.items()
                    }
                    tags = run.data.tags
                except Exception as e:
                    # Best effort: still report the version without run data.
                    logger.warning(
                        "Could not fetch metrics for %s v%s: %s", model_name, v.version, e
                    )
                    metrics, tags = {}, {}

            results.append({
                "version": int(v.version),
                "stage": v.current_stage,
                "run_id": run_id,
                "metrics": metrics,
                "tags": tags,
                "creation_timestamp": v.creation_timestamp,
                "last_updated_timestamp": v.last_updated_timestamp,
                "description": v.description or ""
            })

        # Sort by version number for consistent ordering
        results.sort(key=lambda r: r["version"])
        return results

    async def compare_model_versions(
        self,
        model_name: str,
        version_a: int,
        version_b: int
    ) -> Dict[str, Any]:
        """
        Compare two registered versions of a model on the quality-gate metrics.

        Only accuracy, f1_macro, precision_macro and recall_macro are compared,
        and only when both runs logged them. A metric counts as an improvement
        when version B is strictly greater than version A.

        Args:
            model_name: Name of the registered model.
            version_a: Baseline version.
            version_b: Challenger version.

        Returns:
            Dict with per-version info, a per-metric ``comparison``, a
            ``recommendation`` (``insufficient_data``, ``promote_version_b``,
            ``keep_version_a`` or ``mixed_results``) and a ``summary`` string.
            If either version cannot be fetched, returns ``error`` and
            ``comparison=None`` instead.
        """
        registry = self.mlflow_client

        try:
            info_a = registry.get_model_version(model_name, version_a)
            info_b = registry.get_model_version(model_name, version_b)
        except Exception as exc:
            return {
                "error": f"Could not fetch model versions: {exc}",
                "comparison": None
            }

        scores_a = await self._get_version_metrics(info_a.run_id)
        scores_b = await self._get_version_metrics(info_b.run_id)

        comparison = {}
        for key in ("accuracy", "f1_macro", "precision_macro", "recall_macro"):
            # Skip metrics that either run failed to log.
            if key not in scores_a or key not in scores_b:
                continue
            before, after = scores_a[key], scores_b[key]
            delta = after - before
            comparison[key] = {
                "version_a": before,
                "version_b": after,
                "difference": delta,
                "improvement": delta > 0
            }

        compared = len(comparison)
        wins = len([entry for entry in comparison.values() if entry["improvement"]])

        if not compared:
            recommendation = "insufficient_data"
        elif wins == compared:
            recommendation = "promote_version_b"
        elif not wins:
            recommendation = "keep_version_a"
        else:
            recommendation = "mixed_results"

        def _describe(number, info, scores):
            # Per-version section of the report.
            return {
                "version": number,
                "stage": info.current_stage,
                "run_id": info.run_id,
                "metrics": scores
            }

        return {
            "model_name": model_name,
            "version_a": _describe(version_a, info_a, scores_a),
            "version_b": _describe(version_b, info_b, scores_b),
            "comparison": comparison,
            "recommendation": recommendation,
            "summary": f"Version B improves {wins}/{compared} metrics"
        }

    async def _get_version_metrics(self, run_id: str) -> Dict[str, float]:
        """Return the metrics logged on MLflow run *run_id* as plain floats.

        Values exposing a ``.value`` attribute (MLflow Metric objects) are
        unwrapped first. Any failure is logged as a warning and yields ``{}``.
        """
        try:
            logged = self.mlflow_client.get_run(run_id).data.metrics
            return {
                name: float(entry.value) if hasattr(entry, 'value') else float(entry)
                for name, entry in logged.items()
            }
        except Exception as e:
            logger.warning(f"Could not fetch metrics for run {run_id}: {e}")
            return {}

    # Manual training endpoints (for UI)
    async def train_iris(self, model_type: str = "rf") -> None:
        """
        Train one of the Iris classifiers and reload it.

        Args:
            model_type: ``"rf"`` trains the Random Forest
                (``iris_random_forest``); ``"logreg"`` trains Logistic
                Regression (``iris_logreg``).

        Raises:
            ValueError: If ``model_type`` is neither ``"rf"`` nor ``"logreg"``.
                This is the same validation ``predict_iris`` applies, so a
                typo cannot silently train the wrong model.
        """
        registry_names = {"rf": "iris_random_forest", "logreg": "iris_logreg"}
        if model_type not in registry_names:
            raise ValueError("model_type must be 'rf' or 'logreg'")
        name = registry_names[model_type]

        # reuse the shared train-then-reload helper
        await self._train_and_reload(name, TRAINERS[name])

    async def train_cancer(self, model_type: str = "bayes") -> None:
        """
        Train one of the breast-cancer models and reload it.

        Args:
            model_type: ``"bayes"`` trains the Bayesian (PyMC) model
                (``breast_cancer_bayes``); ``"stub"`` trains the LogReg stub
                (``breast_cancer_stub``).

        Raises:
            ValueError: If ``model_type`` is neither ``"bayes"`` nor ``"stub"``.
                This is the same validation ``predict_cancer`` applies, so a
                typo cannot silently train the wrong model.
        """
        registry_names = {"bayes": "breast_cancer_bayes", "stub": "breast_cancer_stub"}
        if model_type not in registry_names:
            raise ValueError("model_type must be 'bayes' or 'stub'")
        name = registry_names[model_type]

        await self._train_and_reload(name, TRAINERS[name])

    async def train_bayes_cancer_with_params(self, params=None) -> str:
        """
        Train the Bayesian breast-cancer model with validated parameters.

        ``train_breast_cancer_bayes`` is a synchronous call, so it runs in the
        default thread-pool executor. This keeps the event loop serving other
        requests while the model is fitted. Afterwards the model is re-loaded
        via ``_try_load``.

        Args:
            params: Parameter object forwarded unchanged as
                ``train_breast_cancer_bayes(params_obj=params)``. ``None``
                presumably lets the trainer use its defaults; confirm in
                ``app.ml.builtin_trainers``.

        Returns:
            The MLflow run ID returned by the trainer.
        """
        import asyncio
        import functools

        from app.ml.builtin_trainers import train_breast_cancer_bayes

        # Run the blocking training call off the event loop.
        loop = asyncio.get_running_loop()
        run_id = await loop.run_in_executor(
            None, functools.partial(train_breast_cancer_bayes, params_obj=params)
        )

        # Reload the model after training
        await self._try_load("breast_cancer_bayes")

        return run_id

    # Predict methods (unchanged from your previous version)
    async def predict_iris(
        self,
        features: List[Dict[str, float]],
        model_type: str = "rf",
    ) -> Tuple[List[str], List[List[float]]]:
        """
        Predict Iris species from measurements.

        Each input row is mapped to the sklearn Iris column names, in training
        order. The estimator therefore receives a DataFrame with valid feature
        names, avoiding scikit-learn's "X does not have valid feature names"
        warning. The pyfunc wrapper is unwrapped and probabilities are
        computed through ``_safe_sklearn_proba``, a serial path intended to
        avoid joblib/loky crashes when psutil is broken.

        Args:
            features: One dict per sample with keys ``sepal_length``,
                ``sepal_width``, ``petal_length`` and ``petal_width``.
            model_type: ``"rf"`` (``iris_random_forest``) or ``"logreg"``
                (``iris_logreg``).

        Returns:
            ``(pred_names, probs)``: the predicted species per sample and the
            per-class probability rows (setosa, versicolor, virginica). An
            empty ``features`` list yields ``([], [])``.

        Raises:
            ValueError: On an unknown ``model_type``.
            RuntimeError: If the selected model is not loaded.
            KeyError: If a sample lacks one of the four measurements.
        """
        if model_type not in ("rf", "logreg"):
            raise ValueError("model_type must be 'rf' or 'logreg'")

        model_name = "iris_random_forest" if model_type == "rf" else "iris_logreg"
        model = self.models.get(model_name)
        if not model:
            raise RuntimeError(f"{model_name} not loaded")

        # sklearn rejects zero-sample input, so an empty batch short-circuits.
        if not features:
            return [], []

        # construct DF w/ training column names in correct order
        X_df = pd.DataFrame(
            [{
                "sepal length (cm)":  f["sepal_length"],
                "sepal width (cm)":   f["sepal_width"],
                "petal length (cm)":  f["petal_length"],
                "petal width (cm)":   f["petal_width"],
            } for f in features]
        )
        logger.debug("predict_iris(%s) columns=%s", model_name, X_df.columns.tolist())

        # ALWAYS unwrap and call safe helper (skip top-level pyfunc). If the
        # object is not an MLflow pyfunc wrapper, use it as-is.
        base = model
        try:
            py_model = model.unwrap_python_model()  # mlflow ≥2
            base = getattr(py_model, "model", py_model)
        except Exception:
            pass

        probs = _safe_sklearn_proba(base, X_df, log_prefix=model_name)

        # convert to names
        import numpy as _np
        probs = _np.asarray(probs, dtype=float)
        if probs.ndim == 1:  # defensive
            # Promote to 3-class, treating the 1-D value as P(class 0) and
            # splitting the remainder evenly. The [:, None] makes the (N,)
            # remainder broadcast across the two remaining columns of each row.
            z = _np.zeros((probs.size, 3), dtype=float)
            z[:, 0] = probs
            z[:, 1:] = ((1 - probs) / 2)[:, None]
            probs = z
        preds = probs.argmax(axis=1)
        class_names = ["setosa", "versicolor", "virginica"]
        pred_names = [class_names[int(i)] for i in preds]
        return pred_names, probs.tolist()


    async def predict_cancer(
        self,
        features: List[Dict[str, float]],
        model_type: str = "bayes",
        posterior_samples: Optional[int] = None,
    ) -> Tuple[List[str], List[float], Optional[List[Tuple[float, float]]]]:
        """
        Predict breast cancer diagnosis.

        Model selection:
        - ``model_type="bayes"`` uses ``breast_cancer_bayes`` when loaded and
          otherwise falls back to ``breast_cancer_stub``.
        - ``"stub"`` always uses the stub.

        Scoring:
        - The Bayesian model is scored with its pyfunc ``predict``.
        - The sklearn stub is unwrapped and scored via
          ``_safe_sklearn_proba``, a serial path intended to avoid loky/psutil
          crashes.
        - A 2-D probability output contributes its column 1.

        Args:
            features: One dict of feature values per sample. Column names are
                normalised with ``_rename_cancer_columns``.
            model_type: ``"bayes"`` or ``"stub"``.
            posterior_samples: Used only as a flag. When truthy and the
                Bayesian model is in use, 95% credible intervals are computed
                from all stored posterior draws; the value is not a sample
                count.

        Returns:
            ``(labels, probs, ci)``:
            - ``labels``: ``"malignant"`` where the score is > 0.5, else
              ``"benign"``.
            - ``probs``: the scores.
            - ``ci``: per-sample ``(lo, hi)`` intervals, or ``None`` when not
              requested or not computable.

        Raises:
            ValueError: On an unknown ``model_type``.
            RuntimeError: If no suitable cancer model is loaded.
        """
        if model_type == "bayes":
            model = self.models.get("breast_cancer_bayes") or self.models.get("breast_cancer_stub")
            using_bayes = "breast_cancer_bayes" in self.models and model is self.models["breast_cancer_bayes"]
        elif model_type == "stub":
            model = self.models.get("breast_cancer_stub")
            using_bayes = False
        else:
            raise ValueError("model_type must be 'bayes' or 'stub'")
        if not model:
            raise RuntimeError("No cancer model available")

        X_df = _rename_cancer_columns(pd.DataFrame(features))

        if using_bayes and hasattr(model, "predict"):
            raw = model.predict(X_df)
        else:
            # unwrap & safe path
            base = model
            try:
                py_model = model.unwrap_python_model()
                base = getattr(py_model, "model", py_model)
            except Exception:
                pass
            raw = _safe_sklearn_proba(base, X_df, log_prefix="breast_cancer_stub")

        # Normalise both paths to a 1-D float array (list / Series / DataFrame
        # outputs are accepted); a 2-D output contributes its class-1 column.
        probs = np.asarray(raw, dtype=float)
        if probs.ndim == 2:
            probs = probs[:, 1] if probs.shape[1] > 1 else probs[:, 0]

        # NOTE(review): sklearn's load_breast_cancer encodes 0=malignant and
        # 1=benign, so a class-1 score > 0.5 would mean "benign" unless the
        # trainers flip the target. Confirm the convention against the trainers.
        labels = ["malignant" if p > 0.5 else "benign" for p in probs]

        ci = None
        if posterior_samples and using_bayes:
            try:
                # Access the underlying python model to get the trace
                python_model = model.unwrap_python_model()
                posterior = python_model.trace.posterior

                # Flatten (chain, draw) into one "samples" axis and move it to
                # the FRONT. xarray's stack() appends the new dim last, which
                # would otherwise put samples on the wrong axis for the sum below.
                alpha = posterior["α"].stack(samples=("chain", "draw")).transpose("samples", ...).values
                beta = posterior["β"].stack(samples=("chain", "draw")).transpose("samples", ...).values
                # presumably alpha is (S, n_groups) and beta is (S, n_features) — TODO confirm vs. model

                # Get group indices and standardized features
                g = python_model._quint(X_df)
                Xs = python_model.scaler.transform(X_df)

                # Posterior predictive probabilities, shape (S, N)
                logits = alpha[:, g] + np.dot(beta, Xs.T)
                pp = 1 / (1 + np.exp(-logits))

                # Compute 95% credible intervals per input row
                lo, hi = np.percentile(pp, [2.5, 97.5], axis=0)
                ci = list(zip(lo.tolist(), hi.tolist()))

            except Exception as e:
                logger.warning("Failed to compute uncertainty intervals: %s", e)
                ci = None

        return labels, probs.tolist(), ci

    async def _cleanup_runs(self, model_name: str) -> None:
        """
        Keep the newest N runs named ``model_name``, delete the rest, then
        optionally run ``mlflow gc`` to purge their artifact folders.

        N is ``settings.RETAIN_RUNS_PER_MODEL`` (clamped to >= 0). Runs marked
        *deleted* stay on disk until GC executes, so whenever something is
        pruned and ``settings.MLFLOW_GC_AFTER_TRAIN`` is True, GC runs. When
        there are <= N runs nothing is deleted and GC is skipped.

        The ``mlflow gc`` subprocess runs in a worker thread so it does not
        stall the event loop. Any failure is logged and swallowed: cleanup is
        best effort and must not break the caller.
        """
        import asyncio

        keep = max(settings.RETAIN_RUNS_PER_MODEL, 0)
        try:
            # 1️⃣ fetch runs newest→oldest (across all experiments)
            runs = self.mlflow_client.search_runs(
                experiment_ids=[exp.experiment_id for exp in self.mlflow_client.search_experiments()],
                filter_string=f"tags.mlflow.runName = '{model_name}'",
                order_by=["attributes.start_time DESC"],
            )
            if len(runs) <= keep:
                logger.debug("No pruning needed for %s (runs=%d, keep=%d)",
                             model_name, len(runs), keep)
                return

            to_delete = runs[keep:]
            for r in to_delete:
                self.mlflow_client.delete_run(r.info.run_id)
            logger.info("🗑️  Pruned %d old %s runs; kept %d",
                        len(to_delete), model_name, keep)

            # 2️⃣ garbage-collect artifacts. The tracking URI minus any "file:"
            # prefix is used as both backend store and artifact store path.
            if settings.MLFLOW_GC_AFTER_TRAIN:
                uri = mlflow.get_tracking_uri().removeprefix("file:")
                before = shutil.disk_usage(uri).used
                # subprocess.run blocks until gc finishes; keep it off the loop.
                await asyncio.to_thread(
                    subprocess.run,
                    ["mlflow", "gc",
                     "--backend-store-uri", uri,
                     "--artifact-store", uri],
                    check=True,
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE,
                )
                after = shutil.disk_usage(uri).used
                logger.info("🧹 mlflow gc completed (%.2f MB → %.2f MB)",
                            before/1e6, after/1e6)

        except subprocess.CalledProcessError as exc:
            # CalledProcessError's message omits the captured output; surface
            # mlflow's own stderr so the failure is diagnosable.
            stderr = (exc.stderr or b"").decode(errors="replace").strip()
            logger.warning("Cleanup for %s failed: %s; stderr: %s", model_name, exc, stderr)
        except Exception as exc:
            logger.warning("Cleanup for %s failed: %s", model_name, exc)

    async def vacuum_store(self) -> None:
        """
        Force a *store-wide* ``mlflow gc`` (use from cron jobs).

        The tracking URI minus any ``file:`` prefix is used as both the
        backend store and the artifact store. Filesystem usage is logged
        before and after. The blocking subprocess runs in a worker thread so
        the event loop stays responsive. Failures are logged, never raised.
        """
        import asyncio

        try:
            uri = mlflow.get_tracking_uri().removeprefix("file:")
            before = shutil.disk_usage(uri).used
            await asyncio.to_thread(
                subprocess.run,
                ["mlflow", "gc",
                 "--backend-store-uri", uri,
                 "--artifact-store", uri],
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
            after = shutil.disk_usage(uri).used
            logger.info("🧹 Store-wide vacuum completed (%.2f MB → %.2f MB)",
                        before/1e6, after/1e6)
        except subprocess.CalledProcessError as exc:
            # Include mlflow's stderr, which the exception message omits.
            stderr = (exc.stderr or b"").decode(errors="replace").strip()
            logger.warning("Store vacuum failed: %s; stderr: %s", exc, stderr)
        except Exception as exc:
            logger.warning("Store vacuum failed: %s", exc)


# Global singleton: one ModelService instance, created when this module is
# first imported and shared by every module that imports `model_service`.
model_service = ModelService()


In [None]:
%%writefile api/app/main.py
import logging
import os
import asyncio
import json
from fastapi import FastAPI, Request, Depends, BackgroundTasks, status, HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import SQLAlchemyError
import time
from typing import Optional

from pydantic import BaseModel

# ── Sanitise ML backend configuration BEFORE any JAX import ───────────────
# fix_ml_backends() must run before any module that imports JAX is loaded,
# so keep this call above the application imports below (presumably the
# .services.ml imports pull in JAX — confirm before reordering).
from .utils.env_sanitizer import fix_ml_backends
fix_ml_backends()
# ──────────────────────────────────────────────────────────────────────────

# ── NEW: Rate limiting imports ─────────────────────────────────────────────────────────
from fastapi_limiter import FastAPILimiter
from redis import asyncio as redis
# ────────────────────────────────────────────────────────────────────────────────────────

# ── NEW: Concurrency limiting imports ────────────────────────────────────────────────
from .middleware.concurrency import ConcurrencyLimiter
# ────────────────────────────────────────────────────────────────────────────────────────

from .db import lifespan, get_db, get_app_ready
from .security import create_access_token, get_current_user, verify_password
from .crud import get_user_by_username
from .schemas.iris import IrisPredictRequest, IrisPredictResponse, IrisFeatures
from .schemas.cancer import CancerPredictRequest, CancerPredictResponse, CancerFeatures
from .schemas.train import IrisTrainRequest, CancerTrainRequest, BayesTrainRequest, BayesTrainResponse, BayesConfigResponse, BayesRunMetrics
from .services.ml.model_service import model_service
from .core.config import settings
from .deps.limits import default_limit, heavy_limit, login_limit, training_limit, light_limit
from .security import LoginPayload, get_credentials

# ── Guarantee the log directory exists ────────────────────────────────────
# "logs" is relative to the process's current working directory; created at
# import time, presumably for file-based log handlers configured elsewhere.
os.makedirs("logs", exist_ok=True)
# ──────────────────────────────────────────────────────────────────

# Configure root logging at INFO; this module logs via its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ── Redis cache client for prediction caching ─────────────────────────────
# Uses the same Redis URL logic as db.py: an explicit REDIS_URL environment
# variable wins only in production; otherwise settings.REDIS_URL is used.
# `cache` is None when caching is disabled by config.
def _redact_url_password(url) -> str:
    """Return *url* with any embedded password masked as ``***`` (safe to log)."""
    from urllib.parse import urlsplit, urlunsplit

    try:
        parts = urlsplit(str(url))
        if parts.password is None:
            return url
        host = parts.netloc.rpartition("@")[2]
        user = parts.username or ""
        return urlunsplit(parts._replace(netloc=f"{user}:***@{host}"))
    except ValueError:
        # Unparseable URL: don't risk echoing credentials.
        return "<unparseable URL>"


if settings.CACHE_ENABLED:
    env_url = os.getenv("REDIS_URL")
    if settings.ENVIRONMENT_CANONICAL == "production" and env_url:
        redis_url = env_url
    else:
        redis_url = settings.REDIS_URL

    cache = redis.from_url(
        redis_url,
        encoding="utf-8",
        decode_responses=True,
    )
    # Redis URLs may embed a password (redis://:secret@host); never log it.
    logger.info("📦 Prediction caching enabled (Redis %s)", _redact_url_password(redis_url))
else:
    cache = None
    logger.info("📦 Prediction caching disabled by config")
# ────────────────────────────────────────────────────────────────────────────────────────

# Pydantic request/response models.
class Payload(BaseModel):
    # Minimal example payload carrying a single integer.
    count: int

class PredictionRequest(BaseModel):
    # Request envelope: the Payload is nested under the "data" key.
    data: Payload

class PredictionResponse(BaseModel):
    prediction: str
    confidence: float  # NOTE(review): presumably in [0, 1]; not validated here
    input_received: Payload  # Echo back the input for verification

class Token(BaseModel):
    # Response of the login endpoints.
    access_token: str
    token_type: str  # login() always sets "bearer"

# The ASGI application. Interactive docs and the OpenAPI schema are served
# under /api/v1; startup/shutdown is handled by the `lifespan` context
# imported from .db.
app = FastAPI(
    title="FastAPI + React ML App",
    version="1.0.0",
    docs_url="/api/v1/docs",
    redoc_url="/api/v1/redoc",
    openapi_url="/api/v1/openapi.json",
    swagger_ui_parameters={"persistAuthorization": True},  # Swagger UI keeps the entered bearer token across reloads
    lifespan=lifespan,  # register startup/shutdown events
)

# ── Rate limiting is now initialized in lifespan() ────────────────────────────────────
# ────────────────────────────────────────────────────────────────────────────────────────

# Configure CORS from settings.ALLOWED_ORIGINS: either "*" (any origin) or a
# comma-separated allow-list. Whitespace and empty entries are ignored; an
# empty setting falls back to "*" with a warning.
origins_env = settings.ALLOWED_ORIGINS
origins: list[str] = [o.strip() for o in origins_env.split(",") if o.strip()]
if not origins:
    logger.warning("ALLOWED_ORIGINS is empty; allowing all origins")
    origins = ["*"]
elif "*" in origins:
    origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    # Honour the configured allow-list. With "*" and allow_credentials=True,
    # Starlette reflects the caller's Origin, i.e. credentialed requests are
    # accepted from any site, so production should set explicit origins.
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ── Concurrency limiting middleware ───────────────────────────────────────
# Presumably caps in-flight requests at max_concurrent=4 (hard-coded here;
# semantics live in .middleware.concurrency — confirm there).
# Added after CORSMiddleware, so Starlette wraps it OUTSIDE CORS.
# NOTE(review): any response the limiter produces itself (e.g. a rejection)
# will therefore not carry CORS headers — confirm that is intended.
app.add_middleware(ConcurrencyLimiter, max_concurrent=4)
# ────────────────────────────────────────────────────────────────────────────────────────

@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    """Time each request and report elapsed seconds in ``X-Process-Time``."""
    t0 = time.perf_counter()
    resp = await call_next(request)
    resp.headers["X-Process-Time"] = "{:.4f}".format(time.perf_counter() - t0)
    return resp

# Liveness endpoint
@app.get("/api/v1/health")
async def health_check():
    """Liveness probe: answers 200 whenever the process is serving requests."""
    payload = {"status": "healthy"}
    payload["timestamp"] = time.time()
    return payload

@app.get("/api/v1/hello")
async def hello(current_user: str = Depends(get_current_user)):
    """Greet the authenticated caller; handy for checking a bearer token."""
    greeting = "Hello {}!".format(current_user)
    return {"message": greeting, "status": "authenticated"}

@app.get("/api/v1/ready")
async def ready():
    """Minimal readiness probe exposing the app-ready flag from .db."""
    is_ready = get_app_ready()
    return {"ready": is_ready}

@app.get("/api/v1/ready/frontend")
async def ready_frontend() -> dict:
    """
    Compact readiness payload for the React SPA.

    Only small, stable fields are returned; the large nested dependency-audit
    data included by /api/v1/ready/full is deliberately left out.
    """
    app_ready = get_app_ready()
    loaded = set(model_service.models.keys())

    has_bayes = "breast_cancer_bayes" in loaded
    has_stub = "breast_cancer_stub" in loaded
    iris_available = bool(loaded & {"iris_random_forest", "iris_logreg"})
    required = ("iris_random_forest", "breast_cancer_bayes")

    return {
        "ready": app_ready,
        "models": {
            "iris": iris_available,
            "cancer": has_bayes or has_stub,
        },
        "has_bayes": has_bayes,
        "has_stub": has_stub,
        "all_models_loaded": all(name in loaded for name in required),
    }

@app.post("/api/v1/token", response_model=Token, dependencies=[Depends(login_limit)])
async def login(
    creds: LoginPayload = Depends(get_credentials),
    db: AsyncSession = Depends(get_db),
):
    """
    Issue a JWT. Accepts **either**
    • JSON {"username": "...", "password": "..."}  *or*
    • classic x-www-form-urlencoded
    (body parsing is delegated to ``get_credentials``).

    Raises:
        HTTPException: 503 with a ``Retry-After`` header while the backend is
            still loading; 401 for an unknown user or a wrong password.
    """
    # 1️⃣ readiness gate
    if not get_app_ready():
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Backend still loading models. Try again in a moment.",
            # Header name must be plain ASCII "Retry-After": a look-alike
            # Unicode hyphen is not recognised by clients and cannot be
            # latin-1 encoded by Starlette (which would turn this 503 into 500).
            headers={"Retry-After": "10"},
        )

    # 2️⃣ verify credentials
    user = await get_user_by_username(db, creds.username)
    if not user or not verify_password(creds.password, user.hashed_password):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
                            detail="Invalid credentials")

    # 3️⃣ issue token
    token = create_access_token(subject=user.username)
    return Token(access_token=token, token_type="bearer")

# --- PATCH: ready_full -------------------------------------------------------
@app.get("/api/v1/ready/full")
async def ready_full(debug: Optional[bool] = False) -> dict:
    """
    Extended readiness probe with an environment-drift summary.

    Query param:
      debug=1  -> also return the filtered (non-meta) status map plus raw vs.
                  filtered entry counts, and log the JSON payload size.
    """
    ready_for_login = get_app_ready()
    expected = {"iris_random_forest", "breast_cancer_bayes"}  # minimal contract
    loaded = set(model_service.models.keys())
    raw_status = model_service.status

    def is_meta(key: str) -> bool:
        # Bookkeeping entries stored alongside per-model statuses.
        return key.endswith(("_dep_audit", "_last_error"))

    # Per-model dependency drift; "critical" = a MAJOR_DRT finding on a core
    # numeric/runtime package.
    # NOTE(review): "MAJOR_DRT" may be a truncated "MAJOR_DRIFT" — confirm
    # against the severity values the auditor actually emits.
    core_pkgs = ("numpy", "scipy", "scikit-learn", "psutil")
    drift = {}
    for m in ("iris_random_forest", "iris_logreg", "breast_cancer_bayes", "breast_cancer_stub"):
        audit = raw_status.get(f"{m}_dep_audit", {})
        has_critical = any(
            pkg in core_pkgs and rec.get("severity") == "MAJOR_DRT"
            for pkg, rec in audit.items()
        )
        drift[m] = {"critical_drift": has_critical, "details": audit}

    filtered_status = {k: v for k, v in raw_status.items() if not is_meta(k)}
    statuses = filtered_status.values()

    response = {
        "ready": ready_for_login,
        "model_status": raw_status,  # raw (includes meta)
        "env_drift": drift,
        "all_models_loaded": all(v == "loaded" for v in statuses),
        "models": {m: (m in loaded) for m in expected},
        "training": [k for k, v in filtered_status.items() if v == "training"],
    }

    if debug:
        response["status_filtered"] = filtered_status
        response["status_counts"] = {
            "raw": len(raw_status),
            "filtered": len(filtered_status),
        }
        logger.info("READY_FULL debug: payload size=%d bytes", len(json.dumps(response)))

    logger.debug("READY endpoint – _app_ready=%s", ready_for_login)
    return response
# --- END PATCH ---------------------------------------------------------------



# ── Alias routes (no auth, not shown in OpenAPI) ────────────────────────────
@app.get("/ready/full", include_in_schema=False)
async def ready_full_alias():
    """Un-prefixed /ready/full for front-end calls that omit /api/v1."""
    payload = await ready_full(debug=False)
    return payload

@app.get("/health", include_in_schema=False)
async def health_alias():
    """Plain /health; the SPA probes it before it knows the API prefix."""
    result = await health_check()
    return result

@app.post("/token", include_in_schema=False, dependencies=[Depends(login_limit)])
async def login_alias(request: Request):
    """
    Alias: accept /token like /api/v1/token.

    Reads ``username``/``password`` from the request's form body (as
    OAuth2PasswordRequestForm would) and delegates to :func:`login`. It uses
    the same ``login_limit`` rate limit as the primary endpoint, so the alias
    cannot be used to bypass brute-force protection.

    Raises:
        HTTPException: 422 if username or password is missing or not a plain
            string field; otherwise whatever :func:`login` raises (401/503).
    """
    from types import SimpleNamespace

    form_data = await request.form()
    username = form_data.get("username")
    password = form_data.get("password")

    # Reject missing/empty values and file uploads posing as fields.
    if not (isinstance(username, str) and username and isinstance(password, str) and password):
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="username and password are required"
        )

    # login() only reads .username / .password from its credentials object.
    creds = SimpleNamespace(username=username, password=password)

    # Drive the get_db() dependency generator by hand and always close it, so
    # any cleanup after its `yield` runs (advancing it with __anext__() alone
    # would leave it suspended).
    db_gen = get_db()
    db = await db_gen.__anext__()
    try:
        return await login(creds, db)
    finally:
        await db_gen.aclose()

@app.post("/iris/predict", include_in_schema=False)
async def iris_predict_alias(
    iris_request: IrisPredictRequest,
    background_tasks: BackgroundTasks,
):
    """
    Alias for /api/v1/iris/predict.

    The JSON body is validated by FastAPI as an ``IrisPredictRequest``, so
    invalid input yields 422 instead of an unhandled error. ``background_tasks``
    is injected by FastAPI so that any task scheduled by ``predict_iris``
    actually runs after the response; a manually constructed
    ``BackgroundTasks()`` is never executed.

    NOTE(review): this alias skips authentication and reports the caller as
    "test_user", exposing the model without a token. Confirm that is
    acceptable outside of testing.
    """
    current_user = "test_user"  # Skip authentication for alias endpoints
    return await predict_iris(iris_request, background_tasks, current_user)

@app.post("/cancer/predict", include_in_schema=False)
async def cancer_predict_alias(
    cancer_request: CancerPredictRequest,
    background_tasks: BackgroundTasks,
):
    """
    Alias for /api/v1/cancer/predict.

    The JSON body is validated by FastAPI as a ``CancerPredictRequest``, so
    invalid input yields 422 instead of an unhandled error. ``background_tasks``
    is injected by FastAPI so that any task scheduled by ``predict_cancer``
    actually runs after the response; a manually constructed
    ``BackgroundTasks()`` is never executed.

    NOTE(review): this alias skips authentication and reports the caller as
    "test_user", exposing the model without a token. Confirm that is
    acceptable outside of testing.
    """
    current_user = "test_user"  # Skip authentication for alias endpoints
    return await predict_cancer(cancer_request, background_tasks, current_user)

# ----- on-demand training endpoints ----------------------------------
@app.post("/api/v1/iris/train", status_code=202, dependencies=[Depends(training_limit)])
async def train_iris(
    request: IrisTrainRequest,
    background_tasks: BackgroundTasks,
    current_user: str = Depends(get_current_user)
):
    """
    Queue training of the requested Iris model and return 202 right away.

    ``current_user`` is not used in the body; its dependency enforces auth.
    """
    chosen = request.model_type
    background_tasks.add_task(model_service.train_iris, chosen)
    return {"status": "started iris training ({})".format(chosen)}

@app.post("/api/v1/cancer/train", status_code=202, dependencies=[Depends(training_limit)])
async def train_cancer(
    request: CancerTrainRequest,
    background_tasks: BackgroundTasks,
    current_user: str = Depends(get_current_user)
):
    """
    Queue training of the requested Cancer model and return 202 right away.

    ``current_user`` is not used in the body; its dependency enforces auth.
    """
    chosen = request.model_type
    background_tasks.add_task(model_service.train_cancer, chosen)
    return {"status": "started cancer training ({})".format(chosen)}

@app.get("/api/v1/cancer/bayes/config", response_model=BayesConfigResponse)
async def get_bayes_config(current_user: str = Depends(get_current_user)):
    """
    Describe the Bayesian training form for the frontend.

    Bundles the default ``BayesCancerParams``, min/max bounds for the numeric
    knobs, help text, and constants for a rough runtime estimate.
    """
    from .schemas.bayes import BayesCancerParams

    default_params = BayesCancerParams()

    bounds = {
        "draws": {"min": 200, "max": 20000},
        "tune": {"min": 200, "max": 20000},
        "target_accept": {"min": 0.80, "max": 0.999},
        "max_rhat_warn": {"min": 1.0, "max": 1.1},
        "min_ess_warn": {"min": 50, "max": 5000},
    }
    help_text = {
        "draws": "Number of posterior draws retained. More draws = better MCSE but longer runtime.",
        "tune": "Warmup steps for NUTS adaptation. Should be ≥ 0.2 * draws for good convergence.",
        "target_accept": "Target acceptance rate. Higher values reduce divergences but increase runtime.",
        "compute_waic": "Compute Widely Applicable Information Criterion. Fast but may be less robust than LOO.",
        "compute_loo": "Compute Leave-One-Out cross-validation. More reliable but slower.",
    }
    runtime_hint = {
        "base_seconds_per_sample": 0.001,  # rough estimate
        "chains": 4,
        "overhead_seconds": 5.0,  # model setup, data loading, etc.
    }

    return BayesConfigResponse(
        defaults=default_params,
        bounds=bounds,
        descriptions=help_text,
        runtime_estimate=runtime_hint,
    )

@app.post("/api/v1/cancer/bayes/train", response_model=BayesTrainResponse, dependencies=[Depends(training_limit)])
async def train_bayes_cancer(
    request: BayesTrainRequest,
    background_tasks: BackgroundTasks,
    current_user: str = Depends(get_current_user)
):
    """
    Train the Bayesian cancer model with validated hyperparameters.

    With ``request.async_training`` the job is queued as a background task and
    a "queued" response (empty ``run_id``) is returned immediately; otherwise
    training runs inline and the resulting run id is returned. Training errors
    are reported in-band as ``status="failed"`` rather than as an HTTP error.

    Raises:
        HTTPException: 503 with a ``Retry-After`` header while the backend is
            still loading.
    """
    if not get_app_ready():
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Backend still loading models. Try again in a moment.",
            # Plain-ASCII header name: a look-alike Unicode hyphen is not the
            # Retry-After header and cannot be latin-1 encoded by Starlette.
            headers={"Retry-After": "10"},
        )

    try:
        if request.async_training:
            # Queue for background processing.
            # NOTE(review): second-resolution ids can collide for jobs queued
            # within the same second, and the id is not passed to the task.
            job_id = f"bayes_{int(time.time())}"
            background_tasks.add_task(
                model_service.train_bayes_cancer_with_params,
                request.params
            )
            return BayesTrainResponse(
                run_id="",  # unknown here: the background task's result is not reported back
                job_id=job_id,
                status="queued",
                message="Training queued for background processing"
            )

        # Synchronous training
        run_id = await model_service.train_bayes_cancer_with_params(request.params)
        return BayesTrainResponse(
            run_id=run_id,
            status="completed",
            message="Training completed successfully"
        )
    except Exception as e:
        # logger.exception keeps the traceback for diagnosis.
        logger.exception("Bayesian training failed: %s", e)
        return BayesTrainResponse(
            run_id="",
            status="failed",
            message=str(e)
        )

@app.get("/api/v1/cancer/bayes/runs/{run_id}", response_model=BayesRunMetrics)
async def get_bayes_run_metrics(
    run_id: str,
    current_user: str = Depends(get_current_user)
):
    """
    Get metrics for a specific Bayesian training run.

    Adds convergence warnings when ``rhat_max`` > 1.01 or ``ess_bulk_min``
    < 400. Metrics that were not logged for the run are returned as ``None``
    (``accuracy`` defaults to 0.0) and do not produce warnings.

    Raises:
        HTTPException: 404 if the run cannot be fetched or its metrics cannot
            be assembled into a ``BayesRunMetrics``.
    """
    try:
        import mlflow
        run = mlflow.get_run(run_id)
        metrics = run.data.metrics

        rhat_max = metrics.get("rhat_max")
        ess_bulk_min = metrics.get("ess_bulk_min")

        # Only warn about metrics that were actually logged: formatting a
        # missing one via metrics[key] would raise KeyError (-> bogus 404).
        warnings = []
        if rhat_max is not None and rhat_max > 1.01:
            warnings.append(f"R-hat exceeds threshold: {rhat_max:.4f} > 1.01")
        if ess_bulk_min is not None and ess_bulk_min < 400:
            warnings.append(f"Bulk ESS below threshold: {ess_bulk_min:.1f} < 400")

        return BayesRunMetrics(
            run_id=run_id,
            accuracy=metrics.get("accuracy", 0.0),
            rhat_max=rhat_max,
            ess_bulk_min=ess_bulk_min,
            ess_tail_min=metrics.get("ess_tail_min"),
            waic=metrics.get("waic"),
            loo=metrics.get("loo"),
            status="completed",
            warnings=warnings
        )
    except Exception as e:
        logger.error("Failed to get run metrics for %s: %s", run_id, e)
        raise HTTPException(status_code=404, detail=f"Run {run_id} not found or metrics unavailable") from e

# ----- debug endpoints ----------------------------------
@app.get("/api/v1/debug/ready")
async def debug_ready():
    """
    Echo selected ``settings`` values to verify configuration loading.

    Covers environment, rate limits, quality gates, MLflow URIs and training
    flags.

    NOTE(review): unlike /api/v1/debug/effective-config this route has no
    authentication dependency; confirm exposing MLflow URIs here is intended.
    """
    environment = settings.ENVIRONMENT
    rate_limits = dict(
        default=settings.RATE_LIMIT_DEFAULT,
        cancer=settings.RATE_LIMIT_CANCER,
        login=settings.RATE_LIMIT_LOGIN,
        training=settings.RATE_LIMIT_TRAINING,
        window=settings.RATE_LIMIT_WINDOW,
        window_light=settings.RATE_LIMIT_WINDOW_LIGHT,
    )
    quality_gates = dict(
        accuracy_threshold=settings.QUALITY_GATE_ACCURACY_THRESHOLD,
        f1_threshold=settings.QUALITY_GATE_F1_THRESHOLD,
    )
    mlflow_info = dict(
        experiment=settings.MLFLOW_EXPERIMENT,
        tracking_uri=settings.MLFLOW_TRACKING_URI,
        registry_uri=settings.MLFLOW_REGISTRY_URI,
    )
    training = dict(
        skip_background=settings.SKIP_BACKGROUND_TRAINING,
        auto_train_missing=settings.AUTO_TRAIN_MISSING,
    )
    return {
        "status": "ready",
        "environment": environment,
        "rate_limits": rate_limits,
        "quality_gates": quality_gates,
        "mlflow": mlflow_info,
        "training": training,
        "debug": dict(debug_ratelimit=settings.DEBUG_RATELIMIT),
    }

# --- effective config debug --------------------------------------------------
@app.get("/api/v1/debug/effective-config")
async def effective_config(current_user: str = Depends(get_current_user)):
    """
    Inspect the *effective* runtime configuration (after YAML + env overrides).

    Non-None values are redacted for SECRET_KEY, DATABASE_URL, REDIS_URL
    (which may embed a password) and any field whose upper-cased name contains
    SECRET, PASSWORD, CREDENTIAL or API_KEY. Use to debug environment drift
    across dev/staging/production deployments.
    """
    # Uses the module-level `settings` (imported relatively from .core.config).
    # An absolute `app.core.config` import can resolve to a different module
    # object, or fail, depending on how this package is imported.
    redacted = {"SECRET_KEY", "DATABASE_URL", "REDIS_URL"}
    sensitive_markers = ("SECRET", "PASSWORD", "CREDENTIAL", "API_KEY")

    cfg = settings.model_dump()
    for k in list(cfg):
        name = k.upper()
        is_sensitive = name in redacted or any(m in name for m in sensitive_markers)
        if is_sensitive and cfg[k] is not None:
            cfg[k] = "***redacted***"
    return {
        "environment": settings.ENVIRONMENT_CANONICAL,
        "config": cfg,
    }

# ----- MLOps endpoints (new) ----------------------------------------
@app.post("/api/v1/mlops/evaluate/{model_name}")
async def evaluate_model(
    model_name: str,
    run_id: str,
    current_user: str = Depends(get_current_user)
):
    """
    Score a candidate run against the production baseline (CI/CD quality gate).

    Delegates to ``model_service.evaluate_model_quality`` and returns its
    result unchanged.
    """
    logger.info("User %s evaluating model %s (run: %s)", current_user, model_name, run_id)
    return await model_service.evaluate_model_quality(model_name, run_id)

@app.post("/api/v1/mlops/promote/{model_name}/staging")
async def promote_to_staging(
    model_name: str,
    run_id: str,
    current_user: str = Depends(get_current_user)
):
    """
    Promote a run's model to staging.

    Delegates to ``model_service.promote_model_to_staging`` (quality-gate
    evaluation, staging registration and the @staging alias are handled
    there) and returns its result unchanged.
    """
    logger.info("User %s promoting %s to staging (run: %s)", current_user, model_name, run_id)
    return await model_service.promote_model_to_staging(model_name, run_id)

@app.post("/api/v1/mlops/promote/{model_name}/production")
async def promote_to_production(
    model_name: str,
    version: Optional[int] = None,
    approved_by: Optional[str] = None,
    current_user: str = Depends(get_current_user)
):
    """
    Promote a staging model to production (manually or from CI/CD).

    ``version`` selects a specific registry version; when omitted the service
    is expected to use the current @staging alias. Delegates to
    ``model_service.promote_model_to_production`` and returns its result.
    """
    logger.info("User %s promoting %s to production (version: %s)", current_user, model_name, version)
    return await model_service.promote_model_to_production(model_name, version, approved_by)

@app.post("/api/v1/mlops/reload-model")
async def reload_model(
    model_name: Optional[str] = None,
    current_user: str = Depends(get_current_user)
):
    """
    Hot-reload models from MLflow registry without restarting the service.

    Useful for CI/CD deployments that update models, manual promotions, and
    testing new model versions.

    Args:
        model_name: Specific model to reload (optional, reloads all if None)

    Returns:
        For a single model: ``{"reloaded": True, "model", "status"}``.
        For all models: per-model reloaded/failed lists plus the full status.

    Raises:
        HTTPException: 404 if the requested single model fails to load;
            500 on any unexpected error.
    """
    logger.info("User %s reloading models (specific: %s)", current_user, model_name)

    try:
        if model_name:
            if not await model_service._try_load(model_name):
                raise HTTPException(
                    status_code=404,
                    detail=f"Failed to reload model {model_name}"
                )
            return {
                "reloaded": True,
                "model": model_name,
                "status": model_service.status.get(model_name, "unknown")
            }

        # Reload all models; one failure does not stop the others.
        reloaded = []
        failed = []
        for name in ("iris_random_forest", "iris_logreg",
                     "breast_cancer_bayes", "breast_cancer_stub"):
            try:
                ok = await model_service._try_load(name)
            except Exception as e:
                logger.error("Failed to reload %s: %s", name, e)
                ok = False
            if ok:
                reloaded.append(name)
            else:
                failed.append(name)

        return {
            "reloaded": len(failed) == 0,
            "reloaded_models": reloaded,
            "failed_models": failed,
            "status": model_service.status
        }

    except HTTPException:
        # Deliberate HTTP errors (the 404 above) must not be rewrapped as 500.
        raise
    except Exception as e:
        logger.error("Model reload failed: %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Model reload failed: {e}"
        ) from e

@app.get("/api/v1/mlops/status")
async def mlops_status(current_user: str = Depends(get_current_user)):
    """
    Get MLOps status including model versions and stages.

    Returns:
    - Model loading status and the names of loaded models
    - Per-model registry info: version count, versions grouped by stage,
      and alias -> version assignments
    - The app-ready flag and the configured MLflow tracking URI

    Per-model registry errors are reported inline under ``"error"``; any
    other failure becomes an HTTP 500.
    """
    logger.info("User %s requesting MLOps status", current_user)

    try:
        client = model_service.mlflow_client

        registry_info = {}
        for model_name in ["iris_random_forest", "iris_logreg",
                           "breast_cancer_bayes", "breast_cancer_stub"]:
            try:
                versions = client.search_model_versions(f"name='{model_name}'")
                info = {"versions": len(versions), "stages": {}, "aliases": {}}
                registry_info[model_name] = info

                # Group versions by (legacy) stage
                for v in versions:
                    info["stages"].setdefault(v.current_stage, []).append({
                        "version": v.version,
                        "run_id": v.run_id,
                        "created_at": v.creation_timestamp,
                    })

                # Aliases are exposed on the RegisteredModel entity as an
                # {alias: version} dict (MLflow >= 2.3); MlflowClient has no
                # get_registered_model_aliases() method.
                try:
                    registered = client.get_registered_model(model_name)
                    info["aliases"] = dict(getattr(registered, "aliases", None) or {})
                except Exception as e:
                    logger.debug("Could not get aliases for %s: %s", model_name, e)

            except Exception as e:
                logger.warning("Could not get registry info for %s: %s", model_name, e)
                registry_info[model_name] = {"error": str(e)}

        return {
            "model_status": model_service.status,
            "loaded_models": list(model_service.models.keys()),
            "registry_info": registry_info,
            "app_ready": get_app_ready(),
            "mlflow_uri": settings.MLFLOW_TRACKING_URI
        }

    except Exception as e:
        logger.error("MLOps status failed: %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"MLOps status failed: {e}"
        ) from e

@app.get("/api/v1/mlops/models/{model_name}/metrics")
async def get_model_metrics(
    model_name: str,
    current_user: str = Depends(get_current_user)
):
    """
    Get metrics for all versions of a registered model.

    Enables comparison between model versions for quality-gate decisions and
    promotion workflows.

    Raises:
        HTTPException: 404 when no metrics/registered model are found;
            500 on any other failure.
    """
    try:
        metrics = await model_service.get_model_metrics(model_name)
        if not metrics:
            raise HTTPException(
                status_code=404,
                detail=f"No registered model found with name '{model_name}'"
            )
        return {
            "model_name": model_name,
            "versions": metrics,
            "total_versions": len(metrics)
        }
    except HTTPException:
        # Keep the deliberate 404 instead of rewrapping it as a 500.
        raise
    except Exception as e:
        logger.error("Error fetching metrics for %s: %s", model_name, e)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to fetch model metrics: {str(e)}"
        ) from e

@app.get("/api/v1/mlops/models/{model_name}/compare")
async def compare_model_versions(
    model_name: str,
    version_a: int,
    version_b: int,
    current_user: str = Depends(get_current_user)
):
    """
    Compare two specific model versions for MLOps decision making.

    Delegates to ``model_service.compare_model_versions``; a result containing
    an ``"error"`` key is reported as HTTP 400.

    Raises:
        HTTPException: 400 when the service reports an error; 500 on any
            other failure.
    """
    try:
        comparison = await model_service.compare_model_versions(
            model_name, version_a, version_b
        )

        if "error" in comparison:
            raise HTTPException(
                status_code=400,
                detail=comparison["error"]
            )

        return comparison
    except HTTPException:
        # Keep the deliberate 400 instead of rewrapping it as a 500.
        raise
    except Exception as e:
        logger.error("Error comparing versions for %s: %s", model_name, e)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to compare model versions: {str(e)}"
        ) from e

@app.get("/api/v1/mlops/models/{model_name}/quality-gate")
async def check_quality_gate(
    model_name: str,
    version: Optional[int] = None,
    current_user: str = Depends(get_current_user)
):
    """
    Check if a model version passes quality gates.

    This endpoint evaluates a model against production baseline
    or absolute thresholds to determine if it's ready for promotion.

    If ``version`` is omitted, the highest-numbered version whose
    ``current_stage`` is ``"Staging"`` is evaluated.

    Returns:
        dict with the resolved ``version``, its ``run_id``, the raw
        ``quality_gate_result`` from ``model_service.evaluate_model_quality``
        and its ``promoted`` / ``reason`` fields surfaced as ``passes_gate``
        and ``reason``.

    Raises:
        HTTPException(404): no version given and no Staging version exists.
        HTTPException(500): any other failure (MLflow lookup, evaluation).
    """
    try:
        # Bind the client unconditionally: it is needed both to choose a
        # default version and to resolve the run_id of an explicit one.
        client = model_service.mlflow_client

        # If no version specified, use the latest staging version
        if version is None:
            # Only the model name goes into the MLflow filter string; the
            # stage is checked client-side on each ModelVersion.
            staging_versions = [
                mv
                for mv in client.search_model_versions(f"name='{model_name}'")
                if mv.current_stage == "Staging"
            ]
            if not staging_versions:
                raise HTTPException(
                    status_code=404,
                    detail=f"No staging version found for {model_name}"
                )
            # "Latest" means the highest version number, not whichever entry
            # the search happened to return first.
            version = max(staging_versions, key=lambda mv: int(mv.version)).version

        # Get the run_id for this version
        version_info = client.get_model_version(model_name, version)
        run_id = version_info.run_id

        # Evaluate quality gate
        eval_result = await model_service.evaluate_model_quality(model_name, run_id)

        return {
            "model_name": model_name,
            "version": version,
            "run_id": run_id,
            "quality_gate_result": eval_result,
            "passes_gate": eval_result["promoted"],
            "reason": eval_result["reason"]
        }
    except HTTPException:
        # Preserve the intentional 404 instead of rewrapping it as a 500.
        raise
    except Exception as e:
        logger.error(f"Error checking quality gate for {model_name}: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to check quality gate: {str(e)}"
        )

@app.get("/api/v1/iris/ready")
async def iris_ready():
    """Report whether ``iris_random_forest`` is present in ``model_service.models``."""
    loaded = "iris_random_forest" in model_service.models
    return {"loaded": loaded}

@app.get("/api/v1/cancer/ready")
async def cancer_ready():
    """Report whether ``breast_cancer_bayes`` is present in ``model_service.models``."""
    loaded = "breast_cancer_bayes" in model_service.models
    return {"loaded": loaded}

@app.post(
    "/api/v1/iris/predict",
    response_model=IrisPredictResponse,
    status_code=status.HTTP_200_OK,
    dependencies=[Depends(light_limit)]
)
async def predict_iris(
    request: IrisPredictRequest,
    background_tasks: BackgroundTasks,
    current_user: str = Depends(get_current_user),
):
    """
    Classify iris measurements, serving from Redis when caching is enabled.

    Responds 503 with ``Retry-After: 30`` while the selected model is not
    yet in ``model_service.models``. On a cache hit the stored JSON is
    returned directly; otherwise the prediction is computed, optionally
    cached for ``CACHE_TTL_MINUTES`` minutes, and an audit line is logged
    as a background task.
    """
    samples = request.samples
    logger.info(f"User {current_user} called /iris/predict with {len(samples)} samples")

    # "rf" selects the random forest; every other model_type maps to logreg.
    wanted = "iris_random_forest" if request.model_type == "rf" else "iris_logreg"
    if wanted not in model_service.models:
        raise HTTPException(
            status_code=503,
            detail="Iris model still loading. Try again shortly.",
            headers={"Retry-After": "30"},
        )

    # Plain dicts (not Pydantic objects) so the key and payload are JSON-able;
    # sort_keys makes the key independent of field order.
    payload = [sample.dict() for sample in samples]
    cache_key = f"iris:{request.model_type}:{json.dumps(payload, sort_keys=True)}"

    if settings.CACHE_ENABLED:
        hit = await cache.get(cache_key)
        if hit:
            logger.debug("Cache hit for key %s", cache_key)
            return IrisPredictResponse(**json.loads(hit))

    predictions, probabilities = await model_service.predict_iris(
        features=payload,
        model_type=request.model_type,
    )

    result = dict(
        predictions=predictions,
        probabilities=probabilities,
        input_received=payload,
    )

    if settings.CACHE_ENABLED:
        await cache.set(
            cache_key,
            json.dumps(result),
            ex=settings.CACHE_TTL_MINUTES * 60,
        )

    background_tasks.add_task(
        logger.info,
        f"[audit] user={current_user} endpoint=iris input={payload} output={predictions}",
    )
    return IrisPredictResponse(**result)

@app.post(
    "/api/v1/cancer/predict",
    response_model=CancerPredictResponse,
    status_code=status.HTTP_200_OK,
    dependencies=[Depends(heavy_limit)]
)
async def predict_cancer(
    request: CancerPredictRequest,
    background_tasks: BackgroundTasks,
    current_user: str = Depends(get_current_user),
):
    """
    Return breast-cancer predictions, serving from Redis when caching is enabled.

    On a cache hit the stored JSON is returned directly; otherwise the
    predictions, probabilities and uncertainties are computed, optionally
    cached for ``CACHE_TTL_MINUTES`` minutes, and an audit line is logged as
    a background task.
    """
    samples = request.samples
    logger.info(f"User {current_user} called /cancer/predict with {len(samples)} samples")

    # posterior_samples is part of the key so requests passing different
    # values to predict_cancer never share a cached result; a falsy value
    # is normalised to 0.
    payload = [sample.dict() for sample in samples]
    cache_key = (
        f"cancer:{request.model_type}:"
        f"{request.posterior_samples or 0}:"
        f"{json.dumps(payload, sort_keys=True)}"
    )

    if settings.CACHE_ENABLED:
        hit = await cache.get(cache_key)
        if hit:
            logger.debug("Cache hit for key %s", cache_key)
            return CancerPredictResponse(**json.loads(hit))

    predictions, probabilities, uncertainties = await model_service.predict_cancer(
        features=payload,
        model_type=request.model_type,
        posterior_samples=request.posterior_samples,
    )

    result = dict(
        predictions=predictions,
        probabilities=probabilities,
        uncertainties=uncertainties,
        input_received=payload,
    )

    if settings.CACHE_ENABLED:
        await cache.set(
            cache_key,
            json.dumps(result),
            ex=settings.CACHE_TTL_MINUTES * 60,
        )

    background_tasks.add_task(
        logger.info,
        f"[audit] user={current_user} endpoint=cancer input={payload} output={predictions}",
    )
    return CancerPredictResponse(**result)

@app.get("/api/v1/debug/compiler")
async def debug_compiler():
    """
    Describe the JAX / NumPyro / PyMC stack visible to this process.

    On success the response lists the three library versions, the JAX device
    list and default backend. If an import fails, ``status`` is
    ``"missing_dependencies"``; any other error while probing yields
    ``"configuration_failed"``.
    """
    try:
        import jax
        import numpyro
        import pymc as pm

        report = {"backend": "jax_numpyro"}
        report["jax_version"] = jax.__version__
        report["numpyro_version"] = numpyro.__version__
        report["pymc_version"] = pm.__version__
        report["jax_devices"] = str(jax.devices())
        report["jax_platform"] = jax.default_backend()
        report["status"] = "jax_backend_configured"
        return report
    except ImportError as exc:
        failure = f"Import error: {exc}"
        outcome = "missing_dependencies"
    except Exception as exc:
        failure = f"Configuration error: {exc}"
        outcome = "configuration_failed"

    return {"backend": "unknown", "error": failure, "status": outcome}

@app.get("/api/v1/debug/psutil")
async def debug_psutil():
    """
    Debug endpoint to check psutil status and configuration.
    Returns information about psutil module and its Process class.

    Response shapes:
      * ``{"status": "loaded", "info": {...}}``: psutil imported. ``info``
        holds its file path, version, whether ``Process`` exists, the
        interpreter's ``sys.path`` and the outcome of a ``Process()`` probe.
      * ``{"status": "import_failed", "error": ...}``: ``import psutil`` failed.
      * ``{"status": "error", "error": ...}``: any other unexpected failure.
    """
    import sys

    try:
        import psutil
        module_info = {
            "module_path": getattr(psutil, "__file__", "?"),
            "version": getattr(psutil, "__version__", "?"),
            "has_Process": hasattr(psutil, "Process"),
            "sys_path": sys.path
        }

        # Try a safe Process call; a failure here is reported, not raised.
        try:
            proc = psutil.Process()
            module_info["process_test"] = {
                "success": True,
                "pid": proc.pid,
                "cpu_count": psutil.cpu_count()
            }
        except Exception as e:
            module_info["process_test"] = {
                "success": False,
                "error": str(e)
            }

        return {
            "status": "loaded",
            "info": module_info
        }
    except ImportError as e:
        return {
            "status": "import_failed",
            "error": str(e)
        }
    except Exception as e:
        return {
            "status": "error",
            "error": str(e)
        }

@app.get("/api/v1/debug/deps")
async def debug_deps():
    """
    Report recorded vs. runtime dependency versions for each loaded model.

    Uses audit data collected during ModelService._load_production_model().
    Helpful when MLflow logs 'requirements_utils' mismatch warnings.

    ``runtime`` maps each probed distribution to its installed version, or
    None when the lookup fails. ``model_audits`` holds every
    ``model_service.status`` entry whose key ends in ``_dep_audit``.

    NOTE: purely diagnostic – no secrets.
    """
    import importlib.metadata as im

    def _installed(dist_name):
        # Best-effort probe: any lookup failure is reported as None.
        try:
            return im.version(dist_name)
        except Exception:
            return None

    probed = ("numpy", "scipy", "scikit-learn", "psutil", "pandas")
    runtime = {dist: _installed(dist) for dist in probed}

    audits = {
        key: value
        for key, value in model_service.status.items()
        if key.endswith("_dep_audit")
    }
    policy = os.getenv("MODEL_ENV_ENFORCEMENT", "warn")
    return {
        "runtime": runtime,
        "model_audits": audits,
        "enforcement_policy": policy,
    }


@app.get("/api/v1/test/401")
async def test_401():
    """Always fail with HTTP 401 so clients can exercise session-expiry handling."""
    unauthorized = status.HTTP_401_UNAUTHORIZED
    raise HTTPException(status_code=unauthorized, detail="Test 401 response")

# ── Debug‑only ratelimit helpers ─────────────────────────────────────────────
from .deps.limits import get_redis, _user_or_ip as user_or_ip

@app.post("/api/v1/debug/ratelimit/reset", include_in_schema=False)
async def rl_reset(request: Request):
    """
    Flush **all** rate‑limit counters bound to the caller (JWT _or_ IP).

    We match every fragment that contains the identifier to survive
    future changes in FastAPI‑Limiter's key schema.

    The identifier is backslash-escaped before it is embedded in the Redis
    glob pattern. An identity containing ``*``, ``?``, ``[`` or ``]`` is
    therefore matched literally and cannot select other callers' counters.

    NOTE(review): unlike the ``rl_status`` route, this one is registered
    regardless of ``settings.DEBUG_RATELIMIT``, so any caller can clear
    their own limits. Confirm that is intended outside CI.

    Returns:
        ``{"reset": <number of keys deleted>}``.

    Raises:
        HTTPException(503): no Redis connection is available.
    """
    r = get_redis()
    if not r:
        raise HTTPException(status_code=503, detail="Rate‑limiter not initialised")

    ident = await user_or_ip(request)
    # Neutralise Redis glob metacharacters so the identity is a literal match.
    literal_ident = f"{ident}".translate(
        str.maketrans({ch: "\\" + ch for ch in "\\*?[]"})
    )
    keys = await r.keys(f"ratelimit:*{literal_ident}*")        # <— broader pattern
    if keys:
        await r.delete(*keys)
    return {"reset": len(keys)}

if settings.DEBUG_RATELIMIT:          # OFF by default
    @app.get("/api/v1/debug/ratelimit/{bucket}", include_in_schema=False)
    async def rl_status(bucket: str, request: Request):
        """
        Inspect Redis keys for the current identifier + bucket.
        Handy for CI tests – **never enable in prod**.

        Returns a mapping of each matching Redis key to its stored value.

        Raises:
            HTTPException(503): no Redis connection is available. Without this
            check the call would fail with an AttributeError on ``None``.
        """
        key_prefix = f"ratelimit:{bucket}:{await user_or_ip(request)}"
        r = get_redis()
        if not r:
            raise HTTPException(status_code=503, detail="Rate‑limiter not initialised")
        keys = await r.keys(f"{key_prefix}*")
        values = await r.mget(keys) if keys else []
        return dict(zip(keys, values))
