<a href="https://colab.research.google.com/github/margaridagomes/dataeng-basic-course/blob/main/spark/challenges/challenge_4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# CHALLENGE 4
## Analyze data

- Query table "vehicles_enriched" in the gold layer
- Aggregate data by municipality_name (an array column)
- Calculate:
  - count of vehicles (id) that pass through that municipality
  - sum of vehicle speeds
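The calculation above can be prototyped without Spark. Below is a minimal plain-Python sketch. It assumes `municipality_name` has already been exploded into one `(municipality, id, speed)` row per municipality, and the sample rows are invented for illustration. It shows why the notebook counts distinct ids: a vehicle reported twice for the same municipality counts once in `vehicle_count`, but both of its speeds go into the sum and the average.

```python
# Plain-Python model of the per-municipality aggregation (no Spark needed).
# The rows below are hypothetical, already-exploded (municipality, id, speed) tuples.
rows = [
    ("Lisboa", "v1", 10.0),
    ("Lisboa", "v2", 6.0),
    ("Lisboa", "v1", 8.0),  # same vehicle reported twice
    ("Sintra", "v3", 5.0),
]


def aggregate(rows):
    """Return {municipality: {vehicle_count, total_speed, avg_speed}}."""
    groups = {}
    for municipality, vehicle_id, speed in rows:
        group = groups.setdefault(municipality, {"ids": set(), "speeds": []})
        if vehicle_id is not None:  # countDistinct ignores null ids
            group["ids"].add(vehicle_id)
        if speed is not None:  # sum/avg ignore null speeds
            group["speeds"].append(speed)

    result = {}
    for municipality, group in groups.items():
        speeds = group["speeds"]
        result[municipality] = {
            "vehicle_count": len(group["ids"]),  # like F.countDistinct("id")
            "total_speed": round(sum(speeds), 2) if speeds else None,  # like F.sum("speed")
            "avg_speed": round(sum(speeds) / len(speeds), 2) if speeds else None,  # like F.avg("speed")
        }
    return result


print(aggregate(rows))
```

Python's `round()` rounds halves to even, whereas Spark's `F.round` rounds half up. Exact `.5` cases can therefore differ in the last digit.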

Questions:
  - What are the top 3 municipalities by vehicle routes?
  - What are the top 3 municipalities with the highest average vehicle speed?
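Both questions reduce to the same operation: sort the aggregated table by one metric in descending order and keep three rows. Sorting on a single column with `orderBy(...).limit(3)` does not guarantee any order among tied rows. Adding a secondary key makes the answer reproducible, for example `orderBy(F.col("vehicle_count").desc(), F.col("municipality"))` in PySpark. A plain-Python sketch of that ranking, using invented municipalities A to D:

```python
# Hypothetical aggregated results: municipality -> (vehicle_count, avg_speed).
stats = {
    "A": (10, 4.0),
    "B": (8, 7.5),
    "C": (8, 6.0),
    "D": (5, 6.5),
}


def top_n(stats, metric_index, n=3):
    """Names of the n municipalities with the largest metric.

    Sorts descending with None values last (like Spark's desc()) and
    breaks ties alphabetically so the result is deterministic.
    """
    def sort_key(item):
        name, metrics = item
        value = metrics[metric_index]
        return (value is None, -(value or 0), name)

    return [name for name, _ in sorted(stats.items(), key=sort_key)[:n]]


print(top_n(stats, 0))  # by vehicle_count -> ['A', 'B', 'C']
print(top_n(stats, 1))  # by avg_speed -> ['B', 'D', 'C']
```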


Tips:
- explode array into rows -> https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.explode.html
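The difference between `explode` and `explode_outer` matters for this challenge. The pipeline uses `explode_outer` when expanding each line's municipalities, so a line with a null or empty array still yields a row and its vehicles still receive a `line_name`. The analysis uses `explode`, so vehicles with an empty `municipality_name` array drop out of the per-municipality stats. Below is a plain-Python model of both behaviours, with invented rows. Each output row keeps the original array and adds the element under the alias, as `select("*", F.explode(...).alias("municipality"))` does.

```python
def explode(rows, column, alias):
    """Model of F.explode: one row per array element; null or empty arrays produce no rows."""
    return [{**row, alias: value} for row in rows for value in (row[column] or [])]


def explode_outer(rows, column, alias):
    """Model of F.explode_outer: like explode, but null or empty arrays produce one row with None."""
    return [{**row, alias: value} for row in rows for value in (row[column] or [None])]


# Hypothetical gold-layer rows: one vehicle per row with its municipality_name array.
vehicles = [
    {"id": "v1", "municipality_name": ["Lisboa", "Loures"]},
    {"id": "v2", "municipality_name": []},
    {"id": "v3", "municipality_name": None},
]

print([(r["id"], r["municipality"]) for r in explode(vehicles, "municipality_name", "municipality")])
print([(r["id"], r["municipality"]) for r in explode_outer(vehicles, "municipality_name", "municipality")])
```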


# Setting up PySpark

In [None]:
%pip install pyspark



In [None]:
# Create the lake directories (bronze, silver, gold)
!mkdir -p /content/lake/bronze
!mkdir -p /content/lake/silver
!mkdir -p /content/lake/gold

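from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.types import *
import pyspark.sql.functions as F
import json
import requests

class ETLFlow:
    """
    Base ETL class with common extraction and loading logic.
    """

    def __init__(self, spark: SparkSession) -> None:
        self.spark = spark

    def extract_from_api(self, url: str, schema: StructType = None) -> DataFrame:
        """
        Extract data from an API endpoint, returning a Spark DataFrame.
        Applies schema if provided.
        """
        response = requests.get(url, timeout=60)
        # Fail on HTTP errors instead of trying to parse an error payload
        response.raise_for_status()
        # spark.read.json() expects an RDD of JSON strings; str(dict) would emit
        # Python literals (None, True) that are not valid JSON, so serialise with json.dumps.
        # Use the instance's session rather than relying on a global `spark` variable.
        rdd = self.spark.sparkContext.parallelize([json.dumps(record) for record in response.json()])
        if schema:
            df = self.spark.read.schema(schema).json(rdd)
        else:
            df = self.spark.read.json(rdd)
        return df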

    def extract_from_file(self, format: str, path: str, **kwargs) -> DataFrame:
        """
        Read data from file (e.g. Parquet) and return a DataFrame.
        Extra keyword arguments are passed to the reader as options.
        """
        df = self.spark.read.format(format).options(**kwargs).load(path)
        return df

    def load(self, df: DataFrame, format: str, path: str, partition_column: str = None, dynamic_partition_overwrite: bool = False, **kwargs) -> None:
        """
        Save the DataFrame to the specified path in the chosen format.
        If 'partition_column' is provided, partition the data accordingly.
        If 'dynamic_partition_overwrite' is True, only overwrite the partitions present in the DataFrame.
        """
        # Set Spark configuration for dynamic/static partition overwrite mode
        if dynamic_partition_overwrite:
            self.spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
        else:
            self.spark.conf.set("spark.sql.sources.partitionOverwriteMode", "static")

        # coalesce(1) produces a single output file per directory
        # (one file per partition directory when partition_column is set)
        writer = df.coalesce(1).write.mode("overwrite").format(format)
        if partition_column:
            writer = writer.partitionBy(partition_column)
        writer.save(path)

class ETLTask(ETLFlow):
    """
    ETL pipeline for lines, vehicles, and municipalities ingestion, cleansing, enrichment and analysis.
    """

    def __init__(self, spark: SparkSession) -> None:
        super().__init__(spark)

    def ingestion_lines(self):
        """
        Extract 'lines' data from the API, enforce schema, and write as a single parquet file to bronze layer.
        """
        # Schema definition for lines
        lines_schema = StructType([StructField('color', StringType(), True),
                                   StructField('facilities', ArrayType(StringType(), True), True),
                                   StructField('id', StringType(), True),
                                   StructField('localities', ArrayType(StringType(), True), True),
                                   StructField('long_name', StringType(), True),
                                   StructField('municipalities', ArrayType(StringType(), True), True),
                                   StructField('patterns', ArrayType(StringType(), True), True),
                                   StructField('routes', ArrayType(StringType(), True), True),
                                   StructField('short_name', StringType(), True),
                                   StructField('text_color', StringType(), True)])

        # Ingestion
        df = self.extract_from_api(url="https://api.carrismetropolitana.pt/lines", schema=lines_schema)
        # Load
        self.load(df=df, format="parquet", path="/content/lake/bronze/lines")

    def ingestion_vehicles(self):
        """
        Extract 'vehicles' data from the API, enforce schema, add a partitioning 'date' column,
        and write to bronze layer as a partitioned parquet file with dynamic partition overwrite enabled.
        """
        # Schema definition for vehicles
        vehicle_schema = StructType([StructField('bearing', IntegerType(), True),
                                     StructField('block_id', StringType(), True),
                                     StructField('current_status', StringType(), True),
                                     StructField('id', StringType(), True),
                                     StructField('lat', FloatType(), True),
                                     StructField('line_id', StringType(), True),
                                     StructField('lon', FloatType(), True),
                                     StructField('pattern_id', StringType(), True),
                                     StructField('route_id', StringType(), True),
                                     StructField('schedule_relationship', StringType(), True),
                                     StructField('shift_id', StringType(), True),
                                     StructField('speed', FloatType(), True),
                                     StructField('stop_id', StringType(), True),
                                     StructField('timestamp', TimestampType(), True),
                                     StructField('trip_id', StringType(), True)])

        # Ingestion
        df = self.extract_from_api(url="https://api.carrismetropolitana.pt/vehicles", schema=vehicle_schema)

        # Create "date" column from "timestamp"
        df = df.withColumn("date", F.expr("date(timestamp)"))

        # Load
        self.load(df=df, format="parquet", path="/content/lake/bronze/vehicles", partition_column="date", dynamic_partition_overwrite=True)

    def ingestion_municipalities(self):
        """
        Extract municipalities data from API, enforce schema, write to bronze layer as a single parquet file.
        """
        # Schema definition for municipalities
        municipalities_schema = StructType([StructField('district_id', StringType(), True),
                                            StructField('district_name', StringType(), True),
                                            StructField('id', StringType(), True),
                                            StructField('name', StringType(), True),
                                            StructField('prefix', StringType(), True),
                                            StructField('region_id', StringType(), True),
                                            StructField('region_name', StringType(), True)])

        # Ingestion
        df = self.extract_from_api(url="https://api.carrismetropolitana.pt/municipalities", schema=municipalities_schema)

        # Load
        self.load(df=df, format="parquet", path="/content/lake/bronze/municipalities")

    def cleansing_vehicles(self):
        """
        Cleansing vehicles: rename lat/lon, remove duplicates, drop nulls and corrupted records, save to silver as a partitioned parquet file with dynamic partition overwrite enabled.
        """
        # Ingestion
        df = self.extract_from_file(format="parquet", path="/content/lake/bronze/vehicles")

        # Renaming columns
        df = df.withColumnRenamed("lat", "latitude").withColumnRenamed("lon", "longitude")

        # Removing duplicates
        df = df.dropDuplicates()

        # Remove rows where 'current_status' is null
        df = df.filter(F.col("current_status").isNotNull())

        # Remove corrupted records
        if "_corrupt_record" in df.columns:
            df = df.filter(F.col("_corrupt_record").isNull()).drop("_corrupt_record")

        # Load
        self.load(df=df, format="parquet", path="/content/lake/silver/vehicles", partition_column="date", dynamic_partition_overwrite=True)


    def cleansing_lines(self):
        """
        Cleansing lines: remove duplicates, remove nulls and corrupted records, save to silver as a single parquet file.
        """
        # Ingestion
        df = self.extract_from_file(format="parquet", path="/content/lake/bronze/lines")

        # Remove duplicates
        df = df.dropDuplicates()

        # Remove rows where 'long_name' is null
        df = df.filter(F.col("long_name").isNotNull())

        # Remove corrupted records
        if "_corrupt_record" in df.columns:
            df = df.filter(F.col("_corrupt_record").isNull()).drop("_corrupt_record")

        # Load
        self.load(df=df, format="parquet", path="/content/lake/silver/lines")


    def cleansing_municipalities(self):
        """
        Cleansing municipalities: remove duplicates, remove nulls and corrupted records, save to silver as a single parquet file.
        """
        # Ingestion
        df = self.extract_from_file(format="parquet", path="/content/lake/bronze/municipalities")

        # Remove duplicates
        df = df.dropDuplicates()

        # Remove rows where 'name' or 'district_name' is null
        df = df.filter(F.col("name").isNotNull() & F.col("district_name").isNotNull())

        # Remove corrupted records (malformed rows flagged in _corrupt_record, if that column is present)
        if "_corrupt_record" in df.columns:
            df = df.filter(F.col("_corrupt_record").isNull()).drop("_corrupt_record")

        # Load
        self.load(df=df, format="parquet", path="/content/lake/silver/municipalities")


    def enrich_vehicles(self):
        """
        Enrich vehicles with lines and municipalities (gold layer).
        Joins, explodes, and aggregates municipality names.
        """

        # Ingestion
        df_vehicles = self.extract_from_file(format="parquet", path="/content/lake/silver/vehicles")
        df_lines = self.extract_from_file(format="parquet", path="/content/lake/silver/lines")
        df_municipalities = self.extract_from_file(format="parquet", path="/content/lake/silver/municipalities")

        # Explode municipalities array from lines to create one row per municipality
        df_lines_exploded = df_lines.select(
            F.col("id").alias("line_id_from_lines"),
            F.col("long_name").alias("line_name"),
            F.explode_outer("municipalities").alias("municipality_id")
        )

        # Join vehicles with lines (left join on line_id)
        df_joined = df_vehicles.join(
            df_lines_exploded,
            df_vehicles["line_id"] == df_lines_exploded["line_id_from_lines"],
            how="left"
        ).drop("line_id_from_lines")

        df_municipalities_selected = df_municipalities.select(
            F.col("id").alias("municipality_id_select"),
            F.col("name")
        )

        # Join with municipalities table to get municipality names
        df_enriched = df_joined.join(
            df_municipalities_selected,
            df_joined["municipality_id"] == df_municipalities_selected["municipality_id_select"],
            how="left"
        ).drop("municipality_id_select")

        # Select all vehicle columns + line_name + municipality name
        vehicle_columns = list(df_vehicles.columns)
        df_selected = df_enriched.select(
            *vehicle_columns,
            F.col("line_name"),
            F.col("name")
        )

        # Group by all vehicle columns + line_name and collect municipality names into array
        group_columns = vehicle_columns + ["line_name"]
        df_final = df_selected.groupBy(*group_columns).agg(
            F.collect_list("name").alias("municipality_name")
        )

        # Remove duplicates from municipality_name array and handle nulls
        df_final = df_final.withColumn(
            "municipality_name",
            F.array_distinct(
                F.filter(F.col("municipality_name"), lambda x: x.isNotNull())
            )
        )

        # Writing vehicles enriched to gold layer partitioned by date
        self.load(df=df_final, format="parquet", path="/content/lake/gold/vehicles_enriched", partition_column="date", dynamic_partition_overwrite=True)

    def analyze_vehicles_by_municipality(self):
        """
        Analysis: aggregate vehicle stats by municipality, answering business questions.
        """

        # Read the enriched vehicles data from gold layer
        df_enriched = self.extract_from_file(format="parquet", path="/content/lake/gold/vehicles_enriched")

        # Explode the municipality_name array to create one row per municipality
        df_exploded = df_enriched.select("*", F.explode(F.col("municipality_name")).alias("municipality"))

        # Aggregate data by municipality
        df_aggregated = df_exploded.groupBy("municipality").agg(
            F.countDistinct("id").alias("vehicle_count"),
            F.round(F.sum("speed"), 2).alias("total_speed"),
            F.round(F.avg("speed"), 2).alias("avg_speed")
        )

        # Question 1: Top 3 municipalities by vehicle routes (count of unique vehicles)
        # Ties are broken alphabetically so the top 3 is reproducible
        print("\n=== Top 3 municipalities by vehicle routes ===")
        df_top_by_routes = df_aggregated.orderBy(F.col("vehicle_count").desc(), F.col("municipality")).limit(3)
        df_top_by_routes.select("municipality", "vehicle_count").show()

        # Question 2: Top 3 municipalities with highest average vehicle speed
        print("\n=== Top 3 municipalities with highest average vehicle speed ===")
        df_top_by_speed = df_aggregated.orderBy(F.col("avg_speed").desc(), F.col("municipality")).limit(3)
        df_top_by_speed.select("municipality", "avg_speed").show()

        # Load the aggregated results
        self.load(df=df_aggregated, format="parquet", path="/content/lake/gold/municipality_vehicle_stats")

if __name__ == '__main__':
    # Stop Spark if already running (prevents "Only one SparkContext" error)
    try:
        spark.stop()
    except NameError:
        # No `spark` session has been created yet in this runtime
        pass

    # Initialize SparkSession
    spark = SparkSession.builder.master('local').appName('ETL Program').getOrCreate()

    try:
        print("Starting ETL program.")

        etl = ETLTask(spark)

        # Ingestion steps (bronze)
        print("Running Task - Ingestion Vehicles")
        etl.ingestion_vehicles()
        print("Running Task - Ingestion Lines")
        etl.ingestion_lines()
        print("Running Task - Ingestion Municipalities")
        etl.ingestion_municipalities()

        # Cleansing steps (silver)
        print("Running Task - Cleansing Vehicles")
        etl.cleansing_vehicles()
        print("Running Task - Cleansing Lines")
        etl.cleansing_lines()
        print("Running Task - Cleansing Municipalities")
        etl.cleansing_municipalities()

        # Enrichment (gold)
        print("Running Task - Enrichment")
        etl.enrich_vehicles()

        # Analysis (gold)
        print("Running Task - Analysis")
        etl.analyze_vehicles_by_municipality()

        print("ETL program completed.")

        # Preview gold layer results
        print("Preview - municipality vehicle stats (gold)")
        spark.read.parquet("/content/lake/gold/municipality_vehicle_stats").show(truncate=False)

    finally:
        # Always stop SparkSession to free resources
        spark.stop()


Starting ETL program
Running Task - Ingestion Vehicles
Running Task - Ingestion Lines
Running Task - Ingestion Municipalities
Running Task - Cleansing Vehicles
Running Task - Cleansing Lines
Running Task - Cleansing Municipalities
Running Task - Enrichment
Running Task - Analysis

=== Top 3 municipalities by vehicle routes ===
+------------+-------------+
|municipality|vehicle_count|
+------------+-------------+
|      Lisboa|          134|
|      Sintra|          118|
|      Loures|           77|
+------------+-------------+


=== Top 3 municipalities with highest average vehicle speed ===
+------------+---------+
|municipality|avg_speed|
+------------+---------+
|Vendas Novas|     7.22|
|       Mafra|     6.89|
|     Palmela|     6.26|
+------------+---------+

ETL program completed
Preview - municipality vehicle stats
+-------------------+-------------+-----------+---------+
|municipality       |vehicle_count|total_speed|avg_speed|
+-------------------+-------------+-----------+----