In [89]:
import time
from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.feature import StandardScaler, VectorAssembler
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when
from sparkmeasure import StageMetrics

In [90]:
# Root of the FRACTAL dataset; it is expected to contain train/, val/ and test/.
# Local alternative to the S3 bucket: "/opt/spark/work-dir/data/FRACTAL"
data_path = "s3a://ubs-datasets/FRACTAL/data"

# Share of the Parquet files of each split handed to load_sample(); that
# function always loads at least one file, however small this value is.
sample_fraction = 1e-6

# Written to spark.executor.instances when the session is created.
num_executors = 2

In [91]:
# Cores given to each executor; also used to derive the application-wide core cap.
executor_cores = 2

# Build (or reuse) the session against the standalone master.
spark = (
    SparkSession.builder.appName("fractal-rf")
    .master("spark://spark-master:7077")
    # S3A access; credentials are resolved by the default AWS provider chain.
    .config("spark.hadoop.fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem")
    .config(
        "spark.hadoop.fs.s3a.aws.credentials.provider",
        "com.amazonaws.auth.DefaultAWSCredentialsProviderChain",
    )
    .config("spark.executor.memory", "8g")
    .config("spark.driver.memory", "8g")
    .config("spark.executor.cores", str(executor_cores))
    .config("spark.executor.instances", str(num_executors))
    # A standalone master sizes an application by cores, not by
    # spark.executor.instances: without a cap the application may take every
    # free core in the cluster. Capping total cores at
    # num_executors * executor_cores makes num_executors effective.
    .config("spark.cores.max", str(num_executors * executor_cores))
    .config("spark.driver.maxResultSize", "4g")
    # Event logs for the history server.
    .config("spark.eventLog.enabled", "true")
    .config("spark.eventLog.dir", "/opt/spark/spark-events")
    .getOrCreate()
)

In [92]:
# sparkmeasure collector bound to this session; begin()/end() are called in
# later cells around the measured workload.
stage_metrics = StageMetrics(spark)
# Print the driver's web UI address so the run can be followed live.
print(f"Spark UI: {spark.sparkContext.uiWebUrl}")

Spark UI: http://2b8282c8bdeb:4040


In [93]:
def prepare_data(df):
    """Derive the model input columns from a raw point DataFrame.

    Adds ``z_raw`` (element at index 2 of the ``xyz`` column) and ``ndvi``
    ((Infrared - Red) / (Infrared + Red), or 0 when that sum is not non-zero),
    then keeps only the feature columns plus ``Classification`` renamed to
    ``label``.

    Args:
        df: DataFrame providing ``xyz``, ``Intensity``, ``Classification``,
            ``Red``, ``Green``, ``Blue`` and ``Infrared``.

    Returns:
        DataFrame with columns ``z_raw``, ``Intensity``, ``Red``, ``Green``,
        ``Blue``, ``Infrared``, ``ndvi`` and ``label``.
    """
    infrared, red = col("Infrared"), col("Red")
    band_sum = infrared + red

    # Guard the division: rows whose band sum is zero get an NDVI of 0.
    ndvi = when(band_sum != 0, (infrared - red) / band_sum).otherwise(0)

    enriched = df.withColumn("z_raw", col("xyz")[2]).withColumn("ndvi", ndvi)

    return enriched.select(
        "z_raw",
        "Intensity",
        "Red",
        "Green",
        "Blue",
        "Infrared",
        "ndvi",
        col("Classification").alias("label"),
    )


