# MLFlow and Ray example

In this notebook we will train an ML model and deploy it in to Ray cluster.

## MLFlow experiment tracking

Here we train the ML model and log metrics using MLFlow tracking server.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
import numpy as np
import mlflow
import mlflow.pytorch

# -------------------
# Prepare Data
# -------------------
mlflow.set_tracking_uri(uri="http://ai-starter-kit-mlflow:5000")
data = load_diabetes()
X = data.data
y = data.target.reshape(-1, 1)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32)
X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.float32)

# -------------------
# Define Model
# -------------------
class RegressionModel(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, 1)

    def forward(self, x):
        return self.linear(x)

input_dim = X_train.shape[1]
model = RegressionModel(input_dim)

# -------------------
# Training
# -------------------
epochs = 100
lr = 0.01

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

mlflow.set_experiment("Diabetes_Prediction_PyTorch")

with mlflow.start_run():
    mlflow.log_param("epochs", epochs)
    mlflow.log_param("learning_rate", lr)
    mlflow.log_param("optimizer", "Adam")
    mlflow.log_param("loss_fn", "MSELoss")
    mlflow.log_param("input_features", input_dim)

    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        outputs = model(X_train_tensor)
        loss = criterion(outputs, y_train_tensor)
        loss.backward()
        optimizer.step()

    # -------------------
    # Evaluation
    # -------------------
    model.eval()
    with torch.no_grad():
        preds = model(X_test_tensor)
        mse = criterion(preds, y_test_tensor).item()
        rmse = np.sqrt(mse)

    mlflow.log_metric("mse", mse)
    mlflow.log_metric("rmse", rmse)

    # # Log model to MLflow
    # mlflow.pytorch.log_model(model, "pytorch_model")


2025/08/12 15:00:18 INFO mlflow.tracking.fluent: Experiment with name 'Diabetes_Prediction_PyTorch' does not exist. Creating a new experiment.
The git executable must be specified in one of the following ways:
    - be included in your $PATH
    - be set via $GIT_PYTHON_GIT_EXECUTABLE
    - explicitly set via git.refresh(<full-path-to-git-executable>)

All git commands will error until this is rectified.

This initial message can be silenced or aggravated in the future by setting the
$GIT_PYTHON_REFRESH environment variable. Use one of the following values:
    - quiet|q|silence|s|silent|none|n|0: for no message or exception
    - error|e|exception|raise|r|2: for a raised exception

Example:
    export GIT_PYTHON_REFRESH=quiet



🏃 View run stately-kite-741 at: http://ai-starter-kit-mlflow:5000/#/experiments/1/runs/872da23f5a5541a39c0f893adbe53466
🧪 View experiment at: http://ai-starter-kit-mlflow:5000/#/experiments/1


## Ray deployment

In this step we will use the model from the previous step to deploy it to our Ray cluster.

In [2]:
!pip install "ray[serve,client,default]"

Defaulting to user installation because normal site-packages is not writeable
Collecting ray[client,default,serve]
  Downloading ray-2.48.0-cp312-cp312-manylinux2014_aarch64.whl.metadata (19 kB)
Collecting msgpack<2.0.0,>=1.0.0 (from ray[client,default,serve])
  Downloading msgpack-1.1.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl.metadata (8.4 kB)
Collecting grpcio (from ray[client,default,serve])
  Downloading grpcio-1.74.0-cp312-cp312-manylinux_2_17_aarch64.whl.metadata (3.8 kB)
Collecting aiohttp_cors (from ray[client,default,serve])
  Downloading aiohttp_cors-0.8.1-py3-none-any.whl.metadata (20 kB)
Collecting colorful (from ray[client,default,serve])
  Downloading colorful-0.5.7-py2.py3-none-any.whl.metadata (16 kB)
Collecting py-spy>=0.4.0 (from ray[client,default,serve])
  Downloading py_spy-0.4.1-py2.py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl.metadata (510 bytes)
Collecting opencensus (from ray[client,default,serve])
  Downloading opencensus-0.11

In [3]:
import torch
import mlflow.pytorch
import numpy as np
from starlette.requests import Request
from typing import Dict

from ray import serve
import ray

# ray.init("ray://ai-starter-kit-kuberay-head-svc:10001", namespace="my_new_namespace")

