# Spark ML Classification Demo: Iris

This notebook demonstrates an end-to-end classification workflow in Spark ML on the classic Iris dataset. It covers:
- Data download into DBFS
- Preprocessing and model pipeline
- Training and evaluation metrics (accuracy, F1, weighted precision/recall)
- Confusion matrix
- Model introspection (coefficients or feature importances)

In [0]:
# Imports and setup
from pyspark.sql import functions as F
from pyspark.sql import types as T
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, VectorAssembler, StandardScaler
from pyspark.ml.classification import LogisticRegression, RandomForestClassifier, GBTClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.mllib.evaluation import MulticlassMetrics
import os
import requests

In [0]:
# Widgets for configuration
dbutils.widgets.text("data_url", "https://raw.githubusercontent.com/uiuc-cse/data-fa14/gh-pages/data/iris.csv", "Data URL (CSV)")
dbutils.widgets.text("dbfs_output_path", "dbfs:/tmp/datasets/iris.csv", "DBFS output path")
dbutils.widgets.dropdown("classifier", "logistic_regression", ["logistic_regression","random_forest","gbt"], "Classifier")
dbutils.widgets.text("test_size", "0.2", "Test fraction (0-1)")
dbutils.widgets.text("random_seed", "42", "Random seed")
dbutils.widgets.dropdown("standardize_features", "true", ["true","false"], "Standardize numeric features")
dbutils.widgets.dropdown("force_download", "false", ["false","true"], "Force re-download data")

In [0]:
# Read widget values
data_url = dbutils.widgets.get("data_url").strip()
dbfs_output_path = dbutils.widgets.get("dbfs_output_path").strip()
clf_choice = dbutils.widgets.get("classifier").strip()
test_fraction = float(dbutils.widgets.get("test_size"))
seed = int(float(dbutils.widgets.get("random_seed")))
standardize = dbutils.widgets.get("standardize_features") == "true"
force_download = dbutils.widgets.get("force_download") == "true"

print(f"Data URL: {data_url}")
print(f"DBFS path: {dbfs_output_path}")
print(f"Classifier: {clf_choice}, standardize: {standardize}")
print(f"Test fraction: {test_fraction}, seed: {seed}")
print(f"Force download: {force_download}")

Data URL: https://raw.githubusercontent.com/uiuc-cse/data-fa14/gh-pages/data/iris.csv
DBFS path: dbfs:/tmp/datasets/iris.csv
Classifier: logistic_regression, standardize: True
Test fraction: 0.2, seed: 42
Force download: False


In [0]:
# Utilities to download and cache data in DBFS
def dbfs_to_local(dbfs_path: str) -> str:
    """Convert dbfs:/path to /dbfs/path so Python can write the file."""
    if dbfs_path.startswith("dbfs:/"):
        return "/dbfs/" + dbfs_path[len("dbfs:/"):]  # strip only the leading scheme
    elif dbfs_path.startswith("/dbfs/"):
        return dbfs_path
    else:
        return "/dbfs/" + dbfs_path.lstrip("/")

def ensure_parent_dir(local_path: str):
    os.makedirs(os.path.dirname(local_path), exist_ok=True)

def download_to_dbfs(url: str, dbfs_path: str, force: bool = False) -> str:
    """Stream-download a URL into DBFS, with simple caching."""
    local_path = dbfs_to_local(dbfs_path)
    ensure_parent_dir(local_path)
    if os.path.exists(local_path) and os.path.getsize(local_path) > 0 and not force:
        print(f"File already exists at {dbfs_path} ({os.path.getsize(local_path)} bytes). Skipping download.")
        return dbfs_path
    print(f"Downloading {url} -> {dbfs_path}")
    with requests.get(url, stream=True, timeout=30) as r:
        r.raise_for_status()
        tmp_path = local_path + ".tmp"
        with open(tmp_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=1<<14):
                if chunk:
                    f.write(chunk)
        os.replace(tmp_path, local_path)
    print("Download complete.")
    return dbfs_path
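The path conversion can be checked in isolation. The following is a minimal sketch that restates `dbfs_to_local` in condensed form so the block runs on its own; apart from the default widget value, the example paths are hypothetical.

```python
def dbfs_to_local(dbfs_path: str) -> str:
    """Map a DBFS URI to the /dbfs mount path used by plain Python file I/O."""
    prefix = "dbfs:/"
    if dbfs_path.startswith(prefix):
        return "/dbfs/" + dbfs_path[len(prefix):]
    if dbfs_path.startswith("/dbfs/"):
        return dbfs_path
    return "/dbfs/" + dbfs_path.lstrip("/")


# Four spellings of the same location; each should resolve to one local path.
examples = [
    "dbfs:/tmp/datasets/iris.csv",   # default widget value
    "/dbfs/tmp/datasets/iris.csv",   # already a local mount path
    "/tmp/datasets/iris.csv",        # absolute path without a scheme
    "tmp/datasets/iris.csv",         # relative path
]
for path in examples:
    print(f"{path!r:35} -> {dbfs_to_local(path)}")
```

Normalizing every spelling to a single local path is what lets the cache check in `download_to_dbfs` rely on `os.path.exists`.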

In [0]:
# Download and load the dataset
dbfs_csv = download_to_dbfs(data_url, dbfs_output_path, force_download)

# Spark can read "dbfs:/..." paths directly
df = spark.read.option("header", "true").option("inferSchema", "true").csv(dbfs_csv)

print("Schema:")
df.printSchema()

print("Sample:")
display(df.limit(5))

# Basic validation for the Iris dataset
expected_cols = {"sepal_length","sepal_width","petal_length","petal_width","species"}
missing = expected_cols - set(df.columns)
if missing:
    raise ValueError(f"CSV does not look like Iris dataset. Missing columns: {missing}")

Downloading https://raw.githubusercontent.com/uiuc-cse/data-fa14/gh-pages/data/iris.csv -> dbfs:/tmp/datasets/iris.csv
Download complete.
Schema:
root
 |-- sepal_length: double (nullable = true)
 |-- sepal_width: double (nullable = true)
 |-- petal_length: double (nullable = true)
 |-- petal_width: double (nullable = true)
 |-- species: string (nullable = true)

Sample:


sepal_length,sepal_width,petal_length,petal_width,species
5.1,3.5,1.4,0.2,setosa
4.9,3.0,1.4,0.2,setosa
4.7,3.2,1.3,0.2,setosa
4.6,3.1,1.5,0.2,setosa
5.0,3.6,1.4,0.2,setosa


In [0]:
# EDA: label distribution
display(df.groupBy("species").count().orderBy(F.desc("count")))

species,count
virginica,50
versicolor,50
setosa,50
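Because every species has 50 rows in the full dataset, the index that `StringIndexer` assigns depends on the class counts in the training split rather than on the table above. The following plain-Python sketch mimics the default `frequencyDesc` ordering. The training counts are an assumption derived from this run's outputs (50 rows per class minus the per-class test rows in the confusion matrix further down), and the alphabetical tie-break reflects what the Spark 3.x documentation describes.

```python
from collections import Counter

# Assumed training-split counts: 50 per class minus the test rows per class.
train_counts = Counter({"versicolor": 44, "virginica": 43, "setosa": 39})


def frequency_desc_index(counts):
    """Order labels by descending count, breaking ties alphabetically."""
    ordered = sorted(counts, key=lambda label: (-counts[label], label))
    return {label: idx for idx, label in enumerate(ordered)}


label_to_index = frequency_desc_index(train_counts)
print(label_to_index)
```

Under these assumed counts the ordering agrees with the "Label index mapping" printed later in the notebook. A different seed or test fraction can change the mapping, so it should always be read from the fitted model.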


In [0]:
# Preprocessing and model selection
label_col = "species"
feature_cols = ["sepal_length","sepal_width","petal_length","petal_width"]
indexed_label_col = "label"
features_vec_col = "features"
scaled_features_col = "scaledFeatures" if standardize else features_vec_col

# Transform label strings to indexed numeric labels
label_indexer = StringIndexer(inputCol=label_col, outputCol=indexed_label_col, handleInvalid="error")

# Assemble numeric features into a vector
assembler = VectorAssembler(inputCols=feature_cols, outputCol=features_vec_col, handleInvalid="error")

# Optionally standardize features (recommended for LR)
scaler = StandardScaler(withMean=True, withStd=True, inputCol=features_vec_col, outputCol=scaled_features_col)

# Choose classifier
if clf_choice == "logistic_regression":
    classifier = LogisticRegression(featuresCol=scaled_features_col, labelCol=indexed_label_col, maxIter=100, regParam=0.0, elasticNetParam=0.0)
elif clf_choice == "random_forest":
    classifier = RandomForestClassifier(featuresCol=scaled_features_col, labelCol=indexed_label_col, numTrees=100, maxDepth=5, seed=seed)
elif clf_choice == "gbt":
    # GBTClassifier is binary-only; fall back gracefully for a multiclass dataset
    print("Warning: GBTClassifier supports binary classification only. Falling back to RandomForestClassifier.")
    classifier = RandomForestClassifier(featuresCol=scaled_features_col, labelCol=indexed_label_col, numTrees=200, maxDepth=6, seed=seed)
else:
    raise ValueError(f"Unknown classifier choice: {clf_choice}")

stages = [label_indexer, assembler] + ([scaler] if standardize else []) + [classifier]
pipeline = Pipeline(stages=stages)
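With `withMean=True` and `withStd=True`, `StandardScaler` centers each feature and divides it by its standard deviation, which Spark documents as the corrected sample standard deviation. The following is a minimal sketch of that arithmetic on a single column, using the five `sepal_length` values from the sample output above. In the pipeline, the statistics come from the whole training split instead.

```python
from statistics import mean, stdev

# The five sepal_length values shown in the sample output.
sepal_length = [5.1, 4.9, 4.7, 4.6, 5.0]

mu = mean(sepal_length)
sigma = stdev(sepal_length)  # sample standard deviation (n - 1 denominator)

# z-score: subtract the mean, then divide by the standard deviation.
scaled = [(x - mu) / sigma for x in sepal_length]
print([round(z, 3) for z in scaled])
```

After scaling, the column has mean 0 and unit sample standard deviation, which keeps the logistic regression weights comparable across features.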

In [0]:
# Train/test split and training
train_df, test_df = df.randomSplit([1.0 - test_fraction, test_fraction], seed=seed)
print(f"Train count: {train_df.count()}, Test count: {test_df.count()}")

model = pipeline.fit(train_df)

Train count: 126, Test count: 24
🏃 View run able-kit-423 at: https://adb-3017385027020604.4.azuredatabricks.net/ml/experiments/3784693748081254/runs/48495965e5814562a38ff974609a6ae1
🧪 View experiment at: https://adb-3017385027020604.4.azuredatabricks.net/ml/experiments/3784693748081254
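The split yields 126/24 rows rather than an exact 120/30 because `randomSplit` assigns rows by random sampling, not by cutting fixed-size blocks. The following plain-Python sketch illustrates the idea. It assumes each row independently draws a uniform number and goes to the test set when the draw falls in the top `test_fraction` of the range. Its counts will not match Spark's, since the random number generators differ.

```python
import random


def bernoulli_split(rows, test_fraction, seed):
    """Assign each row independently to train or test using one uniform draw."""
    rng = random.Random(seed)
    train, test = [], []
    for row in rows:
        target = test if rng.random() >= 1.0 - test_fraction else train
        target.append(row)
    return train, test


rows = list(range(150))  # stand-in for the 150 Iris rows
train, test = bernoulli_split(rows, test_fraction=0.2, seed=42)
print(len(train), len(test))
```

The weights set only the expected sizes of the splits, so on a dataset this small the test set can differ from the nominal 20% by several rows.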


In [0]:
# Evaluation utilities
def evaluate_predictions(pred_df, label_col=indexed_label_col, prediction_col="prediction"):
    """
    Compute common classification metrics and confusion matrix for a prediction DataFrame.
    Returns a dict with metrics and confusion matrix (as nested lists).
    """
    predictionAndLabels = pred_df.select(F.col(prediction_col).cast("double"), F.col(label_col).cast("double")).rdd.map(tuple)
    metrics = MulticlassMetrics(predictionAndLabels)
    labels = sorted(float(r[0]) for r in pred_df.select(label_col).distinct().collect())
    cm = metrics.confusionMatrix().toArray().tolist()
    eval_acc = MulticlassClassificationEvaluator(labelCol=label_col, predictionCol=prediction_col, metricName="accuracy").evaluate(pred_df)
    eval_f1 = MulticlassClassificationEvaluator(labelCol=label_col, predictionCol=prediction_col, metricName="f1").evaluate(pred_df)
    eval_wp = MulticlassClassificationEvaluator(labelCol=label_col, predictionCol=prediction_col, metricName="weightedPrecision").evaluate(pred_df)
    eval_wr = MulticlassClassificationEvaluator(labelCol=label_col, predictionCol=prediction_col, metricName="weightedRecall").evaluate(pred_df)
    return {
        "labels": labels,
        "confusion_matrix": cm,
        "accuracy": eval_acc,
        "f1": eval_f1,
        "weightedPrecision": eval_wp,
        "weightedRecall": eval_wr,
        "precisionByLabel": [metrics.precision(l) for l in labels],
        "recallByLabel": [metrics.recall(l) for l in labels],
        "fMeasureByLabel": [metrics.fMeasure(l) for l in labels],
    }

# Generate predictions
train_pred = model.transform(train_df)
test_pred = model.transform(test_df)

# Compute metrics
train_stats = evaluate_predictions(train_pred)
test_stats = evaluate_predictions(test_pred)

print("Training metrics:")
print({k: v for k,v in train_stats.items() if k not in ["confusion_matrix","labels","precisionByLabel","recallByLabel","fMeasureByLabel"]})

print("Test metrics:")
print({k: v for k,v in test_stats.items() if k not in ["confusion_matrix","labels","precisionByLabel","recallByLabel","fMeasureByLabel"]})

# Display confusion matrix for the test set
import pandas as pd
test_cm_df = pd.DataFrame(test_stats["confusion_matrix"], index=[f"true_{int(l)}" for l in test_stats["labels"]], columns=[f"pred_{int(l)}" for l in test_stats["labels"]])
display(test_cm_df)

# Recover original species labels (index -> string)
si_model = model.stages[0]  # StringIndexerModel
id_to_label = {i: lbl for i, lbl in enumerate(si_model.labels)}
print("Label index mapping:", id_to_label)



Training metrics:
{'accuracy': 0.9841269841269841, 'f1': 0.9841269841269841, 'weightedPrecision': 0.9841269841269841, 'weightedRecall': 0.9841269841269841}
Test metrics:
{'accuracy': 1.0, 'f1': 1.0, 'weightedPrecision': 1.0, 'weightedRecall': 1.0}


pred_0,pred_1,pred_2
6.0,0.0,0.0
0.0,7.0,0.0
0.0,0.0,11.0


Label index mapping: {0: 'versicolor', 1: 'virginica', 2: 'setosa'}
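The headline metrics can be recomputed by hand from the confusion matrix, which is a useful sanity check on small test sets. The following plain-Python sketch assumes rows are true labels and columns are predicted labels, the orientation that `MulticlassMetrics.confusionMatrix()` documents. The matrix and label mapping are copied from the output above.

```python
# Confusion matrix and index mapping copied from the test-set output.
cm = [
    [6.0, 0.0, 0.0],
    [0.0, 7.0, 0.0],
    [0.0, 0.0, 11.0],
]
id_to_label = {0: "versicolor", 1: "virginica", 2: "setosa"}

total = sum(sum(row) for row in cm)
accuracy = sum(cm[i][i] for i in range(len(cm))) / total

per_class = {}
for i, name in id_to_label.items():
    predicted_i = sum(cm[r][i] for r in range(len(cm)))  # column sum
    actual_i = sum(cm[i])                                # row sum
    precision = cm[i][i] / predicted_i if predicted_i else 0.0
    recall = cm[i][i] / actual_i if actual_i else 0.0
    per_class[name] = {"precision": precision, "recall": recall, "support": int(actual_i)}

print(f"accuracy={accuracy}")
for name, stats in per_class.items():
    print(name, stats)
```

With only 24 test rows, a single misclassification would move accuracy by about 4 percentage points, so the perfect score should be read with that granularity in mind.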


In [0]:
# Model introspection: coefficients or feature importances
from pyspark.ml.classification import LogisticRegressionModel, RandomForestClassificationModel, GBTClassificationModel

last_stage = model.stages[-1]
if isinstance(last_stage, LogisticRegressionModel):
    print("Logistic Regression coefficients (multinomial):")
    coef = last_stage.coefficientMatrix.toArray()
    intercepts = last_stage.interceptVector.toArray()
    feature_names = feature_cols
    for cls_idx in range(coef.shape[0]):
        cls_label = id_to_label.get(cls_idx, str(cls_idx))
        weights = {feature_names[i]: float(coef[cls_idx, i]) for i in range(coef.shape[1])}
        print(f"Class {cls_idx} ({cls_label}) intercept={float(intercepts[cls_idx]):.4f} weights={weights}")
elif isinstance(last_stage, (RandomForestClassificationModel, GBTClassificationModel)):
    fi = last_stage.featureImportances
    feature_names = feature_cols
    pairs = sorted(zip(feature_names, fi.toArray().tolist()), key=lambda x: x[1], reverse=True)
    print("Feature importances:")
    for name, importance in pairs:
        print(f"{name}: {importance:.4f}")
else:
    print(f"Model type {type(last_stage)} not recognized for introspection.")

Logistic Regression coefficients (multinomial):
Class 0 (versicolor) intercept=11.1964 weights={'sepal_length': 4.102409055953124, 'sepal_width': -7.003273159281689, 'petal_length': 3.6580246040108375, 'petal_width': 4.461056427552989}
Class 1 (virginica) intercept=-5.8866 weights={'sepal_length': 2.2766833809266664, 'sepal_width': -9.986849094451065, 'petal_length': 18.916705498969627, 'petal_width': 16.870575796944898}
Class 2 (setosa) intercept=-5.3097 weights={'sepal_length': -6.37909243687979, 'sepal_width': 16.990122253732753, 'petal_length': -22.574730102980464, 'petal_width': -21.331632224497888}
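In the multinomial model each class gets a linear score (intercept plus weighted features), and a softmax over those scores gives the class probabilities. The following is a minimal sketch using the intercepts and weights printed above, rounded to four decimals. Because the features were standardized with `withMean=True`, the all-zeros input used here stands for a flower with training-set mean measurements, which is an illustrative choice.

```python
import math

classes = ["versicolor", "virginica", "setosa"]
intercepts = [11.1964, -5.8866, -5.3097]
# Rows: classes; columns: sepal_length, sepal_width, petal_length, petal_width.
weights = [
    [4.1024, -7.0033, 3.6580, 4.4611],
    [2.2767, -9.9868, 18.9167, 16.8706],
    [-6.3791, 16.9901, -22.5747, -21.3316],
]


def predict_proba(x):
    """Softmax over per-class linear scores for a standardized feature vector."""
    margins = [b + sum(w_i * x_i for w_i, x_i in zip(w, x))
               for w, b in zip(weights, intercepts)]
    top = max(margins)  # subtract the max for numerical stability
    exps = [math.exp(m - top) for m in margins]
    total = sum(exps)
    return [e / total for e in exps]


x_mean = [0.0, 0.0, 0.0, 0.0]  # standardized "average" flower
probs = predict_proba(x_mean)
print(dict(zip(classes, probs)))
```

The large weight magnitudes are expected with `regParam=0.0` on nearly separable classes. Adding regularization would shrink them.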


## How to run this notebook in Databricks

1. Import this file as "Source":
   - Workspace sidebar → right-click your folder → Import → Choose this .py file → Import format = Source.
2. Attach to a cluster:
   - Runtime: Apache Spark 3.x (DBR 11+ recommended).
   - Python 3.8+ with MLlib included (standard runtimes suffice).
3. Configure widgets at the top:
   - data_url: public Iris CSV (default provided).
   - dbfs_output_path: e.g., dbfs:/tmp/datasets/iris.csv
   - classifier: logistic_regression, random_forest, or gbt (gbt falls back to RF).
   - standardize_features: true is recommended for LR; tree models are unaffected by feature scaling, so either value works.
   - test_size and random_seed: control the split.
   - force_download: true to re-fetch the CSV.
4. Run cells from top to bottom. The notebook prints training/testing metrics and displays a confusion matrix.
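The notebook converts `test_size` with `float(...)` and does not check its range. A small validation helper could catch bad widget input before the split. In the sketch below, `parse_test_fraction` is a hypothetical name and not part of the notebook.

```python
def parse_test_fraction(raw: str) -> float:
    """Parse a widget string and require a value strictly between 0 and 1."""
    try:
        value = float(raw.strip())
    except ValueError as exc:
        raise ValueError(f"test_size must be a number, got {raw!r}") from exc
    # NaN fails this comparison as well, so it is rejected here too.
    if not 0.0 < value < 1.0:
        raise ValueError(f"test_size must be between 0 and 1 (exclusive), got {value}")
    return value


print(parse_test_fraction("0.2"))
```

In the notebook this would wrap the `dbutils.widgets.get("test_size")` call, so a bad value fails with a clear message instead of producing an empty or near-empty split.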

Troubleshooting:
- Network/SSL issues when downloading? Upload the CSV via the UI (Data → DBFS → Upload), point dbfs_output_path to it, and keep force_download=false so the notebook uses the uploaded file instead of re-fetching it.
- In source format, "COMMAND ----------" markers preserve cell boundaries; once imported, the cells and markdown render as in a standard notebook.

## Verification checklist

Use this quick checklist after running the notebook top-to-bottom to verify everything is working:

- Widgets appear at the top with defaults populated (data_url, dbfs_output_path, classifier, etc.).
- Download step reports either "Skipping download" (file cached) or "Download complete."
- Schema prints with 5 columns: sepal_length, sepal_width, petal_length, petal_width, species; all numeric except species (string).
- Class distribution display shows 3 classes (setosa, versicolor, virginica) with 50 rows each.
- Train/Test counts are printed; both counts are > 0.
- Training and Test metrics are printed dictionaries including accuracy, f1, weightedPrecision, weightedRecall.
- Test accuracy is typically > 0.90 for Iris with LR or RF.
- A confusion matrix table is displayed for the test set.
- "Label index mapping" is printed, mapping 0/1/2 to species names.
- Model introspection prints:
  - Logistic Regression: intercepts and per-feature weights per class, or
  - RandomForest: feature importances (descending).
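Parts of this checklist can be automated. The following is a minimal sketch, assuming a metrics dictionary shaped like the one printed by the evaluation cell. `check_metrics` and its `min_accuracy` argument are illustrative names, and the sample values are this run's test metrics.

```python
REQUIRED_KEYS = ("accuracy", "f1", "weightedPrecision", "weightedRecall")


def check_metrics(stats, min_accuracy=0.90):
    """Return a list of problems found in a metrics dict (empty list = OK)."""
    problems = []
    for key in REQUIRED_KEYS:
        if key not in stats:
            problems.append(f"missing metric: {key}")
        elif not 0.0 <= stats[key] <= 1.0:
            problems.append(f"{key} out of range: {stats[key]}")
    if stats.get("accuracy", 0.0) <= min_accuracy:
        problems.append(f"accuracy not above {min_accuracy}")
    return problems


# Test metrics copied from the run shown in this notebook.
test_metrics = {"accuracy": 1.0, "f1": 1.0, "weightedPrecision": 1.0, "weightedRecall": 1.0}
print(check_metrics(test_metrics) or "all checks passed")
```

The function returns a list of problems instead of raising, so a scheduled run can report every failed check at once.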

Optional next steps:
- Replace the dataset URL with your own CSV that has a string label column and numeric features.
- Add hyperparameter tuning via CrossValidator or TrainValidationSplit.
- Log parameters/metrics/model to MLflow (Databricks: Experiment sidebar).