In [1]:
import os
# Set JAVA_HOME to Java 17 which is already installed.
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-17-openjdk-amd64"
os.environ["PATH"] += os.pathsep + os.path.join(os.environ["JAVA_HOME"], "bin")

#Install the required libraries
!pip install pyspark
# Initialize a Spark session
from pyspark.sql import SparkSession

# Stop any existing Spark session to ensure new configurations take effect
if 'spark' in locals() and spark is not None:
    spark.stop()

spark = (
    SparkSession.builder
    .appName("03_classification_experiments_pyspark")
    .master("local[*]")
    .config("spark.driver.memory", "8g")   # Increased from 6g
    .config("spark.executor.memory", "6g")
    .config("spark.sql.shuffle.partitions", "8")
    .getOrCreate()
)

spark.sparkContext.setLogLevel("WARN")



In [2]:
import os
import argparse
from pyspark.sql import SparkSession
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml import PipelineModel
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.feature import VectorAssembler
import pyspark.sql.functions as F
import json

# Notebook "CLI": parsing an empty argv means the defaults below are always used.
# NOTE(review): "paraquet" is misspelled, but the next cell reads this path
# successfully, so the file on disk presumably has this exact name; rename both together.
_ARG_DEFAULTS = {
    "--featured_parquet": "/content/data/featured.paraquet",
    "--out_dir": "/content/results",
}
parser = argparse.ArgumentParser()
for _flag, _default in _ARG_DEFAULTS.items():
    parser.add_argument(_flag, default=_default)
args = parser.parse_args(args=[])

In [3]:
# Read the engineered-feature table and report how many rows it holds.
df = spark.read.parquet(args.featured_parquet)
n_loaded = df.count()
print("Loaded featured data count:", n_loaded)


Loaded featured data count: 2499784


In [4]:
# Fail fast with a clear message if a column the model needs is missing, rather than
# hitting an AnalysisException deeper in the pipeline.
if 'y_binary' not in df.columns:
    raise RuntimeError("y_binary target not found in featured data. Run preprocessing to create label.")
if 'features' not in df.columns:
    raise RuntimeError("features vector column not found in featured data. Run feature engineering to create it.")


In [5]:
# Keep only rows that actually have a feature vector.
df = df.where(F.col("features").isNotNull())

In [6]:
# Seeded 70/30 split so the partition is repeatable across runs.
train, test = df.randomSplit(weights=[0.7, 0.3], seed=42)
n_train, n_test = train.count(), test.count()
print("Train/Test counts:", n_train, n_test)

Train/Test counts: 1749381 750403


In [7]:
# Cache the training split in 4 partitions; count() is the action that populates
# the cache and, as the last expression, its value is displayed.
train = (
    train
    .repartition(4)
    .cache()
)
train.count()


1749381

In [8]:
# First RF configuration: 100 trees with a fixed seed; maxDepth is searched by the
# parameter grid that follows.
rf = RandomForestClassifier(
    featuresCol="features",
    labelCol="y_binary",
    predictionCol="prediction",
    probabilityCol="probability",
    rawPredictionCol="rawPrediction",
    numTrees=100,
    seed=42,
)

In [9]:
# First tuning grid: compare depth 10 vs 20 at a fixed 100 trees.
paramGrid = (
    ParamGridBuilder()
    .addGrid(rf.maxDepth, [10, 20])
    .addGrid(rf.numTrees, [100])
    .build()
)


In [10]:
# Model-selection metric: area under the ROC curve computed from rawPrediction.
# The CrossValidator is configured in the next cell; building one here as well
# (with parallelism=2) was dead code because it was overwritten immediately.
evaluator = BinaryClassificationEvaluator(labelCol="y_binary", rawPredictionCol="rawPrediction", metricName="areaUnderROC")


In [11]:
# Cross-validation over `paramGrid`, tuned to be cheap on memory.
cv = CrossValidator(
    numFolds=3,      # 3 folds rather than 5 to reduce the number of model fits
    parallelism=1,   # fit one model at a time; concurrent fits multiply memory use
    estimator=rf,
    estimatorParamMaps=paramGrid,
    evaluator=evaluator,
)


In [12]:
# Lighter RF configuration: 50 trees of depth 5, each trained on a 70% subsample of
# the rows. seed=42 makes it reproducible, matching the other RF definitions here.
# NOTE: `rf` is redefined again in the cross-validation cell below.
rf = RandomForestClassifier(
    featuresCol="features",
    labelCol="y_binary",
    numTrees=50,
    maxDepth=5,
    subsamplingRate=0.7,
    seed=42,
)


In [15]:
from pyspark.ml.linalg import Vectors, VectorUDT
from pyspark.sql.types import BooleanType
from pyspark.sql.functions import udf
import math

# Final, memory-friendly RF: 50 trees (down from 100) of depth 5 (instead of
# searching depth 10/20).
rf = RandomForestClassifier(
    seed=42,
    featuresCol="features",
    labelCol="y_binary",
    predictionCol="prediction",
    probabilityCol="probability",
    rawPredictionCol="rawPrediction",
    numTrees=50,
    maxDepth=5,
)

