In [0]:
%run "./01_config"

In [0]:
class Upserter:
    """Callable target for ``foreachBatch`` that runs a SQL statement per micro-batch.

    Each micro-batch is registered as a temporary view and then the stored
    SQL statement (a MERGE in this file) is executed against that view.
    """

    def __init__(self, merge_query, temp_view_name):
        """Store the SQL to run and the view name the SQL refers to.

        Args:
            merge_query: SQL text executed once per micro-batch.
            temp_view_name: Name under which each micro-batch is exposed;
                ``merge_query`` is expected to reference this name.
        """
        self.merge_query = merge_query
        self.temp_view_name = temp_view_name

    def upsert(self, df_micro_batch, batch_id):
        """Expose the micro-batch as a temp view and run the stored SQL.

        Args:
            df_micro_batch: The micro-batch DataFrame handed over by
                ``foreachBatch``.
            batch_id: Micro-batch id supplied by ``foreachBatch``; unused,
                but required by the callback signature.
        """
        df_micro_batch.createOrReplaceTempView(self.temp_view_name)

        # The statement must run on the session that owns the micro-batch,
        # because that is where the temp view was just registered.
        # Prefer the public ``sparkSession`` property; the private ``_jdf``
        # handle is not available on every runtime (e.g. Spark Connect), so
        # it is only used as a fallback where the property does not exist.
        session = getattr(df_micro_batch, "sparkSession", None)
        if session is None:
            session = df_micro_batch._jdf.sparkSession()
        session.sql(self.merge_query)

