In [0]:
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.StreamingQuery
import play.api.libs.json._
import play.api.libs.functional.syntax._
import com.databricks.dais2025.SensorDataGenerator
import org.apache.spark.sql.execution.streaming.functions.current_batch_id // EDGE

object HelperFunctions {

  // ---------------------------
  // Latency parsing
  // ---------------------------

  /** Latency percentiles read from a progress JSON object.
    * Decoded with `Json.reads`, so all five keys must be present for the object to parse.
    */
  case class LatencyMetrics(P0: Long, P50: Long, P90: Long, P95: Long, P99: Long)

  /** The `latencies` section of a progress JSON; each metric group is optional (missing => None). */
  case class Latencies(
    processingLatencyMs: Option[LatencyMetrics],
    sourceQueuingLatencyMs: Option[LatencyMetrics],
    e2eLatencyMs: Option[LatencyMetrics]
  )

  /** One entry of the `stateOperators` array of a progress JSON. */
  case class StateOperatorMetrics(
    operatorName: String,
    numRowsTotal: Long,
    numRowsUpdated: Long,
    numRowsRemoved: Long,
    commitTimeMs: Long,
    memoryUsedBytes: Long,
    customMetrics: Map[String, Long]
  )

  // JSON readers
  implicit val latencyMetricsReads: Reads[LatencyMetrics] = Json.reads[LatencyMetrics]
  implicit val latenciesReads: Reads[Latencies] = Json.reads[Latencies]
  // Only operatorName is mandatory; absent numeric fields default to 0 and
  // absent customMetrics to an empty map.
  implicit val stateOperatorMetricsReads: Reads[StateOperatorMetrics] = (
    (JsPath \ "operatorName").read[String] and
    (JsPath \ "numRowsTotal").readWithDefault[Long](0L) and
    (JsPath \ "numRowsUpdated").readWithDefault[Long](0L) and
    (JsPath \ "numRowsRemoved").readWithDefault[Long](0L) and
    (JsPath \ "commitTimeMs").readWithDefault[Long](0L) and
    (JsPath \ "memoryUsedBytes").readWithDefault[Long](0L) and
    (JsPath \ "customMetrics").readWithDefault[Map[String, Long]](Map.empty)
  )(StateOperatorMetrics.apply _)


  // ---------------------------
  // Core Functions
  // ---------------------------

  /** Parses the JSON of a StreamingQueryProgress into Latencies + State Operators.
    *
    * @param progress a progress update from a streaming query
    * @return the `latencies` object when present and parseable (otherwise None), and the
    *         `stateOperators` array (empty when absent, or when any element fails to parse,
    *         because the whole array is decoded as one value)
    */
  def parseProgressMetrics(progress: org.apache.spark.sql.streaming.StreamingQueryProgress): (Option[Latencies], Seq[StateOperatorMetrics]) = {
    val jsValue = Json.parse(progress.json)

    val latenciesOpt = (jsValue \ "latencies").asOpt[Latencies]
    val stateOps = (jsValue \ "stateOperators").asOpt[Seq[StateOperatorMetrics]].getOrElse(Seq.empty)

    (latenciesOpt, stateOps)
  }


  /** Pretty prints latency metrics if present.
    *
    * Prints the P50/P90/P95/P99 percentiles (P0 is parsed but not shown) for E2E,
    * processing and source-queuing latency, or `N/A` for any group that is missing.
    */
  def printLatencies(latenciesOpt: Option[Latencies]): Unit = {
    def fmt(label: String, opt: Option[LatencyMetrics]): String = opt match {
      case Some(m) => f"$label%-28s P50=${m.P50}%-8d P90=${m.P90}%-8d P95=${m.P95}%-8d P99=${m.P99}%-8d"
      case None    => f"$label%-28s N/A"
    }

    latenciesOpt match {
      case Some(lat) =>
        println("Latencies:")
        println("  " + fmt("E2E Latency (ms):", lat.e2eLatencyMs))
        println("  " + fmt("Processing Latency (ms):", lat.processingLatencyMs))
        println("  " + fmt("Source Queuing Latency (ms):", lat.sourceQueuingLatencyMs))
      case None =>
        println("Latencies: N/A")
    }
  }

