In [None]:
%run ./CityLocations

In [0]:
package com.databricks.dais2025

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

/**
 * Generate synthetic environmental sensor data using the Spark rate source.
 *
 * This is a Scala equivalent of the Python example6_rate_source.py
 */
object SensorDataGenerator {

  /**
   * Column names and data types of the stream returned by `createStream`, in output order.
   *
   * `createStream` projects its result onto these field names. The nullability flags
   * declared here are not enforced on the generated stream.
   */
  val sensorSchema: StructType = StructType(Seq(
    StructField("sensor_id", StringType, nullable = true),
    StructField("city", StringType, nullable = true),
    StructField("location", StringType, nullable = true),
    StructField("timestamp", TimestampType, nullable = true),
    StructField("temperature", DoubleType, nullable = false),
    StructField("humidity", DoubleType, nullable = false),
    StructField("co2_level", DoubleType, nullable = false),
    StructField("pm25_level", DoubleType, nullable = false)
  ))

  /**
   * Generate environmental sensor data with city information using Spark rate source
   *
   * Each row emitted by the rate source becomes one reading: the row counter selects a
   * city (cycling through the configured cities), a location is picked at random from
   * that city's list, and the measurements are random values that are occasionally
   * scaled by a multiplier to create anomalies.
   *
   * @param spark SparkSession object
   * @param rowsPerSecond Rate of data generation (default: 10)
   * @param numPartitions Number of partitions for the rate source (default: 4)
   * @param partitionBySensorId Currently unused: this method never reads the flag, so it
   *                            has no effect on the returned stream (default: false)
   * @return Streaming DataFrame with environmental sensor data, with the columns listed
   *         in `sensorSchema`
   * @throws IllegalArgumentException if `CityLocations.cityLocations` is empty
   */
  def createStream(
      spark: SparkSession,
      rowsPerSecond: Int = 10,
      numPartitions: Int = 4,
      partitionBySensorId: Boolean = false
  ): DataFrame = {

    // Reference city locations from external object (see CityLocations.ipynb)
    val cityLocations = CityLocations.cityLocations
    // Fail fast with a clear message; without any city there is nothing to build the
    // city/location expressions from.
    require(
      cityLocations.nonEmpty,
      "CityLocations.cityLocations must contain at least one city"
    )
    val cities = cityLocations.keys.toSeq

    // element_at uses 1-based indexes, hence the "+ 1" after the modulo.
    val cityIndex = (col("generator_id") % lit(cities.length)).cast(IntegerType) + 1

    // One (condition, value) pair per city: when the row belongs to that city, pick a
    // random entry from the city's location list.
    val locationBranches = cityLocations.toSeq.map { case (city, locations) =>
      val cityArray = array(locations.map(lit): _*)
      val randomIndex = floor(rand() * locations.length).cast(IntegerType) + 1
      (col("city") === lit(city), element_at(cityArray, randomIndex))
    }

    // Chain the pairs into a single CASE WHEN expression. `head` is safe because the
    // map was checked to be non-empty above.
    val (firstCondition, firstLocation) = locationBranches.head
    val locationExpr = locationBranches.tail.foldLeft(when(firstCondition, firstLocation)) {
      case (chain, (condition, location)) => chain.when(condition, location)
    }

    val readings = spark
      .readStream
      .format("rate")
      .option("rowsPerSecond", rowsPerSecond)
      .option("numPartitions", numPartitions)
      .load()
      // Rename value to generator_id to match dbldatagen
      .withColumnRenamed("value", "generator_id")
      // Cycle through the cities based on generator_id
      .withColumn("city_for_id", element_at(array(cities.map(lit): _*), cityIndex))
      // sensor_id = first three characters of the city + "-SENSOR-" + generator_id
      .withColumn(
        "sensor_id",
        concat(
          substring(col("city_for_id"), 1, 3),
          lit("-SENSOR-"),
          col("generator_id").cast(StringType)
        )
      )
      .withColumn("city", col("city_for_id"))
      .withColumn("location", locationExpr)
      // Temperature multiplier (to create anomalies).
      // NOTE(review): the two rand() calls are independent draws, so roughly 20% of rows
      // get 3.0, 32% (0.8 * 0.4) get 0.2 and the remaining 48% get 1.0. Confirm this
      // split is intended rather than 20% / 20% / 60%.
      .withColumn(
        "temp_multiplier",
        when(rand() < 0.2, lit(3.0))
          .when(rand() < 0.4, lit(0.2))
          .otherwise(lit(1.0))
      )
      // Base temperature by city: a per-city offset plus up to 5 degrees of noise
      .withColumn(
        "base_temp",
        when(col("city") === "Tokyo", lit(25) + rand() * 5)
          .when(col("city") === "Sydney", lit(22) + rand() * 5)
          .when(col("city") === "New York", lit(20) + rand() * 5)
          .when(col("city") === "London", lit(18) + rand() * 5)
          .when(col("city") === "Paris", lit(21) + rand() * 5)
          .otherwise(lit(20) + rand() * 5)
      )
      .withColumn("temperature", col("base_temp") * col("temp_multiplier"))
      // Humidity: 30-90 scaled by an occasional 1.8 multiplier, capped at 100
      .withColumn(
        "humidity_multiplier",
        when(rand() < 0.25, lit(1.8)).otherwise(lit(1.0))
      )
      .withColumn(
        "humidity",
        least(lit(100), (rand() * 60 + 30) * col("humidity_multiplier"))
      )
      // CO2 level: 350 plus up to 800, the random part occasionally scaled by 2.5
      .withColumn(
        "co2_multiplier",
        when(rand() < 0.3, lit(2.5)).otherwise(lit(1.0))
      )
      .withColumn(
        "co2_level",
        lit(350) + rand() * 800 * col("co2_multiplier")
      )
      // PM2.5 level: up to 30, occasionally scaled by 3.0
      .withColumn(
        "pm25_multiplier",
        when(rand() < 0.25, lit(3.0)).otherwise(lit(1.0))
      )
      .withColumn(
        "pm25_level",
        rand() * 30 * col("pm25_multiplier")
      )

    // Keep only the published columns, in schema order. This drops the intermediate
    // helper columns and keeps the rate source's own timestamp as the reading time.
    readings.select(sensorSchema.fieldNames.toSeq.map(name => col(name)): _*)
  }
}