In [1]:
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import (
    StructType,
    StructField,
    StringType,
    TimestampType,
    LongType,
    DateType,
)
from datetime import datetime, timedelta
import os

In [2]:
from datetime import datetime
from pyspark.sql import SparkSession, SQLContext

spark = (
    SparkSession.builder.appName("data-optimization-{}".format(datetime.today()))
    .master("spark://spark-master:7077")
    .getOrCreate()
)


sqlContext = SQLContext(spark)
sc = spark.sparkContext
sc



# Data preparation

In [3]:
class ETLPipeline:
    def __init__(self, spark):
        self.bucket_name = f"s3a://warehouse"
        self.spark = spark

        # Define storage paths
        self.bronze_path = f"{self.bucket_name}/bronze"
        self.silver_path = f"{self.bucket_name}/silver"
        self.gold_path = f"{self.bucket_name}/gold"

    def define_schemas(self):
        """Define schemas for the datasets"""
        self.interactions_schema = StructType(
            [
                StructField("user_id", StringType(), False),
                StructField("timestamp", TimestampType(), False),
                StructField("action_type", StringType(), False),
                StructField("page_id", StringType(), False),
                StructField("duration_ms", LongType(), False),
                StructField("app_version", StringType(), False),
            ]
        )

        self.metadata_schema = StructType(
            [
                StructField("user_id", StringType(), False),
                StructField("join_date", DateType(), False),
                StructField("country", StringType(), False),
                StructField("device_type", StringType(), False),
                StructField("subscription_type", StringType(), False),
            ]
        )

    def ingest_to_bronze(self, csv_path: str, dataset_type: str):
        """Ingest CSV files to bronze layer in parquet format"""
        schema = (
            self.interactions_schema
            if dataset_type == "interactions"
            else self.metadata_schema
        )

        df = self.spark.read.schema(schema).csv(csv_path)

        if dataset_type == "interactions":
            # Partition by date for interactions
            df = df.withColumn("partition_date", F.to_date("timestamp"))
            output_path = f"{self.bronze_path}/interactions"
            partition_by = ["partition_date"]
        else:
            # Partition by country for metadata
            output_path = f"{self.bronze_path}/metadata"
            partition_by = ["country"]

        df.write.mode("append").partitionBy(partition_by).parquet(output_path)

    def process_silver_layer(self, process_date: datetime = None):
        """
        Process bronze data into silver layer with cleaned and validated data.
        Supports incremental processing by date.

        Args:
            process_date: Optional date to process. If None, processes current date
        """
        if process_date is None:
            process_date = datetime.now().date()

        # Read only the partition we need from bronze interactions
        interactions_df = self.spark.read.option(
            "basePath", f"{self.bronze_path}/interactions"
        ).parquet(f"{self.bronze_path}/interactions/partition_date={process_date}")

        # For metadata, check if we need to process updates
        metadata_path = f"{self.silver_path}/dim_users"
        metadata_df = self.spark.read.parquet(f"{self.bronze_path}/metadata")

        # Get existing metadata last modified date if exists
        try:
            existing_metadata = self.spark.read.parquet(metadata_path)
            last_modified = existing_metadata.agg(F.max("_modified_date")).collect()[0][
                0
            ]
        except:
            last_modified = None
            existing_metadata = None
            
        # Process metadata changes first
        if existing_metadata is not None:
            # Identify new or updated metadata records
            new_metadata_df = metadata_df.withColumn(
                "_modified_date", F.current_date()
            ).join(
                existing_metadata, "user_id", "left_anti"
            )  # Get only new records

            # Combine existing and new metadata
            combined_metadata = existing_metadata.union(new_metadata_df)
        else:
            combined_metadata = metadata_df.withColumn("_modified_date", F.current_date())
            new_metadata_df = combined_metadata

        # Process interactions incrementally
        clean_interactions = interactions_df.filter(
            F.col("duration_ms").between(0, 7200000)
        ).dropDuplicates(["user_id", "timestamp", "action_type", "page_id"])

        # Join with metadata to create enriched fact table
        # Broadcast the metadata as it's typically smaller
        fact_interactions = (
            clean_interactions
            .join(
                F.broadcast(combined_metadata.select(
                    "user_id",
                    "country",
                    "device_type",
                    "subscription_type"
                )),
                "user_id",
                "left"
            )
            .select(
                "user_id",
                "timestamp",
                "action_type",
                "page_id",
                "duration_ms",
                "partition_date",
                "country",
                "device_type",
                "subscription_type"
            )
            .withColumn("_modified_date", F.current_date())
        )

        # Write fact table incrementally
        fact_path = f"{self.silver_path}/fact_interactions"

        (
            fact_interactions.write.mode("append")  # Use append mode for incremental
            .partitionBy("partition_date")
            .option(
                "replaceWhere", f"partition_date = '{process_date}'"
            )  # Overwrite only this partition
            .parquet(fact_path)
        )

        # Process metadata changes
        if existing_metadata is not None:
            # Identify new or updated metadata records
            metadata_df = metadata_df.withColumn(
                "_modified_date", F.current_date()
            ).join(
                existing_metadata, "user_id", "left_anti"
            )  # Get only new records
        else:
            metadata_df = metadata_df.withColumn("_modified_date", F.current_date())

        if metadata_df.count() > 0:  # Only process if we have changes
            dim_users = metadata_df.dropDuplicates(["user_id"]).select(
                "user_id",
                "join_date",
                "country",
                "device_type",
                "subscription_type",
                "_modified_date",
            )

            # Write dimension table
            # For small dimension tables, we can use overwrite mode
            # For larger ones, consider using merge/upsert operations
            if existing_metadata is None:
                write_mode = "overwrite"
            else:
                write_mode = "append"

            (
                dim_users.write.mode(write_mode)
                .partitionBy("country")
                .parquet(metadata_path)
            )

        # Return metrics about processed data
        return {
            "date_processed": process_date,
            "interactions_processed": clean_interactions.count(),
            "metadata_updates": metadata_df.count() if metadata_df.count() > 0 else 0,
        }

    def process_date_range(self, start_date: datetime, end_date: datetime):
        """Process a range of dates incrementally"""
        current_date = start_date
        processing_metrics = []

        while current_date <= end_date:
            try:
                metrics = self.process_silver_layer(current_date)
                processing_metrics.append(metrics)
                current_date += timedelta(days=1)
            except Exception as e:
                print(f"Error processing date {current_date}: {str(e)}")
                raise

        return processing_metrics

    def cleanup_old_partitions(self, retention_days: int = 90):
        """Clean up old partitions based on retention policy"""
        cutoff_date = datetime.now().date() - timedelta(days=retention_days)

        # List partitions
        bronze_partitions = self.spark._jvm.org.apache.hadoop.fs.Path(
            f"{self.bronze_path}/interactions"
        )
        silver_partitions = self.spark._jvm.org.apache.hadoop.fs.Path(
            f"{self.silver_path}/fact_interactions"
        )

        # Delete old partitions
        fs = bronze_partitions.getFileSystem(self.spark._jsc.hadoopConfiguration())

        for path in [bronze_partitions, silver_partitions]:
            if fs.exists(path):
                for partition in fs.listStatus(path):
                    partition_date = datetime.strptime(
                        partition.getPath().getName().split("=")[1], "%Y-%m-%d"
                    ).date()

                    if partition_date < cutoff_date:
                        fs.delete(partition.getPath(), True)

    def _calculate_session_metrics(
        self,
        fact_interactions: DataFrame,
        process_date: datetime,
        lookback_days: int = 1,
    ) -> DataFrame:
        """
        Calculate session-based metrics with window functions, handling session boundaries.

        Args:
            fact_interactions: DataFrame of interactions
            process_date: Date to process
            lookback_days: Number of days to look back for ongoing sessions
        """
        # Calculate date range for session boundary handling
        start_date = process_date - timedelta(days=lookback_days)
        end_date = process_date + timedelta(days=1)  # Include full day

        # Create window specs without range specification for lag/lead
        user_window = Window.partitionBy("user_id").orderBy("timestamp")

        # Create window spec for cumulative operations
        cumulative_window = Window.partitionBy("user_id").orderBy("timestamp")

        sessions_df = (
            fact_interactions.filter(
                F.col("partition_date").between(start_date, process_date)
            )
            .withColumn("prev_timestamp", F.lag("timestamp").over(user_window))
            .withColumn(
                "time_diff_minutes",
                F.when(
                    F.col("prev_timestamp").isNotNull(),
                    (F.unix_timestamp("timestamp") - F.unix_timestamp("prev_timestamp"))
                    / 60,
                ).otherwise(0),
            )
            .withColumn(
                "is_new_session",
                F.when(F.col("time_diff_minutes") >= 30, 1).otherwise(0),
            )
            .withColumn(
                "session_id",
                F.concat(
                    F.col("user_id"),
                    F.lit("_"),
                    F.date_format("partition_date", "yyyyMMdd"),
                    F.lit("_"),
                    F.sum("is_new_session").over(cumulative_window),
                ),
            )
        )

        # Calculate metrics only for sessions that end on process_date
        return (
            sessions_df.withColumn(
                "next_timestamp", F.lead("timestamp").over(user_window)
            )
            .withColumn(
                "session_end",
                F.when(
                    F.col("next_timestamp").isNull()
                    | (
                        (
                            F.unix_timestamp("next_timestamp")
                            - F.unix_timestamp("timestamp")
                        )
                        / 60
                        >= 30
                    ),
                    True,
                ).otherwise(False),
            )
            .filter(
                (F.col("partition_date") == process_date)
                | (F.col("session_end") == True)
            )
            .groupBy("session_id")
            .agg(
                F.count("*").alias("actions_per_session"),
                F.sum("duration_ms").alias("session_duration_ms"),
                F.first("partition_date").alias("session_date"),
                F.last("timestamp").alias("session_end_time"),
            )
        )

    def create_gold_layer(self, process_date: datetime = None):
        """
        Create gold layer with pre-aggregated data and business metrics incrementally.

        Args:
            process_date: Date to process, defaults to current date
        """
        if process_date is None:
            process_date = datetime.now().date()

        # Load relevant data from silver layer
        fact_interactions = self.spark.read.option(
            "basePath", f"{self.silver_path}/fact_interactions"
        ).parquet(f"{self.silver_path}/fact_interactions/partition_date={process_date}")

        dim_users = self.spark.read.parquet(f"{self.silver_path}/dim_users")
        broadcast_users = F.broadcast(dim_users)

        # Calculate daily metrics
        daily_metrics = fact_interactions.groupBy("partition_date").agg(
            F.countDistinct("user_id").alias("daily_active_users"),
            F.count("*").alias("total_actions"),
            F.avg("duration_ms").alias("avg_duration_ms"),
        )

        # Update monthly metrics
        month_start = process_date.replace(day=1)
        month_end = (process_date + timedelta(days=32)).replace(day=1) - timedelta(
            days=1
        )

        # Read existing monthly metrics for current month if exists
        monthly_path = f"{self.gold_path}/monthly_metrics"
        try:
            existing_monthly = self.spark.read.option("basePath", monthly_path).parquet(
                f"{monthly_path}/month_date={month_start}"
            )
        except:
            existing_monthly = None

        # Calculate monthly metrics for current month
        month_interactions = (
            self.spark.read.option("basePath", f"{self.silver_path}/fact_interactions")
            .parquet(f"{self.silver_path}/fact_interactions")
            .filter(F.col("partition_date").between(month_start, month_end))
        )

        monthly_metrics = (
            month_interactions.withColumn(
                "month_date", F.date_trunc("month", F.col("partition_date"))
            )
            .groupBy("month_date")
            .agg(
                F.countDistinct("user_id").alias("monthly_active_users"),
                F.count("*").alias("total_monthly_actions"),
            )
        )

        # Calculate session metrics with lookback
        session_metrics = self._calculate_session_metrics(
            fact_interactions, process_date, lookback_days=1
        )

        # Write metrics to gold layer
        # Daily metrics - append mode with partition replacement
        (
            daily_metrics.write.mode("append")
            .partitionBy("partition_date")
            .option("replaceWhere", f"partition_date = '{process_date}'")
            .parquet(f"{self.gold_path}/daily_metrics")
        )

        # Monthly metrics - replace partition for current month
        (
            monthly_metrics.write.mode("append")
            .partitionBy("month_date")
            .option("replaceWhere", f"month_date = '{month_start}'")
            .parquet(monthly_path)
        )

        # Session metrics - append mode with date partitioning
        (
            session_metrics.write.mode("append")
            .partitionBy("session_date")
            .option("replaceWhere", f"session_date = '{process_date}'")
            .parquet(f"{self.gold_path}/session_metrics")
        )

        return {
            "date_processed": process_date,
            "daily_metrics_updated": daily_metrics.count(),
            "monthly_metrics_updated": monthly_metrics.count(),
            "sessions_processed": session_metrics.count(),
        }

    def backfill_gold_metrics(
        self, start_date: datetime, end_date: datetime, parallel: bool = False
    ):
        """
        Backfill gold metrics for a date range.

        Args:
            start_date: Start date for backfill
            end_date: End date for backfill
            parallel: Whether to process dates in parallel
        """
        if parallel:
            # Create list of dates to process
            dates = [
                (start_date + timedelta(days=x)).date()
                for x in range((end_date - start_date).days + 1)
            ]

            # Process dates in parallel using Spark
            date_df = self.spark.createDataFrame(
                [(date,) for date in dates], ["process_date"]
            )

            date_df.repartition(min(len(dates), 50)).foreach(
                lambda row: self.create_gold_layer(row.process_date)
            )
        else:
            current_date = start_date
            while current_date <= end_date:
                try:
                    self.create_gold_layer(current_date)
                    current_date += timedelta(days=1)
                except Exception as e:
                    print(f"Error processing {current_date}: {str(e)}")
                    raise