  /** Main entry point — prints everything.
    *
    * For each progress update, prints the batch id and processed rows/s, then the
    * latencies when `printLatenciesFlag` is set. State-operator printing is not
    * implemented yet; its flag only prints a placeholder line.
    */
  def printRTMStreamMetrics(
    progressSeq: Seq[org.apache.spark.sql.streaming.StreamingQueryProgress],
    printLatenciesFlag: Boolean = true,
    printTransformWithStateMetricsFlag: Boolean = false
  ): Unit = {
    progressSeq.foreach { progress =>
      val batchId = progress.batchId
      val processedRows = progress.processedRowsPerSecond
      println(s"\nBatchId: $batchId | ProcessedRows/s: $processedRows")

      val (latenciesOpt, stateOps) = parseProgressMetrics(progress)

      if (printLatenciesFlag) {
        printLatencies(latenciesOpt)
      }

      if (printTransformWithStateMetricsFlag) {
        println("Not Added Yet")
        // printTransformWithStateMetrics(stateOps)
      }
    }
  }

  /** Creates a Kafka topic, treating "already exists" as success.
    *
    * Any other failure is printed and rethrown. The AdminClient is always closed.
    */
  def createKafkaTopic(topicName: String, props: java.util.Properties, partitionCount: Int = 4, replicationFactor: Short = 2): Unit = {
    import kafkashaded.org.apache.kafka.clients.admin.{AdminClient, NewTopic}
    import kafkashaded.org.apache.kafka.common.errors.TopicExistsException

    val adminClient = AdminClient.create(props)
    val newTopic = new NewTopic(topicName, partitionCount, replicationFactor)
    try {
      adminClient.createTopics(java.util.Collections.singleton(newTopic)).all().get()
      println(s"Kafka topic '$topicName' created successfully.")
    } catch {
      // The admin future wraps broker errors in an ExecutionException.
      case e: java.util.concurrent.ExecutionException if e.getCause.isInstanceOf[TopicExistsException] =>
        println(s"Kafka topic '$topicName' already exists.")
      case e: Exception =>
        println(s"Error creating Kafka topic '$topicName': ${e.getMessage}")
        throw e
    } finally {
      adminClient.close()
    }
  }

  /** Deletes a Kafka topic, blocking until the deletion completes.
    *
    * Failures propagate to the caller. The AdminClient is closed in either case.
    */
  def deleteKafkaTopic(topicName: String, props: java.util.Properties): Unit = {
    import kafkashaded.org.apache.kafka.clients.admin.AdminClient
    val adminClient = AdminClient.create(props)
    try {
      adminClient.deleteTopics(java.util.Collections.singleton(topicName)).all().get()
      println(s"Kafka topic '$topicName' deleted successfully.")
    } finally {
      // Close even when deletion fails so the client's threads and sockets are not leaked.
      adminClient.close()
    }
  }

  /** Returns the query's most recent progress, or None before the first update (lastProgress is null then). */
  def getLastProgress(query: org.apache.spark.sql.streaming.StreamingQuery): Option[org.apache.spark.sql.streaming.StreamingQueryProgress] = {
    Option(query.lastProgress)
  }

