In [None]:
# Colab-specific helper module used to attach Google Drive to this runtime.
from google.colab import drive
# Mount Google Drive at /content/drive so files stored there can be reached
# through ordinary filesystem paths under that directory.
drive.mount('/content/drive')

In [None]:
import os
import matplotlib.pyplot as plt
from pyspark.sql import SparkSession
from pyspark.sql.functions import month, hour, avg, col


class TaxiAnalytics:
    """
    Answer analytical queries over the Silver Layer data.

    Every query returns a Spark DataFrame and can optionally render a
    matplotlib chart of the aggregated result.
    """

    def __init__(self, spark: SparkSession, silver_base: str):
        """
        Store the Spark session and the normalized silver layer path.

        :param spark: SparkSession used to read the parquet data.
        :param silver_base: Base path of the silver layer. Trailing slashes
            are stripped so sub-paths can be appended with a single "/".
        """
        self.spark = spark
        self.silver_base = silver_base.rstrip("/")

    def average_total_amount_by_month(self, show_plot=True):
        """
        Compute the average ``total_amount`` per pickup month (yellow taxis only).

        Reads the parquet data under ``<silver_base>/yellow`` and groups it by
        the month number extracted from ``tpep_pickup_datetime``.

        NOTE(review): only the month number is used as the grouping key, so
        rows from the same month of different years share one bucket --
        confirm the dataset covers a single year.

        :param show_plot: When True, display a bar chart of the result.
        :return: DataFrame with columns ``month`` and ``avg_total_amount``,
            ordered by ``month``.
        """
        trips = self.spark.read.parquet(f"{self.silver_base}/yellow")
        trips_with_month = trips.withColumn(
            "month", month(col("tpep_pickup_datetime"))
        )
        monthly_avg = trips_with_month.groupBy("month").agg(
            avg("total_amount").alias("avg_total_amount")
        )
        result = monthly_avg.orderBy("month")

        if not show_plot:
            return result

        # Bring the aggregated rows to the driver as a pandas DataFrame for plotting.
        pdf = result.toPandas()
        months = pdf["month"]

        plt.figure(figsize=(8, 5))
        plt.bar(months, pdf["avg_total_amount"], color="gold")
        plt.xlabel("Mês")
        plt.ylabel("Média de total_amount (USD)")
        plt.title("🟨 Média de total_amount por mês (Yellow Táxis)")
        # One tick per month actually present in the result.
        plt.xticks(months)
        plt.grid(axis='y', linestyle="--", alpha=0.5)
        plt.tight_layout()
        plt.show()

        return result

    def average_passenger_count_may_by_hour(self, show_plot=True):
        """
        Compute the average ``passenger_count`` per hour of day for May trips.

        Reads every parquet dataset found under ``silver_base``, keeps the rows
        whose ``tpep_pickup_datetime`` falls in month 5, and groups them by
        pickup hour.

        NOTE(review): the chart title says "Yellow + Green" while the query
        relies on ``tpep_pickup_datetime`` only; this presumes the silver
        layer exposes that column for both fleets -- confirm against the
        silver layer schema.

        :param show_plot: When True, display a line chart of the result.
        :return: DataFrame with columns ``hour`` and ``avg_passenger_count``,
            ordered by ``hour``.
        """
        trips = self.spark.read.parquet(self.silver_base)
        pickup_ts = col("tpep_pickup_datetime")

        may_trips = (
            trips
            .withColumn("month", month(pickup_ts))
            .filter(col("month") == 5)
        )
        hourly_avg = (
            may_trips
            .withColumn("hour", hour(pickup_ts))
            .groupBy("hour")
            .agg(avg("passenger_count").alias("avg_passenger_count"))
        )
        result = hourly_avg.orderBy("hour")

        if not show_plot:
            return result

        # Bring the aggregated rows to the driver as a pandas DataFrame for plotting.
        pdf = result.toPandas()

        plt.figure(figsize=(9, 5))
        plt.plot(pdf["hour"], pdf["avg_passenger_count"], marker='o', linestyle='-')
        plt.xlabel("Hora do dia")
        plt.ylabel("Média de passageiros")
        plt.title("Média de passageiros por hora (Maio - Yellow + Green)")
        # Always label all 24 hours, even when some have no data.
        plt.xticks(range(0, 24))
        plt.grid(True, linestyle="--", alpha=0.5)
        plt.tight_layout()
        plt.show()

        return result


In [None]:
# Obtain a Spark session for this notebook; getOrCreate() returns the active
# session when one already exists instead of starting a second one.
spark = (
    SparkSession.builder
    .appName("TaxiAnalytics")
    .getOrCreate()
)

# Root directory of the silver layer on the mounted Google Drive.
silver_base = "/content/drive/MyDrive/ifood/teste_2/silver_layer"

analytics = TaxiAnalytics(spark=spark, silver_base=silver_base)

# Run both analyses; show_plot defaults to True, so each call also renders its chart.
analytics.average_total_amount_by_month()
analytics.average_passenger_count_may_by_hour()