In [0]:
from pyspark.ml import PipelineModel
from pyspark.ml.classification import GBTClassifier
from pyspark.mllib.evaluation import MulticlassMetrics
from pyspark.sql.functions import col, udf
from pyspark.sql.types import DoubleType
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

# Load the fitted feature slicer and the prepared train/validation splits.
slicer_model = PipelineModel.load("/FileStore/models/slicer_top10")
train_ready = spark.read.format("delta").load("/FileStore/data/train_ready")
val_ready = spark.read.format("delta").load("/FileStore/data/val_ready")

# Apply the same fitted feature selection to both splits.
train_topk = slicer_model.transform(train_ready)
val_topk = slicer_model.transform(val_ready)

# Light undersampling of the majority class: keep every label == 1 row and a
# random MAJORITY_SAMPLE_FRACTION (sampled without replacement) of the
# label != 1 rows. Rows with a null label satisfy neither filter and are dropped.
MAJORITY_SAMPLE_FRACTION = 0.2
RANDOM_SEED = 42
minority_df = train_topk.filter(col("label") == 1)
majority_df = train_topk.filter(col("label") != 1)
train_balanced = (
    majority_df.sample(False, MAJORITY_SAMPLE_FRACTION, seed=RANDOM_SEED)
    .union(minority_df)
    # Cached so the distribution check below and gbt.fit() reuse one
    # materialized sample instead of each re-running the Delta read,
    # slicer transform and sampling.
    .cache()
)

# Check class distribution after balancing
print("\nClass distribution after balancing:")
train_balanced.groupBy("label").count().show()

# Define and train the GBT model
gbt = GBTClassifier(
    labelCol="label",
    featuresCol="features",
    maxIter=20,
    maxDepth=5,
    seed=RANDOM_SEED,
)
model = gbt.fit(train_balanced)
train_balanced.unpersist()  # training is done; release the cached sample

# Score the validation set (produces the "probability" column thresholded below).
val_preds = model.transform(val_topk)

def apply_threshold(df, threshold, probability_col="probability",
                    output_col="adjusted_prediction"):
    """Add a hard 0/1 prediction column by thresholding the class-1 probability.

    Args:
        df: DataFrame with a probability vector column; element 1 of that
            vector is treated as the positive-class probability.
        threshold: rows whose positive-class probability is strictly greater
            than this get 1.0; all others get 0.0.
        probability_col: name of the probability vector column.
        output_col: name of the double column to add (replaced if present).

    Returns:
        ``df`` with ``output_col`` added. A null probability yields a null
        prediction instead of failing the Spark job.
    """
    threshold = float(threshold)
    try:
        # Spark >= 3.0: native vector access avoids a per-row Python UDF.
        from pyspark.ml.functions import vector_to_array
    except ImportError:
        vector_to_array = None

    if vector_to_array is not None:
        positive_prob = vector_to_array(col(probability_col))[1]
        # The comparison is true/false/null; casting maps it to 1.0/0.0/null.
        return df.withColumn(output_col, (positive_prob > threshold).cast("double"))

    # Fallback for older Spark versions: null-safe Python UDF.
    def _predict(prob):
        if prob is None:
            return None
        return 1.0 if prob[1] > threshold else 0.0

    return df.withColumn(output_col, udf(_predict, DoubleType())(col(probability_col)))

# Threshold applied to the class-1 probability.
# NOTE(review): if 0.30 was tuned on this same validation split, the metrics
# below are optimistically biased; confirm it came from a separate split.
# NOTE(review): the recorded run predicts class 1 for only 177 of ~740k rows
# at this threshold despite undersampling; worth checking the probability
# column and feature alignment between train and validation.
best_threshold = 0.30
val_preds_adjusted = apply_threshold(val_preds, best_threshold)

