In [0]:
%run "./01_config"

In [0]:
class Upserter:
    """Run a fixed MERGE statement against each streaming micro-batch.

    ``upsert`` takes ``(df, batch_id)``, the same argument pair that
    ``foreachBatch`` passes to its callback, so a bound ``upsert`` can be
    handed to ``writeStream.foreachBatch`` directly.

    Attributes:
        merge_query: SQL text executed for every micro-batch. It is expected
            to reference ``temp_view_name`` as its source relation.
        temp_view_name: Name under which the micro-batch is registered as a
            temporary view before the query runs.
    """

    def __init__(self, merge_query, temp_view_name):
        """Store the SQL statement and the view name it reads from."""
        self.merge_query = merge_query
        self.temp_view_name = temp_view_name

    def upsert(self, df_micro_batch, batch_id):
        """Register the micro-batch as a temp view and execute the merge.

        Args:
            df_micro_batch: DataFrame holding the current micro-batch.
            batch_id: Micro-batch identifier; unused, accepted only to match
                the callback signature.
        """
        df_micro_batch.createOrReplaceTempView(self.temp_view_name)
        # Use the DataFrame's own session (public attribute) so the SQL runs
        # in the same session that owns the temp view registered above,
        # instead of going through the private ``_jdf`` handle.
        df_micro_batch.sparkSession.sql(self.merge_query)

