In [1]:
from pyspark.sql import SparkSession
import time

In [2]:
spark = SparkSession.builder \
    .appName("fractal-cv-rf") \
    .master("spark://spark-master:7077") \
    .config("spark.hadoop.fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem") \
    .config("spark.hadoop.fs.s3a.aws.credentials.provider", "com.amazonaws.auth.DefaultAWSCredentialsProviderChain") \
    .config("spark.executor.memory", "4g") \
    .config("spark.driver.memory", "2g") \
    .getOrCreate()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/11/05 07:22:44 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
from pyspark.sql.functions import col, when
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [4]:
parq_cols = ["xyz", "Intensity", "Classification", "Red", "Green", "Blue", "Infrared"]
data_path = "/opt/spark/work-dir/data/FRACTAL"

In [5]:
def prepare_data(df):
    df = df.withColumn("z", col("xyz")[2])
    df = df.withColumn("ndvi", when(
        (col("Infrared") + col("Red")) != 0,
        (col("Infrared") - col("Red")) / (col("Infrared") + col("Red"))
    ).otherwise(0))
    
    feature_assembler = VectorAssembler(
        inputCols=["z", "Intensity", "Red", "Green", "Blue", "Infrared", "ndvi"],
        outputCol="features"
    )
    
    df = feature_assembler.transform(df)
    return df.select("features", col("Classification").alias("label"))

In [6]:
def load_sample(path, fraction=0.2): # get 20 percent of the data
    return prepare_data(spark.read.parquet(path).select(*parq_cols).sample(fraction=fraction, seed=42))

train = load_sample(f"{data_path}/train/")
val = load_sample(f"{data_path}/val/")
test = load_sample(f"{data_path}/test/")

In [7]:
evaluator = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="accuracy")
best_accuracy = 0
best_params = {}

rf_template = RandomForestClassifier(labelCol="label", featuresCol="features", seed=42)

for num_trees in [50, 100, 200]:
    for max_depth in [10, 15, 20]:
        for max_bins in [32, 64]:
            rf = RandomForestClassifier(
                labelCol="label", 
                featuresCol="features", 
                numTrees=num_trees,
                maxDepth=max_depth,
                maxBins=max_bins,
                seed=42
            )
            
            model = rf.fit(val)
            accuracy = evaluator.evaluate(model.transform(val))
            
            if accuracy > best_accuracy:
                best_accuracy = accuracy
                best_params = {"numTrees": num_trees, "maxDepth": max_depth, "maxBins": max_bins}
                print(f"New best: {best_params} -> Accuracy: {accuracy:.4f}")

25/11/05 07:30:25 WARN DAGScheduler: Broadcasting large task binary with size 1453.6 KiB
25/11/05 07:32:11 WARN DAGScheduler: Broadcasting large task binary with size 2.8 MiB
25/11/05 07:34:21 WARN DAGScheduler: Broadcasting large task binary with size 5.7 MiB
25/11/05 07:37:01 WARN DAGScheduler: Broadcasting large task binary with size 1916.6 KiB
25/11/05 07:37:15 WARN DAGScheduler: Broadcasting large task binary with size 9.3 MiB
25/11/05 07:39:57 WARN DAGScheduler: Broadcasting large task binary with size 3.0 MiB
25/11/05 07:40:15 WARN DAGScheduler: Broadcasting large task binary with size 9.5 MiB
25/11/05 07:42:06 WARN DAGScheduler: Broadcasting large task binary with size 3.0 MiB
25/11/05 07:42:24 WARN DAGScheduler: Broadcasting large task binary with size 10.0 MiB
25/11/05 07:43:50 WARN DAGScheduler: Broadcasting large task binary with size 3.0 MiB
25/11/05 07:44:07 WARN DAGScheduler: Broadcasting large task binary with size 10.0 MiB
25/11/05 07:45:33 WARN DAGScheduler: Broadcast

