# Fractal ML Training

## Setup

In [4]:
import time
from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.feature import StandardScaler, VectorAssembler
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when
from sparkmeasure import StageMetrics

In [None]:
data_path = "/opt/spark/work-dir/data/FRACTAL"
sample_fraction = 2
num_executors = 2

In [None]:
spark = SparkSession.builder.appName("fractal-cv-rf") \
    .master("spark://spark-master:7077") \
    .config("spark.hadoop.fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem") \
    .config("spark.hadoop.fs.s3a.aws.credentials.provider", "com.amazonaws.auth.DefaultAWSCredentialsProviderChain") \
    .config("spark.executor.memory", "8g") \
    .config("spark.driver.memory", "2g") \
    .config("spark.executor.cores", "2") \
    .config("spark.dynamicAllocation.enabled", "true") \
    .config("spark.dynamicAllocation.minExecutors", "1") \
    .config("spark.dynamicAllocation.maxExecutors", str(num_executors)) \
    .config("spark.dynamicAllocation.initialExecutors", str(num_executors)) \
    .config("spark.eventLog.enabled", "true") \
    .config("spark.eventLog.dir", "/opt/spark/spark-events") \
    .getOrCreate()

In [7]:
stage_metrics = StageMetrics(spark)
print(f"Spark UI: {spark.sparkContext.uiWebUrl}")

Spark UI: http://59d281bd1cfe:4040


## Data Loading

In [8]:
def prepare_data(df):
    df = df.withColumn("z_raw", col("xyz")[2]).withColumn(
        "ndvi",
        when(
            (col("Infrared") + col("Red")) != 0,
            (col("Infrared") - col("Red")) / (col("Infrared") + col("Red")),
        ).otherwise(0),
    )
    return df.select("z_raw", "Intensity", "Red", "Green", "Blue", "Infrared", "ndvi", col("Classification").alias("label"))


def load_sample(spark, path, fraction, cols):
    import random
    
    sc = spark.sparkContext
    hadoop_conf = sc._jsc.hadoopConfiguration()
    fs = spark._jvm.org.apache.hadoop.fs.FileSystem.get(hadoop_conf)
    path_obj = spark._jvm.org.apache.hadoop.fs.Path(path)
    
    file_list = []
    if fs.exists(path_obj):
        statuses = fs.listStatus(path_obj)
        for status in statuses:
            file_path = str(status.getPath())
            if file_path.endswith('.parquet') or not '.' in file_path.split('/')[-1]:
                file_list.append(file_path)
    
    if not file_list:
        file_list = [path]
    
    random.seed(62)
    
    if fraction <= 1.0:
        num_files = max(1, int(len(file_list) * fraction))
        print(f"Sampling {num_files}/{len(file_list)} files (fraction={fraction})")
    else:
        num_files = min(int(fraction), len(file_list))
        print(f"Sampling {num_files}/{len(file_list)} files (count={int(fraction)})")
    
    sampled_files = random.sample(file_list, num_files)
    print(sampled_files)
    df = spark.read.parquet(*sampled_files).select(*cols)
    
    df = prepare_data(df).cache()
    print(f"Loaded {df.count()} rows")
    return df

In [None]:
cols = ["xyz", "Intensity", "Classification", "Red", "Green", "Blue", "Infrared"]

stage_metrics.begin()
start_time = time.time()

train = load_sample(spark, f"{data_path}/train/", sample_fraction, cols)
val = load_sample(spark, f"{data_path}/val/", sample_fraction, cols)
test = load_sample(spark, f"{data_path}/test/", sample_fraction, cols)

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)

[Stage 2:>                                                      (53 + 4) / 4602]

## Pipeline

In [None]:
z_assembler = VectorAssembler(inputCols=["z_raw"], outputCol="z_vec")
z_scaler = StandardScaler(inputCol="z_vec", outputCol="z", withMean=False, withStd=True)

assembler = VectorAssembler(
    inputCols=["z", "Intensity", "Red", "Green", "Blue", "Infrared", "ndvi"],
    outputCol="features",
)

rf = RandomForestClassifier(labelCol="label", featuresCol="features", seed=62)
pipeline = Pipeline(stages=[z_assembler, z_scaler, assembler, rf])

## CrossValidator

