In [None]:
from pyspark.sql import SparkSession
from pyspark.ml import PipelineModel
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import RandomForestClassificationModel
from pyspark.sql.functions import col, from_json, expr
from pyspark.sql.functions import from_csv
from pyspark.sql.types import StructType, StructField, FloatType
from pyspark.sql.functions import udf
from pyspark.sql.types import DoubleType
from pyspark.ml.linalg import Vectors
import matplotlib.pyplot as plt
import os
import sys

# Point both the PySpark worker and driver interpreter settings at the Python
# executable that is running this code.
os.environ.update(
    PYSPARK_PYTHON=sys.executable,
    PYSPARK_DRIVER_PYTHON=sys.executable,
)

# Create (or reuse) the Spark session with the Kafka source package attached.
# NOTE(review): the connector coordinates pin Scala 2.12 / version 3.0.1;
# presumably these must match the installed Spark build - confirm.
spark = (
    SparkSession.builder
    .appName("KafkaConsumerApplication")
    .config("spark.jars.packages", "org.apache.spark:spark-sql-kafka-0-10_2.12:3.0.1")
    .getOrCreate()
)

# Schema of the JSON payload: five nullable float fields.
schema = StructType(
    [
        StructField(field_name, FloatType(), True)
        for field_name in ("WDIR", "WSPD", "GST", "PRES", "ATMP")
    ]
)

# Streaming DataFrame over the Kafka topic "test" on the local broker.
df = (
    spark.readStream
    .format("kafka")
    .option("kafka.bootstrap.servers", "localhost:9092")
    .option("subscribe", "test")
    .load()
)

# Load the trained model from its saved directory.
model = RandomForestClassificationModel.load("models/random_forest_model")

# Select the 'value' column of each Kafka record and cast it to a string.
value_df = df.select(col("value").cast("string"))

# Deserialize the JSON text using the provided schema. The already-cast
# 'value' column is reused here instead of casting the raw column a second
# time.
json_df = value_df.select(from_json(col("value"), schema).alias("data"))

# Flatten the parsed struct so each JSON field becomes a top-level column.
flattened_df = json_df.select("data.*")


def _positive_class_probability(probability_vector):
    """Return element 1 of the model's probability vector as a Python float.

    Index 1 is taken to be the positive class, as in the original comment;
    this presumes a binary model with labels 0 and 1 - confirm against the
    training code.
    """
    return float(probability_vector[1])


# UDF to extract the probability of the positive class.
extract_prob_udf = udf(_positive_class_probability, DoubleType())


# Input columns assembled into the feature vector, in this order.
# NOTE(review): presumably this must match the column order used when the
# model was trained - confirm against the training code.
FEATURE_COLUMNS = ["WDIR", "WSPD", "GST", "PRES", "ATMP"]


def process_batch(batch_df, epoch_id):
    """Score one micro-batch with the loaded model and print the predictions.

    Args:
        batch_df: The micro-batch DataFrame holding the ``FEATURE_COLUMNS``
            columns.
        epoch_id: Micro-batch identifier passed by ``foreachBatch``; unused.

    Returns:
        None. Predictions are printed with ``DataFrame.show``; nothing is
        printed when the batch has no usable rows.
    """
    # handleInvalid="skip" drops rows containing null/NaN features. The
    # default ("error") raises on such rows, so a single malformed Kafka
    # message (parsed to nulls by from_json) would abort the whole stream.
    assembler = VectorAssembler(
        inputCols=FEATURE_COLUMNS,
        outputCol="features",
        handleInvalid="skip",
    )
    features_df = assembler.transform(batch_df)

    # take(1) only needs to find one row, whereas count() scans the whole
    # batch just to learn whether it is empty.
    if not features_df.take(1):
        return

    # Make predictions, then replace the probability vector with the scalar
    # probability extracted by the UDF.
    predictions = model.transform(features_df).withColumn(
        "probability", extract_prob_udf(col("probability"))
    )

    # Show the input features alongside the probability and predicted label.
    predictions.select(*FEATURE_COLUMNS, "probability", "prediction").show(truncate=False)

