# Iris Species Classification: Spark Data Processing + Sklearn Training

This example demonstrates a hybrid approach using **Spark for data processing** and **sklearn for model training**.

## Architecture

- **Data Processing**: PySpark for distributed ETL (the same code scales to datasets that do not fit on one machine)
- **Training**: sklearn RandomForestClassifier, trained on the driver after the data is collected to pandas
- **Model Logging**: MLflow sklearn flavor, which needs no Spark or Java runtime to load the model
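
One detail of the Spark step is worth illustrating: `randomSplit` assigns each row to a split by an independent weighted draw, so an 80/20 split of 150 rows is only approximately 120/30. The following is a plain-Python sketch of that behavior, not Spark's implementation; the helper name `random_split` is hypothetical.

```python
import random


def random_split(rows, weights, seed):
    """Assign each row to a bucket by an independent weighted draw."""
    rng = random.Random(seed)
    total = sum(weights)
    # Cumulative upper bounds, e.g. [0.8, 1.0] for weights [0.8, 0.2].
    bounds = []
    running = 0.0
    for w in weights:
        running += w / total
        bounds.append(running)
    buckets = [[] for _ in weights]
    for row in rows:
        draw = rng.random()
        for i, bound in enumerate(bounds):
            # The last bucket also catches any floating-point shortfall.
            if draw < bound or i == len(bounds) - 1:
                buckets[i].append(row)
                break
    return buckets


train, test = random_split(range(150), [0.8, 0.2], seed=42)
print(len(train), len(test))  # sizes depend on the seed; they sum to 150
```

If exact split sizes matter, a stratified split on the collected pandas DataFrame is one alternative, at the cost of no longer splitting inside Spark.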

## Execution Modes

This notebook supports two execution modes:
- **Darwin Cluster Mode**: Uses the Darwin SDK with Ray for distributed Spark processing
- **Local Mode**: Uses a local Spark session for development and testing
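
The mode is chosen by checking whether the cluster-only packages can be imported. A minimal sketch of that idea follows, assuming the same two package names (`ray` and `darwin`) that the import cell uses; the helper `select_mode` is hypothetical, and unlike the notebook's `try`/`except` it only checks that the modules can be found, so it would not surface an error raised while importing them.

```python
import importlib.util


def select_mode(required_modules=("ray", "darwin")):
    """Return "darwin" if every required module can be found, else "local"."""
    for name in required_modules:
        # find_spec returns None for a top-level module that is not installed.
        if importlib.util.find_spec(name) is None:
            return "local"
    return "darwin"


print(f"Execution mode: {select_mode()}")
```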

## Dataset

- **Name**: Iris Dataset (Fisher, 1936)
- **Samples**: 150 iris flowers
- **Target**: Species (Setosa, Versicolor, Virginica)
- **Type**: Multi-class Classification

## Model

- **Framework**: sklearn RandomForestClassifier
- **Data Processing**: PySpark
- **Objective**: Multi-class classification

## Features

The dataset includes 4 measurements:
- `sepal_length`: Sepal length (cm)
- `sepal_width`: Sepal width (cm)
- `petal_length`: Petal length (cm)
- `petal_width`: Petal width (cm)
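
For reference, scikit-learn ships these columns under longer names such as `sepal length (cm)`. The sketch below maps them to the snake_case names above with an explicit dictionary; `COLUMN_MAP` is an illustrative name, and the notebook itself assigns the new names positionally instead.

```python
from sklearn.datasets import load_iris

# Map scikit-learn's original column names to the snake_case names listed above.
COLUMN_MAP = {
    "sepal length (cm)": "sepal_length",
    "sepal width (cm)": "sepal_width",
    "petal length (cm)": "petal_length",
    "petal width (cm)": "petal_width",
}

iris = load_iris(as_frame=True)
features = iris.data.rename(columns=COLUMN_MAP)  # rename returns a new DataFrame
features["label"] = iris.target

print(features.shape)
print(features["label"].value_counts().sort_index())
```

An explicit mapping fails safe: if the column order ever changed upstream, a positional assignment would silently mislabel features, while a dictionary keyed on the original names would not.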

## Key Features

- Spark handles the train/test split, and is where loading and transformation would run for larger datasets
- sklearn handles model training on the collected pandas DataFrames
- The model is logged with the built-in `mlflow.sklearn` flavor
- The model loads quickly at serving time because no Java or Spark dependencies are needed
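
The last point can be illustrated without MLflow or Spark. The sketch below uses a plain `pickle` round trip as a stand-in for model logging and loading, on the assumption that the sklearn flavor stores a pickled estimator (cloudpickle by default in MLflow 2.x, to the best of my knowledge). The small `n_estimators=10` is only there to keep the example fast.

```python
import pickle

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

X, y = load_iris(return_X_y=True)
model = RandomForestClassifier(n_estimators=10, random_state=42).fit(X, y)

# Serialize and restore in-process; a serving process would only do the load.
payload = pickle.dumps(model)
restored = pickle.loads(payload)

sample = X[:5]
same = (model.predict(sample) == restored.predict(sample)).all()
print(f"Serialized size: {len(payload)} bytes, predictions match: {same}")
```

Only scikit-learn (and its NumPy dependency) is imported on the loading side, which is why the serving image does not need a JVM.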

In [None]:
# Fix pyOpenSSL/cryptography compatibility issue first
%pip install --upgrade pyOpenSSL cryptography

# Install main dependencies (pin MLflow to match server version)
%pip install pandas numpy scikit-learn mlflow==2.12.2 pyspark

In [None]:
import os
import argparse
import json
import tempfile
import numpy as np
import pandas as pd
from datetime import datetime

# Spark imports (for data processing only)
from pyspark.sql import SparkSession

# MLflow imports
import mlflow
import mlflow.sklearn
from mlflow import set_tracking_uri, set_experiment
from mlflow.client import MlflowClient
from mlflow.models import infer_signature

# Scikit-learn imports (for training and metrics)
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix

# Darwin SDK imports (optional - only available on Darwin cluster)
DARWIN_SDK_AVAILABLE = False
try:
    import ray
    from darwin import init_spark_with_configs, stop_spark
    DARWIN_SDK_AVAILABLE = True
    print("Darwin SDK available - will use distributed Spark on Darwin cluster")
except ImportError as e:
    print(f"Darwin SDK not available: {e}")
    print("Running in LOCAL mode - will use local Spark session")
except AttributeError as e:
    # "module 'lib' has no attribute ..." or a missing X509_V_FLAG_* constant
    # usually means pyOpenSSL and cryptography are out of sync
    if "X509_V_FLAG" in str(e) or "module 'lib'" in str(e):
        print("=" * 80)
        print("ERROR: pyOpenSSL/cryptography version conflict detected!")
        print("Please run the following in a cell before importing:")
        print("  %pip install --upgrade pyOpenSSL cryptography")
        print("Then restart the kernel and try again.")
        print("=" * 80)
    raise

In [None]:
def initialize_spark():
    """Initialize Spark session for data processing.
    
    Uses Darwin SDK on cluster, local Spark otherwise.
    Spark is used for distributed data processing (ETL, splitting).
    Training is done with sklearn on the driver.
    """
    print("\n" + "=" * 80)
    print("INITIALIZING SPARK SESSION")
    print("=" * 80)
    
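    # Base Spark configurations for data processing.
    # Note: the spark.executor.* settings only take effect on a cluster;
    # under master("local[*]") everything runs inside the driver process.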
    spark_configs = {
        "spark.sql.execution.arrow.pyspark.enabled": "true",
        "spark.sql.session.timeZone": "UTC",
        "spark.sql.shuffle.partitions": "4",
        "spark.default.parallelism": "4",
        "spark.executor.memory": "1g",
        "spark.executor.cores": "1",
        "spark.driver.memory": "1g",
        "spark.executor.instances": "2",
    }
    
    if DARWIN_SDK_AVAILABLE:
        # Running on Darwin cluster - use distributed Spark via Ray
        print("Mode: Darwin Cluster (Distributed)")
        # ignore_reinit_error lets this cell be re-run in the same kernel
        ray.init(ignore_reinit_error=True)
        spark = init_spark_with_configs(spark_configs=spark_configs)
    else:
        # Running locally - use local Spark session
        print("Mode: Local Spark Session")
        builder = SparkSession.builder \
            .appName("Iris-Spark-DataProcessing") \
            .master("local[*]")
        
        for key, value in spark_configs.items():
            builder = builder.config(key, value)
        
        spark = builder.getOrCreate()
    
    print(f"Spark version: {spark.version}")
    print(f"Application ID: {spark.sparkContext.applicationId}")
    
    return spark


def cleanup_spark(spark):
    """Stop Spark session properly based on environment."""
    print("\nStopping Spark session...")
    if DARWIN_SDK_AVAILABLE:
        stop_spark()
    else:
        spark.stop()
    print("Spark session stopped.")

In [None]:
def setup_mlflow(mlflow_uri: str, username: str, password: str) -> MlflowClient:
    """Configure MLflow tracking and return client."""
    os.environ["MLFLOW_TRACKING_USERNAME"] = username
    os.environ["MLFLOW_TRACKING_PASSWORD"] = password
    
    set_tracking_uri(mlflow_uri)
    client = MlflowClient(mlflow_uri)
    
    print(f"MLflow tracking URI: {mlflow_uri}")
    return client


def load_and_prepare_data(spark: SparkSession):
    """Load Iris dataset using Spark for processing, return pandas for training.
    
    Uses Spark for distributed data operations (can scale to large datasets).
    Returns pandas DataFrames for sklearn training.
    """
    print("\n" + "=" * 80)
    print("LOADING DATASET")
    print("=" * 80)
    
    # Load dataset
    iris = load_iris(as_frame=True)
    pdf = iris.data.copy()
    pdf.columns = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
    pdf['label'] = iris.target
    
    target_names = iris.target_names.tolist()
    feature_names = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
    
    print(f"Dataset: Iris")
    print(f"Samples: {len(pdf):,}")
    print(f"Features: {len(feature_names)}")
    
    print(f"\nFeature names:")
    for i, name in enumerate(feature_names, 1):
        print(f"  {i}. {name}")
    
    print(f"\nTarget classes:")
    for i, name in enumerate(target_names):
        count = (pdf['label'] == i).sum()
        print(f"  {i}. {name}: {count} samples")
    
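    # Use Spark for distributed data splitting (demonstrates Spark processing).
    # randomSplit samples each row independently, so the 80/20 ratio is
    # approximate; the seed makes the split repeatable for the same input
    # and partitioning.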
    print("\nUsing Spark for distributed data splitting...")
    spark_df = spark.createDataFrame(pdf)
    train_spark, test_spark = spark_df.randomSplit([0.8, 0.2], seed=42)
    
    # Collect to pandas for sklearn training
    print("Collecting to pandas for training...")
    train_pdf = train_spark.toPandas()
    test_pdf = test_spark.toPandas()
    
    print(f"\nTrain samples: {len(train_pdf):,}")
    print(f"Test samples: {len(test_pdf):,}")
    
    return train_pdf, test_pdf, feature_names, target_names


def train_model(train_pdf, test_pdf, feature_names, hyperparams: dict):
    """Train Random Forest model using sklearn."""
    print("\n" + "=" * 80)
    print("TRAINING MODEL (sklearn RandomForest)")
    print("=" * 80)
    
    print("Hyperparameters:")
    for key, value in hyperparams.items():
        print(f"  {key}: {value}")
    
    # Prepare data
    X_train = train_pdf[feature_names].values
    y_train = train_pdf['label'].values
    X_test = test_pdf[feature_names].values
    y_test = test_pdf['label'].values
    
    # Create and train sklearn Random Forest
    print("\nTraining Random Forest model...")
    model = RandomForestClassifier(
        n_estimators=hyperparams.get("n_estimators", 100),
        max_depth=hyperparams.get("max_depth", 10),
        min_samples_leaf=hyperparams.get("min_samples_leaf", 1),
        random_state=hyperparams.get("random_state", 42),
        n_jobs=-1
    )
    model.fit(X_train, y_train)
    
    print("Training completed!")
    print(f"Number of trees: {model.n_estimators}")
    
    # Make predictions
    y_train_pred = model.predict(X_train)
    y_test_pred = model.predict(X_test)
    
    return model, X_train, y_train, X_test, y_test, y_train_pred, y_test_pred


def calculate_metrics(y_true, y_pred, dataset_name="Test"):
    """Calculate evaluation metrics using sklearn."""
    return {
        f"{dataset_name.lower()}_accuracy": accuracy_score(y_true, y_pred),
        f"{dataset_name.lower()}_precision": precision_score(y_true, y_pred, average='weighted'),
        f"{dataset_name.lower()}_recall": recall_score(y_true, y_pred, average='weighted'),
        f"{dataset_name.lower()}_f1": f1_score(y_true, y_pred, average='weighted')
    }


def log_to_mlflow(model, X_train, y_train, y_test, y_train_pred, y_test_pred, 
                  hyperparams, feature_names, target_names):
    """Log sklearn model, parameters, and metrics to MLflow."""
    print("\n" + "=" * 80)
    print("LOGGING TO MLFLOW")
    print("=" * 80)
    
    # Log hyperparameters
    for key, value in hyperparams.items():
        mlflow.log_param(key, value)
    
    # Log additional params
    mlflow.log_param("training_framework", "sklearn")
    mlflow.log_param("data_processing", "spark")
    mlflow.log_param("num_trees", model.n_estimators)
    
    # Calculate and log metrics
    train_metrics = calculate_metrics(y_train, y_train_pred, dataset_name="Train")
    test_metrics = calculate_metrics(y_test, y_test_pred, dataset_name="Test")
    all_metrics = {**train_metrics, **test_metrics}
    
    for metric_name, metric_value in all_metrics.items():
        mlflow.log_metric(metric_name, metric_value)
    
    print("\nModel Performance:")
    print(f"  Training Accuracy: {train_metrics['train_accuracy']:.4f}")
    print(f"  Training F1: {train_metrics['train_f1']:.4f}")
    print(f"  Test Accuracy: {test_metrics['test_accuracy']:.4f}")
    print(f"  Test Precision: {test_metrics['test_precision']:.4f}")
    print(f"  Test Recall: {test_metrics['test_recall']:.4f}")
    print(f"  Test F1: {test_metrics['test_f1']:.4f}")
    
    # Print confusion matrix
    cm = confusion_matrix(y_test, y_test_pred)
    print(f"\n  Confusion Matrix:")
    print(f"  {cm}")
    
    # Log model using mlflow.sklearn
    model_logged = False
    print("\nSaving model artifacts...")
    try:
        # Create a one-row sample input and the model's actual prediction
        # for it, then infer the signature from that pair
        sample_input = pd.DataFrame([X_train[0]], columns=feature_names)
        sample_output = pd.DataFrame({"prediction": model.predict(X_train[:1])})
        signature = infer_signature(sample_input, sample_output)
        
        # Log sklearn model
        print("  Logging to MLflow using sklearn flavor...")
        mlflow.sklearn.log_model(
            sk_model=model,
            artifact_path="model",
            signature=signature,
            input_example=sample_input
        )
        
        # Also log target names as an additional artifact; the temporary
        # directory is removed once the file has been uploaded
        with tempfile.TemporaryDirectory(prefix="sklearn_model_") as tmp_dir:
            target_names_file = os.path.join(tmp_dir, "target_names.json")
            with open(target_names_file, "w") as f:
                json.dump(target_names, f)
            mlflow.log_artifact(target_names_file, artifact_path="model")
        
        model_logged = True
        print("Model artifacts logged successfully!")
    except Exception as e:
        print(f"Warning: Could not log model artifacts: {e}")
        import traceback
        traceback.print_exc()
        print("Metrics and parameters were logged successfully.")
    
    all_metrics["_model_logged"] = model_logged
    
    return all_metrics


def register_model(client: MlflowClient, model_name: str, run_id: str, experiment_id: str):
    """Register model in MLflow Model Registry."""
    print("\n" + "=" * 80)
    print("REGISTERING MODEL")
    print("=" * 80)
    
    model_uri = f"runs:/{run_id}/model"
    
    # Create registered model if it doesn't exist
    try:
        client.get_registered_model(model_name)
        print(f"Model '{model_name}' already exists in registry")
    except Exception:
        try:
            client.create_registered_model(model_name)
            print(f"Created registered model: {model_name}")
        except Exception as e:
            print(f"Could not create registered model: {e}")
    
    # Create model version
    try:
        result = client.create_model_version(
            name=model_name,
            source=model_uri,
            run_id=run_id
        )
        print(f"Model version registered successfully!")
        print(f"   Model Name: {model_name}")
        print(f"   Version: {result.version}")
        print(f"   Run ID: {run_id}")
        return result.version
    except Exception as e:
        print(f"Model registration failed (model still usable via run URI): {e}")
        print(f"   You can deploy using: mlflow-artifacts:/{experiment_id}/{run_id}/artifacts/model")
        return None


def print_deployment_info(run_id: str, experiment_id: str, model_name: str, model_version: str):
    """Print deployment instructions and sample payloads."""
    print("\n" + "=" * 80)
    print("TRAINING COMPLETE!")
    print("=" * 80)
    
    print(f"\nRun Information:")
    print(f"  Run ID: {run_id}")
    print(f"  Experiment ID: {experiment_id}")
    print(f"  Model URI (run): runs:/{run_id}/model")
    if model_version:
        print(f"  Model URI (registry): models:/{model_name}/{model_version}")
    
    print("\n" + "=" * 80)
    print("DEPLOYMENT PAYLOAD (deploy-model API)")
    print("=" * 80)
    
    deploy_payload = {
        "serve_name": "iris-rf-spark-classifier",
        "model_uri": f"mlflow-artifacts:/{experiment_id}/{run_id}/artifacts/model",
        "env": "local",
        "cores": 2,
        "memory": 4,
        "node_capacity": "spot",
        "min_replicas": 1,
        "max_replicas": 3
    }
    
    print(json.dumps(deploy_payload, indent=2))

In [None]:
def main():
    parser = argparse.ArgumentParser(description="Train Iris Classification Model (Spark data processing + sklearn training)")
    parser.add_argument(
        "--mlflow-uri",
        default="http://darwin-mlflow-lib.darwin.svc.cluster.local:8080",
        help="MLflow tracking URI"
    )
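    parser.add_argument(
        "--username",
        default="abc@gmail.com",
        help="MLflow username (placeholder default; override for a real server)"
    )
    parser.add_argument(
        "--password",
        default="password",
        help="MLflow password (placeholder default; override for a real server)"
    )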
    parser.add_argument(
        "--experiment-name",
        default="iris_spark_sklearn_classification",
        help="MLflow experiment name"
    )
    parser.add_argument(
        "--model-name",
        default="IrisSklearnRFClassifier",
        help="Registered model name"
    )
    
    args, _ = parser.parse_known_args()
    
    print("\n" + "=" * 80)
    print("IRIS CLASSIFICATION: SPARK DATA PROCESSING + SKLEARN TRAINING")
    print("=" * 80)
    print(f"Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    
    # Initialize Spark for data processing
    spark = initialize_spark()
    
    # Setup MLflow
    client = setup_mlflow(args.mlflow_uri, args.username, args.password)
    set_experiment(experiment_name=args.experiment_name)
    print(f"Experiment: {args.experiment_name}")
    
    # Load and prepare data using Spark (returns pandas DataFrames)
    train_pdf, test_pdf, feature_names, target_names = load_and_prepare_data(spark)
    
    # Define hyperparameters
    hyperparams = {
        "n_estimators": 100,
        "max_depth": 10,
        "min_samples_leaf": 1,
        "random_state": 42,
    }
    
    # Start MLflow run
    with mlflow.start_run(run_name=f"sklearn_rf_iris_{datetime.now().strftime('%Y%m%d_%H%M%S')}"):
        # Train sklearn model
        model, X_train, y_train, X_test, y_test, y_train_pred, y_test_pred = train_model(
            train_pdf, test_pdf, feature_names, hyperparams
        )
        
        # Log to MLflow
        metrics = log_to_mlflow(
            model, X_train, y_train, y_test, y_train_pred, y_test_pred,
            hyperparams, feature_names, target_names
        )
        
        # Get run information
        run_id = mlflow.active_run().info.run_id
        experiment_id = mlflow.active_run().info.experiment_id
    
    # Register model (outside of run context) - only if artifacts were logged
    model_version = None
    if metrics.get("_model_logged", False):
        model_version = register_model(client, args.model_name, run_id, experiment_id)
    else:
        print("\nSkipping model registration (artifacts not logged to MLflow)")
    
    # Demo prediction with sklearn model
    print("\n" + "=" * 80)
    print("SAMPLE PREDICTION")
    print("=" * 80)
    sample_idx = 0
    sample_features = X_test[sample_idx:sample_idx+1]
    prediction = model.predict(sample_features)[0]
    probabilities = model.predict_proba(sample_features)[0]
    
    print(f"\nInput features:")
    for i, name in enumerate(feature_names):
        print(f"  {name}: {sample_features[0][i]:.4f}")
    
    print(f"\nClass Probabilities:")
    for i, prob in enumerate(probabilities):
        print(f"  {target_names[i]}: {prob:.4f}")
    
    predicted_species = target_names[prediction]
    actual_species = target_names[y_test[sample_idx]]
    print(f"\nPredicted: {prediction} ({predicted_species})")
    print(f"Actual: {y_test[sample_idx]} ({actual_species})")
    print(f"Correct: {prediction == y_test[sample_idx]}")
    
    # Print deployment information
    print_deployment_info(run_id, experiment_id, args.model_name, model_version)
    
    # Cleanup: Stop Spark session
    cleanup_spark(spark)
    
    print("\nScript completed successfully!")
    print("=" * 80 + "\n")


if __name__ == "__main__":
    main()