# MLFLOW_MODEL_URI = "mlruns/0/<RUN_ID>/artifacts/pytorch_model"  # Change to your run path

@serve.deployment
class PyTorchMLflowDeployment:
    def __init__(self):
        print("Loading model from MLflow...")
        # # self.model = mlflow.pytorch.load_model(MLFLOW_MODEL_URI)
        self.model = model
        self.model.eval()
        print("Model loaded successfully.")

    async def __call__(self, request: Request) -> Dict:
        try:
            data = await request.json()
            features = data.get("features", None)
            if features is None:
                return {"error": "Missing 'features' in request"}

            X = np.array(features).reshape(1, -1)
            X_tensor = torch.tensor(X, dtype=torch.float32)

            with torch.no_grad():
                prediction = self.model(X_tensor).numpy().tolist()

            return {"prediction": prediction}
        except Exception as e:
            return {"error": str(e)}

app = PyTorchMLflowDeployment.bind()
serve.run(app, route_prefix="/predict")


2025-08-12 15:00:37,849	INFO worker.py:1606 -- Using address ray://ai-starter-kit-kuberay-head-svc:10001 set in the environment variable RAY_ADDRESS
2025-08-12 15:00:37,850	INFO client_builder.py:242 -- Passing the following kwargs to ray.init() on the server: log_to_driver
SIGTERM handler is not set because current thread is not the main thread.
    Ray: 2.48.0
    Python: 3.12.9
This process on Ray Client was started with:
    Ray: 2.48.0
    Python: 3.12.10

[36m(ProxyActor pid=3272)[0m INFO 2025-08-12 08:00:44,231 proxy 10.244.0.9 -- Proxy starting on node 66724a5e2332cd618965646d0b7ab0d4d89622990053ab677dc8f588 (HTTP port: 8000).
[36m(ProxyActor pid=3272)[0m INFO 2025-08-12 08:00:44,333 proxy 10.244.0.9 -- Got updated endpoints: {}.
INFO 2025-08-12 15:00:44,830 serve 72 -- Started Serve in namespace "serve".
[36m(ServeController pid=3174)[0m INFO 2025-08-12 08:00:46,387 controller 3174 -- Deploying new version of Deployment(name='PyTorchMLflowDeployment', app='default') (ini

[36m(ServeReplica:default:PyTorchMLflowDeployment pid=1940, ip=10.244.0.10)[0m Loading model from MLflow...
[36m(ServeReplica:default:PyTorchMLflowDeployment pid=1940, ip=10.244.0.10)[0m Model loaded successfully.


INFO 2025-08-12 15:00:50,456 serve 72 -- Application 'default' is ready at http://127.0.0.1:8000/predict.
INFO 2025-08-12 15:00:50,519 serve 72 -- Started <ray.serve._private.router.SharedRouterLongPollClient object at 0xffff704638f0>.


DeploymentHandle(deployment='PyTorchMLflowDeployment')

2025-08-12 15:00:50,725	ERROR dataclient.py:312 -- Callback error:
Traceback (most recent call last):
  File "/opt/bitnami/jupyterhub-singleuser/.local/lib/python3.12/site-packages/ray/util/client/dataclient.py", line 301, in _process_response
    can_remove = callback(response)
                 ^^^^^^^^^^^^^^^^^^
  File "/opt/bitnami/jupyterhub-singleuser/.local/lib/python3.12/site-packages/ray/util/client/dataclient.py", line 179, in __call__
    self.callback(self.data)
  File "/opt/bitnami/jupyterhub-singleuser/.local/lib/python3.12/site-packages/ray/util/client/common.py", line 179, in deserialize_obj
    py_callback(data)
  File "/opt/bitnami/jupyterhub-singleuser/.local/lib/python3.12/site-packages/ray/util/client/common.py", line 147, in set_future
    fut.set_result(data)
  File "/opt/bitnami/miniforge/lib/python3.12/concurrent/futures/_base.py", line 544, in set_result
    raise InvalidStateError('{}: {!r}'.format(self._state, self))
concurrent.futures._base.InvalidStateError

In [4]:
serve.delete("PyTorchMLflowDeployment")
serve.shutdown()

INFO 2025-08-12 15:01:38,215 serve 72 -- Deleting app ['PyTorchMLflowDeployment']
