# Diabetes Progression Prediction with XGBoost on Spark

This example demonstrates **distributed training** of an XGBoost regressor using Spark to predict diabetes disease progression.

## Execution Modes

This notebook supports two execution modes:
- **Darwin Cluster Mode**: Uses the Darwin SDK with Ray to run Spark in distributed mode (selected automatically when `ray` and the Darwin SDK import successfully)
- **Local Mode**: Uses a local Spark session (`local[*]`) for development and testing (automatic fallback when the Darwin SDK cannot be imported)
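
The mode switch is an optional import. If the SDK imports cleanly the notebook uses it, and otherwise it falls back to local Spark. The notebook does this with a module-level `try`/`except ImportError` around `import ray` and `from darwin import ...`. Below is a minimal stdlib sketch of the same pattern. `detect_mode` and the `"cluster"`/`"local"` strings are hypothetical names chosen for illustration.

```python
import importlib


def detect_mode(*required_modules: str) -> str:
    """Return "cluster" if every module imports, else "local" (hypothetical helper)."""
    for name in required_modules:
        try:
            importlib.import_module(name)
        except ImportError:
            # ModuleNotFoundError is a subclass of ImportError, so a missing
            # package lands here, just like in the notebook's import cell
            return "local"
    return "cluster"
```

For example, `detect_mode("ray", "darwin")` returns `"local"` on a machine where either package is missing.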

## Dataset

- **Name**: Diabetes Dataset
- **Samples**: 442
- **Features**: 10 baseline variables
- **Target**: Quantitative measure of disease progression one year after baseline
- **Type**: Regression

## Model

- **Framework**: XGBoost with Spark (SparkXGBRegressor)
- **Type**: Distributed Gradient Boosting Regressor
- **Objective**: Squared error regression
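
With `reg:squarederror`, XGBoost minimizes `0.5 * (pred - label) ** 2` per sample. Each boosting round therefore sees a gradient of `pred - label` and a constant Hessian of 1, and a leaf's weight is `-G / (H + lambda)` summed over the samples in that leaf. The plain-Python sketch below illustrates that math, not XGBoost's internal code. The function names are hypothetical, and `reg_lambda=1.0` mirrors XGBoost's default L2 regularization.

```python
from typing import List, Tuple


def squared_error_grad_hess(preds: List[float], labels: List[float]) -> Tuple[List[float], List[float]]:
    """Per-sample gradient and Hessian of 0.5 * (pred - label) ** 2."""
    grads = [p - y for p, y in zip(preds, labels)]
    hess = [1.0 for _ in preds]  # the second derivative is constant
    return grads, hess


def optimal_leaf_weight(grads: List[float], hess: List[float], reg_lambda: float = 1.0) -> float:
    """Leaf value -G / (H + lambda) for the samples routed to one leaf."""
    return -sum(grads) / (sum(hess) + reg_lambda)


# Three samples in one leaf, all currently predicted at 150.0
grads, hess = squared_error_grad_hess([150.0, 150.0, 150.0], [160.0, 170.0, 180.0])
print(grads)                             # [-10.0, -20.0, -30.0]
print(optimal_leaf_weight(grads, hess))  # 15.0
```

The `learning_rate` of 0.1 then shrinks that leaf value before it is added to the prediction, which is why many trees (`n_estimators`) are needed.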

## Features

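The dataset includes 10 baseline variables. In scikit-learn's copy, each feature is mean-centered and scaled so that the squares in each column sum to 1. As a result, `age`, `bmi` and the other columns are not in their original units (`load_diabetes(scaled=False)` returns the unscaled values):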
- `age`: Age in years
- `sex`: Sex
- `bmi`: Body mass index
- `bp`: Average blood pressure
- `s1`: Total serum cholesterol
- `s2`: Low-density lipoproteins
- `s3`: High-density lipoproteins
- `s4`: Total cholesterol / HDL
- `s5`: Log of serum triglycerides level
- `s6`: Blood sugar level
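
The scaling that scikit-learn applies divides each mean-centered column by its population standard deviation times `sqrt(n_samples)`. That denominator equals the square root of the centered sum of squares. Below is a stdlib sketch of that transformation. `scale_like_sklearn_diabetes` is a hypothetical name, and the ages are illustrative inputs, not rows from the dataset.

```python
import math
from typing import List


def scale_like_sklearn_diabetes(values: List[float]) -> List[float]:
    """Mean-center, then divide by std * sqrt(n) so the squares sum to 1."""
    n = len(values)
    mean = sum(values) / n
    centered = [v - mean for v in values]
    # Population std times sqrt(n) equals the root of the centered sum of squares
    norm = math.sqrt(sum(c * c for c in centered))
    if norm == 0:
        raise ValueError("a constant column cannot be scaled")
    return [c / norm for c in centered]


scaled = scale_like_sklearn_diabetes([48.0, 59.0, 72.0, 24.0, 50.0])  # illustrative ages
print(round(sum(s * s for s in scaled), 6))  # 1.0
```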

## Key Features

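- Training uses `SparkXGBRegressor` from `xgboost.spark`. Its `num_workers` parameter (default 1) sets how many Spark tasks train in parallel
- Features are assembled into a single vector column with `VectorAssembler` before training
- The trained booster is logged to MLflow as a raw `model.json` artifact. When that upload succeeds, it is also registered in the MLflow Model Registry for versioning
- Includes a helper that can load the model back from the run's artifacts. By default, the script predicts with the in-memory booster
- Compares the predicted and actual values for one test sample, alongside RMSE, MAE and R2 on the train and test sets
- **Auto-installs xgboost on Spark executors** in Darwin cluster mode (required for distributed training)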
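
`VectorAssembler` concatenates the listed numeric input columns, in order, into one vector column, which `SparkXGBRegressor` then reads through `features_col`. Below is a plain-Python sketch of that per-row behavior, using dicts in place of Spark rows. `assemble_features` is hypothetical. The real `VectorAssembler` also accepts vector inputs and may emit sparse vectors. Its default `handleInvalid="error"` raises on null or NaN inputs, which the sketch mimics.

```python
import math
from typing import Dict, List, Optional


def assemble_features(rows: List[Dict[str, Optional[float]]], input_cols: List[str],
                      output_col: str = "features") -> List[Dict[str, object]]:
    """Return copies of rows with input_cols packed into a tuple under output_col."""
    assembled = []
    for row in rows:
        values = []
        for col in input_cols:
            value = row[col]
            if value is None or math.isnan(float(value)):
                # Mirrors handleInvalid="error": refuse null/NaN instead of guessing
                raise ValueError(f"null or NaN value in column {col!r}")
            values.append(float(value))
        # Keep the original columns (e.g. the label) next to the new vector
        assembled.append({**row, output_col: tuple(values)})
    return assembled
```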

In [None]:
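# Fix pyOpenSSL/cryptography compatibility issue first (restart the kernel afterwards)
%pip install --upgrade pyOpenSSL cryptography

# Install main dependencies; quote the version specifier so the shell does not treat ">" as a redirect
%pip install "xgboost>=2.0.0" pandas numpy scikit-learn mlflow pyspark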
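
The notebook pins `xgboost>=2.0.0`. After installing and restarting the kernel, you can confirm that the kernel actually sees a new enough version. Below is a stdlib sketch using `importlib.metadata`. `release_tuple` and `meets_minimum` are hypothetical helpers that compare only the leading numeric release segments, so a pre-release such as `2.0.0rc1` counts as equal to `2.0.0`. For full PEP 440 ordering, use `packaging.version.Version`.

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Tuple


def release_tuple(ver: str) -> Tuple[int, ...]:
    """Leading numeric release segments, e.g. "2.1.3rc1" -> (2, 1, 3)."""
    parts = []
    for piece in ver.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
        if match.end() != len(piece):
            break  # stop at a pre/post-release suffix such as "3rc1"
    return tuple(parts)


def meets_minimum(installed: str, minimum: str) -> bool:
    """Compare release tuples, padding the shorter one with zeros."""
    a, b = release_tuple(installed), release_tuple(minimum)
    width = max(len(a), len(b))
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))


try:
    installed = version("xgboost")
    print(f"xgboost {installed}, meets >=2.0.0: {meets_minimum(installed, '2.0.0')}")
except PackageNotFoundError:
    print("xgboost is not installed in this kernel")
```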

In [None]:
import os
import argparse
import json
import tempfile
import numpy as np
import pandas as pd
from datetime import datetime

# XGBoost imports
import xgboost as xgb
from xgboost.spark import SparkXGBRegressor

# Spark imports
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.evaluation import RegressionEvaluator

# MLflow imports
import mlflow
import mlflow.xgboost
from mlflow import set_tracking_uri, set_experiment
from mlflow.client import MlflowClient
from mlflow.models import infer_signature

# Scikit-learn imports (for loading dataset)
from sklearn.datasets import load_diabetes

# Darwin SDK imports (optional - only available on Darwin cluster)
DARWIN_SDK_AVAILABLE = False
try:
    import ray
    from darwin import init_spark_with_configs, stop_spark
    DARWIN_SDK_AVAILABLE = True
    print("Darwin SDK available - will use distributed Spark on Darwin cluster")
except ImportError as e:
    print(f"Darwin SDK not available: {e}")
    print("Running in LOCAL mode - will use local Spark session")
except AttributeError as e:
    # Typically caused by a pyOpenSSL/cryptography version mismatch
    if "X509_V_FLAG" in str(e) or "lib" in str(e):
        print("=" * 80)
        print("ERROR: pyOpenSSL/cryptography version conflict detected!")
        print("Please run the following in a cell before importing:")
        print("  %pip install --upgrade pyOpenSSL cryptography")
        print("Then restart the kernel and try again.")
        print("=" * 80)
    raise

In [None]:
def install_packages_on_executors(spark):
    """Install required packages on Spark executors using pip (best effort).

    Runs one small task per default-parallelism slot; Spark does not guarantee
    that every executor receives one of these tasks.
    """
    print("Installing packages on Spark executors...")
    
    # Packages required for xgboost.spark on executors
    EXECUTOR_PACKAGES = [
        "xgboost>=2.0.0",
        "scikit-learn",
        "pandas",
        "numpy",
    ]
    
    def install_packages(iterator):
        import subprocess
        import sys
        # Runs inside a Spark task: install into that worker's Python interpreter
        subprocess.check_call([
            sys.executable, "-m", "pip", "install",
            *EXECUTOR_PACKAGES,
            "-q", "--disable-pip-version-check"
        ])
        yield True
    
    # defaultParallelism is a task-slot count (or the spark.default.parallelism
    # setting), not the number of executors
    num_tasks = spark.sparkContext.defaultParallelism
    spark.sparkContext.parallelize(range(num_tasks), num_tasks) \
        .mapPartitions(install_packages) \
        .collect()
    
    print(f"Ran pip install in {num_tasks} Spark tasks: {', '.join(EXECUTOR_PACKAGES)}")


def initialize_spark():
    """Initialize Spark session - uses Darwin SDK on cluster, local Spark otherwise."""
    print("\n" + "=" * 80)
    print("INITIALIZING SPARK SESSION")
    print("=" * 80)
    
    # Spark configurations
    spark_configs = {
        "spark.sql.execution.arrow.pyspark.enabled": "true",
        "spark.sql.session.timeZone": "UTC",
        "spark.sql.shuffle.partitions": "10",
        "spark.default.parallelism": "10",
    }
    
    if DARWIN_SDK_AVAILABLE:
        # Running on Darwin cluster - use distributed Spark via Ray
        print("Mode: Darwin Cluster (Distributed)")
        ray.init()
        spark = init_spark_with_configs(spark_configs=spark_configs)
        
        # Install xgboost on executor nodes
        install_packages_on_executors(spark)
    else:
        # Running locally - use local Spark session
        print("Mode: Local Spark Session")
        builder = SparkSession.builder \
            .appName("XGBoost-Diabetes-Spark-Local") \
            .master("local[*]")
        
        for key, value in spark_configs.items():
            builder = builder.config(key, value)
        
        spark = builder.getOrCreate()
    
    print(f"Spark version: {spark.version}")
    print(f"Application ID: {spark.sparkContext.applicationId}")
    
    return spark


def cleanup_spark(spark):
    """Stop Spark session properly based on environment."""
    print("\nStopping Spark session...")
    if DARWIN_SDK_AVAILABLE:
        stop_spark()
    else:
        spark.stop()
    print("Spark session stopped.")
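
Inside each executor task, `install_packages` passes pip an argument list through `subprocess.check_call`. No shell is involved, so a specifier like `xgboost>=2.0.0` reaches pip verbatim and needs no quoting there. Below is a sketch of the same command construction as a pure function, which can be checked without a cluster. `build_pip_command` is a hypothetical name.

```python
import sys
from typing import List, Sequence


def build_pip_command(packages: Sequence[str], python: str = sys.executable) -> List[str]:
    """argv for a quiet pip install into the given interpreter (no shell parsing)."""
    if not packages:
        raise ValueError("no packages to install")
    return [python, "-m", "pip", "install", *packages, "-q", "--disable-pip-version-check"]


cmd = build_pip_command(["xgboost>=2.0.0", "pandas"])
print(cmd)
```

Passing `cmd` to `subprocess.check_call(cmd)` would run the install. Keeping the command as a list is what makes the unquoted `>=` safe.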

In [None]:
def setup_mlflow(mlflow_uri: str, username: str, password: str) -> MlflowClient:
    """Configure MLflow tracking and return client."""
    os.environ["MLFLOW_TRACKING_USERNAME"] = username
    os.environ["MLFLOW_TRACKING_PASSWORD"] = password
    
    set_tracking_uri(mlflow_uri)
    client = MlflowClient(mlflow_uri)
    
    print(f"MLflow tracking URI: {mlflow_uri}")
    return client


def load_and_prepare_data(spark: SparkSession):
    """Load Diabetes dataset and prepare Spark DataFrames."""
    print("\n" + "=" * 80)
    print("LOADING DATASET")
    print("=" * 80)
    
    # Load dataset as pandas
    data = load_diabetes(as_frame=True)
    pdf = data.data.copy()
    pdf['target'] = data.target
    
    print(f"Dataset: Diabetes")
    print(f"Samples: {len(pdf):,}")
    print(f"Features: {len(data.feature_names)}")
    
    print(f"\nFeature names:")
    for i, col in enumerate(data.feature_names, 1):
        print(f"  {i}. {col}")
    
    print(f"\nTarget statistics:")
    print(f"  Mean: {pdf['target'].mean():.2f}")
    print(f"  Std: {pdf['target'].std():.2f}")
    print(f"  Min: {pdf['target'].min():.2f}")
    print(f"  Max: {pdf['target'].max():.2f}")
    
    # Convert to Spark DataFrame
    print("\nConverting to Spark DataFrame...")
    df = spark.createDataFrame(pdf)
    
    # Get feature column names (all except target)
    feature_cols = [c for c in df.columns if c != 'target']
    
    # Assemble features into vector column
    assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")
    df = assembler.transform(df)
    
    # Split data ~80/20; randomSplit samples rows independently, so split sizes are approximate
    train_df, test_df = df.randomSplit([0.8, 0.2], seed=42)
    
    train_count = train_df.count()
    test_count = test_df.count()
    
    print(f"\nTrain samples: {train_count:,}")
    print(f"Test samples: {test_count:,}")
    print(f"Spark partitions: {train_df.rdd.getNumPartitions()}")
    
    return train_df, test_df, feature_cols


def train_model(train_df, test_df, hyperparams: dict):
    """Train XGBoost model using SparkXGBRegressor (distributed training)."""
    print("\n" + "=" * 80)
    print("TRAINING MODEL (Distributed on Spark)")
    print("=" * 80)
    
    print("Hyperparameters:")
    for key, value in hyperparams.items():
        print(f"  {key}: {value}")
    
    # Create SparkXGBRegressor. num_workers sets how many Spark tasks train in
    # parallel (the library default is 1); raise it to spread training across executors.
    xgb_regressor = SparkXGBRegressor(
        features_col="features",
        label_col="target",
        prediction_col="prediction",
        objective="reg:squarederror",
        num_workers=hyperparams.get("num_workers", 1),
        max_depth=hyperparams.get("max_depth", 5),
        learning_rate=hyperparams.get("learning_rate", 0.1),
        n_estimators=hyperparams.get("n_estimators", 100),
        subsample=hyperparams.get("subsample", 0.8),
        colsample_bytree=hyperparams.get("colsample_bytree", 0.8),
        random_state=hyperparams.get("random_state", 42),
    )
    
    # Train model (runs as num_workers Spark tasks)
    print("\nTraining XGBoost model on Spark...")
    model = xgb_regressor.fit(train_df)
    
    print("Training completed!")
    
    # Make predictions
    train_pred = model.transform(train_df)
    test_pred = model.transform(test_df)
    
    return model, train_pred, test_pred


def calculate_metrics(predictions_df, label_col="target", prediction_col="prediction", dataset_name="Test"):
    """Calculate evaluation metrics using Spark's RegressionEvaluator."""
    evaluator = RegressionEvaluator(labelCol=label_col, predictionCol=prediction_col)
    
    rmse = evaluator.setMetricName("rmse").evaluate(predictions_df)
    mae = evaluator.setMetricName("mae").evaluate(predictions_df)
    r2 = evaluator.setMetricName("r2").evaluate(predictions_df)
    
    return {
        f"{dataset_name.lower()}_rmse": rmse,
        f"{dataset_name.lower()}_mae": mae,
        f"{dataset_name.lower()}_r2": r2
    }


def log_to_mlflow(model, train_df, train_pred, test_pred, hyperparams, feature_names):
    """Log model, parameters, and metrics to MLflow."""
    print("\n" + "=" * 80)
    print("LOGGING TO MLFLOW")
    print("=" * 80)
    
    # Log hyperparameters
    for key, value in hyperparams.items():
        mlflow.log_param(key, value)
    
    # Log additional Spark-specific params
    mlflow.log_param("training_framework", "spark_xgboost")
    mlflow.log_param("distributed", True)
    
    # Calculate and log metrics
    train_metrics = calculate_metrics(train_pred, dataset_name="Train")
    test_metrics = calculate_metrics(test_pred, dataset_name="Test")
    all_metrics = {**train_metrics, **test_metrics}
    
    for metric_name, metric_value in all_metrics.items():
        mlflow.log_metric(metric_name, metric_value)
    
    print("\nModel Performance:")
    print(f"  Training RMSE: {train_metrics['train_rmse']:.4f}")
    print(f"  Training R2: {train_metrics['train_r2']:.4f}")
    print(f"  Test RMSE: {test_metrics['test_rmse']:.4f}")
    print(f"  Test MAE: {test_metrics['test_mae']:.4f}")
    print(f"  Test R2: {test_metrics['test_r2']:.4f}")
    
    # Get native XGBoost model (Booster) from Spark model
    native_model = model.get_booster()
    
    # Create sample input for signature (pandas DataFrame)
    sample_input = train_df.select(feature_names).limit(1).toPandas()
    
    # Log the raw booster (model.json) and an input example as plain artifacts.
    # This is not an MLflow-format model (no MLmodel file); if the upload fails,
    # the metrics and parameters logged above are kept.
    model_logged = False
    
    print("\nSaving model artifacts...")
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            # Save just the XGBoost model file (minimal, most compatible)
            model_file = os.path.join(tmpdir, "model.json")
            native_model.save_model(model_file)
            
            # Save input example as JSON
            input_example_file = os.path.join(tmpdir, "input_example.json")
            sample_input.to_json(input_example_file, orient="records", indent=2)
            
            # Log individual files
            mlflow.log_artifact(model_file, artifact_path="model")
            mlflow.log_artifact(input_example_file, artifact_path="model")
            
            model_logged = True
            print("Model artifacts logged successfully (minimal format)!")
    except Exception as e:
        print(f"Warning: Could not log model artifacts: {e}")
        print("Metrics and parameters were logged successfully.")
        print("You can save the model locally and upload manually if needed.")
    
    # Store native model reference for later use
    all_metrics["_native_model"] = native_model
    all_metrics["_model_logged"] = model_logged
    
    return all_metrics


def register_model(client: MlflowClient, model_name: str, run_id: str, experiment_id: str):
    """Register model in MLflow Model Registry."""
    print("\n" + "=" * 80)
    print("REGISTERING MODEL")
    print("=" * 80)
    
    model_uri = f"runs:/{run_id}/model"
    
    # Create registered model if it doesn't exist
    try:
        client.get_registered_model(model_name)
        print(f"Model '{model_name}' already exists in registry")
    except Exception:
        try:
            client.create_registered_model(model_name)
            print(f"Created registered model: {model_name}")
        except Exception as e:
            print(f"Could not create registered model: {e}")
    
    # Create model version
    try:
        result = client.create_model_version(
            name=model_name,
            source=model_uri,
            run_id=run_id
        )
        print(f"Model version registered successfully!")
        print(f"   Model Name: {model_name}")
        print(f"   Version: {result.version}")
        print(f"   Run ID: {run_id}")
        return result.version
    except Exception as e:
        print(f"Model registration failed (model still usable via run URI): {e}")
        print(f"   You can deploy using: mlflow-artifacts:/{experiment_id}/{run_id}/artifacts/model")
        return None


def load_model_and_predict(model_uri: str, sample_data: pd.DataFrame, native_model=None):
    """Load model from MLflow URI and run prediction.
    
    Args:
        model_uri: MLflow model URI (e.g., runs:/{run_id}/model)
        sample_data: Pandas DataFrame with feature data
        native_model: Optional - use this model directly instead of loading from MLflow
    """
    print("\n" + "=" * 80)
    print("LOADING MODEL AND RUNNING PREDICTION")
    print("=" * 80)
    
    if native_model is not None:
        # Use the native model directly (no need to load from MLflow)
        print("Using in-memory model for prediction")
        loaded_model = native_model
    else:
        # Try to load from MLflow
        print(f"Model URI: {model_uri}")
        try:
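            # mlflow.xgboost.load_model expects an MLflow-format model (with an MLmodel
            # file); log_to_mlflow uploads only model.json, so this is expected to fail
            loaded_model = mlflow.xgboost.load_model(model_uri)
            print("Model loaded from MLflow successfully!")
        except Exception as e:
            # Fallback: download the raw model.json artifact and load it manually
            print(f"MLflow loader failed, trying artifact download: {e}")
            try:
                if not model_uri.startswith("runs:/"):
                    raise ValueError(f"Artifact fallback only supports runs:/ URIs, got: {model_uri}")
                client = MlflowClient()
                run_id = model_uri[len("runs:/"):].split("/")[0]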
                
                with tempfile.TemporaryDirectory() as tmpdir:
                    local_path = client.download_artifacts(run_id, "model/model.json", tmpdir)
                    loaded_model = xgb.Booster()
                    loaded_model.load_model(local_path)
                    print("Model loaded from artifact successfully!")
            except Exception as e2:
                print(f"Could not load model: {e2}")
                raise
    
    # Create DMatrix for prediction
    dmatrix = xgb.DMatrix(sample_data)
    
    # Run prediction
    predictions = loaded_model.predict(dmatrix)
    
    print("\n" + "=" * 80)
    print("SAMPLE PREDICTION RESULTS")
    print("=" * 80)
    print(f"\nInput features:")
    for col in sample_data.columns:
        print(f"  {col}: {sample_data[col].iloc[0]:.4f}")
    print(f"\nPredicted diabetes progression: {predictions[0]:.2f}")
    
    return predictions


def print_deployment_info(run_id: str, experiment_id: str, model_name: str, model_version: str):
    """Print deployment instructions and sample payloads."""
    print("\n" + "=" * 80)
    print("TRAINING COMPLETE!")
    print("=" * 80)
    
    print(f"\nRun Information:")
    print(f"  Run ID: {run_id}")
    print(f"  Experiment ID: {experiment_id}")
    print(f"  Model URI (run): runs:/{run_id}/model")
    if model_version:
        print(f"  Model URI (registry): models:/{model_name}/{model_version}")
    
    print("\n" + "=" * 80)
    print("DEPLOYMENT PAYLOAD (deploy-model API)")
    print("=" * 80)
    
    deploy_payload = {
        "serve_name": "diabetes-xgboost-spark-regressor",
        "model_uri": f"mlflow-artifacts:/{experiment_id}/{run_id}/artifacts/model",
        "env": "local",
        "cores": 2,
        "memory": 4,
        "node_capacity": "spot",
        "min_replicas": 1,
        "max_replicas": 3
    }
    
    print(json.dumps(deploy_payload, indent=2))
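
`calculate_metrics` delegates to Spark's `RegressionEvaluator`. The three metrics it requests follow the standard definitions: RMSE, MAE, and `R2 = 1 - SS_res / SS_tot`. The stdlib sketch below can sanity-check the evaluator on a small sample pulled to the driver with `toPandas()`. `regression_metrics` is a hypothetical name.

```python
import math
from typing import Dict, Sequence


def regression_metrics(labels: Sequence[float], preds: Sequence[float]) -> Dict[str, float]:
    """RMSE, MAE and R^2 = 1 - SS_res / SS_tot for paired labels and predictions."""
    n = len(labels)
    if n == 0 or n != len(preds):
        raise ValueError("labels and preds must be non-empty and of equal length")
    errors = [p - y for y, p in zip(labels, preds)]
    mean_y = sum(labels) / n
    ss_res = sum(e * e for e in errors)
    ss_tot = sum((y - mean_y) ** 2 for y in labels)  # zero for a constant label
    return {
        "rmse": math.sqrt(ss_res / n),
        "mae": sum(abs(e) for e in errors) / n,
        "r2": 1.0 - ss_res / ss_tot,
    }


print(regression_metrics([3.0, -0.5, 2.0, 7.0], [2.5, 0.0, 2.0, 8.0]))
```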

In [None]:
def main():
    parser = argparse.ArgumentParser(description="Train XGBoost Diabetes Regression Model on Spark")
    parser.add_argument(
        "--mlflow-uri",
        default="http://darwin-mlflow-lib.darwin.svc.cluster.local:8080",
        help="MLflow tracking URI"
    )
    parser.add_argument(
        "--username",
        default="abc@gmail.com",
        help="MLflow username (placeholder default; pass your own)"
    )
    parser.add_argument(
        "--password",
        default="password",
        help="MLflow password (placeholder default; pass your own)"
    )
    parser.add_argument(
        "--experiment-name",
        default="diabetes_xgboost_spark_regression",
        help="MLflow experiment name"
    )
    parser.add_argument(
        "--model-name",
        default="DiabetesXGBoostSparkRegressor",
        help="Registered model name"
    )
    
    args, _ = parser.parse_known_args()
    
    print("\n" + "=" * 80)
    print("DIABETES PROGRESSION PREDICTION WITH XGBOOST ON SPARK")
    print("=" * 80)
    print(f"Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    
    # Initialize Spark
    spark = initialize_spark()
    
    # Setup MLflow
    client = setup_mlflow(args.mlflow_uri, args.username, args.password)
    set_experiment(experiment_name=args.experiment_name)
    print(f"Experiment: {args.experiment_name}")
    
    # Load data
    train_df, test_df, feature_names = load_and_prepare_data(spark)
    
    # Define hyperparameters
    hyperparams = {
        "objective": "reg:squarederror",
        "max_depth": 5,
        "learning_rate": 0.1,
        "n_estimators": 100,
        "subsample": 0.8,
        "colsample_bytree": 0.8,
        "random_state": 42
    }
    
    # Start MLflow run
    with mlflow.start_run(run_name=f"spark_xgboost_diabetes_{datetime.now().strftime('%Y%m%d_%H%M%S')}"):
        # Train model (distributed on Spark)
        model, train_pred, test_pred = train_model(
            train_df, test_df, hyperparams
        )
        
        # Log to MLflow
        metrics = log_to_mlflow(
            model, train_df, train_pred, test_pred, hyperparams, feature_names
        )
        
        # Get run information
        run_id = mlflow.active_run().info.run_id
        experiment_id = mlflow.active_run().info.experiment_id
    
    # Register model (outside of run context) - only if artifacts were logged
    model_version = None
    if metrics.get("_model_logged", False):
        model_version = register_model(client, args.model_name, run_id, experiment_id)
    else:
        print("\nSkipping model registration (artifacts not logged to MLflow)")
    
    # Demonstrate running inference on one test row. Select features and target
    # together: two separate limit(1) queries are not guaranteed to return the same row.
    model_uri = f"runs:/{run_id}/model"
    sample_row = test_df.select(*feature_names, "target").limit(1).toPandas()
    sample_pdf = sample_row[feature_names]
    
    # Use the in-memory booster (the logged artifact is not an MLflow-format model)
    native_model = metrics.get("_native_model")
    predictions = load_model_and_predict(model_uri, sample_pdf, native_model=native_model)
    
    # Compare with the actual value of the same row
    actual_value = sample_row["target"].iloc[0]
    print(f"Actual diabetes progression: {actual_value:.2f}")
    print(f"Absolute error: {abs(predictions[0] - actual_value):.2f}")
    
    # Print deployment information
    print_deployment_info(run_id, experiment_id, args.model_name, model_version)
    
    # Cleanup: Stop Spark session
    cleanup_spark(spark)
    
    print("\nScript completed successfully!")
    print("=" * 80 + "\n")


if __name__ == "__main__":
    main()
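
When `load_model_and_predict` falls back to downloading `model/model.json`, it needs the run ID from a `runs:/<run_id>/<artifact_path>` URI. Parsing that URI explicitly makes the step fail loudly on other schemes such as `models:/`. Below is a stdlib sketch. `parse_runs_uri` is hypothetical, and it does not reproduce MLflow's own internal URI handling.

```python
from typing import Tuple


def parse_runs_uri(uri: str) -> Tuple[str, str]:
    """Split "runs:/<run_id>/<path>" into (run_id, path)."""
    prefix = "runs:/"
    if not uri.startswith(prefix):
        raise ValueError(f"not a runs:/ URI: {uri!r}")
    # partition splits at the first "/", so nested artifact paths stay intact
    run_id, _, artifact_path = uri[len(prefix):].partition("/")
    if not run_id:
        raise ValueError(f"missing run ID in {uri!r}")
    return run_id, artifact_path


print(parse_runs_uri("runs:/0123abcd/model"))  # ('0123abcd', 'model')
```

The file to download would then be `f"{artifact_path}/model.json"` within that run.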