def load_sample(spark, path, fraction, cols):
    """Load a deterministic subset of the Parquet files directly under ``path``.

    The directory is listed through the Hadoop FileSystem API, the entries
    ending in ``.parquet`` are sorted by path, and the first
    ``max(1, int(n_files * fraction))`` of them (capped at ``n_files``) are
    read. The requested columns are selected and passed through
    ``prepare_data``.

    Args:
        spark: Active SparkSession.
        path: Directory URI to list (e.g. an ``s3a://`` prefix).
        fraction: Share of the files to load. At least one file is always
            loaded, and values above 1 load every file.
        cols: Names of the columns to read from the Parquet files.

    Returns:
        The prepared DataFrame. It is not cached, so later actions re-read
        the selected files.

    Raises:
        ValueError: If ``path`` contains no ``.parquet`` files, or if the
            selected files contain no rows.
    """
    print(f"Loading data from {path} with fraction={fraction}")

    sc = spark.sparkContext
    hadoop_conf = sc._jsc.hadoopConfiguration()

    # List the directory on the JVM side so any Hadoop-supported scheme works.
    uri = sc._jvm.java.net.URI(path)
    fs = sc._jvm.org.apache.hadoop.fs.FileSystem.get(uri, hadoop_conf)
    file_path = sc._jvm.org.apache.hadoop.fs.Path(path)

    # Sorted so the same fraction always selects the same files.
    all_files = sorted(
        file_uri
        for file_uri in (str(status.getPath()) for status in fs.listStatus(file_path))
        if file_uri.endswith(".parquet")
    )

    # Fail early with a clear message instead of calling read.parquet() with
    # no paths.
    if not all_files:
        raise ValueError(f"No .parquet files found in {path}. Check data path.")

    # Always load at least one file; never report more files than exist.
    num_files = min(len(all_files), max(1, int(len(all_files) * fraction)))
    selected_files = all_files[:num_files]

    # Report the share actually loaded; because of the one-file floor it can
    # differ from `fraction`, and tiny shares must not be rounded to "0.0%".
    loaded_pct = 100 * num_files / len(all_files)
    print(f"Loading {num_files}/{len(all_files)} files ({loaded_pct:.3g}%)")

    df = spark.read.parquet(*selected_files).select(*cols)
    df = prepare_data(df)

    # count() runs a Spark job over the selected files; it serves as an early
    # sanity check that the split is not empty.
    row_count = df.count()

    if row_count == 0:
        raise ValueError(f"No data loaded from {path}. Check data path and fraction.")

    print(f"Loaded {row_count} rows")
    return df

In [94]:
# Raw columns read from the Parquet files before feature preparation.
cols = [
    "xyz",
    "Intensity",
    "Classification",
    "Red",
    "Green",
    "Blue",
    "Infrared",
]

# Metrics collection and the wall-clock timer both start before any data is read.
stage_metrics.begin()
start_time = time.time()

print("Loading datasets")
split_paths = {name: f"{data_path}/{name}/" for name in ("train", "val", "test")}

train = load_sample(spark, split_paths["train"], sample_fraction, cols)
val = load_sample(spark, split_paths["val"], sample_fraction, cols)
test = load_sample(spark, split_paths["test"], sample_fraction, cols)

Loading datasets
Loading data from s3a://ubs-datasets/FRACTAL/data/train/ with fraction=1e-06
Loading 1/80000 files (0.0%)


                                                                                

Loaded 31408 rows
Loading data from s3a://ubs-datasets/FRACTAL/data/val/ with fraction=1e-06
Loading 1/10000 files (0.0%)


                                                                                

Loaded 80164 rows
Loading data from s3a://ubs-datasets/FRACTAL/data/test/ with fraction=1e-06
Loading 1/10000 files (0.0%)
Loaded 107172 rows


In [95]:
# StandardScaler operates on vector columns, so z_raw is wrapped in a
# one-element vector first and then scaled by its standard deviation
# (withStd=True) without centering (withMean=False).
z_assembler = VectorAssembler(inputCols=["z_raw"], outputCol="z_vec", handleInvalid="skip")
z_scaler = StandardScaler(inputCol="z_vec", outputCol="z", withMean=False, withStd=True)

# Final feature vector: scaled z plus the raw intensity/colour bands and NDVI.
# handleInvalid="skip" drops rows with invalid values instead of failing.
feature_cols = ["z", "Intensity", "Red", "Green", "Blue", "Infrared", "ndvi"]
assembler = VectorAssembler(inputCols=feature_cols, outputCol="features", handleInvalid="skip")

# Small, deep forest with a fixed seed.
rf = RandomForestClassifier(
    labelCol="label", featuresCol="features", numTrees=5, maxDepth=20, seed=62
)

