In [0]:
from pyspark.ml import PipelineModel
from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.sql import functions as F

# Load slicing stage (top-10 features)
slicer_model = PipelineModel.load("/FileStore/models/slicer_top10")

# Load preprocessed datasets
train_ready = spark.read.format("delta").load("/FileStore/data/train_ready")
val_ready   = spark.read.format("delta").load("/FileStore/data/val_ready")

# Apply slicer to keep only top-10 features
train_topk = slicer_model.transform(train_ready)
val_topk   = slicer_model.transform(val_ready)

# Efficient class balancing using sampleBy
fractions = {0: 0.6, 1: 1.0}  # 60% of majority, 100% of minority
train_sample = train_topk.sampleBy("label", fractions=fractions, seed=42)

# Define Decision Tree
dt = DecisionTreeClassifier(
    labelCol="label",
    featuresCol="features_topK",
    seed=42
)

# Small param grid (to limit resource usage)
param_grid = ParamGridBuilder() \
    .addGrid(dt.maxDepth, [5, 10, 15]) \
    .addGrid(dt.maxBins, [32, 64]) \
    .build()

# Evaluator
evaluator = MulticlassClassificationEvaluator(
    labelCol="label",
    predictionCol="prediction",
    metricName="f1"
)

# TrainValidationSplit (full training set used for internal tuning)
tvs = TrainValidationSplit(
    estimator=dt,
    estimatorParamMaps=param_grid,
    evaluator=evaluator,
    trainRatio=1.0,
    parallelism=1  # For Databricks CE
)

# Train the model with grid search
tvs_model = tvs.fit(train_sample)

# Evaluate on external validation set
val_preds = tvs_model.transform(val_topk)
f1_score = evaluator.evaluate(val_preds)

print(f"Best Decision Tree model F1-score on validation set: {f1_score:.4f}")

# Save best model
tvs_model.bestModel.write().overwrite().save("/FileStore/models/dt_top10_model_grid")

[0;31m---------------------------------------------------------------------------[0m
[0;31mPy4JJavaError[0m                             Traceback (most recent call last)
File [0;32m<command-2282065208578701>:55[0m
[1;32m     46[0m tvs [38;5;241m=[39m TrainValidationSplit(
[1;32m     47[0m     estimator[38;5;241m=[39mlr,
[1;32m     48[0m     estimatorParamMaps[38;5;241m=[39mparam_grid,
[0;32m   (...)[0m
[1;32m     51[0m     parallelism[38;5;241m=[39m[38;5;241m1[39m
[1;32m     52[0m )
[1;32m     54[0m [38;5;66;03m# 11. Treino[39;00m
[0;32m---> 55[0m tvs_model [38;5;241m=[39m tvs[38;5;241m.[39mfit(train_sample)
[1;32m     57[0m [38;5;66;03m# 12. Avaliação[39;00m
[1;32m     58[0m val_preds [38;5;241m=[39m tvs_model[38;5;241m.[39mtransform(val_topk)

File [0;32m/databricks/python_shell/dbruntime/MLWorkloadsInstrumentation/_pyspark.py:30[0m, in [0;36m_create_patch_function.<locals>.patched_method[0;34m(self, *args, **kwargs)[0m
[1;32m    