# Data processing

In [12]:
# Initialize ETL pipeline
etl = ETLPipeline(spark)
etl.define_schemas()

## Bronze layer: Ingest and convert to parquet

In [5]:
# Bronze layer: Ingest and convert to parquet
etl.ingest_to_bronze(
    "s3a://warehouse/data/user_interactions_sample.csv", "interactions"
)
etl.ingest_to_bronze("s3a://warehouse/data/user_metadata_sample.csv", "metadata")

## Silver layer: Clean data and create fact/dimension tables

In [13]:
# Silver layer: Clean data and create fact/dimension tables
processing_date = datetime(2023, 1, 1).date()
etl.process_silver_layer(processing_date)

{'date_processed': datetime.date(2023, 1, 1),
 'interactions_processed': 27488,
 'metadata_updates': 0}

## Silver layer: process multiple dates incrementally

In [14]:
%%time
start_date = datetime(2023, 1, 2).date()
end_date = datetime(2023, 1, 5).date()
etl.process_date_range(start_date, end_date)

CPU times: user 121 ms, sys: 24 ms, total: 146 ms
Wall time: 28.8 s


[{'date_processed': datetime.date(2023, 1, 2),
  'interactions_processed': 27530,
  'metadata_updates': 0},
 {'date_processed': datetime.date(2023, 1, 3),
  'interactions_processed': 27306,
  'metadata_updates': 0},
 {'date_processed': datetime.date(2023, 1, 4),
  'interactions_processed': 27683,
  'metadata_updates': 0},
 {'date_processed': datetime.date(2023, 1, 5),
  'interactions_processed': 27364,
  'metadata_updates': 0}]