  /** Writes `stream` to Kafka and collects progress until `maxBatches` distinct batches are seen.
    *
    * Each record gets three headers: `triggertype`, `runId` and `batchId`. The query is polled
    * every 200 ms and is stopped once enough batches have been recorded. It is also stopped if
    * polling is interrupted. If the query dies first, its error is printed and the progress
    * collected so far is returned.
    *
    * @param spark              unused; kept for interface compatibility
    * @param writeStreamOptions Kafka sink options (e.g. bootstrap servers, topic, checkpoint)
    * @return one progress entry per distinct batch id, in batch order
    */
  def stopStreamAfterBatchesCollectProgress(
    spark: SparkSession,
    stream: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row],
    queryName: String,
    maxBatches: Int,
    trigger: org.apache.spark.sql.streaming.Trigger,
    writeStreamOptions: Map[String, String],
    runId: String
  ): Seq[org.apache.spark.sql.streaming.StreamingQueryProgress] = {
    import org.apache.spark.sql.functions._

    // Classify the trigger by its toString.
    // NOTE(review): the "processingtime0" value (trailing 0) is kept verbatim because
    // downstream readers may match on it — confirm it is intentional.
    val triggerTypeHeader = trigger.toString.toLowerCase match {
      case s if s.contains("realtime")       => ("triggertype", "realtime")
      case s if s.contains("processingtime") => ("triggertype", "processingtime0")
      case _                                 => ("triggertype", "unknown")
    }

    val headersCol = array(
      struct(lit(triggerTypeHeader._1).as("key"), lit(triggerTypeHeader._2).cast("binary").alias("value")),
      struct(lit("runId").as("key"), lit(runId).cast("binary").alias("value")),
      struct(lit("batchId").as("key"), current_batch_id().cast("string").cast("binary").alias("value"))
    )

    val query = stream
      .withColumn("headers", headersCol)
      .writeStream
      .format("kafka")
      .queryName(queryName)
      .trigger(trigger)
      .options(writeStreamOptions)
      .outputMode("update")
      .start()

    val progressBuffer = scala.collection.mutable.ArrayBuffer.empty[org.apache.spark.sql.streaming.StreamingQueryProgress]
    var lastBatchId = -1L
    try {
      while (query.isActive && progressBuffer.size < maxBatches) {
        // Scan recentProgress (oldest first) rather than only lastProgress, so batches
        // that finish within a single poll interval are not skipped. Entries whose
        // batchId is not newer than the last recorded one are ignored.
        query.recentProgress.foreach { p =>
          if (p.batchId > lastBatchId && progressBuffer.size < maxBatches) {
            lastBatchId = p.batchId
            progressBuffer += p
            println("\nPretty JSON for last StreamingQueryProgress:")
            println(Json.prettyPrint(Json.parse(p.json)))
          }
        }
        Thread.sleep(200)
      }
      // The loop also exits when the query dies; report that instead of returning silently.
      query.exception.foreach { e =>
        println(s"Query '$queryName' terminated with an error: ${e.getMessage}")
      }
    } finally {
      // Stop the query even if polling was interrupted (e.g. notebook cell cancelled).
      if (query.isActive) query.stop()
    }
    progressBuffer.toSeq
  }

  /** Batch-reads Kafka and keeps only records whose `runId` header equals `runId`.
    *
    * NOTE(review): reading `headers` presumably requires `includeHeaders=true` in
    * `kafkaBatchReadOptions` — confirm at the call site.
    *
    * @return the parsed sensor fields (`data.*`) plus `triggertype`, `latency`, `runId`
    *         and `batchId` (int) columns
    */
  def readAndFilterKafkaBatch(
    spark: SparkSession,
    kafkaBatchReadOptions: Map[String, String],
    runId: String
  ): org.apache.spark.sql.DataFrame = {
    import org.apache.spark.sql.functions.{col, exists, expr, from_json, lit, unix_millis}

    val kafkaBatchDF = spark.read.format("kafka").options(kafkaBatchReadOptions).load()

    // Compare against runId as a bound literal Column instead of splicing it into a
    // SQL string, so a runId containing a quote cannot break or alter the predicate.
    val runIdBytes = lit(runId).cast("binary")

    kafkaBatchDF
      .filter(exists(col("headers"), h => h.getField("key") === "runId" && h.getField("value") === runIdBytes))
      .withColumn("triggertype", expr("filter(headers, h -> h.key = 'triggertype')[0].value").cast("string"))
      .withColumn("batchId", expr("filter(headers, h -> h.key = 'batchId')[0].value").cast("string").cast("int"))
      .withColumn("data", from_json(col("value").cast("string"), SensorDataGenerator.sensorSchema))
      .withColumn("sink-timestamp", unix_millis(col("timestamp")))
      // Assumes data.timestamp in sensorSchema is a timestamp column — TODO confirm.
      .withColumn("source-timestamp", unix_millis(col("data.timestamp")))
      // Latency (ms) = Kafka record timestamp - event timestamp carried in the payload.
      // The Kafka timestamp is CreateTime or LogAppendTime depending on topic config.
      .withColumn("latency", col("sink-timestamp") - col("source-timestamp"))
      .withColumn("runId", lit(runId))
      .select("data.*", "triggertype", "latency", "runId", "batchId")
  }
}