# Final evaluation: MulticlassMetrics takes (prediction, label) pairs as floats.
# Cached because the metrics computation and the support counts below each
# run Spark jobs over this RDD. Without caching, every such job would re-run
# the slicer transform, GBT scoring and thresholding.
final_rdd = (
    val_preds_adjusted.select("adjusted_prediction", "label")
    .rdd.map(lambda r: (float(r[0]), float(r[1])))
    .cache()
)
metrics = MulticlassMetrics(final_rdd)

labels = [0.0, 1.0]

# Count supports in a single pass over final_rdd instead of one filter().count()
# job per class (each of which re-evaluates the whole validation pipeline).
# .get(label, 0) gives a class absent from the data zero support.
support_by_label = final_rdd.map(lambda r: r[1]).countByValue()

# Per-class report plus running sums for the support-weighted averages.
report = {}
total_support = 0
weighted_sum = {"precision": 0, "recall": 0, "f1": 0}

for label in labels:
    precision = metrics.precision(label)
    recall = metrics.recall(label)
    f1 = metrics.fMeasure(label)
    support = support_by_label.get(label, 0)
    report[label] = {"precision": precision, "recall": recall, "f1-score": f1, "support": support}
    total_support += support
    weighted_sum["precision"] += precision * support
    weighted_sum["recall"] += recall * support
    weighted_sum["f1"] += f1 * support

# Unweighted mean over classes, reusing the per-class values gathered above.
macro_avg = {
    key: sum(report[lbl][key] for lbl in labels) / len(labels)
    for key in ("precision", "recall", "f1-score")
}
macro_avg["support"] = total_support

# Print confusion matrix.
# NOTE(review): per Spark's MulticlassMetrics docs, rows are actual labels and
# columns are predicted labels, both in ascending label order; confirm for
# the Spark version in use.
print("\nConfusion Matrix:")
print(metrics.confusionMatrix().toArray().astype(int))

# Print a classification report in the style of sklearn's output.
# The label column is 12 wide so the longest row label, "Weighted Avg",
# fits; at width 10 it overflowed and shifted every number in that row.
label_width = 12
# Avoid ZeroDivisionError for the weighted averages on an empty evaluation set.
weight_denom = total_support if total_support else 1

print("\nClassification Report:")
print(f"{'Class':<{label_width}}{'Precision':>10}{'Recall':>10}{'F1-Score':>10}{'Support':>10}")
for label in labels:
    vals = report[label]
    print(
        f"{str(int(label)):<{label_width}}{vals['precision']:10.4f}"
        f"{vals['recall']:10.4f}{vals['f1-score']:10.4f}{vals['support']:10d}"
    )

print(f"\n{'Accuracy':<{label_width}}{'':>10}{'':>10}{metrics.accuracy:10.4f}{total_support:10d}")
print(
    f"{'Macro Avg':<{label_width}}{macro_avg['precision']:10.4f}"
    f"{macro_avg['recall']:10.4f}{macro_avg['f1-score']:10.4f}{macro_avg['support']:10d}"
)
print(
    f"{'Weighted Avg':<{label_width}}{weighted_sum['precision'] / weight_denom:10.4f}"
    f"{weighted_sum['recall'] / weight_denom:10.4f}{weighted_sum['f1'] / weight_denom:10.4f}"
    f"{total_support:10d}"
)

# Persist the fitted GBT model, overwriting any earlier save at the same path.
gbt_model_path = "/FileStore/models/gbt_top10_no_weights"
model.write().overwrite().save(gbt_model_path)


Class distribution after balancing:
+-----+------+
|label| count|
+-----+------+
|  0.0|674488|
|  1.0|125441|
+-----+------+


Confusion Matrix:
[[713749    172]
 [ 26693      5]]

Classification Report:
Class      Precision    Recall  F1-Score   Support
0             0.9639    0.9998    0.9815    713921
1             0.0282    0.0002    0.0004     26698

Accuracy                          0.9637    740619
Macro Avg     0.4961    0.5000    0.4910    740619
Weighted Avg    0.9302    0.9637    0.9462    740619
