# Iris Species Classification with Random Forest on Spark

This example demonstrates **distributed training** of a Random Forest classifier using PySpark ML to predict iris species.

## Execution Modes

This notebook supports two execution modes:
- **Darwin Cluster Mode**: Uses the Darwin SDK with Ray for distributed Spark processing (selected automatically when the Darwin SDK can be imported, e.g. on a Darwin cluster)
- **Local Mode**: Uses local Spark session for development/testing (automatic fallback when Darwin SDK is not available)

## Dataset

- **Name**: Iris Dataset (Fisher, 1936)
- **Samples**: 150 iris flowers
- **Target**: Species (Setosa, Versicolor, Virginica)
- **Type**: Multi-class Classification

## Model

- **Framework**: PySpark ML RandomForestClassifier
- **Type**: Distributed Random Forest Classifier
- **Objective**: Multi-class classification

## Features

The dataset includes 4 measurements:
- `sepal_length`: Sepal length (cm)
- `sepal_width`: Sepal width (cm)
- `petal_length`: Petal length (cm)
- `petal_width`: Petal width (cm)

## Key Features

- Training uses PySpark ML `RandomForestClassifier` for distributed training
- Data is processed as Spark DataFrames with `VectorAssembler`
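- Model artifacts are logged to MLflow and, when artifact logging succeeds, registered in the MLflow Model Registry for versioning
- Demonstrates running inference on the driver with the trained in-memory model
- **Auto-installs required packages on Spark executors** (Darwin cluster mode only)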

In [None]:
# Upgrade pyOpenSSL and cryptography first: a version mismatch between them can
# surface as an AttributeError (e.g. mentioning X509_V_FLAG) when the Darwin SDK
# imports are attempted in the next cell. Restart the kernel after upgrading.
%pip install --upgrade pyOpenSSL cryptography

# Install the main dependencies (unpinned: pip picks the latest available versions)
%pip install pandas numpy scikit-learn mlflow pyspark

In [None]:
import os
import argparse
import json
import tempfile
import numpy as np
import pandas as pd
from datetime import datetime

# Spark imports
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import RandomForestClassifier as SparkRFClassifier
from pyspark.ml.classification import RandomForestClassificationModel
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.sql.functions import col

# MLflow imports
import mlflow
import mlflow.spark
from mlflow import set_tracking_uri, set_experiment
from mlflow.client import MlflowClient
from mlflow.models import infer_signature

# Scikit-learn imports (for loading dataset and metrics)
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix

# Darwin SDK imports (optional - only available on Darwin cluster).
# DARWIN_SDK_AVAILABLE selects cluster vs. local Spark everywhere below; it is
# flipped to True only when both `ray` and the `darwin` helpers import cleanly.
DARWIN_SDK_AVAILABLE = False
try:
    import ray
    from darwin import init_spark_with_configs, stop_spark
except ImportError as import_err:
    # Missing SDK is the expected case off-cluster: fall back to local Spark.
    print(f"Darwin SDK not available: {import_err}")
    print("Running in LOCAL mode - will use local Spark session")
except AttributeError as attr_err:
    # An AttributeError during import is never recoverable here; when it looks like
    # the known pyOpenSSL/cryptography mismatch, print a remediation hint first.
    attr_message = str(attr_err)
    if "X509_V_FLAG" in attr_message or "lib" in attr_message:
        banner = "=" * 80
        print(banner)
        print("ERROR: pyOpenSSL/cryptography version conflict detected!")
        print("Please run the following in a cell before importing:")
        print("  %pip install --upgrade pyOpenSSL cryptography")
        print("Then restart the kernel and try again.")
        print(banner)
    raise
else:
    DARWIN_SDK_AVAILABLE = True
    print("Darwin SDK available - will use distributed Spark on Darwin cluster")

In [None]:
def install_packages_on_executors(spark):
    """Install required Python packages on Spark executors using pip.

    Launches ``defaultParallelism`` single-partition tasks; each one runs
    ``pip install`` for ``EXECUTOR_PACKAGES`` with the executor's own Python
    interpreter and reports the hostname it ran on.

    Note: ``defaultParallelism`` is a count of task slots, not of executors,
    and Spark's scheduler does not guarantee these tasks are spread over every
    executor. The summary therefore reports the distinct hosts that actually
    ran the install instead of claiming a per-executor count.

    Args:
        spark: Active SparkSession.

    Returns:
        Sorted list of hostnames on which an install task ran.
    """
    print("Installing packages on Spark executors...")

    # Packages required for data processing on executors
    EXECUTOR_PACKAGES = [
        "scikit-learn",
        "pandas",
        "numpy",
    ]

    def install_packages(iterator):
        # Executes on the executor: keep imports local so the closure is self-contained.
        import socket
        import subprocess
        import sys
        subprocess.check_call([
            sys.executable, "-m", "pip", "install",
            *EXECUTOR_PACKAGES,
            "-q", "--disable-pip-version-check",
        ])
        # Report where this task ran so the driver can tell which hosts were covered.
        yield socket.gethostname()

    num_tasks = spark.sparkContext.defaultParallelism
    hosts = sorted(set(
        spark.sparkContext.parallelize(range(num_tasks), num_tasks)
        .mapPartitions(install_packages)
        .collect()
    ))

    print(
        f"Packages installed by {num_tasks} tasks on {len(hosts)} host(s) "
        f"[{', '.join(hosts)}]: {', '.join(EXECUTOR_PACKAGES)}"
    )
    return hosts


def initialize_spark():
    """Create a Spark session: distributed via the Darwin SDK when available, local otherwise.

    In Darwin mode Ray is initialized (only if not already running), Spark is
    started through ``init_spark_with_configs`` and the executor packages are
    installed. In local mode a ``local[*]`` session is built (or reused) with
    the same configuration dictionary.

    Returns:
        An active SparkSession.
    """
    print("\n" + "=" * 80)
    print("INITIALIZING SPARK SESSION")
    print("=" * 80)

    # Base Spark configurations.
    # Resource settings are kept small for modest cluster nodes to avoid OOM.
    # NOTE: in local mode the executor.* settings have no effect (work runs inside
    # the driver JVM), and spark.driver.memory only applies if the JVM has not been
    # started yet - getOrCreate() reuses an existing session.
    spark_configs = {
        "spark.sql.execution.arrow.pyspark.enabled": "true",
        "spark.sql.session.timeZone": "UTC",
        "spark.sql.shuffle.partitions": "4",
        "spark.default.parallelism": "4",
        # Executor resource settings - keep low to avoid OOM
        "spark.executor.memory": "1g",
        "spark.executor.cores": "1",
        "spark.driver.memory": "1g",
        # Limit number of executors
        "spark.executor.instances": "2",
    }

    if DARWIN_SDK_AVAILABLE:
        # Running on Darwin cluster - use distributed Spark via Ray
        print("Mode: Darwin Cluster (Distributed)")
        # Derived from spark_configs so the message cannot drift from the real settings.
        print(
            f"Resource settings: {spark_configs['spark.executor.instances']} executors, "
            f"{spark_configs['spark.executor.cores']} core(s) each, "
            f"{spark_configs['spark.executor.memory']} memory each"
        )
        # ray.init() raises if Ray is already initialized (e.g. when this cell is re-run).
        if not ray.is_initialized():
            ray.init()
        spark = init_spark_with_configs(spark_configs=spark_configs)

        # Install packages on executor nodes
        install_packages_on_executors(spark)
    else:
        # Running locally - use local Spark session
        print("Mode: Local Spark Session")
        builder = SparkSession.builder \
            .appName("Iris-RandomForest-Spark-Local") \
            .master("local[*]")

        for key, value in spark_configs.items():
            builder = builder.config(key, value)

        spark = builder.getOrCreate()

    print(f"Spark version: {spark.version}")
    print(f"Application ID: {spark.sparkContext.applicationId}")

    return spark


def cleanup_spark(spark):
    """Shut Spark down with the mechanism that matches how it was started.

    On a Darwin cluster the SDK's ``stop_spark()`` is used (the ``spark``
    argument is not consulted); locally the session's own ``stop()`` is called.
    """
    print("\nStopping Spark session...")
    # Only the selected branch is evaluated, so stop_spark need not exist locally.
    shutdown = stop_spark if DARWIN_SDK_AVAILABLE else spark.stop
    shutdown()
    print("Spark session stopped.")

In [None]:
def setup_mlflow(mlflow_uri: str, username: str, password: str) -> MlflowClient:
    """Point MLflow at ``mlflow_uri`` and return a client bound to it.

    Credentials are handed to MLflow through the MLFLOW_TRACKING_USERNAME and
    MLFLOW_TRACKING_PASSWORD environment variables, which are set for the
    current process as a side effect.
    """
    os.environ.update({
        "MLFLOW_TRACKING_USERNAME": username,
        "MLFLOW_TRACKING_PASSWORD": password,
    })

    set_tracking_uri(mlflow_uri)
    tracking_client = MlflowClient(mlflow_uri)

    print(f"MLflow tracking URI: {mlflow_uri}")
    return tracking_client


def load_and_prepare_data(spark: SparkSession):
    """Load the Iris dataset and return assembled, train/test-split Spark DataFrames.

    Args:
        spark: Active SparkSession used to build the DataFrame.

    Returns:
        Tuple ``(train_df, test_df, feature_cols, target_names)``. Both
        DataFrames contain the four raw feature columns, a double ``label``
        column and an assembled ``features`` vector column. ``target_names``
        lists class names indexed by numeric label.
    """
    print("\n" + "=" * 80)
    print("LOADING DATASET")
    print("=" * 80)

    # Load dataset as pandas
    iris = load_iris(as_frame=True)
    pdf = iris.data.copy()
    pdf.columns = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
    pdf['label'] = iris.target.astype(float)  # Spark ML requires double type for labels

    target_names = iris.target_names.tolist()

    print("Dataset: Iris")
    print(f"Samples: {len(pdf):,}")
    print(f"Features: {len(iris.feature_names)}")

    print("\nFeature names:")
    for i, col_name in enumerate(pdf.columns[:-1], 1):
        print(f"  {i}. {col_name}")

    print("\nTarget classes:")
    for i, name in enumerate(target_names):
        count = (pdf['label'] == i).sum()
        print(f"  {i}. {name}: {count} samples")

    # Convert to Spark DataFrame
    print("\nConverting to Spark DataFrame...")
    df = spark.createDataFrame(pdf)

    # Get feature column names (all except label)
    feature_cols = [c for c in df.columns if c != 'label']

    # Assemble features into vector column
    assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")
    df = assembler.transform(df)

    # 80/20 random split (seeded). NOTE: randomSplit is NOT stratified - rows are
    # assigned at random, so per-class counts in each split can be unbalanced on
    # a dataset this small.
    train_df, test_df = df.randomSplit([0.8, 0.2], seed=42)

    train_count = train_df.count()
    test_count = test_df.count()

    print(f"\nTrain samples: {train_count:,}")
    print(f"Test samples: {test_count:,}")
    print(f"Spark partitions: {train_df.rdd.getNumPartitions()}")

    return train_df, test_df, feature_cols, target_names


def train_model(train_df, test_df, hyperparams: dict):
    """Fit a PySpark ML random forest on ``train_df`` and score both splits.

    Args:
        train_df: DataFrame with ``features`` (vector) and ``label`` (double) columns.
        test_df: DataFrame with the same schema, used for held-out predictions.
        hyperparams: sklearn-style names (``n_estimators``, ``max_depth``,
            ``min_samples_leaf``, ``random_state``) mapped onto Spark ML params;
            missing keys fall back to 100 / 10 / 1 / 42.

    Returns:
        ``(model, train_predictions, test_predictions)``, where the prediction
        DataFrames are the output of ``model.transform`` on each split.
    """
    banner = "=" * 80
    print("\n" + banner)
    print("TRAINING MODEL (Distributed on Spark)")
    print(banner)

    print("Hyperparameters:")
    for param_name, param_value in hyperparams.items():
        print(f"  {param_name}: {param_value}")

    # NOTE(review): minInstancesPerNode is Spark's closest analogue to sklearn's
    # min_samples_leaf (minimum rows per child after a split) - not an exact match.
    estimator = SparkRFClassifier(
        featuresCol="features",
        labelCol="label",
        predictionCol="prediction",
        probabilityCol="probability",
        rawPredictionCol="rawPrediction",
        numTrees=hyperparams.get("n_estimators", 100),
        maxDepth=hyperparams.get("max_depth", 10),
        minInstancesPerNode=hyperparams.get("min_samples_leaf", 1),
        seed=hyperparams.get("random_state", 42),
    )

    # fit() runs as Spark jobs across the available executors.
    print("\nTraining distributed Random Forest model...")
    fitted_model = estimator.fit(train_df)

    print("Training completed!")
    print(f"Number of trees: {fitted_model.getNumTrees}")

    train_scored = fitted_model.transform(train_df)
    test_scored = fitted_model.transform(test_df)
    return fitted_model, train_scored, test_scored


def calculate_metrics_spark(predictions_df, label_col="label", prediction_col="prediction", dataset_name="Test"):
    """Evaluate a predictions DataFrame with Spark's MulticlassClassificationEvaluator.

    Returns a dict keyed ``<dataset_name lowercased>_{accuracy,precision,recall,f1}``.
    Precision and recall use Spark's weighted variants; Spark's ``f1`` metric is
    likewise label-weighted.
    """
    prefix = dataset_name.lower()
    evaluator = MulticlassClassificationEvaluator(labelCol=label_col, predictionCol=prediction_col)

    # Output suffix -> Spark evaluator metric name (insertion order fixes the result key order).
    spark_metric_for = {
        "accuracy": "accuracy",
        "precision": "weightedPrecision",
        "recall": "weightedRecall",
        "f1": "f1",
    }
    return {
        f"{prefix}_{suffix}": evaluator.setMetricName(spark_metric).evaluate(predictions_df)
        for suffix, spark_metric in spark_metric_for.items()
    }


def log_to_mlflow(model, train_df, train_pred, test_pred, hyperparams, feature_names, target_names):
    """Log parameters, metrics and Spark ML model artifacts to the active MLflow run.

    Args:
        model: Fitted Spark ML random forest model.
        train_df: Training DataFrame; one row of its feature columns is saved
            as ``input_example.json``.
        train_pred: Output of ``model.transform`` on the training split.
        test_pred: Output of ``model.transform`` on the test split.
        hyperparams: Hyperparameters to log as MLflow params.
        feature_names: Feature column names.
        target_names: Class names, indexed by numeric label.

    Returns:
        Dict of train/test metrics plus two private entries: ``_spark_model``
        (the in-memory model) and ``_model_logged`` (whether artifact logging
        succeeded). Artifact logging is best-effort: failures are printed, not raised.
    """
    print("\n" + "=" * 80)
    print("LOGGING TO MLFLOW")
    print("=" * 80)

    # Log hyperparameters and Spark-specific params (batched: fewer tracking-server calls)
    mlflow.log_params(hyperparams)
    mlflow.log_params({
        "training_framework": "spark_ml",
        "distributed": True,
        "num_trees": model.getNumTrees,
    })

    # Calculate and log metrics
    train_metrics = calculate_metrics_spark(train_pred, dataset_name="Train")
    test_metrics = calculate_metrics_spark(test_pred, dataset_name="Test")
    all_metrics = {**train_metrics, **test_metrics}
    mlflow.log_metrics(all_metrics)

    print("\nModel Performance:")
    print(f"  Training Accuracy: {train_metrics['train_accuracy']:.4f}")
    print(f"  Training F1: {train_metrics['train_f1']:.4f}")
    print(f"  Test Accuracy: {test_metrics['test_accuracy']:.4f}")
    print(f"  Test Precision: {test_metrics['test_precision']:.4f}")
    print(f"  Test Recall: {test_metrics['test_recall']:.4f}")
    print(f"  Test F1: {test_metrics['test_f1']:.4f}")

    # Confusion matrix on the test split. Labels are pinned (as doubles, matching the
    # Spark label column) so the matrix is always n_classes x n_classes, even when a
    # class is absent from this small random split.
    test_pdf = test_pred.select("label", "prediction").toPandas()
    class_labels = [float(i) for i in range(len(target_names))]
    cm = confusion_matrix(test_pdf["label"], test_pdf["prediction"], labels=class_labels)
    print("\n  Confusion Matrix (rows = actual, columns = predicted):")
    # Indent every row, not just the first line of the numpy repr.
    print("\n".join(f"  {row}" for row in str(cm).splitlines()))

    # Create sample input for reference
    sample_input = train_df.select(feature_names).limit(1).toPandas()

    # Try to log model artifacts.
    # NOTE(review): these are raw Spark ML save files - no MLmodel file is written,
    # so the logged "model" directory is not an MLflow-format model; confirm whether
    # mlflow.spark.log_model was intended for pyfunc loading/serving.
    model_logged = False
    print("\nSaving model artifacts...")
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            # Save the Spark ML model
            model_path = os.path.join(tmpdir, "spark_model")
            model.save(model_path)

            # Save input example as JSON
            input_example_file = os.path.join(tmpdir, "input_example.json")
            sample_input.to_json(input_example_file, orient="records", indent=2)

            # Save target names
            target_names_file = os.path.join(tmpdir, "target_names.json")
            with open(target_names_file, "w") as f:
                json.dump(target_names, f)

            # Log artifacts
            mlflow.log_artifacts(model_path, artifact_path="model/spark_model")
            mlflow.log_artifact(input_example_file, artifact_path="model")
            mlflow.log_artifact(target_names_file, artifact_path="model")

            model_logged = True
            print("Model artifacts logged successfully!")
    except Exception as e:
        # Deliberately best-effort: params and metrics are already logged above.
        print(f"Warning: Could not log model artifacts: {e}")
        print("Metrics and parameters were logged successfully.")

    # Store references for later use (added after logging so they never reach MLflow)
    all_metrics["_spark_model"] = model
    all_metrics["_model_logged"] = model_logged

    return all_metrics


def register_model(client: MlflowClient, model_name: str, run_id: str, experiment_id: str):
    """Register ``runs:/<run_id>/model`` as a new version of ``model_name`` (best-effort).

    Creates the registered model first if looking it up fails. Every registry
    error is printed rather than raised.

    Returns:
        The new model version on success, otherwise ``None``.
    """
    banner = "=" * 80
    print("\n" + banner)
    print("REGISTERING MODEL")
    print(banner)

    source_uri = f"runs:/{run_id}/model"

    # Any lookup failure is treated as "not registered yet".
    try:
        client.get_registered_model(model_name)
    except Exception:
        try:
            client.create_registered_model(model_name)
        except Exception as create_err:
            print(f"Could not create registered model: {create_err}")
        else:
            print(f"Created registered model: {model_name}")
    else:
        print(f"Model '{model_name}' already exists in registry")

    # Create model version
    try:
        new_version = client.create_model_version(
            name=model_name,
            source=source_uri,
            run_id=run_id,
        )
    except Exception as version_err:
        print(f"Model registration failed (model still usable via run URI): {version_err}")
        print(f"   You can deploy using: mlflow-artifacts:/{experiment_id}/{run_id}/artifacts/model")
        return None

    print("Model version registered successfully!")
    print(f"   Model Name: {model_name}")
    print(f"   Version: {new_version.version}")
    print(f"   Run ID: {run_id}")
    return new_version.version


def load_model_and_predict(sample_data: pd.DataFrame, spark_model, spark: SparkSession, feature_names: list, target_names: list):
    """Run prediction on the driver with an in-memory Spark model.

    Despite the name, nothing is loaded from MLflow: the fitted ``spark_model``
    passed in is used directly. Only the first row's result is returned and printed.

    Args:
        sample_data: Pandas DataFrame with feature data
        spark_model: The Spark ML Random Forest model
        spark: SparkSession for creating DataFrame
        feature_names: List of feature column names
        target_names: List of target class names

    Returns:
        ``(prediction, probabilities)`` for the first row: the predicted class
        index as ``int`` and the class-probability vector as a numpy array.
        Returns ``(None, None)`` when no model or no sample rows are given.
    """
    print("\n" + "=" * 80)
    print("LOADING MODEL AND RUNNING PREDICTION")
    print("=" * 80)

    if spark_model is None:
        print("Error: No model provided")
        return None, None

    # Guard against an empty sample (previously collect()[0] raised IndexError).
    if sample_data is None or sample_data.empty:
        print("Error: No sample data provided")
        return None, None

    print("Using in-memory Spark model for prediction")

    # Create Spark DataFrame from sample
    sample_spark_df = spark.createDataFrame(sample_data)

    # Assemble features (same column order as training so the vector layout matches)
    assembler = VectorAssembler(inputCols=feature_names, outputCol="features")
    sample_spark_df = assembler.transform(sample_spark_df)

    # Run prediction; first() avoids collecting every row when only one is reported.
    predictions_df = spark_model.transform(sample_spark_df)
    first_row = predictions_df.select("prediction", "probability").first()
    prediction = int(first_row["prediction"])
    probabilities = first_row["probability"].toArray()

    print("\n" + "=" * 80)
    print("SAMPLE PREDICTION RESULTS")
    print("=" * 80)
    print("\nInput features:")
    for col_name in feature_names:
        print(f"  {col_name}: {sample_data[col_name].iloc[0]:.4f}")

    print("\nClass Probabilities:")
    for i, prob in enumerate(probabilities):
        species = target_names[i] if i < len(target_names) else f"Class {i}"
        print(f"  {species}: {prob:.4f}")

    predicted_species = target_names[prediction] if prediction < len(target_names) else f"Class {prediction}"
    print(f"\nPredicted Class: {prediction} ({predicted_species})")

    return prediction, probabilities


def print_deployment_info(run_id: str, experiment_id: str, model_name: str, model_version: str):
    """Print run identifiers, model URIs and a sample deploy-model API payload.

    The registry URI line is printed only when ``model_version`` is truthy
    (registration may have been skipped or failed).
    """
    rule = "=" * 80
    lines = [
        "\n" + rule,
        "TRAINING COMPLETE!",
        rule,
        "\nRun Information:",
        f"  Run ID: {run_id}",
        f"  Experiment ID: {experiment_id}",
        f"  Model URI (run): runs:/{run_id}/model",
    ]
    if model_version:
        lines.append(f"  Model URI (registry): models:/{model_name}/{model_version}")

    lines += ["\n" + rule, "DEPLOYMENT PAYLOAD (deploy-model API)", rule]

    payload = dict(
        serve_name="iris-rf-spark-classifier",
        model_uri=f"mlflow-artifacts:/{experiment_id}/{run_id}/artifacts/model",
        env="local",
        cores=2,
        memory=4,
        node_capacity="spot",
        min_replicas=1,
        max_replicas=3,
    )
    lines.append(json.dumps(payload, indent=2))

    # One print of newline-joined lines == one print per line.
    print("\n".join(lines))

In [None]:
def main():
    """Run the end-to-end Iris random-forest pipeline on Spark.

    Steps: parse CLI args, start Spark, configure MLflow, load and split the
    data, train and log inside an MLflow run, register the model if its
    artifacts were logged, score one test row on the driver, and print
    deployment info. The Spark session is always stopped, even if a step fails.
    """
    parser = argparse.ArgumentParser(description="Train Iris Classification Model on Spark")
    parser.add_argument(
        "--mlflow-uri",
        default="http://darwin-mlflow-lib.darwin.svc.cluster.local:8080",
        help="MLflow tracking URI"
    )
    # Credentials default to the standard MLflow env vars so secrets need not be
    # hard-coded or passed on the command line; literal fallbacks keep old behavior.
    parser.add_argument(
        "--username",
        default=os.environ.get("MLFLOW_TRACKING_USERNAME", "abc@gmail.com"),
        help="MLflow username (default: $MLFLOW_TRACKING_USERNAME)"
    )
    parser.add_argument(
        "--password",
        default=os.environ.get("MLFLOW_TRACKING_PASSWORD", "password"),
        help="MLflow password (default: $MLFLOW_TRACKING_PASSWORD)"
    )
    parser.add_argument(
        "--experiment-name",
        default="iris_spark_classification",
        help="MLflow experiment name"
    )
    parser.add_argument(
        "--model-name",
        default="IrisSparkRFClassifier",
        help="Registered model name"
    )

    # parse_known_args tolerates extra argv entries (e.g. injected by a Jupyter kernel).
    args, _ = parser.parse_known_args()

    print("\n" + "=" * 80)
    print("IRIS SPECIES CLASSIFICATION WITH RANDOM FOREST ON SPARK")
    print("=" * 80)
    print(f"Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    # Initialize Spark
    spark = initialize_spark()
    try:
        # Setup MLflow
        client = setup_mlflow(args.mlflow_uri, args.username, args.password)
        set_experiment(experiment_name=args.experiment_name)
        print(f"Experiment: {args.experiment_name}")

        # Load data
        train_df, test_df, feature_names, target_names = load_and_prepare_data(spark)

        # Define hyperparameters
        hyperparams = {
            "n_estimators": 100,
            "max_depth": 10,
            "min_samples_leaf": 1,
            "random_state": 42,
        }

        # Start MLflow run
        with mlflow.start_run(run_name=f"spark_rf_iris_{datetime.now().strftime('%Y%m%d_%H%M%S')}") as run:
            # Train model (distributed on Spark)
            model, train_pred, test_pred = train_model(
                train_df, test_df, hyperparams
            )

            # Log to MLflow
            metrics = log_to_mlflow(
                model, train_df, train_pred, test_pred, hyperparams, feature_names, target_names
            )

            run_id = run.info.run_id
            experiment_id = run.info.experiment_id

        # Register model (outside of run context) - only if artifacts were logged
        model_version = None
        if metrics.get("_model_logged", False):
            model_version = register_model(client, args.model_name, run_id, experiment_id)
        else:
            print("\nSkipping model registration (artifacts not logged to MLflow)")

        # Demonstrate running inference on driver.
        # Features and label come from ONE query: two separate limit(1) calls are not
        # guaranteed to return the same row.
        sample_pdf = test_df.select(feature_names + ["label"]).limit(1).toPandas()
        if sample_pdf.empty:
            print("\nTest split is empty - skipping sample prediction")
        else:
            prediction, _ = load_model_and_predict(
                sample_pdf[feature_names], metrics.get("_spark_model"), spark, feature_names, target_names
            )

            # Get actual value for comparison
            actual_label = int(sample_pdf["label"].iloc[0])
            actual_species = target_names[actual_label] if actual_label < len(target_names) else f"Class {actual_label}"
            print(f"\nActual Class: {actual_label} ({actual_species})")
            print(f"Prediction Correct: {prediction == actual_label}")

        # Print deployment information
        print_deployment_info(run_id, experiment_id, args.model_name, model_version)
    finally:
        # Cleanup: always stop Spark, even when a step above raised.
        cleanup_spark(spark)

    print("\nScript completed successfully!")
    print("=" * 80 + "\n")


if __name__ == "__main__":
    main()