# Demo 5 — Model Serving

Wrap the ARIMA model as a **pyfunc** for serving, then:
1. Serve it locally with `mlflow models serve`
2. Package it as a Docker image with `mlflow models build-docker`

In [1]:
import pandas as pd
import numpy as np
from statsmodels.tsa.arima.model import ARIMA
from sklearn.metrics import mean_squared_error
import mlflow
import mlflow.pyfunc
import mlflow.statsmodels
import warnings
warnings.filterwarnings("ignore")

In [None]:
mlflow.set_tracking_uri("http://localhost:5050")
mlflow.set_experiment("temperature-forecast-simple")

MODEL_NAME = "temperature-forecast-simple"

## Define a pyfunc wrapper for ARIMA

Serving the raw statsmodels model would not give clients a simple "forecast the next N steps" interface over REST. We wrap it in a custom `PythonModel` so that `predict()` accepts a DataFrame with an `n_steps` column and returns that many forecasted values as a list.

In [None]:
class ARIMAWrapper(mlflow.pyfunc.PythonModel):
    """Wraps a fitted statsmodels ARIMA for MLflow serving."""

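    def load_context(self, context):
        # Import the flavor explicitly so it is available in the serving process
        import mlflow.statsmodels
        # context.artifacts maps each key passed via `artifacts=` to a local path
        self.model = mlflow.statsmodels.load_model(context.artifacts["arima_model"])

    def predict(self, context, model_input, params=None):
        # Only the first row's n_steps is used; any extra rows are ignored
        n_steps = int(model_input["n_steps"].iloc[0])
        # Forecast n_steps periods past the end of the training data
        forecast = self.model.forecast(steps=n_steps)
        return forecast.tolist()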
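
`predict()` above reads only the first row and passes `n_steps` straight through, so extra rows are silently ignored and `int()` quietly truncates a value like `7.5`. A minimal stdlib sketch of a stricter check, assuming you want exactly one row and a positive integer horizon; `validate_n_steps` and `MAX_STEPS` are hypothetical names, and the 365-step cap is an arbitrary illustration:

```python
# Hypothetical upper bound on the forecast horizon; choose one that suits your data.
MAX_STEPS = 365


def validate_n_steps(values):
    """Return a single positive int horizon from the `n_steps` column values.

    `values` is the column as a plain list, e.g. model_input["n_steps"].tolist().
    """
    if len(values) != 1:
        raise ValueError(f"expected exactly one row, got {len(values)}")
    raw = values[0]
    # Reject booleans and non-integral numbers such as 7.5 instead of truncating.
    if isinstance(raw, bool) or int(raw) != raw:
        raise ValueError(f"n_steps must be an integer, got {raw!r}")
    n_steps = int(raw)
    if not 1 <= n_steps <= MAX_STEPS:
        raise ValueError(f"n_steps must be between 1 and {MAX_STEPS}, got {n_steps}")
    return n_steps
```

Inside the wrapper it would replace the first line of `predict()`, e.g. `n_steps = validate_n_steps(model_input["n_steps"].tolist())`. An exception raised there makes the scoring server return an error response instead of a forecast.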

## Train, wrap, and register the pyfunc model

In [None]:
df = pd.read_csv("../data/jena_daily_temp.csv", parse_dates=["Date Time"], index_col="Date Time")
train = df.iloc[:-90]
test = df.iloc[-90:]

# Fit the underlying ARIMA model
order = (5, 1, 2)
model = ARIMA(train["temperature"], order=order)
results = model.fit()

print(f"Fitted ARIMA{order} — AIC: {results.aic:.2f}")

In [None]:
with mlflow.start_run(run_name="pyfunc-serving") as run:
    # First log the statsmodels model as an internal artifact
    arima_model_info = mlflow.statsmodels.log_model(results, name="statsmodels_arima")

    # Then log the pyfunc wrapper that references it
    pyfunc_model_info = mlflow.pyfunc.log_model(
        name="arima_model",
        python_model=ARIMAWrapper(),
        artifacts={"arima_model": arima_model_info.model_uri},
        input_example=pd.DataFrame({"n_steps": [30]}),
    )

    # Register the pyfunc model, using the URI returned by log_model
    mv = mlflow.register_model(
        model_uri=pyfunc_model_info.model_uri,
        name=MODEL_NAME,
    )

    # Set @champion alias
    client = mlflow.MlflowClient()
    client.set_registered_model_alias(MODEL_NAME, "champion", mv.version)

    print(f"Registered pyfunc model v{mv.version} with @champion alias")

## Test the pyfunc model locally

In [None]:
loaded = mlflow.pyfunc.load_model(f"models:/{MODEL_NAME}@champion")

result = loaded.predict(pd.DataFrame({"n_steps": [7]}))
print("7-day forecast:", result)
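
The wrapper returns bare floats with no timestamps. A minimal client-side sketch that labels them, assuming the series is daily and the forecast starts the day after the last training observation; `label_forecast` is a hypothetical helper, not part of MLflow, and the date and values below are illustrative only:

```python
from datetime import date, timedelta


def label_forecast(last_observed, forecast):
    """Pair each forecast value with its calendar date.

    Assumes daily data: step k (1-based) falls k days after `last_observed`.
    """
    return [
        (last_observed + timedelta(days=step), value)
        for step, value in enumerate(forecast, start=1)
    ]


# Illustrative inputs only, not real model output.
for day, temp in label_forecast(date(2020, 1, 31), [1.2, 0.8, 0.5]):
    print(day.isoformat(), f"{temp:.1f}")
```

In this notebook, `label_forecast(train.index[-1], result)` should also work, since adding a `timedelta` to a pandas `Timestamp` yields a `Timestamp`.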

---

## Part 1: Serve with `mlflow models serve`

Run this in a terminal:

```bash
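# --env-manager local reuses the current Python environment (it must have statsmodels installed)
MLFLOW_TRACKING_URI=http://localhost:5050 mlflow models serve \
  -m "models:/temperature-forecast-simple@champion" \
  -p 5001 --env-manager local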
```

> **Note:** We set `MLFLOW_TRACKING_URI` so the CLI can reach the dockerized MLflow server to download the model artifacts.

In [None]:
import requests

payload = {
    "dataframe_split": {
        "columns": ["n_steps"],
        "data": [[30]]
    }
}

response = requests.post(
    "http://localhost:5001/invocations",
    json=payload,
)

print(f"Status: {response.status_code}")
print(f"Forecast: {response.json()}")

Equivalent curl command:

```bash
curl http://localhost:5001/invocations \
  -H "Content-Type: application/json" \
  -d '{"dataframe_split": {"columns": ["n_steps"], "data": [[30]]}}'
```
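
The scoring server accepts more than one JSON layout for DataFrame input. A stdlib sketch that builds the `dataframe_split` body used above and the equivalent `dataframe_records` body, then pulls the forecast out of a response, assuming the `{"predictions": [...]}` envelope the MLflow scoring server wraps results in; the helper names are hypothetical:

```python
import json


def split_payload(n_steps):
    """Request body in the `dataframe_split` layout: column names plus rows."""
    return {"dataframe_split": {"columns": ["n_steps"], "data": [[n_steps]]}}


def records_payload(n_steps):
    """The same request in the `dataframe_records` layout: one dict per row."""
    return {"dataframe_records": [{"n_steps": n_steps}]}


def parse_forecast(body):
    """Extract the forecast list from a JSON response body (a str)."""
    decoded = json.loads(body)
    if not isinstance(decoded, dict) or "predictions" not in decoded:
        raise ValueError(f"unexpected response: {decoded!r}")
    return decoded["predictions"]


# Either body can be sent with requests.post(..., json=...) or curl -d.
print(json.dumps(split_payload(30)))
print(json.dumps(records_payload(30)))
```

With `requests`, `parse_forecast(response.text)` returns the same list as `response.json()["predictions"]`, but fails loudly if the server replied with an error object instead.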

---

## Part 2: Build a Docker image

```bash
MLFLOW_TRACKING_URI=http://localhost:5050 mlflow models build-docker \
  -m "models:/temperature-forecast-simple@champion" \
  -n temp-forecast-server
```

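Stop the Part 1 server first (it also listens on host port 5001), then run the container. The image serves on port 8080 inside the container, so map it to host port 5001: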

```bash
docker run -p 5001:8080 temp-forecast-server
```
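
The container can take a while to start, so the request cell below may fail if it runs too early. A minimal stdlib sketch that polls the scoring server's `/ping` health endpoint until it answers HTTP 200, assuming the port mapping above; `wait_until_ready` is a hypothetical helper, and the 60-second timeout is an arbitrary default:

```python
import http.client
import time
import urllib.request


def wait_until_ready(base_url, timeout=60.0, interval=1.0):
    """Poll `<base_url>/ping` until it answers HTTP 200 or `timeout` seconds pass.

    Returns True if the server became ready, False otherwise.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(f"{base_url}/ping", timeout=interval) as resp:
                if resp.status == 200:
                    return True
        except (OSError, http.client.HTTPException):
            # Not up yet: connection refused/reset, timeout, or a non-200 status.
            pass
        time.sleep(interval)
    return False
```

Call `wait_until_ready("http://localhost:5001")` before sending the request; it returns `False` rather than raising, so you can decide whether to retry or inspect `docker logs`.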

In [None]:
# After starting the Docker container, test with the same request:

response = requests.post(
    "http://localhost:5001/invocations",
    json=payload,
)

print(f"Status: {response.status_code}")
print(f"Forecast: {response.json()}")

---
**That's it!** We've gone from experiment tracking to a containerized model serving forecasts via a REST API.