In [0]:
from pyspark.sql import functions as F
from pyspark.sql.functions import broadcast
class Gold:
    """Gold-layer loader that maintains the ``activity_daily_gold`` table.

    A stream over ``calories_daily_sl`` drives the load: every micro-batch is
    joined with the other daily silver tables and merged into the gold table
    on ``(user_id, date)``.

    Relies on names provided outside this class: ``Config`` (presumably from
    the ``01_config`` notebook run at the top of the file; confirm) and the
    notebook-global ``spark`` session.
    """

    def __init__(self, env):
        """Derive paths and names from ``Config`` and select the database.

        Args:
            env: Environment tag interpolated into the catalog name
                (``fitbit_{env}_catalog``).

        Side effects:
            Issues ``USE <catalog>.<db_name>`` on the global ``spark`` session.
        """
        self.Conf = Config()
        self.landing_zone = self.Conf.landing + 'landing_zone'
        self.checkpoint_base = self.Conf.checkpoint + 'checkpoints'
        self.initial = self.Conf.medallion + "initial"
        self.bronze = self.Conf.medallion + "bronze"
        self.silver = self.Conf.medallion + "silver"
        self.gold = self.Conf.medallion + "gold"
        self.catalog = f"fitbit_{env}_catalog"
        self.db_name = self.Conf.db_name
        self.maxFilesPerTrigger = self.Conf.maxFilesPerTrigger
        spark.sql(f"USE {self.catalog}.{self.db_name}")

    def _get_user_date_grid(self):
        """Return the full user x date Cartesian grid.

        Returns:
            DataFrame: ``user_list`` cross-joined with ``date_list``.
        """
        df_date = spark.read.table(f"{self.catalog}.{self.db_name}.date_list")
        df_user = spark.read.table(f"{self.catalog}.{self.db_name}.user_list")
        # Broadcast hint on the date table to speed up the cross join
        # (assumes date_list is small; confirm).
        return df_user.crossJoin(broadcast(df_date))

    def upsert_activity_daily_gold(self, once=True, processing_time="1 minute"):
        """Start the stream that upserts ``activity_daily_gold``.

        The calories stream is the trigger source; each micro-batch is
        enriched with steps, sleep, intensity and heart-rate columns and
        merged into the gold table.

        Args:
            once: If true, process everything currently available and stop
                (``availableNow`` trigger); otherwise run continuously.
            processing_time: Trigger interval used when ``once`` is false.

        Returns:
            The started streaming query.
        """

        # 1. MERGE statement covering all dimensions: calories, steps, sleep
        #    duration and the four activity-minute columns.
        # NOTE(review): a Delta MERGE needs at most one source row per
        # (user_id, date); presumably every silver table joined below is
        # unique on that key. Confirm.
        query = f"""
            MERGE INTO {self.catalog}.{self.db_name}.activity_daily_gold a
            USING activity_daily_gold_delta b
            ON a.user_id = b.user_id AND a.date = b.date
            WHEN MATCHED THEN UPDATE SET *
            WHEN NOT MATCHED THEN INSERT *
        """

        # 2. Per-micro-batch handler passed to foreachBatch.
        def process_micro_batch(df_batch, batch_id):
            if df_batch.isEmpty():
                return

            # Static user/date grid.
            df_grid = self._get_user_date_grid()

            # Other silver tables are read as plain (non-streaming) tables,
            # i.e. their current snapshot at the time the batch executes.
            df_daily_steps = spark.read.table(f"{self.catalog}.{self.db_name}.steps_daily_sl")
            df_sleep = spark.read.table(f"{self.catalog}.{self.db_name}.sleep_daily_sl")
            # Collapse sleep rows to one per user/day by summing both minute columns.
            df_daily_sleep = df_sleep.groupBy("user_id", "date").agg(
                F.sum("asleep_minutes").alias("asleep_minutes"),
                F.sum("total_minutes_in_bed").alias("total_minutes_in_bed"),
            )
            df_daily_intensities = spark.read.table(f"{self.catalog}.{self.db_name}.intensities_daily_sl")
            df_daily_heartrate = spark.read.table(f"{self.catalog}.{self.db_name}.heartrate_daily_sl")

            # Use the batch's user/date pairs as the base and attach every dimension.
            df_enriched = (df_batch.alias("cal")
                # Inner join keeps only pairs that exist in the grid.
                .join(df_grid, ["user_id", "date"], "inner")
                .join(df_daily_steps.alias("st"), ["user_id", "date"], "left")
                .join(df_daily_sleep.alias("sl"), ["user_id", "date"], "left")
                .join(df_daily_intensities.alias("it"), ["user_id", "date"], "left")
                .join(df_daily_heartrate.alias("hr"), ["user_id", "date"], "left")
                .select(
                    "user_id", "date",
                    # Missing measurements default to 0.
                    F.coalesce(F.col("st.total_steps"), F.lit(0)).alias("total_steps"),
                    F.coalesce(F.col("cal.daily_calories"), F.lit(0)).alias("total_calories"),
                    F.coalesce(F.col("it.very_active_minutes"), F.lit(0)).alias("very_active_minutes"),
                    F.coalesce(F.col("it.fairly_active_minutes"), F.lit(0)).alias("fairly_active_minutes"),
                    F.coalesce(F.col("it.lightly_active_minutes"), F.lit(0)).alias("lightly_active_minutes"),
                    F.coalesce(F.col("it.sedentary_minutes"), F.lit(0)).alias("sedentary_minutes"),
                    F.coalesce(F.col("hr.avg_heartrate"), F.lit(0)).alias("avg_heartrate"),
                    F.coalesce(F.col("hr.max_heartrate"), F.lit(0)).alias("max_heartrate"),
                    F.coalesce(F.col("sl.asleep_minutes"), F.lit(0)).alias("asleep_minutes"),
                    F.coalesce(F.col("sl.total_minutes_in_bed"), F.lit(0)).alias("total_minutes_in_bed")
                )
            )

            # Expose the enriched batch as the view named in the MERGE, then run it.
            df_enriched.createOrReplaceTempView("activity_daily_gold_delta")
            df_batch.sparkSession.sql(query)

        # 3. Start the stream with the calories table as the driving source:
        #    rows arriving on it trigger a recompute of all dimensions for
        #    the affected user/day.
        # NOTE(review): "ignoreDeletes" only concerns delete commits; if
        # calories_daily_sl is itself updated in place, the read may need
        # "skipChangeCommits"/"ignoreChanges". Confirm against the silver job.
        df_stream = (spark.readStream
            .option("ignoreDeletes", "true")
            .table(f"{self.catalog}.{self.db_name}.calories_daily_sl")
        )

        stream_writer = (df_stream.writeStream
            .foreachBatch(process_micro_batch)
            .outputMode("update")
            .option("checkpointLocation", f"{self.checkpoint_base}/activity_daily_gold")
            .queryName("activity_daily_gold_stream")
        )

        if once:
            return stream_writer.trigger(availableNow=True).start()
        else:
            return stream_writer.trigger(processingTime=processing_time).start()

    def _await_queries(self, once, queries=None):
        """Block until streaming queries finish; only when ``once`` is true.

        Args:
            once: When false, return immediately without waiting.
            queries: Optional iterable of streaming queries to wait for. When
                omitted, every active stream on the global ``spark`` session
                is awaited (the previous behaviour).
        """
        if not once:
            return
        pending = spark.streams.active if queries is None else queries
        for stream in pending:
            stream.awaitTermination()

    def upsert(self, once=True, processing_time="5 seconds"):
        """Run the gold-layer upsert and print the elapsed wall-clock seconds.

        Args:
            once: Forwarded to ``upsert_activity_daily_gold``; when true this
                call blocks until the started query terminates.
            processing_time: Forwarded to ``upsert_activity_daily_gold``.
        """
        import time
        start = int(time.time())
        print("\nExecuting gold layer upsert ...")
        # Wait only on the query started here, so an unrelated long-running
        # stream in the same session cannot keep this call blocked.
        started_query = self.upsert_activity_daily_gold(once, processing_time)
        self._await_queries(once, [started_query])
        print(f"Completed gold layer  upsert {int(time.time()) - start} seconds")