# Route every micro-batch of the flattened stream through process_batch.
batch_writer = flattened_df.writeStream.foreachBatch(process_batch)

# Start the streaming query and block until it terminates.
query = batch_writer.start()
query.awaitTermination()

+-----+----+----+------+----+-------------------+----------+
|WDIR |WSPD|GST |PRES  |ATMP|probability        |prediction|
+-----+----+----+------+----+-------------------+----------+
|310.0|8.0 |10.0|1018.2|15.4|0.28295528416307675|0.0       |
|330.0|4.0 |5.0 |1019.4|14.4|0.3093080747049778 |0.0       |
|350.0|6.0 |7.0 |1020.2|14.1|0.29826484162227107|0.0       |
|350.0|7.0 |9.0 |1022.6|13.2|0.3143281837151583 |0.0       |
|290.0|4.0 |5.0 |1016.2|17.6|0.28089443894267196|0.0       |
|310.0|1.0 |2.0 |1019.7|15.4|0.24955256522482916|0.0       |
|340.0|5.0 |7.0 |1023.0|12.9|0.31003743501318853|0.0       |
|340.0|8.0 |10.0|1017.3|15.4|0.3658174074831528 |0.0       |
|340.0|5.0 |7.0 |1019.1|14.8|0.3093080747049778 |0.0       |
|270.0|2.0 |3.0 |1014.9|16.9|0.3412338204278479 |0.0       |
|310.0|9.0 |11.0|1015.5|16.2|0.4258387917944176 |0.0       |
|310.0|8.0 |10.0|1015.3|16.2|0.4258387917944176 |0.0       |
|260.0|4.0 |5.0 |1014.6|18.1|0.31273490920463176|0.0       |
+-----+----+----+------+----+-------------------+----------+

+-----+----+----+-------+----+-------------------+----------+
|WDIR |WSPD|GST |PRES   |ATMP|probability        |prediction|
+-----+----+----+-------+----+-------------------+----------+
|310.0|9.0 |11.0|1018.15|15.4|0.28295528416307675|0.0       |
|330.0|6.0 |8.0 |1019.1 |14.1|0.3093080747049778 |0.0       |
|340.0|6.0 |8.0 |1020.4 |14.1|0.29826484162227107|0.0       |
|350.0|7.0 |9.0 |1022.6 |13.3|0.296413499048581  |0.0       |
|320.0|5.0 |6.0 |1016.2 |17.6|0.33813196516780863|0.0       |
|310.0|3.0 |4.0 |1019.8 |15.4|0.2621984048706664 |0.0       |
|340.0|5.0 |8.0 |1023.0 |13.0|0.31003743501318853|0.0       |
|340.0|7.0 |8.0 |1017.2 |15.4|0.3322649438140547 |0.0       |
|320.0|5.0 |7.0 |1019.1 |14.9|0.3093080747049778 |0.0       |
|250.0|3.0 |4.0 |1014.8 |16.8|0.31273490920463176|0.0       |
|310.0|8.0 |10.0|1015.4 |16.2|0.4258387917944176 |0.0       |
|310.0|8.0 |9.0 |1015.3 |16.2|0.4258387917944176 |0.0       |
|270.0|6.0 |7.0 |1014.6 |18.1|0.3437835654533454 |0.0       |
+-----+----+----+-------+----+-------------------+----------+