In [17]:
paramGrid = ParamGridBuilder() \
    .addGrid(rf.numTrees, [50, 100, 200]) \
    .addGrid(rf.maxDepth, [10, 15, 20]) \
    .build()

evaluator = MulticlassClassificationEvaluator(
    labelCol="label", predictionCol="prediction", metricName="accuracy"
)

n_folds = 1

cv = CrossValidator(
    estimator=pipeline,
    estimatorParamMaps=paramGrid,
    evaluator=evaluator,
    numFolds=n_folds,
    parallelism=num_executors,
    seed=62
)

print(f"Training {len(paramGrid)} models × {n_folds} folds = {len(paramGrid) * n_folds} fits")

Training 9 models × 1 folds = 9 fits


In [18]:
cv_model = cv.fit(train)

25/11/09 20:54:44 WARN CacheManager: Asked to cache already cached data.
25/11/09 20:54:44 WARN CacheManager: Asked to cache already cached data.
25/11/09 20:54:44 WARN TaskSetManager: Lost task 0.0 in stage 59.0 (TID 114) (172.19.0.6 executor 0): java.lang.IllegalArgumentException: requirement failed: Nothing has been added to this summarizer.
	at scala.Predef$.require(Predef.scala:281)
	at org.apache.spark.ml.stat.SummarizerBuffer.mean(Summarizer.scala:624)
	at org.apache.spark.ml.stat.SummaryBuilderImpl$MetricsAggregate.$anonfun$eval$1(Summarizer.scala:358)
	at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at scala.collection.TraversableLike.map(TraversableLike.scala:286)
	at scala.collection.TraversableLike.map$(TraversableLike.sc

Py4JJavaError: An error occurred while calling o402074.fit.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 59.0 failed 4 times, most recent failure: Lost task 0.3 in stage 59.0 (TID 120) (172.19.0.5 executor 1): java.lang.IllegalArgumentException: requirement failed: Nothing has been added to this summarizer.
	at scala.Predef$.require(Predef.scala:281)
	at org.apache.spark.ml.stat.SummarizerBuffer.mean(Summarizer.scala:624)
	at org.apache.spark.ml.stat.SummaryBuilderImpl$MetricsAggregate.$anonfun$eval$1(Summarizer.scala:358)
	at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at scala.collection.TraversableLike.map(TraversableLike.scala:286)
	at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
	at scala.collection.AbstractTraversable.map(Traversable.scala:108)
	at org.apache.spark.ml.stat.SummaryBuilderImpl$MetricsAggregate.eval(Summarizer.scala:357)
	at org.apache.spark.ml.stat.SummaryBuilderImpl$MetricsAggregate.eval(Summarizer.scala:345)
	at org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate.eval(interfaces.scala:594)
	at org.apache.spark.sql.execution.aggregate.AggregationIterator.$anonfun$generateResultProjection$5(AggregationIterator.scala:257)
	at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.next(ObjectAggregationIterator.scala:97)
	at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.next(ObjectAggregationIterator.scala:33)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:43)
	at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:388)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:893)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:893)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:621)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:624)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(Unknown Source)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(Unknown Source)
	at java.base/java.lang.Thread.run(Unknown Source)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2898)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2834)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2833)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2833)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1253)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1253)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1253)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3102)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:3036)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:3025)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:995)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2393)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2414)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2433)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2458)
	at org.apache.spark.rdd.RDD.$anonfun$collect$1(RDD.scala:1049)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:410)
	at org.apache.spark.rdd.RDD.collect(RDD.scala:1048)
	at org.apache.spark.sql.execution.SparkPlan.executeCollect(SparkPlan.scala:448)
	at org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec.$anonfun$executeCollect$1(AdaptiveSparkPlanExec.scala:392)
	at org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec.withFinalPlanUpdate(AdaptiveSparkPlanExec.scala:420)
	at org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec.executeCollect(AdaptiveSparkPlanExec.scala:392)
	at org.apache.spark.sql.Dataset.collectFromPlan(Dataset.scala:4333)
	at org.apache.spark.sql.Dataset.$anonfun$head$1(Dataset.scala:3316)
	at org.apache.spark.sql.Dataset.$anonfun$withAction$2(Dataset.scala:4323)
	at org.apache.spark.sql.execution.QueryExecution$.withInternalError(QueryExecution.scala:546)
	at org.apache.spark.sql.Dataset.$anonfun$withAction$1(Dataset.scala:4321)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$6(SQLExecution.scala:125)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:201)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:108)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:900)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:66)
	at org.apache.spark.sql.Dataset.withAction(Dataset.scala:4321)
	at org.apache.spark.sql.Dataset.head(Dataset.scala:3316)
	at org.apache.spark.sql.Dataset.head(Dataset.scala:3323)
	at org.apache.spark.sql.Dataset.first(Dataset.scala:3330)
	at org.apache.spark.ml.feature.StandardScaler.fit(StandardScaler.scala:113)
	at org.apache.spark.ml.feature.StandardScaler.fit(StandardScaler.scala:84)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(Unknown Source)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(Unknown Source)
	at java.base/java.lang.reflect.Method.invoke(Unknown Source)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.base/java.lang.Thread.run(Unknown Source)