# Single-point grid: CV still runs its folds, but only for this one configuration,
# which greatly reduces memory use during cross-validation.
paramGrid = (
    ParamGridBuilder()
    .addGrid(rf.maxDepth, [5])
    .addGrid(rf.numTrees, [50])
    .build()
)

evaluator = BinaryClassificationEvaluator(
    metricName="areaUnderROC",
    rawPredictionCol="rawPrediction",
    labelCol="y_binary",
)

# parallelism=1 trains one model at a time; concurrent training can exacerbate
# OutOfMemory errors.
cv = CrossValidator(
    numFolds=3,
    parallelism=1,
    estimator=rf,
    estimatorParamMaps=paramGrid,
    evaluator=evaluator,
)

def has_nan_or_inf(v):
    """Return True when vector `v` is None or contains any NaN or +/-Infinity entry.

    Used as a row-filter predicate: True marks the row's feature vector as unusable.
    A null vector also counts as invalid, even though null rows were dropped earlier.
    """
    if v is None:
        return True
    # math.isfinite(x) is False exactly when x is NaN or +/-Inf.
    return any(not math.isfinite(x) for x in v.toArray())

has_nan_or_inf_udf = udf(has_nan_or_inf, BooleanType())

# Drop training rows whose feature vector is null or contains NaN/+-Infinity.
# The filtered frame is cached and materialized once; otherwise every Spark action
# CrossValidator runs on it would re-execute the Python UDF over all training rows.
train_clean = train.filter(~has_nan_or_inf_udf(F.col("features"))).cache()
print("Train rows after NaN/Inf filter:", train_clean.count())
train.unpersist()  # the unfiltered cache from the earlier cell is no longer needed
train = train_clean

cvModel = cv.fit(train)
bestModel = cvModel.bestModel
# Use the fitted model's public param getters rather than the private _java_obj.
# (getNumTrees is a property on the model, getMaxDepth a method.)
print("Best model params:", bestModel.getMaxDepth(), bestModel.getNumTrees)

Best model params: 5 50


In [16]:
# Score the held-out split once and cache it: each evaluate() call below (and the
# sample export later) runs a full Spark job over `preds`, which would otherwise
# re-apply the model to every row of `test` each time.
# NOTE(review): `test` did not go through the NaN/Inf feature filter applied to
# `train`; confirm that is intended.
preds = bestModel.transform(test).cache()

# ROC AUC from rawPrediction, via the areaUnderROC evaluator defined above.
bauc = evaluator.evaluate(preds)

# MulticlassClassificationEvaluator's "f1", "weightedPrecision" and "weightedRecall"
# are averaged over both label values weighted by class frequency; they are not
# positive-class-only metrics.
mce = MulticlassClassificationEvaluator(labelCol="y_binary", predictionCol="prediction", metricName="f1")
f1 = mce.evaluate(preds)
precision_ev = MulticlassClassificationEvaluator(labelCol="y_binary", predictionCol="prediction", metricName="weightedPrecision")
recall_ev = MulticlassClassificationEvaluator(labelCol="y_binary", predictionCol="prediction", metricName="weightedRecall")
prec = precision_ev.evaluate(preds)
rec = recall_ev.evaluate(preds)

In [17]:
# Collect the test-set metrics and write them as pretty-printed JSON.
# NOTE: "f1", "precision" and "recall" hold the weighted multiclass values computed above.
metrics = dict(roc_auc=bauc, f1=f1, precision=prec, recall=rec)
metrics_path = os.path.join(args.out_dir, "rf_metrics.json")
os.makedirs(args.out_dir, exist_ok=True)
with open(metrics_path, "w") as fh:
    json.dump(metrics, fh, indent=2)
print("Saved metrics:", metrics)

Saved metrics: {'roc_auc': 0.6666578594716507, 'f1': 0.8744697893318645, 'precision': 0.872267851304158, 'recall': 0.8781228220036433}


In [18]:
# Persist the winning model with Spark ML's writer, replacing any previous save.
# NOTE: this path is relative to the kernel's working directory, not args.out_dir.
model_path = os.path.join("models", "rf_spark_model")
model_writer = bestModel.write().overwrite()
model_writer.save(model_path)
print("Saved best RF model to:", model_path)

Saved best RF model to: models/rf_spark_model


In [19]:
# Export the first 1000 predictions for manual inspection. `probability` is a Spark
# ML vector column; in the CSV it presumably appears as the vector's string form.
sample_path = os.path.join(args.out_dir, "rf_preds_sample.csv")
(
    preds.select("prediction", "probability", "y_binary")
    .limit(1000)
    .toPandas()
    .to_csv(sample_path, index=False)
)
# Report the actual destination (args.out_dir), not a hard-coded "results/".
print("Saved predictions sample to:", sample_path)

Saved predictions sample to results/


In [None]:
# End of run: stop the SparkSession created in the first cell.
spark.stop()