# Spark Streaming


You can integrate Flypipe graphs with Spark Structured Streaming.

One way to do this is with Spark's [foreachBatch](https://spark.apache.org/docs/3.1.1/api/python/reference/api/pyspark.sql.streaming.DataStreamWriter.foreachBatch.html). Here we create a function `total_sales(batch_df, batch_id)` that receives the batch dataframe and runs a Flypipe node, passing the batch as the node's `Spark("sales")` input.

## Cleaning the environment

In [1]:
import shutil

# Start from a clean slate: drop the output table, then delete any leftover
# table files, the stream checkpoint and previously generated JSON input.
# NOTE(review): "/spark-warehouse" is an absolute path, while Spark's default
# warehouse dir is relative to the working directory — confirm it matches
# `spark.sql.warehouse.dir` in the target environment.
spark.sql("drop table if exists total_sales")
for stale_path in (
    "/spark-warehouse/total_sales",
    "./tmp/stream/_checkpoints",
    "./tmp/stream/json",
):
    # ignore_errors=True: a directory that does not exist yet is not an error.
    shutil.rmtree(stale_path, ignore_errors=True)

## Adding JSON files to simulate a source

In [2]:
import json
from time import time
import os
import uuid
import random
from pprint import pprint

# Directory the simulated producer (`add_sale`) writes to and the stream reads from.
JSON_LOCATION = "./tmp/stream/json"
os.makedirs(name=JSON_LOCATION, exist_ok=True)

def add_sale(quantity, location=None):
    """Write `quantity` random sale records, one JSON file per sale.

    Each file is named ``<sale_id>.json`` and holds a single object with
    ``sale_id``, ``product_id``, ``price``, ``quantity`` and ``sale_datetime``
    (epoch seconds). Every record is printed after it is written so the
    notebook shows what the stream will pick up.

    Parameters
    ----------
    quantity : int
        Number of sale files to create; 0 (or a negative value) writes nothing.
    location : str, optional
        Target directory, which must already exist. Defaults to ``JSON_LOCATION``.
    """
    if location is None:
        location = JSON_LOCATION

    for _ in range(quantity):
        sale_id = str(uuid.uuid4())
        out_path = f"{location}/{sale_id}.json"
        # Spark's file source ignores hidden files (names starting with "."),
        # so the stream never reads a half-written file; os.replace then
        # publishes the finished file in one step.
        tmp_path = f"{location}/.{sale_id}.json.tmp"

        data = {
            'sale_id': sale_id,
            'product_id': random.randrange(1, 5, 1),
            'price': random.randrange(100, 1000, 1),
            'quantity': random.randrange(1, 10, 1),
            'sale_datetime': int(time())
        }

        # `with` guarantees the data is flushed and the handle closed before
        # the file becomes visible to the stream.
        with open(tmp_path, "w") as out_file:
            json.dump(data, out_file)
        os.replace(tmp_path, out_path)

        print(f"\nAdded {out_path}")
        pprint(data)
             
        
add_sale(quantity=5)


Added ./tmp/stream/json/a3c0a5c4-3c06-4251-a52a-fa6ab3311edf.json
{'price': 946,
 'product_id': 2,
 'quantity': 5,
 'sale_datetime': 1669756681,
 'sale_id': 'a3c0a5c4-3c06-4251-a52a-fa6ab3311edf'}

Added ./tmp/stream/json/e08ed278-c31c-4d55-a4a1-9bc3fc2de8ec.json
{'price': 105,
 'product_id': 4,
 'quantity': 3,
 'sale_datetime': 1669756681,
 'sale_id': 'e08ed278-c31c-4d55-a4a1-9bc3fc2de8ec'}

Added ./tmp/stream/json/f2df3417-e046-4006-bd36-792b24466ea3.json
{'price': 247,
 'product_id': 3,
 'quantity': 2,
 'sale_datetime': 1669756681,
 'sale_id': 'f2df3417-e046-4006-bd36-792b24466ea3'}

Added ./tmp/stream/json/5aec26b7-6ec8-4e85-8139-e56bbff4e513.json
{'price': 734,
 'product_id': 2,
 'quantity': 1,
 'sale_datetime': 1669756681,
 'sale_id': '5aec26b7-6ec8-4e85-8139-e56bbff4e513'}

Added ./tmp/stream/json/4e6c6f28-b09f-4c0a-b8a3-afc493958083.json
{'price': 496,
 'product_id': 4,
 'quantity': 8,
 'sale_datetime': 1669756681,
 'sale_id': '4e6c6f28-b09f-4c0a-b8a3-afc493958083'}


