Configuração Inicial e Carregamento dos Dados

In [22]:
import findspark
findspark.init()
findspark.find()

from pyspark.sql import SparkSession

# Create (or reuse) the Hive-enabled Spark session used by every cell below.
spark = (
    SparkSession.builder
    .appName("Projeto_AJP")
    .master("local[*]")
    # Memory sizing for the session.
    .config("spark.executor.memory", "8g")
    .config("spark.driver.memory", "8g")
    .config("spark.driver.maxResultSize", "1g")
    .config("spark.memory.offHeap.enabled", False)
    .enableHiveSupport()
    .getOrCreate()
)

In [23]:
# Load the test split from its Parquet location.
test_data = (
    spark
    .read
    .parquet("../data/processed/test_data.parquet")
)

Carregamento dos Modelos Treinados

In [24]:
# Carregar os modelos treinados
from pyspark.ml.classification import LogisticRegressionModel, DecisionTreeClassificationModel, RandomForestClassificationModel, GBTClassificationModel

# Restore the three fitted classifiers from the ../models directory.
lr_model, dt_model, rf_model = (
    LogisticRegressionModel.load("../models/logistic_regression_model"),
    DecisionTreeClassificationModel.load("../models/decision_tree_model"),
    RandomForestClassificationModel.load("../models/random_forest_model"),
)

Seleção de Métricas

In [25]:
# Define the evaluators shared by all model comparisons below.
# The evaluator classes are imported here because no earlier cell brings them
# into scope; without this import a fresh-kernel "Run All" stops with NameError.
from pyspark.ml.evaluation import (
    BinaryClassificationEvaluator,
    MulticlassClassificationEvaluator,
)

# Metrics computed from the hard class predictions against the "IsDelayed" label.
# NOTE: "weightedPrecision" / "weightedRecall" are averages over the classes
# weighted by class frequency, not the precision/recall of the delayed class
# alone; weightedRecall therefore matches accuracy in the outputs below.
accuracy_evaluator = MulticlassClassificationEvaluator(
    labelCol="IsDelayed", predictionCol="prediction", metricName="accuracy"
)
precision_evaluator = MulticlassClassificationEvaluator(
    labelCol="IsDelayed", predictionCol="prediction", metricName="weightedPrecision"
)
recall_evaluator = MulticlassClassificationEvaluator(
    labelCol="IsDelayed", predictionCol="prediction", metricName="weightedRecall"
)
f1_evaluator = MulticlassClassificationEvaluator(
    labelCol="IsDelayed", predictionCol="prediction", metricName="f1"
)

# Area under the ROC curve, computed from the "rawPrediction" column.
roc_auc_evaluator = BinaryClassificationEvaluator(
    labelCol="IsDelayed", rawPredictionCol="rawPrediction", metricName="areaUnderROC"
)

In [26]:
def compute_metrics(predictions):
    """Score a predictions DataFrame with the notebook-level evaluators.

    Args:
        predictions: output of ``model.transform(...)``; passed unchanged to
            each evaluator's ``evaluate`` method.

    Returns:
        A 5-tuple in the order ``(accuracy, roc_auc, precision, recall, f1)``.
        Callers index this tuple by position, so the order must not change.
    """
    # Evaluation order is accuracy, precision, recall, f1, ROC AUC; only the
    # returned tuple places ROC AUC second.
    evaluators = (
        accuracy_evaluator,
        precision_evaluator,
        recall_evaluator,
        f1_evaluator,
        roc_auc_evaluator,
    )
    acc, prec, rec, f1_score, auc = (
        evaluator.evaluate(predictions) for evaluator in evaluators
    )
    return acc, auc, prec, rec, f1_score

Regressão Logística

In [27]:
# Score the test set with the logistic regression model and report its metrics.
lr_predictions = lr_model.transform(test_data)
lr_metrics = compute_metrics(lr_predictions)

# compute_metrics returns (accuracy, roc_auc, precision, recall, f1).
print(
    "Logistic Regression Metrics:",
    f"Accuracy: {lr_metrics[0]}",
    f"ROC AUC: {lr_metrics[1]}",
    f"Precision: {lr_metrics[2]}",
    f"Recall: {lr_metrics[3]}",
    f"F1 Score: {lr_metrics[4]}",
    sep="\n",
)

Logistic Regression Metrics:
Accuracy: 0.615651429194504
ROC AUC: 0.6541888094853675
Precision: 0.6570287089829441
Recall: 0.6156514291945041
F1 Score: 0.6272727447484185


Árvore de Decisão

In [28]:
# Score the test set with the decision tree model and report its metrics.
dt_predictions = dt_model.transform(test_data)
dt_metrics = compute_metrics(dt_predictions)

# compute_metrics returns (accuracy, roc_auc, precision, recall, f1).
print(
    "\nDecision Tree Metrics:",
    f"Accuracy: {dt_metrics[0]}",
    f"ROC AUC: {dt_metrics[1]}",
    f"Precision: {dt_metrics[2]}",
    f"Recall: {dt_metrics[3]}",
    f"F1 Score: {dt_metrics[4]}",
    sep="\n",
)


Decision Tree Metrics:
Accuracy: 0.633772944338641
ROC AUC: 0.5163442885959006
Precision: 0.6535648227382512
Recall: 0.6337729443386408
F1 Score: 0.6412064129527568


Random Forest

In [29]:
# Score the test set with the random forest model and report its metrics.
rf_predictions = rf_model.transform(test_data)
rf_metrics = compute_metrics(rf_predictions)

# compute_metrics returns (accuracy, roc_auc, precision, recall, f1).
print(
    "\nRandom Forest Metrics:",
    f"Accuracy: {rf_metrics[0]}",
    f"ROC AUC: {rf_metrics[1]}",
    f"Precision: {rf_metrics[2]}",
    f"Recall: {rf_metrics[3]}",
    f"F1 Score: {rf_metrics[4]}",
    sep="\n",
)


Random Forest Metrics:
Accuracy: 0.6086839229336289
ROC AUC: 0.6352610490166465
Precision: 0.6448718910406317
Recall: 0.6086839229336288
F1 Score: 0.6198581541931213


Salvar as previsões

In [30]:
# Write each model's scored test set to Parquet, overwriting any existing output.
lr_predictions.write.parquet(
    "../data/predictions/lr_predictions.parquet", mode="overwrite"
)
dt_predictions.write.parquet(
    "../data/predictions/dt_predictions.parquet", mode="overwrite"
)
rf_predictions.write.parquet(
    "../data/predictions/rf_predictions.parquet", mode="overwrite"
)

print("Model predictions saved.")

Model predictions saved.


Encerrar Spark

In [31]:
# Stop the Spark session created at the top of the notebook.
spark.stop()