In [0]:
class Gold:
    """Streaming upserts that maintain the ``workout_bpm_summary`` and
    ``gym_summary`` tables.

    Relies on two names that are not defined in this cell: ``Config``
    (presumably provided by the ``01_config`` notebook run above) and the
    ambient ``spark`` session.
    """

    def __init__(self, env):
        """Read settings from ``Config`` and switch to the target schema.

        Args:
            env: Environment name; used to build the catalog name
                ``sbit_{env}_catalog``.
        """
        self.Conf = Config()
        # NOTE(review): plain string concatenation, so project_dir is
        # presumably expected to end with a path separator - confirm.
        self.checkpoint_base = self.Conf.project_dir + "checkpoints"
        self.catalog = f"sbit_{env}_catalog"
        self.db_name = self.Conf.db_name
        self.maxFilesPerTrigger = self.Conf.maxFilesPerTrigger
        spark.sql(f"USE {self.catalog}.{self.db_name}")

    def _read_delta_stream(self, table_name, startingVersion):
        """Return a streaming DataFrame over ``table_name`` in this catalog/schema."""
        return (spark.readStream
                .option("startingVersion", startingVersion)
                .option("ignoreDeletes", True)
                .option("withEventTimeOrder", "true")
                .option("maxFilesPerTrigger", self.maxFilesPerTrigger)
                .table(f"{self.catalog}.{self.db_name}.{table_name}"))

    @staticmethod
    def _start_query(stream_writer, once, processing_time):
        """Start the writer with an availableNow trigger if ``once`` is truthy, else a processing-time trigger."""
        if once:
            return stream_writer.trigger(availableNow=True).start()
        return stream_writer.trigger(processingTime=processing_time).start()

    def upsert_workout_bpm_summary(self, once=True, processing_time="15 seconds", startingVersion=0):
        """Start the stream that fills ``workout_bpm_summary``.

        Aggregates heart-rate readings from ``workout_bpm`` per
        (user_id, workout_id, session_id, end_time), joins the result with
        ``user_bins`` on ``user_id`` and inserts summaries whose key is not
        yet present in the target table.

        Args:
            once: If truthy, process what is available and stop
                (``availableNow`` trigger); otherwise run continuously.
            processing_time: Trigger interval used when ``once`` is falsy.
            startingVersion: Value passed to the stream reader's
                ``startingVersion`` option.

        Returns:
            The started streaming query.
        """
        from pyspark.sql import functions as F

        # Idempotent - once a workout session is complete, it doesn't change,
        # so only records with a new key are inserted.
        query = f"""
            MERGE INTO {self.catalog}.{self.db_name}.workout_bpm_summary a
            USING workout_bpm_summary_delta b
            ON a.user_id=b.user_id AND a.workout_id = b.workout_id AND a.session_id=b.session_id
            WHEN NOT MATCHED THEN INSERT *
        """

        data_upserter = Upserter(query, "workout_bpm_summary_delta")

        # Load the user attributes (age, gender, city, state are selected
        # below) from the user_bins table.
        df_users = spark.read.table(f"{self.catalog}.{self.db_name}.user_bins")

        # Stream-read the heart-rate rows and aggregate them per session.
        df_delta = (self._read_delta_stream("workout_bpm", startingVersion)
                    .withWatermark("end_time", "30 seconds")
                    .groupBy("user_id", "workout_id", "session_id", "end_time")
                    .agg(F.min("heartrate").alias("min_bpm"),
                         F.mean("heartrate").alias("avg_bpm"),
                         F.max("heartrate").alias("max_bpm"),
                         F.count("heartrate").alias("num_recordings"))
                    .join(df_users, ["user_id"])
                    .select("workout_id", "session_id", "user_id", "age", "gender", "city", "state",
                            "min_bpm", "avg_bpm", "max_bpm", "num_recordings")
                   )

        # NOTE(review): in "update" mode a group can be emitted again when
        # more of its rows arrive in a later micro-batch, but the MERGE above
        # only inserts unmatched keys, so such later values would be ignored.
        # Confirm that all rows of a session reach this stream together.
        stream_writer = (df_delta.writeStream
                         .foreachBatch(data_upserter.upsert)
                         .outputMode("update")
                         .option("checkpointLocation", f"{self.checkpoint_base}/workout_bpm_summary")
                         .queryName("workout_bpm_summary_upsert_stream")
                        )

        # Scheduler pool for the query started from this thread.
        spark.sparkContext.setLocalProperty("spark.scheduler.pool", "gold_p1")

        return self._start_query(stream_writer, once, processing_time)

    def upsert_gym_summary(self, once=True, processing_time="15 seconds", startingVersion=0):
        """Start the stream that fills ``gym_summary``.

        Streams ``completed_workouts``, left-joins each row with the gym log
        rows of the same user and inserts, per resulting row, the visit date,
        gym, mac_address, workout/session ids, minutes spent in the gym
        (logout - login) and minutes spent exercising (end_time - start_time).

        Args:
            once: If truthy, process what is available and stop
                (``availableNow`` trigger); otherwise run continuously.
            processing_time: Trigger interval used when ``once`` is falsy.
            startingVersion: Value passed to the stream reader's
                ``startingVersion`` option.

        Returns:
            The started streaming query.
        """
        from pyspark.sql import functions as F

        # Idempotent - once a workout session is complete, it doesn't change,
        # so only records with a new key are inserted.
        # NOTE(review): unlike the workout_bpm_summary MERGE, this key has no
        # user-scoped column. If workout_id/session_id are only unique per
        # user, rows of other users would be treated as already present -
        # verify against the table schema.
        query = f"""
            MERGE INTO {self.catalog}.{self.db_name}.gym_summary a
            USING gym_summary_delta b
            ON a.workout_id = b.workout_id AND a.session_id=b.session_id
            WHEN NOT MATCHED THEN INSERT *
        """

        data_upserter = Upserter(query, "gym_summary_delta")

        # Attach user_id to every gym log row through the device mac_address
        # (left join, so log rows without a known user are kept).
        df_users = spark.read.table(f"{self.catalog}.{self.db_name}.users").select('user_id', 'mac_address')
        df_gym_logs = spark.read.table(f"{self.catalog}.{self.db_name}.gym_logs")
        df_gym_logs_combined_users = df_gym_logs.join(df_users, ["mac_address"], "left")

        # Stream-read completed workouts and derive the per-visit durations.
        # NOTE(review): the join is on user_id only, so a workout is paired
        # with every gym log row of that user, not just the visit whose
        # login/logout window contains it - confirm this is intended.
        df_delta = (self._read_delta_stream("completed_workouts", startingVersion)
                    .withWatermark("end_time", "30 seconds")
                    .join(df_gym_logs_combined_users, ["user_id"], 'left')
                    .withColumn("date", F.date_format(F.col('logout'), "yyyy-MM-dd"))
                    .withColumn("minutes_in_gym", (F.unix_timestamp(F.col('logout')) - F.unix_timestamp(F.col('login'))) / 60.0)
                    .withColumn("minutes_exercising", (F.unix_timestamp(F.col('end_time')) - F.unix_timestamp(F.col('start_time'))) / 60.0)
                    .select("date", "gym", "mac_address", "workout_id", "session_id",
                            F.col("minutes_in_gym").cast("double"),
                            F.col("minutes_exercising").cast("double"))
                   )

        stream_writer = (df_delta.writeStream
                         .foreachBatch(data_upserter.upsert)
                         .outputMode("append")
                         .option("checkpointLocation", f"{self.checkpoint_base}/gym_summary")
                         .queryName("gym_summary_upsert_stream")
                        )

        # NOTE(review): same pool name as the workout_bpm_summary query -
        # confirm whether this query was meant to get its own pool.
        spark.sparkContext.setLocalProperty("spark.scheduler.pool", "gold_p1")

        return self._start_query(stream_writer, once, processing_time)

    def upsert(self, once=True, processing_time="5 seconds"):
        """Start both gold summary streams and, in batch mode, wait for them.

        Args:
            once: If truthy, run both streams to completion and block until
                they have terminated; otherwise start them and return.
            processing_time: Trigger interval forwarded to both streams.
        """
        import time
        start = int(time.time())
        print("\nExecuting gold layer upsert ...")

        # Start the summary streams and keep their handles.
        queries = [
            self.upsert_workout_bpm_summary(once, processing_time),
            self.upsert_gym_summary(once, processing_time),
        ]

        # In batch mode (once=True) wait for completion. Only the queries
        # started here are awaited, so an unrelated long-running stream in
        # the same session cannot block this call.
        if once:
            for query in queries:
                query.awaitTermination()

        print(f"Completed gold layer upsert {int(time.time()) - start} seconds")