# Stage order matters: each stage consumes the output column of the previous one.
pipeline = Pipeline(stages=[z_assembler, z_scaler, assembler, rf])

In [96]:
# Accuracy evaluator comparing the "label" and "prediction" columns; it is
# reused for the validation and test splits.
evaluator = MulticlassClassificationEvaluator(
    labelCol="label", predictionCol="prediction", metricName="accuracy"
)

# Fit every pipeline stage (z scaling and the random forest) on the training split.
model = pipeline.fit(train)

25/11/10 21:55:53 WARN DAGScheduler: Broadcasting large task binary with size 1005.8 KiB
25/11/10 21:55:53 WARN DAGScheduler: Broadcasting large task binary with size 1108.6 KiB
25/11/10 21:55:53 WARN DAGScheduler: Broadcasting large task binary with size 1189.5 KiB
25/11/10 21:55:53 WARN DAGScheduler: Broadcasting large task binary with size 1254.7 KiB
25/11/10 21:55:54 WARN DAGScheduler: Broadcasting large task binary with size 1303.5 KiB
25/11/10 21:55:54 WARN DAGScheduler: Broadcasting large task binary with size 1341.1 KiB
25/11/10 21:55:54 WARN DAGScheduler: Broadcasting large task binary with size 1366.7 KiB


## Results

In [97]:
# Score the validation split with the fitted pipeline and compute its accuracy.
val_predictions = model.transform(val)
val_accuracy = evaluator.evaluate(val_predictions)

# Same for the held-out test split.
test_predictions = model.transform(test)
test_accuracy = evaluator.evaluate(test_predictions)

# NOTE(review): with a sample_fraction small enough that load_sample() falls
# back to a single file per split, each figure reflects one file only and is
# presumably not representative of the full dataset -- confirm before
# drawing conclusions from these numbers.
print(f"Val: {val_accuracy:.4f}, Test: {test_accuracy:.4f}")



Val: 0.0002, Test: 0.2720


                                                                                

In [98]:
# Stop sparkmeasure collection started in the data-loading cell.
stage_metrics.end()
# Wall-clock seconds since start_time was taken in the data-loading cell; when
# cells are run by hand this also includes the idle time between them.
total_time = time.time() - start_time
print(f"Total time: {total_time:.2f}s")

Total time: 67.02s


In [99]:
# Aggregated stage metrics for the begin()/end() window, shown as the cell output.
stage_metrics.aggregate_stagemetrics()

{'numStages': 61, 'numTasks': 64, 'elapsedTime': 36573, 'stageDuration': 26662, 'executorRunTime': 25800, 'executorCpuTime': 4875, 'executorDeserializeTime': 1235, 'executorDeserializeCpuTime': 1082, 'resultSerializationTime': 13, 'jvmGCTime': 80, 'shuffleFetchWaitTime': 14, 'shuffleWriteTime': 7, 'resultSize': 127813, 'diskBytesSpilled': 0, 'memoryBytesSpilled': 0, 'peakExecutionMemory': 43704929, 'recordsRead': 1163968, 'bytesRead': 100428630, 'recordsWritten': 0, 'bytesWritten': 0, 'shuffleRecordsRead': 4089, 'shuffleTotalBlocksFetched': 30, 'shuffleLocalBlocksFetched': 29, 'shuffleRemoteBlocksFetched': 1, 'shuffleTotalBytesRead': 1364636, 'shuffleLocalBytesRead': 1364577, 'shuffleRemoteBytesRead': 59, 'shuffleRemoteBytesReadToDisk': 0, 'shuffleBytesWritten': 1364636, 'shuffleRecordsWritten': 4089}

In [100]:
# Release the DataFrames and shut the session down. No visible cell calls
# cache()/persist() on these DataFrames, so the unpersist() calls are expected
# to be no-ops; they are kept so cleanup stays correct if caching is added.
# The try/finally guarantees spark.stop() runs even when an earlier cell
# failed and one of these names is undefined -- otherwise the application
# would keep holding its executors on the cluster.
try:
    for cached_df in (train, val, test, val_predictions, test_predictions):
        cached_df.unpersist()
finally:
    spark.stop()