- Tumbling windows are fixed-size and non-overlapping.
- Each event belongs to exactly one window.

In [0]:
from pyspark.sql.functions import *
from pyspark.sql.types import *

class TradeSummary:
    """Streaming pipeline that aggregates raw trades into tumbling-window totals.

    The pipeline reads the raw trades table as a stream, parses the JSON held
    in its ``value`` column, and writes per-window BUY and SELL totals to a
    summary Delta table.

    The methods rely on the notebook-provided global ``spark`` session and on
    the names star-imported from ``pyspark.sql.functions`` and
    ``pyspark.sql.types`` at the top of the file.
    """

    def __init__(
        self,
        *,
        raw_table_name="dev.finance.trades_raw",
        summary_table_name="dev.finance.trades_summary",
        checkpoint_location="/Volumes/dev/finance/trades_summary_checkpoint",
        window_duration="15 minutes",
        watermark_delay="30 minutes",
    ):
        """Store the pipeline configuration.

        Args:
            raw_table_name: Table that is read as the streaming source.
            summary_table_name: Table that receives the windowed totals.
            checkpoint_location: Checkpoint path of the streaming query.
            window_duration: Size of each tumbling window.
            watermark_delay: How long to wait for late records, measured on
                the ``CreatedTime`` event-time column.
        """
        self.raw_table_name = raw_table_name
        self.summary_table_name = summary_table_name
        self.checkpoint_location = checkpoint_location
        self.window_duration = window_duration
        self.watermark_delay = watermark_delay

    def get_trades_raw_schema(self):
        """Return the schema used to parse the JSON trade payload."""
        return StructType([
            StructField("CreatedTime", StringType(), True),
            StructField("Type", StringType(), True),
            StructField("Amount", DoubleType(), True),
            StructField("BrokerCode", StringType(), True),
        ])

    def get_trades_raw(self):
        """Return a streaming DataFrame over the raw trades table."""
        return (
            spark
            .readStream
            .format("delta")
            .table(self.raw_table_name)
        )

    def format_trades_raw(self, trades_raw_df):
        """Parse the raw payload and derive per-row BUY and SELL amounts.

        Args:
            trades_raw_df: DataFrame with a ``value`` column holding the
                trade as a JSON string.

        Returns:
            DataFrame with the parsed trade fields, ``CreatedTime`` converted
            to a timestamp, and two extra columns ``BUY`` and ``SELL`` that
            carry the amount for the matching trade type and 0 otherwise.
        """
        parsed_df = (
            trades_raw_df
            .select(from_json(col("value"), self.get_trades_raw_schema()).alias("value"))
            .select("value.*")
        )
        return (
            parsed_df
            .withColumn("CreatedTime", to_timestamp("CreatedTime", "yyyy-MM-dd HH:mm:ss"))
            # Splitting the amount into two columns lets a single aggregation
            # produce both totals.
            .withColumn("BUY", expr('case when type == "BUY" then amount  else 0 end'))
            .withColumn("SELL", expr('case when type == "SELL" then amount  else 0 end'))
        )

    def process_trades(self, formatted_trades_df):
        """Sum BUY and SELL amounts per tumbling event-time window.

        Args:
            formatted_trades_df: Output of ``format_trades_raw``.

        Returns:
            DataFrame with columns ``start``, ``end``, ``total_buy`` and
            ``total_sell``, one row per window.
        """
        # The watermark tells Spark how long to wait for late records on the
        # event-time column.
        # NOTE(review): the sink in write_processed_trades uses "complete"
        # output mode. Per the Structured Streaming guide, complete mode has
        # to preserve all aggregate state, so the watermark cannot be used to
        # drop old state there; it only bounds state in append/update mode.
        watermarked_df = formatted_trades_df.withWatermark(
            "CreatedTime", self.watermark_delay
        )
        # A window() call with only a duration (no slide) gives tumbling windows.
        windowed = watermarked_df.groupBy(
            window(col("CreatedTime"), self.window_duration)
        )
        # `sum` is the Spark aggregate from the star import, not the builtin.
        totals_df = windowed.agg(
            sum("BUY").alias("total_buy"),
            sum("SELL").alias("total_sell"),
        )
        return totals_df.select("window.start", "window.end", "total_buy", "total_sell")

    def write_processed_trades(self, processed_trades_df):
        """Start the streaming write of the windowed totals.

        Args:
            processed_trades_df: Output of ``process_trades``.

        Returns:
            The handle returned by ``toTable`` for the started query.
        """
        return (
            processed_trades_df
            .writeStream
            .queryName("trades_agg")
            .format("delta")
            # Complete mode rewrites the whole result table on every trigger,
            # which is equivalent to an insert overwrite.
            .outputMode("complete")
            .option("checkpointLocation", self.checkpoint_location)
            # toTable is the documented DataStreamWriter method that starts the
            # query and writes to a named table.
            .toTable(self.summary_table_name)
        )

    def start_trade_aggregates(self):
        """Wire the read, format, aggregate and write steps together.

        Returns:
            The handle of the started streaming query.
        """
        raw_data = self.get_trades_raw()
        formatted_data = self.format_trades_raw(raw_data)
        processed_data = self.process_trades(formatted_data)
        return self.write_processed_trades(processed_data)

    


In [0]:
# Instantiate the pipeline and start the streaming aggregation; the returned
# query handle is kept so the stream can be stopped from a later cell.
trades_agg = TradeSummary()
streaming_agg_query = trades_agg.start_trade_aggregates()

In [0]:
# Stop the streaming query started in the previous cell.
streaming_agg_query.stop()

In [0]:
%sql
-- Show the aggregated windows ordered by window end time.
select * from dev.finance.trades_summary order by end;

start,end,total_buy,total_sell
2019-02-05T10:00:00Z,2019-02-05T10:15:00Z,800.0,0.0
2019-02-05T10:15:00Z,2019-02-05T10:30:00Z,600.0,400.0
2019-02-05T10:30:00Z,2019-02-05T10:45:00Z,900.0,0.0
2019-02-05T10:45:00Z,2019-02-05T11:00:00Z,0.0,500.0