Caused by: java.lang.IllegalArgumentException: requirement failed: Nothing has been added to this summarizer.
	at scala.Predef$.require(Predef.scala:281)
	at org.apache.spark.ml.stat.SummarizerBuffer.mean(Summarizer.scala:624)
	at org.apache.spark.ml.stat.SummaryBuilderImpl$MetricsAggregate.$anonfun$eval$1(Summarizer.scala:358)
	at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at scala.collection.TraversableLike.map(TraversableLike.scala:286)
	at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
	at scala.collection.AbstractTraversable.map(Traversable.scala:108)
	at org.apache.spark.ml.stat.SummaryBuilderImpl$MetricsAggregate.eval(Summarizer.scala:357)
	at org.apache.spark.ml.stat.SummaryBuilderImpl$MetricsAggregate.eval(Summarizer.scala:345)
	at org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate.eval(interfaces.scala:594)
	at org.apache.spark.sql.execution.aggregate.AggregationIterator.$anonfun$generateResultProjection$5(AggregationIterator.scala:257)
	at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.next(ObjectAggregationIterator.scala:97)
	at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.next(ObjectAggregationIterator.scala:33)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:43)
	at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:388)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:893)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:893)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:621)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:624)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(Unknown Source)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(Unknown Source)
	... 1 more


## Results

In [None]:
best_model = cv_model.bestModel
best_params = {
    "numTrees": best_model.stages[-1].getNumTrees,
    "maxDepth": best_model.stages[-1].getMaxDepth(),
}

val_accuracy = evaluator.evaluate(cv_model.transform(val))
test_accuracy = evaluator.evaluate(cv_model.transform(test))

print(f"Best: {best_params}")
print(f"Val: {val_accuracy:.4f}, Test: {test_accuracy:.4f}")

25/11/09 20:54:44 WARN TaskSetManager: Lost task 0.0 in stage 65.0 (TID 128) (172.19.0.6 executor 0): java.lang.IllegalArgumentException: requirement failed: Nothing has been added to this summarizer.
	at scala.Predef$.require(Predef.scala:281)
	at org.apache.spark.ml.stat.SummarizerBuffer.mean(Summarizer.scala:624)
	at org.apache.spark.ml.stat.SummaryBuilderImpl$MetricsAggregate.$anonfun$eval$1(Summarizer.scala:358)
	at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at scala.collection.TraversableLike.map(TraversableLike.scala:286)
	at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
	at scala.collection.AbstractTraversable.map(Traversable.scala:108)
	at org.apache.spark.ml.stat.SummaryBuilderImpl$MetricsAggregate.eval

## Metrics

In [None]:
stage_metrics.end()
metrics = stage_metrics.aggregate_stage_metrics()
total_time = (metrics['elapsedTime'] / 1000.0) if metrics else 0

print(f"Time: {total_time:.2f}s")
print(f"CPU: {metrics.get('executorCpuTime', 0) / 1000.0:.2f}s")
print(f"Shuffle Read: {metrics.get('shuffleReadBytes', 0) / (1024**3):.2f} GB")
print(f"Shuffle Write: {metrics.get('shuffleWriteBytes', 0) / (1024**3):.2f} GB")

## Cleanup

In [None]:
train.unpersist()
val.unpersist()
test.unpersist()
spark.stop()