In [0]:
from pyspark.ml.feature import StringIndexer


# Location of the preprocessed evaluation set on DBFS.
TEST_DATA_PATH = "dbfs:/FileStore/tables/preprocessed_data.parquet"
test_data_raw = spark.read.parquet(TEST_DATA_PATH)

# Encode the string label column as a numeric "label" column.
# NOTE(review): the indexer is fitted on the test set itself. StringIndexer
# assigns indices from the data it is fitted on, so this mapping may differ
# from the one used when the model was trained -- confirm, or reuse the
# indexer that was fitted on the training data.
indexer = StringIndexer(inputCol="label_column_name", outputCol="label")
indexer_model = indexer.fit(test_data_raw)
test_data = indexer_model.transform(test_data_raw)


In [0]:
from pyspark.ml.classification import LogisticRegressionModel


# Load the saved logistic-regression model from DBFS.
MODEL_PATH = "dbfs:/models/logistic_model"
lr_model = LogisticRegressionModel.load(MODEL_PATH)

# Score the indexed test set with the loaded model.
predictions = lr_model.transform(test_data)

# Preview a few scored rows: true label, predicted label, class probabilities.
preview_columns = ["label", "prediction", "probability"]
predictions.select(*preview_columns).show(5)


+-----+----------+--------------------+
|label|prediction|         probability|
+-----+----------+--------------------+
|  0.0|       0.0|[0.54335579608903...|
|  0.0|       0.0|[0.93508260141202...|
|  1.0|       0.0|[0.73459801856015...|
|  0.0|       0.0|[0.94274052413697...|
|  1.0|       1.0|[0.33092113174103...|
+-----+----------+--------------------+
only showing top 5 rows



In [0]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Evaluator configured to compare the "prediction" column against "label"
# and report overall accuracy.
evaluator = MulticlassClassificationEvaluator(
    metricName="accuracy",
    labelCol="label",
    predictionCol="prediction",
)

# Evaluate on the scored test set and report the result.
accuracy = evaluator.evaluate(predictions)
print(f"Accuracy: {accuracy}")


Accuracy: 0.8006257110352674


In [0]:
def _make_evaluator(metric_name):
    """Return a multiclass evaluator for ``metric_name`` on label/prediction."""
    return MulticlassClassificationEvaluator(
        labelCol="label", predictionCol="prediction", metricName=metric_name)


# One evaluator per metric. Note these are the class-weighted variants
# (weightedPrecision / weightedRecall), not per-class scores for a single label.
precision_evaluator = _make_evaluator("weightedPrecision")
recall_evaluator = _make_evaluator("weightedRecall")
f1_evaluator = _make_evaluator("f1")

# Evaluate each metric on the same scored test set.
precision = precision_evaluator.evaluate(predictions)
recall = recall_evaluator.evaluate(predictions)
f1_score = f1_evaluator.evaluate(predictions)

# Report the three metrics, one per line.
print(f"Precision: {precision}")
print(f"Recall: {recall}")
print(f"F1-score: {f1_score}")


Precision: 0.7915636712614778
Recall: 0.8006257110352675
F1-score: 0.7939215866017353


In [0]:
# Confusion-matrix style breakdown: row count per (label, prediction) pair.
confusion_counts = predictions.groupBy("label", "prediction").count()
confusion_counts.show()


+-----+----------+-----+
|label|prediction|count|
+-----+----------+-----+
|  1.0|       1.0| 1000|
|  0.0|       1.0|  533|
|  1.0|       0.0|  869|
|  0.0|       0.0| 4630|
+-----+----------+-----+

