In [1]:
# Imports
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import (
    col,
    current_timestamp,
    expr,
    from_json,
    lit,
    sum as spark_sum,
    when,
    window,
)
from pyspark.sql.types import (
    ArrayType,
    DoubleType,
    IntegerType,
    LongType,
    StringType,
    StructField,
    StructType,
    TimestampType,
)

# Create the Spark Session
# - spark.jars.packages pulls in the Kafka connector that provides the "kafka"
#   source/sink used in this notebook.
#   NOTE(review): connector version 3.5.0 / Scala 2.12 must match the installed
#   Spark build — confirm when upgrading Spark.
# - 4 shuffle partitions (instead of the default 200) keeps the small local
#   aggregations from spawning hundreds of tiny tasks.
spark = (
    SparkSession.builder
    .appName("Transaction Streaming Job")
    # NOTE(review): spark.streaming.* settings belong to the legacy DStream API;
    # this one likely has no effect on Structured Streaming queries — confirm.
    .config("spark.streaming.stopGracefullyOnShutdown", True)
    .config("spark.jars.packages", "org.apache.spark:spark-sql-kafka-0-10_2.12:3.5.0")
    .config("spark.sql.shuffle.partitions", "4")
    .master("local[*]")
    .getOrCreate()
)

In [2]:
# Check on spark object: displaying the SparkSession is a quick sanity check
# that the session was created before any streams are defined.
spark

In [3]:
# Expected layout of each Kafka message value (a JSON object).
# Every field is nullable, so a missing key surfaces as NULL and is handled by
# the validation step rather than failing the parse.
# NOTE(review): assumes producers emit `timestamp` in a format from_json parses
# by default — confirm against the producer.
schema = StructType(
    [
        StructField(field_name, field_type, True)
        for field_name, field_type in (
            ("user_id", StringType()),
            ("amount", DoubleType()),
            ("timestamp", TimestampType()),
            ("source", StringType()),
        )
    ]
)

In [4]:
# Streaming source: subscribe to the raw `transactions` topic.
# "earliest" only matters for a query starting without a checkpoint; once a
# checkpoint exists Spark resumes from the offsets stored there.
kafka_df = (
    spark.readStream
    .format("kafka")
    .option("kafka.bootstrap.servers", "kafka:9094")
    .option("subscribe", "transactions")
    .option("startingOffsets", "earliest")
    .load()
)

In [5]:
# Inspect the raw Kafka source schema: binary key/value plus topic, partition,
# offset and timestamp metadata columns (see the output below).
kafka_df.printSchema()

root
 |-- key: binary (nullable = true)
 |-- value: binary (nullable = true)
 |-- topic: string (nullable = true)
 |-- partition: integer (nullable = true)
 |-- offset: long (nullable = true)
 |-- timestamp: timestamp (nullable = true)
 |-- timestampType: integer (nullable = true)



In [6]:
# Deserialize: cast the binary Kafka value to a string, parse it as JSON using
# `schema`, then flatten the parsed struct into top-level columns.
# NOTE(review): unparseable messages presumably come through with all fields
# NULL rather than failing the query — confirm for this Spark version.
value_df = (
    kafka_df
    .select(from_json(col("value").cast("string"), schema).alias("value"))
    .select("value.*")
)


In [7]:
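# value_df.show()  -- not allowed on a streaming DataFrame (Spark requires writeStream.start()); inspect rows through a sink instead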

In [8]:
# Confirm that parsing produced the four `schema` fields as top-level columns.
value_df.printSchema()

root
 |-- user_id: string (nullable = true)
 |-- amount: double (nullable = true)
 |-- timestamp: timestamp (nullable = true)
 |-- source: string (nullable = true)



In [9]:
# Validate the data
#
# Each record gets `error_reason` = the FIRST rule it violates (rules are
# evaluated top to bottom), or NULL when it passes every rule, plus a boolean
# `is_valid` flag derived from it.
MIN_AMOUNT = 1
MAX_AMOUNT = 10_000_000
VALID_SOURCES = ("mobile", "web", "pos")

validated_df = (
    value_df
    .withColumn(
        "error_reason",
        when(
            col("user_id").isNull() | col("amount").isNull() | col("timestamp").isNull(),
            "Missing mandatory fields",
        )
        .when((col("amount") < MIN_AMOUNT) | (col("amount") > MAX_AMOUNT), "Amount out of range")
        # isin() yields NULL (not false) for a NULL source, and a NULL condition
        # never matches, so without the explicit isNull() check records with no
        # source would fall through to "valid".
        .when(col("source").isNull() | ~col("source").isin(*VALID_SOURCES), "Invalid source")
        .otherwise(None),
    )
    .withColumn("is_valid", col("error_reason").isNull())
)

In [11]:
# Split to valid and invalid df
# `is_valid` comes from isNull(), so it is never NULL and these two filters
# partition validated_df exactly — no need to compare against True/False.
valid_df = validated_df.filter(col("is_valid"))
invalid_df = validated_df.filter(~col("is_valid"))

In [12]:
# Apply watermark and dedup
#
# The 3-minute event-time watermark lets Spark discard late data and bounds the
# dedup state kept for (user_id, timestamp), since the event-time column is part
# of the dedup key.
# NOTE: this rebinds `valid_df`; re-running this cell alone stacks a second
# watermark/dedup on the already-deduplicated frame — re-run the split cell first.
valid_df = (
    valid_df
    .withWatermark("timestamp", "3 minutes")
    .dropDuplicates(["user_id", "timestamp"])
)

In [18]:
# Apply tumbling window monitoring
#
# Count valid transactions per 1-minute tumbling event-time window, flatten the
# window struct into start/end columns, and stamp each row with current_timestamp().
window_agg_df = (
    valid_df
    .groupBy(window(col("timestamp"), "1 minutes"))
    .count()
    .select(
        current_timestamp().alias("current_timestamp"),
        col("window.start").alias("window_start"),
        col("window.end").alias("window_end"),
        col("count").alias("total_transactions"),
    )
)


In [19]:
# Inspect the windowed aggregate's columns.
# NOTE(review): the saved output below still shows the pre-`select` schema
# (window struct + count); re-run to see current_timestamp / window_start /
# window_end / total_transactions.
window_agg_df.printSchema()

root
 |-- window: struct (nullable = false)
 |    |-- start: timestamp (nullable = true)
 |    |-- end: timestamp (nullable = true)
 |-- count: long (nullable = false)



In [20]:
# Calculate running total
window_with_total = window_agg_df \
                    .withColumn("running_total", sum("count").over(Window.orderBy("window.start")))\
                    .select(
                        current_timestamp().alias("timestamp"),
                        col("window.start").alias("window_start"),
                        col("window.end").alias("window_end"),
                        col("transactions_in_window"),
                        col("running_total")
                    )
                        


TypeError: unsupported operand type(s) for +: 'int' and 'str'

In [11]:
# console_output_df = window_agg_df.select(
#                     current_timestamp().alias("timestamp"), 
#                     col("total_transactions").alias("running_total")
#                 )

In [12]:
# Write valid data to kafka

# Console output
query_console = window_with_total.writeStream \
                .outputMode("complete") \
                .format("console") \
                .option("truncate", "false") \
                .start()

def _start_kafka_json_sink(source_df, topic, checkpoint_dir):
    """Serialize all columns of `source_df` into one JSON `value` and stream it to a Kafka topic.

    Returns the started StreamingQuery.
    """
    return (
        source_df
        .selectExpr("to_json(struct(*)) as value")
        .writeStream
        .format("kafka")
        .option("kafka.bootstrap.servers", "kafka:9094")
        .option("topic", topic)
        .option("checkpointLocation", checkpoint_dir)
        .start()
    )


# Valid data to kafka
valid_query = _start_kafka_json_sink(valid_df, "transactions_valid", "checkpoints/valid")

# Invalid to dlq
# NOTE(review): invalid_df only carries the parsed fields, so a message whose
# JSON failed to parse presumably lands in the DLQ without its raw payload;
# keep the original Kafka value upstream if the DLQ must be replayable.
invalid_dlq = _start_kafka_json_sink(invalid_df, "transactions_dlq", "checkpoints/dlq")

# Block the notebook until any running query stops; re-raises if one failed.
spark.streams.awaitAnyTermination()

ERROR:root:KeyboardInterrupt while sending command.
Traceback (most recent call last):
  File "/usr/local/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/clientserver.py", line 511, in send_command
    answer = smart_decode(self.stream.readline()[:-1])
                          ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/socket.py", line 706, in readinto
    return self._sock.recv_into(b)
           ^^^^^^^^^^^^^^^^^^^^^^^
KeyboardInterrupt


KeyboardInterrupt: 