In [6]:
df_fact_int = spark.read.parquet(
    "s3a://warehouse/silver/fact_interactions/partition_date=2023-01-03"
)
df_fact_int.printSchema()

root
 |-- user_id: string (nullable = true)
 |-- timestamp: timestamp (nullable = true)
 |-- action_type: string (nullable = true)
 |-- page_id: string (nullable = true)
 |-- duration_ms: long (nullable = true)
 |-- country: string (nullable = true)
 |-- device_type: string (nullable = true)
 |-- subscription_type: string (nullable = true)
 |-- _modified_date: date (nullable = true)



In [7]:
df_fact_int.count()

54612

In [8]:
df_fact_int.show()

+-------+-------------------+-----------+-------+-----------+-------+--------------+-----------------+--------------+
|user_id|          timestamp|action_type|page_id|duration_ms|country|   device_type|subscription_type|_modified_date|
+-------+-------------------+-----------+-------+-----------+-------+--------------+-----------------+--------------+
|u751627|2023-01-03 17:45:22|       edit|p621585|     118746|     JP|Android Tablet|            basic|    2024-12-20|
|u491077|2023-01-03 15:38:30|     create|p265882|     160165|   NULL|          NULL|             NULL|    2024-12-20|
|u807778|2023-01-03 09:56:58|      share|p554854|     151881|     BR|           Mac|          premium|    2024-12-20|
|u346752|2023-01-03 00:18:16|       edit|p949965|      56082|     IN|        iPhone|       enterprise|    2024-12-20|
|u595276|2023-01-03 19:16:12|  page_view|p224316|     264913|     MX|          iPad|          premium|    2024-12-20|
|u239288|2023-01-03 17:12:01|  page_view|p525366|     29