+-----+----+----+------+----+-------------------+----------+
|WDIR |WSPD|GST |PRES  |ATMP|probability        |prediction|
+-----+----+----+------+----+-------------------+----------+
|320.0|10.0|12.0|1018.0|15.4|0.3278653103523855 |0.0       |
|330.0|6.0 |8.0 |1019.0|14.2|0.3093080747049778 |0.0       |
|330.0|7.0 |8.0 |1020.5|14.1|0.3060783175969681 |0.0       |
|360.0|7.0 |9.0 |1022.3|13.2|0.3143281837151583 |0.0       |
|320.0|4.0 |5.0 |1016.2|17.6|0.33813196516780863|0.0       |
|300.0|4.0 |5.0 |1019.9|15.4|0.264398048515669  |0.0       |
|340.0|5.0 |7.0 |1022.5|12.9|0.31003743501318853|0.0       |
|340.0|7.0 |8.0 |1017.1|15.4|0.3322649438140547 |0.0       |
|320.0|5.0 |6.0 |1018.9|14.9|0.3093080747049778 |0.0       |
|260.0|3.0 |4.0 |1014.8|16.6|0.31273490920463176|0.0       |
|310.0|8.0 |9.0 |1015.5|16.1|0.4258387917944176 |0.0       |
|310.0|8.0 |10.0|1015.4|16.2|0.4258387917944176 |0.0       |
|280.0|7.0 |8.0 |1014.7|18.0|0.38509086383135094|0.0       |
+-----+----+----+------+----+-------------------+----------+

+-----+----+----+-------+----+-------------------+----------+
|WDIR |WSPD|GST |PRES   |ATMP|probability        |prediction|
+-----+----+----+-------+----+-------------------+----------+
|320.0|9.0 |11.0|1017.95|15.3|0.3278653103523855 |0.0       |
|330.0|6.0 |8.0 |1019.0 |14.1|0.3093080747049778 |0.0       |
|340.0|6.0 |8.0 |1020.6 |14.1|0.29826484162227107|0.0       |
|350.0|7.0 |9.0 |1022.3 |13.3|0.296413499048581  |0.0       |
|330.0|4.0 |5.0 |1016.1 |17.3|0.33813196516780863|0.0       |
|310.0|5.0 |7.0 |1019.7 |15.3|0.264398048515669  |0.0       |
|340.0|4.0 |7.0 |1022.5 |12.8|0.31100366826842435|0.0       |
|340.0|6.0 |7.0 |1017.2 |15.3|0.3244514678393577 |0.0       |
|310.0|5.0 |6.0 |1018.7 |15.0|0.264398048515669  |0.0       |
|250.0|3.0 |4.0 |1014.7 |16.6|0.31273490920463176|0.0       |
|300.0|8.0 |10.0|1015.7 |16.0|0.4195544634745162 |0.0       |
|310.0|8.0 |9.0 |1015.2 |16.2|0.4258387917944176 |0.0       |
|280.0|7.0 |8.0 |1014.7 |17.9|0.38509086383135094|0.0       |
+-----+----+----+-------+----+-------------------+----------+

+-----+----+----+------+----+-------------------+----------+
|WDIR |WSPD|GST |PRES  |ATMP|probability        |prediction|
+-----+----+----+------+----+-------------------+----------+
|330.0|10.0|12.0|1017.9|15.3|0.35067401434877293|0.0       |
|330.0|7.0 |9.0 |1019.1|14.1|0.31712155067967485|0.0       |
|330.0|7.0 |9.0 |1020.6|14.1|0.3060783175969681 |0.0       |
|360.0|7.0 |8.0 |1022.2|13.3|0.296413499048581  |0.0       |
|320.0|4.0 |5.0 |1016.1|17.2|0.33813196516780863|0.0       |
|330.0|5.0 |6.0 |1019.8|15.3|0.3093080747049778 |0.0       |
|350.0|3.0 |5.0 |1022.3|12.8|0.30880402462342177|0.0       |
|330.0|6.0 |8.0 |1017.2|15.3|0.3244514678393577 |0.0       |
|300.0|6.0 |7.0 |1018.6|15.1|0.264398048515669  |0.0       |
|250.0|3.0 |5.0 |1014.8|16.4|0.31273490920463176|0.0       |
|310.0|9.0 |11.0|1015.7|15.8|0.4195544634745162 |0.0       |
|320.0|8.0 |10.0|1015.4|16.2|0.4577555036274246 |0.0       |
|280.0|7.0 |8.0 |1014.7|17.7|0.38509086383135094|0.0       |
+-----+----+----+------+----+-------------------+----------+