New best: {'numTrees': 50, 'maxDepth': 10, 'maxBins': 32} -> Accuracy: 0.7419


25/11/05 08:20:45 WARN DAGScheduler: Broadcasting large task binary with size 1455.3 KiB
25/11/05 08:22:51 WARN DAGScheduler: Broadcasting large task binary with size 2.8 MiB
25/11/05 08:25:31 WARN DAGScheduler: Broadcasting large task binary with size 4.7 MiB
25/11/05 08:27:59 WARN DAGScheduler: Broadcasting large task binary with size 1563.4 KiB
25/11/05 08:28:17 WARN DAGScheduler: Broadcasting large task binary with size 4.8 MiB
25/11/05 08:30:01 WARN DAGScheduler: Broadcasting large task binary with size 1563.4 KiB
25/11/05 08:30:19 WARN DAGScheduler: Broadcasting large task binary with size 5.0 MiB
25/11/05 08:31:36 WARN DAGScheduler: Broadcasting large task binary with size 1563.4 KiB
25/11/05 08:31:52 WARN DAGScheduler: Broadcasting large task binary with size 5.5 MiB
25/11/05 08:32:56 WARN DAGScheduler: Broadcasting large task binary with size 1563.4 KiB
25/11/05 08:33:12 WARN DAGScheduler: Broadcasting large task binary with size 5.3 MiB
25/11/05 08:34:15 WARN DAGScheduler: Br

New best: {'numTrees': 50, 'maxDepth': 10, 'maxBins': 64} -> Accuracy: 0.7427


25/11/05 09:22:30 WARN DAGScheduler: Broadcasting large task binary with size 1453.6 KiB
25/11/05 09:24:27 WARN DAGScheduler: Broadcasting large task binary with size 2.8 MiB
25/11/05 09:26:50 WARN DAGScheduler: Broadcasting large task binary with size 5.7 MiB
ERROR:root:KeyboardInterrupt while sending command.            (100 + 20) / 574]
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
  File "/usr/local/lib/python3.8/dist-packages/py4j/clientserver.py", line 511, in send_command
    answer = smart_decode(self.stream.readline()[:-1])
  File "/usr/lib/python3.8/socket.py", line 669, in readinto
    return self._sock.recv_into(b)
KeyboardInterrupt

KeyboardInterrupt: 

In [None]:
print(f"\nBest Params: {best_params}")

In [None]:
best_model = RandomForestClassifier(
    labelCol="label", 
    featuresCol="features",
    numTrees=best_params["numTrees"],
    maxDepth=best_params["maxDepth"],
    maxBins=best_params["maxBins"],
    seed=42
).fit(train)

In [None]:
print(f"Test Accuracy: {evaluator.evaluate(best_model.transform(test)):.4f}")

25/11/05 09:27:40 ERROR Instrumentation: org.apache.spark.SparkException: Job 69 cancelled because SparkContext was shut down
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$cleanUpAfterSchedulerStop$1(DAGScheduler.scala:1259)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$cleanUpAfterSchedulerStop$1$adapted(DAGScheduler.scala:1257)
	at scala.collection.mutable.HashSet.foreach(HashSet.scala:79)
	at org.apache.spark.scheduler.DAGScheduler.cleanUpAfterSchedulerStop(DAGScheduler.scala:1257)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onStop(DAGScheduler.scala:3129)
	at org.apache.spark.util.EventLoop.stop(EventLoop.scala:84)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$stop$3(DAGScheduler.scala:3015)
	at org.apache.spark.util.Utils$.tryLogNonFatalError(Utils.scala:1375)
	at org.apache.spark.scheduler.DAGScheduler.stop(DAGScheduler.scala:3015)
	at org.apache.spark.SparkContext.$anonfun$stop$12(SparkContext.scala:2258)
	at org.apache.spark.util.Utils$.tryL

In [None]:
spark.stop()