## Flypipe graph to process the data

In [3]:
from flypipe import node
from flypipe.schema import Schema, Column
from flypipe.schema.types import Decimal, String
from flypipe.datasource.spark import Spark
import pyspark.sql.functions as F

@node(
    type="pyspark",
    dependencies=[Spark("sales")],
    output=Schema(
        Column("product_id", String(), "product identifier"),
        Column("total_sales", Decimal(18, 2), "total sales amount"),
    ),
)
def total_sales_node(sales):
    """Sum sales revenue (price * quantity) per product.

    `sales` is the DataFrame bound to the ``Spark("sales")`` dependency. The
    result has one row per ``product_id`` with the summed amount in
    ``total_sales``, declared in the output schema as Decimal(18, 2).
    """
    line_amount = F.col("price") * F.col("quantity")
    per_product = sales.groupBy("product_id")
    return per_product.agg(F.sum(line_amount).alias("total_sales"))

## Define a batch function that wraps the Flypipe graph

In [4]:
from pyspark.sql.types import StructType, ArrayType, StructField, StringType, DecimalType, IntegerType, TimestampType


def total_sales(batch_df, batch_id):
    """`foreachBatch` handler: aggregate one micro-batch with Flypipe and save it.

    Parameters
    ----------
    batch_df
        The micro-batch DataFrame passed in by `foreachBatch`; it is fed to
        `total_sales_node` in place of the ``Spark("sales")`` datasource.
    batch_id
        Identifier Spark passes for the micro-batch (not used here).

    Returns
    -------
    The DataFrame computed by `total_sales_node` for this batch.

    NOTE(review): the Delta table is written with mode("overwrite"), so after
    each micro-batch `total_sales` holds only that batch's totals, not a
    running total across batches.
    """
    print("Batch dataframe received:")
    display(batch_df)

    # Replace the node's Spark("sales") table dependency with the streamed batch.
    node_inputs = {Spark("sales"): batch_df}
    total_sales_df = total_sales_node.run(inputs=node_inputs)

    print("===> Saving dataframe calculated with node `total_sales_node` into table `total_sales`")

    table_writer = total_sales_df.write.format('delta').mode('overwrite')
    table_writer.saveAsTable("total_sales")

    return total_sales_df

## Set up and start the stream

In [5]:
# Create Stream
# Schema of the JSON sale files written by `add_sale`. Streaming file sources
# need an explicit schema; `sale_datetime` arrives as epoch seconds and is read
# as a timestamp.
json_schema = StructType([
    StructField("sale_id", StringType(), True),
    StructField("product_id", StringType(), True),
    StructField("price", DecimalType(18,2), True),
    StructField("quantity", IntegerType(), True),
    StructField("sale_datetime", TimestampType(), True),
])


# Keep the StreamingQuery handle returned by start() so the cell can wait on
# it and surface failures.
sales_query = (
  spark
  .readStream
  .json(JSON_LOCATION, schema=json_schema)
  .writeStream
  .trigger(availableNow=True) # <-- Change the trigger as you wish (see the note on waiting below)
  .option("checkpointLocation", "./tmp/stream/_checkpoints/")
  .foreachBatch(total_sales)
  .start()
)

# Wait for the stream instead of polling for the `total_sales` table.
# Polling with a catch-all `except` loops forever if a batch fails.
# With availableNow=True the query stops after processing the current backlog:
# awaitTermination() then returns, or re-raises the failure
# (StreamingQueryException) if a batch errored.
# NOTE: with a non-terminating trigger (e.g. processingTime) pass a timeout,
# e.g. sales_query.awaitTermination(60), or call sales_query.stop() when done.
sales_query.awaitTermination()

Batch dataframe received:


sale_id,product_id,price,quantity,sale_datetime
a3c0a5c4-3c06-425...,2,946.0,5,2022-11-29 21:18:01
e08ed278-c31c-4d5...,4,105.0,3,2022-11-29 21:18:01
f2df3417-e046-400...,3,247.0,2,2022-11-29 21:18:01
5aec26b7-6ec8-4e8...,2,734.0,1,2022-11-29 21:18:01
4e6c6f28-b09f-4c0...,4,496.0,8,2022-11-29 21:18:01


===> Saving dataframe calculated with node `total_sales_node` into table `total_sales`


                                                                                

## Display results

In [6]:
# Read back the Delta table written by the `total_sales` batch function.
display(spark.table("total_sales"))

product_id,total_sales
3,494.0
2,5464.0
4,4283.0