# Gold layer: Create pre-aggregated business metrics

In [15]:
# Process single date
processing_date = datetime(2023, 1, 5).date()
metrics = etl.create_gold_layer(processing_date)
metrics

{'date_processed': datetime.date(2023, 1, 5),
 'daily_metrics_updated': 1,
 'monthly_metrics_updated': 1,
 'sessions_processed': 27344}

In [10]:
df_sess_metrics = spark.read.parquet("s3a://warehouse/gold/session_metrics/")
df_sess_metrics.printSchema()

root
 |-- session_id: string (nullable = true)
 |-- actions_per_session: long (nullable = true)
 |-- session_duration_ms: long (nullable = true)
 |-- session_end_time: timestamp (nullable = true)
 |-- session_date: date (nullable = true)



In [11]:
df_sess_metrics.show()

+------------------+-------------------+-------------------+-------------------+------------+
|        session_id|actions_per_session|session_duration_ms|   session_end_time|session_date|
+------------------+-------------------+-------------------+-------------------+------------+
|u004300_20230105_0|                  2|              94148|2023-01-05 05:57:05|  2023-01-05|
|u004546_20230105_0|                  2|              97398|2023-01-05 16:33:54|  2023-01-05|
|u004691_20230105_0|                  2|             583716|2023-01-05 07:32:53|  2023-01-05|
|u004814_20230105_0|                  2|             301572|2023-01-05 20:53:04|  2023-01-05|
|u016776_20230105_0|                  2|             288734|2023-01-05 05:15:32|  2023-01-05|
|u037646_20230105_0|                  2|             536454|2023-01-05 03:46:44|  2023-01-05|
|u049417_20230105_0|                  2|             269394|2023-01-05 13:18:47|  2023-01-05|
|u064662_20230105_0|                  2|             419944|