In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import from_json, col, to_json, struct
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, TimestampType
from pyspark.sql.streaming import GroupState, GroupStateTimeout
from pyspark.sql.functions import expr
from pyspark.sql.types import Row

# Spark session
spark = SparkSession.builder \
    .appName("StatefulAnomalyDetection") \
    .getOrCreate()

# Schema of the JSON payload carried in each Kafka message value.
schema = StructType([
    StructField("ticker", StringType(), True),
    StructField("price", DoubleType(), True),
    StructField("timestamp", StringType(), True)
])

# Kafka input: stream the raw records of "test_topic", starting from the
# latest offsets (messages already in the topic are not replayed).
kafka_streaming_df = spark.readStream \
    .format("kafka") \
    .option("kafka.bootstrap.servers", "kafka:9092") \
    .option("subscribe", "test_topic") \
    .option("startingOffsets", "latest") \
    .load()

# Parse messages: decode the value as a JSON string, flatten the parsed
# struct into top-level columns and cast the timestamp string to a real
# TimestampType column.
# Malformed or incomplete payloads come out of from_json with null fields;
# drop rows lacking a ticker or a price so they never reach the per-ticker
# stateful logic (a null ticker would otherwise form its own state group).
parsed_df = kafka_streaming_df.selectExpr("CAST(value AS STRING)") \
    .select(from_json(col("value"), schema).alias("data")) \
    .select("data.*") \
    .withColumn("timestamp", col("timestamp").cast(TimestampType())) \
    .filter(col("ticker").isNotNull() & col("price").isNotNull())

# Define update function
def detect_anomalies(ticker, rows, state: "GroupState", threshold_pct=10.0):
    """Flag ticks whose price moved more than ``threshold_pct`` percent.

    Update function for ``GroupedData.applyInPandasWithState``: it is called
    once per ticker per micro-batch and keeps the last seen price of that
    ticker in ``state`` so the comparison carries across micro-batches.

    Args:
        ticker: Grouping key as passed by Spark -- a tuple with one element
            per ``groupBy`` column (here just the ticker symbol).
        rows: Iterator of pandas DataFrames holding this ticker's new rows
            (columns ``ticker``, ``price``, ``timestamp``).
        state: ``GroupState`` whose value is a 1-tuple ``(prev_price,)``
            matching the ``stateStructType`` given to Spark.
        threshold_pct: Absolute percentage change above which a tick is
            reported. Defaults to 10.0 (the original hard-coded value).

    Yields:
        At most one pandas DataFrame with columns ``ticker``, ``price``,
        ``prev_price``, ``price_change_pct``, ``timestamp`` -- one row per
        anomalous tick. Nothing is yielded when no anomaly is found.
    """
    import pandas as pd

    # Only reachable when a timeout is configured: drop the idle group's state.
    if state.hasTimedOut:
        state.remove()
        return

    frames = [pdf for pdf in rows if not pdf.empty]
    if not frames:
        return

    key = ticker[0] if isinstance(ticker, tuple) else ticker
    # GroupState.get is a property returning the state as a tuple.
    prev_price = state.get[0] if state.exists else None

    # Row order inside a group is not guaranteed by Spark, so order the batch
    # by event time (stable, null timestamps last) before comparing prices.
    batch = pd.concat(frames, ignore_index=True)
    batch = batch.sort_values("timestamp", kind="stable", na_position="last")

    records = []
    for current_price, timestamp in zip(batch["price"], batch["timestamp"]):
        if pd.isna(current_price):
            # Missing price: nothing to compare, and it must not become the
            # reference for the next tick.
            continue
        current_price = float(current_price)

        # A zero reference price would divide by zero; treat it as "no
        # comparable previous price" and just move the reference forward.
        if prev_price is not None and prev_price != 0:
            change_pct = (current_price - prev_price) / prev_price * 100
            if abs(change_pct) > threshold_pct:
                records.append({
                    "ticker": key,
                    "price": current_price,
                    "prev_price": prev_price,
                    "price_change_pct": change_pct,
                    "timestamp": timestamp,
                })

        prev_price = current_price

    # Persist the last valid price once per batch; update() expects a tuple.
    if prev_price is not None:
        state.update((prev_price,))

    if records:
        yield pd.DataFrame(
            records,
            columns=["ticker", "price", "prev_price",
                     "price_change_pct", "timestamp"],
        )

# Apply stateful logic
from pyspark.sql.functions import expr
from pyspark.sql.types import StructType

# Schema of the rows emitted by detect_anomalies (one row per flagged tick).
output_schema = StructType([
    StructField("ticker", StringType()),
    StructField("price", DoubleType()),
    StructField("prev_price", DoubleType()),
    StructField("price_change_pct", DoubleType()),
    StructField("timestamp", TimestampType())
])

# Group by ticker and run detect_anomalies with per-ticker state holding the
# last seen price.
# timeoutConf must be one of the GroupStateTimeout constants -- a duration
# string such as "10 minutes" is rejected. NoTimeout keeps each ticker's
# state indefinitely; to expire idle tickers, switch to
# GroupStateTimeout.ProcessingTimeTimeout and call state.setTimeoutDuration()
# inside the update function.
anomalies = parsed_df.groupBy("ticker").applyInPandasWithState(
    detect_anomalies,
    output_schema,
    stateStructType=StructType([StructField("prev_price", DoubleType())]),
    outputMode="append",
    timeoutConf=GroupStateTimeout.NoTimeout
)

# Output to console: print each micro-batch of detected anomalies untruncated.
query = anomalies.writeStream \
    .outputMode("append") \
    .format("console") \
    .option("truncate", False) \
    .start()

print("✅ Streaming con stato attivo: rilevamento anomalie in corso...")
# Block the driver until the streaming query stops or fails.
query.awaitTermination()
