# Wine Classification: Spark Data Processing + LightGBM Training

This example demonstrates a hybrid approach using **Spark for data processing** and **native LightGBM for model training**.

## Architecture

- **Data Processing**: PySpark for distributed ETL operations (scales to large datasets)
- **Training**: Native LightGBM (gradient boosting on the driver node, after the data is collected to pandas)
- **Model Logging**: MLflow `lightgbm` flavor (the logged model loads without Spark)

## Execution Modes

This notebook supports two execution modes:
- **Darwin Cluster Mode**: Uses Darwin SDK with Ray for distributed Spark processing
- **Local Mode**: Uses local Spark session for development/testing
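
The mode selection can be sketched with a small availability check. This is a minimal illustration, assuming the mode depends only on whether the `ray` and `darwin` packages can be found; `detect_execution_mode` is a hypothetical helper and not part of the Darwin SDK.

```python
import importlib.util


def detect_execution_mode(required_modules=("ray", "darwin")):
    """Return "darwin" if every required module can be found, else "local".

    Hypothetical helper: it only checks that the packages are installed,
    not that a Darwin cluster is actually reachable.
    """
    for name in required_modules:
        if importlib.util.find_spec(name) is None:
            return "local"
    return "darwin"


print(f"Execution mode: {detect_execution_mode()}")
```

Unlike a `try`/`except ImportError` block, `find_spec` locates a package without importing it, so the check itself has no import side effects.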

## Dataset

- **Name**: Wine dataset (loaded with scikit-learn's `load_wine`)
- **Samples**: 178 wine samples
- **Target**: Wine class (0, 1, 2) - three different cultivars
- **Type**: Multi-class Classification
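
The dataset facts above can be checked directly. The snippet below is a minimal sketch that loads the data with scikit-learn's `load_wine` (the loader this notebook uses) and prints the sample count, feature count, and per-class counts; the variable names are illustrative.

```python
from sklearn.datasets import load_wine

# Load features and target as pandas objects.
wine = load_wine(as_frame=True)
features = wine.data
target = wine.target

n_samples, n_features = features.shape
# Number of samples per class, ordered by class index.
class_counts = target.value_counts().sort_index().to_dict()

print(f"Samples: {n_samples}")
print(f"Features: {n_features}")
print(f"Class counts: {class_counts}")
```

The classes are not perfectly balanced, which is one reason the notebook reports weighted precision, recall, and F1 alongside accuracy.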

## Model

- **Framework**: Native LightGBM
- **Data Processing**: PySpark
- **Objective**: Multi-class classification
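
With the `multiclass` objective, LightGBM produces one score per class and converts the scores to probabilities with a softmax, so each prediction row sums to 1. The predicted class is the index of the largest probability. A minimal NumPy sketch of that last step, using made-up raw scores rather than real model output:

```python
import numpy as np


def softmax(scores: np.ndarray) -> np.ndarray:
    """Row-wise softmax; subtracting the row max keeps exp() numerically stable."""
    shifted = scores - scores.max(axis=1, keepdims=True)
    exp_scores = np.exp(shifted)
    return exp_scores / exp_scores.sum(axis=1, keepdims=True)


# Made-up raw scores for two samples and three classes (not real model output).
raw_scores = np.array([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])

probabilities = softmax(raw_scores)
predicted_classes = np.argmax(probabilities, axis=1)

print(probabilities.round(3))
print(predicted_classes)
```

This mirrors what the training code does with `np.argmax(model.predict(X), axis=1)`, where `model.predict` already returns the probability matrix.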

## Features

The dataset includes 13 physicochemical properties:
- `alcohol`: Alcohol content
- `malic_acid`: Malic acid content
- `ash`: Ash content
- `alcalinity_of_ash`: Alcalinity of ash
- `magnesium`: Magnesium content
- `total_phenols`: Total phenols
- `flavanoids`: Flavanoids content
- `nonflavanoid_phenols`: Non-flavanoid phenols
- `proanthocyanins`: Proanthocyanins content
- `color_intensity`: Color intensity
- `hue`: Hue
- `od280/od315_of_diluted_wines`: OD280/OD315 ratio
- `proline`: Proline content

## Key Features

- Spark handles data loading, transformation, and splitting (can scale to big data)
- Native LightGBM handles model training (efficient and fast)
- Model logged using `mlflow.lightgbm` flavor (works with any MLflow server)
- Fast model loading at serving time (no Spark dependencies needed)
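
Because the logged model carries no Spark dependency, a client only needs to send the 13 feature values. The sketch below builds a request body in MLflow's `dataframe_split` JSON format. Whether the Darwin serving endpoint accepts this exact format is an assumption, `build_request_body` is a hypothetical helper, and the feature values are placeholders for illustration.

```python
import json

FEATURE_NAMES = [
    "alcohol", "malic_acid", "ash", "alcalinity_of_ash", "magnesium",
    "total_phenols", "flavanoids", "nonflavanoid_phenols", "proanthocyanins",
    "color_intensity", "hue", "od280/od315_of_diluted_wines", "proline",
]


def build_request_body(sample: dict) -> str:
    """Serialize one sample as a `dataframe_split` JSON payload.

    Raises ValueError if any expected feature is missing.
    """
    missing = [name for name in FEATURE_NAMES if name not in sample]
    if missing:
        raise ValueError(f"Missing features: {missing}")
    # Keep the column order fixed so it matches the order used in training.
    row = [float(sample[name]) for name in FEATURE_NAMES]
    return json.dumps({"dataframe_split": {"columns": FEATURE_NAMES, "data": [row]}})


# Placeholder values for illustration only.
sample = dict(zip(FEATURE_NAMES, [14.23, 1.71, 2.43, 15.6, 127.0, 2.8, 3.06,
                                  0.28, 2.29, 5.64, 1.04, 3.92, 1065.0]))
body = build_request_body(sample)
print(body)
```

Validating the feature set on the client side gives a clearer error than a schema failure returned by the server.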

In [None]:
# Fix pyOpenSSL/cryptography compatibility issue first
%pip install --upgrade pyOpenSSL cryptography

# Install main dependencies (pin MLflow to match server version)
%pip install lightgbm pandas numpy scikit-learn mlflow==2.12.2 pyspark

In [None]:
import os
import argparse
import json
import numpy as np
import pandas as pd
from datetime import datetime

# LightGBM imports
import lightgbm as lgb

# Spark imports (for data processing only)
from pyspark.sql import SparkSession

# MLflow imports
import mlflow
import mlflow.lightgbm
from mlflow import set_tracking_uri, set_experiment
from mlflow.client import MlflowClient
from mlflow.models import infer_signature

# Scikit-learn imports (for loading dataset and metrics)
from sklearn.datasets import load_wine
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix

# Darwin SDK imports (optional - only available on Darwin cluster)
DARWIN_SDK_AVAILABLE = False
try:
    import ray
    from darwin import init_spark_with_configs, stop_spark
    DARWIN_SDK_AVAILABLE = True
    print("Darwin SDK available - will use distributed Spark on Darwin cluster")
except ImportError as e:
    print(f"Darwin SDK not available: {e}")
    print("Running in LOCAL mode - will use local Spark session")
except AttributeError as e:
    # An old pyOpenSSL paired with a newer cryptography package fails at import time
    if "X509_V_FLAG" in str(e) or "lib" in str(e):
        print("=" * 80)
        print("ERROR: pyOpenSSL/cryptography version conflict detected!")
        print("Please run the following in a cell before importing:")
        print("  %pip install --upgrade pyOpenSSL cryptography")
        print("Then restart the kernel and try again.")
        print("=" * 80)
    # Re-raise in every case; the message above only adds a hint for the known conflict
    raise

In [None]:
def initialize_spark():
    """Initialize Spark session for data processing.
    
    Uses Darwin SDK on cluster, local Spark otherwise.
    Spark is used for distributed data processing (ETL, splitting).
    Training is done with native LightGBM on the driver.
    """
    print("\n" + "=" * 80)
    print("INITIALIZING SPARK SESSION")
    print("=" * 80)
    
    # Base Spark configurations for data processing
    spark_configs = {
        "spark.sql.execution.arrow.pyspark.enabled": "true",
        "spark.sql.session.timeZone": "UTC",
        "spark.sql.shuffle.partitions": "4",
        "spark.default.parallelism": "4",
        "spark.executor.memory": "1g",
        "spark.executor.cores": "1",
        "spark.driver.memory": "1g",
        "spark.executor.instances": "2",
    }
    
    if DARWIN_SDK_AVAILABLE:
        # Running on Darwin cluster - use distributed Spark via Ray
        print("Mode: Darwin Cluster (Distributed)")
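        # Tolerate re-running this cell in a kernel where Ray is already initialized
        ray.init(ignore_reinit_error=True)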
        spark = init_spark_with_configs(spark_configs=spark_configs)
    else:
        # Running locally - use local Spark session
        print("Mode: Local Spark Session")
        builder = SparkSession.builder \
            .appName("Wine-Spark-DataProcessing") \
            .master("local[*]")
        
        for key, value in spark_configs.items():
            builder = builder.config(key, value)
        
        spark = builder.getOrCreate()
    
    print(f"Spark version: {spark.version}")
    print(f"Application ID: {spark.sparkContext.applicationId}")
    
    return spark


def cleanup_spark(spark):
    """Stop Spark session properly based on environment."""
    print("\nStopping Spark session...")
    if DARWIN_SDK_AVAILABLE:
        stop_spark()
    else:
        spark.stop()
    print("Spark session stopped.")

In [None]:
def setup_mlflow(mlflow_uri: str, username: str, password: str) -> MlflowClient:
    """Configure MLflow tracking and return client."""
    os.environ["MLFLOW_TRACKING_USERNAME"] = username
    os.environ["MLFLOW_TRACKING_PASSWORD"] = password
    
    set_tracking_uri(mlflow_uri)
    client = MlflowClient(mlflow_uri)
    
    print(f"MLflow tracking URI: {mlflow_uri}")
    return client


def load_and_prepare_data(spark: SparkSession):
    """Load Wine dataset using Spark for processing, return pandas for training.
    
    Uses Spark for distributed data operations (can scale to large datasets).
    Returns pandas DataFrames for LightGBM training.
    """
    print("\n" + "=" * 80)
    print("LOADING DATASET")
    print("=" * 80)
    
    # Load dataset
    data = load_wine(as_frame=True)
    pdf = data.data.copy()
    pdf['label'] = data.target
    
    feature_names = data.feature_names
    
    print(f"Dataset: Wine")
    print(f"Samples: {len(pdf):,}")
    print(f"Features: {len(feature_names)}")
    
    print(f"\nFeature names:")
    for i, col_name in enumerate(feature_names, 1):
        print(f"  {i}. {col_name}")
    
    print(f"\nTarget distribution:")
    for class_idx in range(3):
        count = (pdf['label'] == class_idx).sum()
        print(f"  Class {class_idx}: {count} samples")
    
    # Use Spark for distributed data splitting (demonstrates Spark processing).
    # randomSplit samples each row independently, so the split is only approximately 80/20.
    print("\nUsing Spark for distributed data splitting...")
    spark_df = spark.createDataFrame(pdf)
    train_spark, test_spark = spark_df.randomSplit([0.8, 0.2], seed=42)
    
    # Collect to pandas for LightGBM training
    print("Collecting to pandas for training...")
    train_pdf = train_spark.toPandas()
    test_pdf = test_spark.toPandas()
    
    print(f"\nTrain samples: {len(train_pdf):,}")
    print(f"Test samples: {len(test_pdf):,}")
    
    return train_pdf, test_pdf, feature_names


def train_model(train_pdf, test_pdf, hyperparams: dict, feature_names: list):
    """Train LightGBM model using native LightGBM."""
    print("\n" + "=" * 80)
    print("TRAINING MODEL (Native LightGBM)")
    print("=" * 80)
    
    print("Hyperparameters:")
    for key, value in hyperparams.items():
        print(f"  {key}: {value}")
    
    # Prepare data
    X_train = train_pdf[feature_names].values
    y_train = train_pdf["label"].values
    X_test = test_pdf[feature_names].values
    y_test = test_pdf["label"].values
    
    print(f"\nTraining samples: {len(X_train)}, Test samples: {len(X_test)}")
    
    # Create LightGBM datasets
    train_data = lgb.Dataset(X_train, label=y_train, feature_name=list(feature_names))
    test_data = lgb.Dataset(X_test, label=y_test, feature_name=list(feature_names), reference=train_data)
    
    # LightGBM parameters
    params = {
        "objective": hyperparams.get("objective", "multiclass"),
        "num_class": hyperparams.get("num_class", 3),
        "num_leaves": hyperparams.get("num_leaves", 31),
        "learning_rate": hyperparams.get("learning_rate", 0.05),
        "feature_fraction": hyperparams.get("feature_fraction", 0.9),
        "bagging_fraction": hyperparams.get("bagging_fraction", 0.8),
        "bagging_freq": hyperparams.get("bagging_freq", 5),
        "verbose": -1,
        "seed": 42,
    }
    
    # Train model
    print("\nTraining LightGBM model...")
    model = lgb.train(
        params,
        train_data,
        num_boost_round=hyperparams.get("num_iterations", 100),
        valid_sets=[train_data, test_data],
        valid_names=["train", "test"],
    )
    
    print("Training completed!")
    
    # Make predictions
    train_proba = model.predict(X_train)
    test_proba = model.predict(X_test)
    train_pred = np.argmax(train_proba, axis=1)
    test_pred = np.argmax(test_proba, axis=1)
    
    # Store predictions for metrics calculation
    train_results = {"y_true": y_train, "y_pred": train_pred, "proba": train_proba}
    test_results = {"y_true": y_test, "y_pred": test_pred, "proba": test_proba}
    
    return model, train_results, test_results


def calculate_metrics(results: dict, dataset_name="Test"):
    """Calculate evaluation metrics from prediction results."""
    y_true = results["y_true"]
    y_pred = results["y_pred"]
    
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average="weighted")
    recall = recall_score(y_true, y_pred, average="weighted")
    f1 = f1_score(y_true, y_pred, average="weighted")
    
    return {
        f"{dataset_name.lower()}_accuracy": accuracy,
        f"{dataset_name.lower()}_precision": precision,
        f"{dataset_name.lower()}_recall": recall,
        f"{dataset_name.lower()}_f1": f1
    }


def log_to_mlflow(model, train_results, test_results, hyperparams, feature_names, sample_input):
    """Log LightGBM model, parameters, and metrics to MLflow."""
    print("\n" + "=" * 80)
    print("LOGGING TO MLFLOW")
    print("=" * 80)
    
    # Log hyperparameters
    for key, value in hyperparams.items():
        mlflow.log_param(key, value)
    
    # Log training framework info
    mlflow.log_param("training_framework", "lightgbm")
    mlflow.log_param("data_processing", "spark")
    
    # Calculate and log metrics
    train_metrics = calculate_metrics(train_results, dataset_name="Train")
    test_metrics = calculate_metrics(test_results, dataset_name="Test")
    all_metrics = {**train_metrics, **test_metrics}
    
    for metric_name, metric_value in all_metrics.items():
        mlflow.log_metric(metric_name, metric_value)
    
    print("\nModel Performance:")
    print(f"  Training Accuracy: {train_metrics['train_accuracy']:.4f}")
    print(f"  Training F1: {train_metrics['train_f1']:.4f}")
    print(f"  Test Accuracy: {test_metrics['test_accuracy']:.4f}")
    print(f"  Test Precision: {test_metrics['test_precision']:.4f}")
    print(f"  Test Recall: {test_metrics['test_recall']:.4f}")
    print(f"  Test F1: {test_metrics['test_f1']:.4f}")
    
    # Print confusion matrix
    cm = confusion_matrix(test_results["y_true"], test_results["y_pred"])
    print(f"\n  Confusion Matrix:")
    print(f"  {cm}")
    
    # Log model using mlflow.lightgbm
    model_logged = False
    print("\nSaving model artifacts...")
    try:
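        # Infer the signature from a real prediction so the output schema matches
        # the booster's (n_samples, num_class) probability array
        sample_output = model.predict(sample_input.values)
        signature = infer_signature(sample_input, sample_output)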
        
        # Log LightGBM model
        print("  Logging to MLflow using lightgbm flavor...")
        mlflow.lightgbm.log_model(
            lgb_model=model,
            artifact_path="model",
            signature=signature,
            input_example=sample_input
        )
        
        model_logged = True
        print("Model artifacts logged successfully!")
    except Exception as e:
        print(f"Warning: Could not log model artifacts: {e}")
        import traceback
        traceback.print_exc()
        print("Metrics and parameters were logged successfully.")
    
    # Store references for later use
    all_metrics["_native_model"] = model
    all_metrics["_model_logged"] = model_logged
    
    return all_metrics


def register_model(client: MlflowClient, model_name: str, run_id: str, experiment_id: str):
    """Register model in MLflow Model Registry."""
    print("\n" + "=" * 80)
    print("REGISTERING MODEL")
    print("=" * 80)
    
    model_uri = f"runs:/{run_id}/model"
    
    # Create registered model if it doesn't exist
    try:
        client.get_registered_model(model_name)
        print(f"Model '{model_name}' already exists in registry")
    except Exception:
        try:
            client.create_registered_model(model_name)
            print(f"Created registered model: {model_name}")
        except Exception as e:
            print(f"Could not create registered model: {e}")
    
    # Create model version
    try:
        result = client.create_model_version(
            name=model_name,
            source=model_uri,
            run_id=run_id
        )
        print(f"Model version registered successfully!")
        print(f"   Model Name: {model_name}")
        print(f"   Version: {result.version}")
        print(f"   Run ID: {run_id}")
        return result.version
    except Exception as e:
        print(f"Model registration failed (model still usable via run URI): {e}")
        print(f"   You can deploy using: mlflow-artifacts:/{experiment_id}/{run_id}/artifacts/model")
        return None


def load_model_and_predict(sample_data: pd.DataFrame, native_model=None, feature_names=None):
    """Run a prediction on the driver using the in-memory model.
    
    Args:
        sample_data: Pandas DataFrame with feature data
        native_model: The native LightGBM booster model
        feature_names: List of feature column names
    """
    print("\n" + "=" * 80)
    print("LOADING MODEL AND RUNNING PREDICTION")
    print("=" * 80)
    
    if native_model is None:
        # Raise instead of returning None: callers unpack the result as a 2-tuple
        raise ValueError("No model provided")
    
    print("Using in-memory model for prediction")
    
    # Convert to numpy array for prediction
    X = sample_data[feature_names].values if feature_names else sample_data.values
    
    # Run prediction using native LightGBM booster
    probabilities = native_model.predict(X)
    predictions = np.argmax(probabilities, axis=1)
    
    print("\n" + "=" * 80)
    print("SAMPLE PREDICTION RESULTS")
    print("=" * 80)
    print(f"\nInput features:")
    for col_name in sample_data.columns:
        print(f"  {col_name}: {sample_data[col_name].iloc[0]:.4f}")
    
    print(f"\nClass Probabilities:")
    for i, prob in enumerate(probabilities[0]):
        print(f"  Class {i}: {prob:.4f}")
    print(f"\nPredicted Class: {predictions[0]}")
    
    return predictions, probabilities


def print_deployment_info(run_id: str, experiment_id: str, model_name: str, model_version: str):
    """Print deployment instructions and sample payloads."""
    print("\n" + "=" * 80)
    print("TRAINING COMPLETE!")
    print("=" * 80)
    
    print(f"\nRun Information:")
    print(f"  Run ID: {run_id}")
    print(f"  Experiment ID: {experiment_id}")
    print(f"  Model URI (run): runs:/{run_id}/model")
    if model_version:
        print(f"  Model URI (registry): models:/{model_name}/{model_version}")
    
    print("\n" + "=" * 80)
    print("DEPLOYMENT PAYLOAD (deploy-model API)")
    print("=" * 80)
    
    deploy_payload = {
        "serve_name": "wine-lightgbm-spark-classifier",
        "model_uri": f"mlflow-artifacts:/{experiment_id}/{run_id}/artifacts/model",
        "env": "local",
        "cores": 2,
        "memory": 4,
        "node_capacity": "spot",
        "min_replicas": 1,
        "max_replicas": 3
    }
    
    print(json.dumps(deploy_payload, indent=2))

In [None]:
def main():
    parser = argparse.ArgumentParser(description="Train Wine Classification Model (Spark data processing + LightGBM training)")
    parser.add_argument(
        "--mlflow-uri",
        default="http://darwin-mlflow-lib.darwin.svc.cluster.local:8080",
        help="MLflow tracking URI"
    )
    parser.add_argument(
        "--username",
        default="abc@gmail.com",
        help="MLflow username"
    )
    parser.add_argument(
        "--password",
        default="password",
        help="MLflow password"
    )
    parser.add_argument(
        "--experiment-name",
        default="wine_spark_lightgbm_classification",
        help="MLflow experiment name"
    )
    parser.add_argument(
        "--model-name",
        default="WineLightGBMClassifier",
        help="Registered model name"
    )
    
    args, _ = parser.parse_known_args()
    
    print("\n" + "=" * 80)
    print("WINE CLASSIFICATION: SPARK DATA PROCESSING + LIGHTGBM TRAINING")
    print("=" * 80)
    print(f"Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    
    # Initialize Spark for data processing
    spark = initialize_spark()
    
    # Setup MLflow
    client = setup_mlflow(args.mlflow_uri, args.username, args.password)
    set_experiment(experiment_name=args.experiment_name)
    print(f"Experiment: {args.experiment_name}")
    
    # Load and prepare data using Spark (returns pandas DataFrames)
    train_pdf, test_pdf, feature_names = load_and_prepare_data(spark)
    
    # Define hyperparameters
    hyperparams = {
        "objective": "multiclass",
        "num_class": 3,
        "num_leaves": 31,
        "learning_rate": 0.05,
        "feature_fraction": 0.9,
        "bagging_fraction": 0.8,
        "bagging_freq": 5,
        "num_iterations": 100,
    }
    
    # Get sample input for MLflow logging
    sample_input = train_pdf[feature_names].head(1)
    
    # Start MLflow run
    with mlflow.start_run(run_name=f"lightgbm_wine_{datetime.now().strftime('%Y%m%d_%H%M%S')}"):
        # Train LightGBM model
        model, train_results, test_results = train_model(
            train_pdf, test_pdf, hyperparams, feature_names
        )
        
        # Log to MLflow
        metrics = log_to_mlflow(
            model, train_results, test_results, hyperparams, feature_names, sample_input
        )
        
        # Get run information
        run_id = mlflow.active_run().info.run_id
        experiment_id = mlflow.active_run().info.experiment_id
    
    # Register model (outside of run context) - only if artifacts were logged
    model_version = None
    if metrics.get("_model_logged", False):
        model_version = register_model(client, args.model_name, run_id, experiment_id)
    else:
        print("\nSkipping model registration (artifacts not logged to MLflow)")
    
    # Demo prediction with LightGBM model
    print("\n" + "=" * 80)
    print("SAMPLE PREDICTION")
    print("=" * 80)
    sample_pdf = test_pdf[feature_names].head(1)
    native_model = metrics.get("_native_model")
    predictions, probabilities = load_model_and_predict(sample_pdf, native_model=native_model, feature_names=feature_names)
    
    # Get actual value for comparison
    actual_value = test_pdf["label"].iloc[0]
    print(f"\nActual Class: {int(actual_value)}")
    print(f"Prediction Correct: {predictions[0] == actual_value}")
    
    # Print deployment information
    print_deployment_info(run_id, experiment_id, args.model_name, model_version)
    
    # Cleanup: Stop Spark session
    cleanup_spark(spark)
    
    print("\nScript completed successfully!")
    print("=" * 80 + "\n")


if __name__ == "__main__":
    main()