# PM2.5 model training with MLflow — **Fixed**

This notebook is a direct conversion of your `models/train_pm25_mlflow.py`, with the following **critical fixes** baked in:

- ✅ Correct `mlflow.lightgbm.log_model` keyword for MLflow 2.x (`artifact_path`, not `name`; MLflow 3.x renamed it to `name`)
- ✅ Tracking URI is configurable via `MLFLOW_TRACKING_URI` (defaults to `http://localhost:5000`)
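- ✅ Safer experiment creation: prefers an `mlflow-artifacts:/...` artifact location (proxied by a tracking server started with `--serve-artifacts`) and falls back to the server's default location if that fails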
- ✅ Robust datetime handling for time splits
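- ✅ Optional auto-transition of the newly registered model to **Production** (enabled by default; disable with `AUTO_PROMOTE_PROD=false`), plus a `prod` alias, so your API can load `models:/routeaq_pm25/Production` or `models:/routeaq_pm25@prod`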

In [1]:
# models/train_pm25_mlflow_fixed.ipynb
import os
import pandas as pd
import numpy as np
import mlflow
import mlflow.lightgbm
from mlflow.tracking import MlflowClient
from mlflow.models.signature import infer_signature
import lightgbm as lgb
from feast import FeatureStore

# -----------------------
# Config (env-overridable)
# -----------------------
# Every setting below can be overridden through the environment variable named in
# the os.getenv(...) call; the second argument is the default used when unset.

_TRUTHY = {"1", "true", "yes", "on"}


def env_flag(name: str, default: bool) -> bool:
    """Read a boolean flag from the environment.

    Returns *default* when the variable is unset. Otherwise the value is stripped
    and lower-cased, and the flag is True only for "1", "true", "yes" or "on";
    any other value (including an empty string) is False.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() in _TRUTHY


TRACKING_URI    = os.getenv("MLFLOW_TRACKING_URI", "http://localhost:5000")
EXPERIMENT_NAME = os.getenv("MLFLOW_EXPERIMENT_NAME", "routeaq_pm25_v2")
MODEL_NAME      = os.getenv("MLFLOW_REGISTERED_MODEL", "routeaq_pm25")  # registered model

JOINED_PARQUET  = os.getenv("JOINED_PARQUET", "/workspaces/airoute_mlops/airoute_mlops/data/silver/joined/hourly_joined_2025_from_28jul.parquet")
FEAST_REPO      = os.getenv("FEAST_REPO", "/workspaces/airoute_mlops/airoute_mlops/feature_repo")

# Feast feature references in "<feature_view>:<feature>" form; the part after the
# colon is the column name in the returned historical-feature frame.
FEATURE_LIST = [
    "aq_hourly:pm25_t_1",
    "aq_hourly:no2_t_1",
    "aq_hourly:o3_t_1",
    "aq_hourly:temp",
    "aq_hourly:wind",
    "aq_hourly:humidity",
]
TARGET_COL = "pm25_target"

# Closed time windows (both ends inclusive), parsed as UTC timestamps at split time.
TRAIN_START = os.getenv("TRAIN_START", "2025-07-30T10:00:00Z")
TRAIN_END   = os.getenv("TRAIN_END",   "2025-07-31T23:00:00Z")
VAL_START   = os.getenv("VAL_START",   "2025-08-01T00:00:00Z")
VAL_END     = os.getenv("VAL_END",     "2025-08-02T23:00:00Z")

# When True (the default), a freshly registered version is moved to the "Production"
# stage and receives the "prod" alias. Accepts 1/true/yes/on, case-insensitively.
AUTO_PROMOTE_PROD = env_flag("AUTO_PROMOTE_PROD", True)


def ensure_experiment(name: str) -> str:
    """Return the ID of MLflow experiment *name*, creating or restoring it as needed.

    - Missing: create it with ``artifact_location="mlflow-artifacts:/<name>"`` so run
      artifacts go through the tracking server's artifact proxy (server started with
      ``--serve-artifacts``). If that create call raises, print the reason and retry
      with the server's default artifact location.
    - Soft-deleted: restore it, because ``mlflow.set_experiment`` refuses to activate
      a deleted experiment.

    The client is built with no explicit URI, so it talks to whatever tracking URI
    was configured beforehand (``main`` calls ``mlflow.set_tracking_uri`` first).
    """
    client = MlflowClient()
    exp = client.get_experiment_by_name(name)

    if exp is not None:
        if exp.lifecycle_stage == "deleted":
            client.restore_experiment(exp.experiment_id)
            print(f"Restored deleted experiment '{name}' (id={exp.experiment_id}).")
        return exp.experiment_id

    # Prefer server-side artifact scheme (requires mlflow server started with --serve-artifacts)
    try:
        return client.create_experiment(name, artifact_location=f"mlflow-artifacts:/{name}")
    except Exception as err:  # any failure here falls back to the server default
        print(
            f"⚠️  Could not create experiment '{name}' with an mlflow-artifacts location "
            f"({err}); falling back to the server's default artifact root."
        )
        return client.create_experiment(name)


# -----------------------
# Main training entrypoint
# -----------------------
def _load_training_frame() -> pd.DataFrame:
    """Join Feast point-in-time features to the target columns of the joined parquet.

    Returns one row per entity row from Feast that has a non-null TARGET_COL; the
    ``date_time`` column is tz-aware UTC.
    """
    joined = pd.read_parquet(JOINED_PARQUET)
    print(f"Loaded joined parquet: {len(joined):,} rows")
    # Normalise once so the entity frame and the target merge use identical UTC keys.
    joined["date_time"] = pd.to_datetime(joined["date_time"], utc=True)

    # Entity dataframe for Feast: unique (site, timestamp) pairs, timestamp column
    # renamed to event_timestamp.
    entities = (
        joined[["site_id", "date_time"]]
        .drop_duplicates()
        .rename(columns={"date_time": "event_timestamp"})
    )

    store = FeatureStore(repo_path=FEAST_REPO)
    hist = store.get_historical_features(entity_df=entities, features=FEATURE_LIST).to_df()

    # Normalise column name and dtype back to the joined parquet's convention.
    if "event_timestamp" in hist.columns:
        hist = hist.rename(columns={"event_timestamp": "date_time"})
    hist["date_time"] = pd.to_datetime(hist["date_time"], utc=True)
    print("Historical feature shape:", hist.shape)

    targets = joined[["site_id", "date_time", "pm25_target", "no2_target", "o3_target"]]
    df = (
        hist.merge(targets, on=["site_id", "date_time"], how="left")
        .dropna(subset=[TARGET_COL])
        .copy()
    )
    print("Rows with target:", len(df))
    return df


def _time_window(df: pd.DataFrame, start: str, end: str) -> pd.DataFrame:
    """Return the rows of *df* whose ``date_time`` is in the closed UTC interval [start, end]."""
    lo = pd.to_datetime(start, utc=True)
    hi = pd.to_datetime(end, utc=True)
    return df[df["date_time"].between(lo, hi)]  # between() is inclusive on both ends by default


def _register_and_promote(model_uri: str) -> bool:
    """Register *model_uri* as MODEL_NAME and, if AUTO_PROMOTE_PROD, promote it.

    Must be called inside an active MLflow run: registration is attempted only when the
    run's artifact URI uses the ``mlflow-artifacts:`` scheme. Returns True when a model
    version was registered.
    """
    if not str(mlflow.get_artifact_uri()).startswith("mlflow-artifacts:"):
        print(
            "\n⚠️  Skipping registration because artifact URI is not 'mlflow-artifacts:'\n"
            "   Update docker-compose mlflow service to add '--serve-artifacts', then either:\n"
            "   • re-run training, or\n"
            "   • register this run after the fix with:\n"
            f"       mlflow.register_model(model_uri='{model_uri}', name='{MODEL_NAME}')\n"
        )
        return False

    print("Server is serving artifacts. Registering model…")
    mv = mlflow.register_model(model_uri=model_uri, name=MODEL_NAME)
    print(f"Registered model version: {mv.version}")

    if AUTO_PROMOTE_PROD:
        client = MlflowClient()
        # NOTE: MLflow emits a deprecation warning for stage transitions (visible in the
        # cell output); the "prod" alias set below is the forward-compatible handle.
        client.transition_model_version_stage(
            name=MODEL_NAME,
            version=mv.version,
            stage="Production",
            archive_existing_versions=True,
        )
        print(f"Transitioned {MODEL_NAME} v{mv.version} to Production.")

        client.set_registered_model_alias(
            name=MODEL_NAME,
            alias="prod",
            version=mv.version,
        )
        print(f"Alias 'prod' now points to {MODEL_NAME} v{mv.version}")
    return True


# -----------------------
# Main training entrypoint
# -----------------------
def main():
    """Train the PM2.5 LightGBM model, log it to MLflow and optionally register/promote it.

    Steps:
    1. Point MLflow at TRACKING_URI and ensure the experiment exists.
    2. Build features from the joined parquet plus Feast.
    3. Split by the TRAIN_* and VAL_* windows.
    4. Fit with early stopping on the validation window.
    5. Log params, metrics and the model.
    6. Register the model when the server proxies artifacts.

    Raises:
        ValueError: if the train or validation window contains no rows.
    """
    # Ensure client talks to the server and doesn't try to force a local artifact path
    os.environ.pop("MLFLOW_ARTIFACT_URI", None)
    mlflow.set_tracking_uri(TRACKING_URI)

    exp_id = ensure_experiment(EXPERIMENT_NAME)
    mlflow.set_experiment(EXPERIMENT_NAME)

    df = _load_training_frame()

    train = _time_window(df, TRAIN_START, TRAIN_END)
    val = _time_window(df, VAL_START, VAL_END)
    print(f"Train rows: {len(train):,}, Val rows: {len(val):,}")
    if train.empty or val.empty:
        raise ValueError(
            f"Empty split (train={len(train)}, val={len(val)}); check TRAIN_*/VAL_* "
            f"against the data's date_time range "
            f"[{df['date_time'].min()} .. {df['date_time'].max()}]."
        )
    tr_lo, tr_hi = pd.to_datetime(TRAIN_START, utc=True), pd.to_datetime(TRAIN_END, utc=True)
    va_lo, va_hi = pd.to_datetime(VAL_START, utc=True), pd.to_datetime(VAL_END, utc=True)
    if tr_lo <= va_hi and va_lo <= tr_hi:
        print("⚠️  Train and validation windows overlap; validation MAE will be optimistic.")

    feature_cols = [f.split(":", 1)[1] for f in FEATURE_LIST]
    X_train, y_train = train[feature_cols], train[TARGET_COL]
    X_val, y_val = val[feature_cols], val[TARGET_COL]

    # Single source of truth: used to build the model AND logged as params.
    hparams = {"n_estimators": 300, "learning_rate": 0.05}
    artifact_path = "model_pm25"

    with mlflow.start_run(run_name="lgbm_pm25_with_signature") as run:
        print("Artifact URI (inside run):", mlflow.get_artifact_uri())

        model = lgb.LGBMRegressor(**hparams)
        # NOTE: the validation window drives early stopping and is also used for the
        # reported MAE, so that MAE is slightly optimistic.
        model.fit(
            X_train, y_train,
            eval_set=[(X_val, y_val)],
            eval_metric="l1",
            callbacks=[lgb.early_stopping(stopping_rounds=20), lgb.log_evaluation(period=0)],
        )

        # LGBMRegressor.predict uses best_iteration_ when one exists (set by early stopping).
        preds = model.predict(X_val)
        mae = float(np.mean(np.abs(preds - y_val.to_numpy())))
        print(f"Validation MAE: {mae:.3f}")

        signature = infer_signature(X_train, model.predict(X_train))
        input_example = X_train.head(5)

        mlflow.log_params({
            "train_start": str(TRAIN_START), "train_end": str(TRAIN_END),
            "val_start": str(VAL_START), "val_end": str(VAL_END),
            **hparams,
        })
        metrics = {"val_mae": mae, "train_rows": len(X_train), "val_rows": len(X_val)}
        if model.best_iteration_:
            metrics["best_iteration"] = model.best_iteration_
        mlflow.log_metrics(metrics)

        # `artifact_path` is the MLflow 2.x keyword; MLflow 3.x renamed it to `name`
        # (artifact_path is deprecated there) — adjust if you upgrade.
        mlflow.lightgbm.log_model(
            model,
            artifact_path=artifact_path,
            signature=signature,
            input_example=input_example,
        )

        run_id = run.info.run_id
        model_uri = f"runs:/{run_id}/{artifact_path}"

        # Register & (optionally) promote to Production if the server is serving artifacts
        registered = _register_and_promote(model_uri)

        print(f"\n✅ Done.\nRun: {run_id}\nUI:  {TRACKING_URI}/#/experiments/{exp_id}/runs/{run_id}")
        if registered and AUTO_PROMOTE_PROD:
            print(f"Model URI for API: models:/{MODEL_NAME}/Production")
            print(f"Alias URI for API: models:/{MODEL_NAME}@prod")

if __name__ == "__main__":
    # Runs when executed as a script and also inside a notebook kernel, where the
    # cell's module name is "__main__" as well.
    main()

Loaded joined parquet: 20,447 rows




Historical feature shape: (20447, 8)
Rows with target: 15797
Train rows: 5,014, Val rows: 3,166
Artifact URI (inside run): mlflow-artifacts:/routeaq_pm25_v2/4a78b61d2a054d7bba7bc5c59bcc4e1f/artifacts
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000362 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 1332
[LightGBM] [Info] Number of data points in the train set: 5014, number of used features: 6
[LightGBM] [Info] Start training from score 4.307868
Training until validation scores don't improve for 20 rounds
Early stopping, best iteration is:
[46]	valid_0's l1: 1.02068	valid_0's l2: 3.64649
Validation MAE: 1.021


  self.utc_time_created = str(utc_time_created or datetime.utcnow())
Registered model 'routeaq_pm25' already exists. Creating a new version of this model...
2025/08/16 15:34:49 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: routeaq_pm25, version 6


Server is serving artifacts. Registering model…
Registered model version: 6
Transitioned routeaq_pm25 v6 to Production.
Alias 'prod' now points to routeaq_pm25 v6

✅ Done.
Run: 4a78b61d2a054d7bba7bc5c59bcc4e1f
UI:  http://localhost:5000/#/experiments/1/runs/4a78b61d2a054d7bba7bc5c59bcc4e1f
Model URI for API: models:/routeaq_pm25/Production


Created version '6' of model 'routeaq_pm25'.
  client.transition_model_version_stage(
