# Spark Streaming


You can integrate Flypipe graphs with Spark Structured Streaming.

One way of doing this is with Spark's [foreachBatch](https://spark.apache.org/docs/3.1.1/api/python/reference/api/pyspark.sql.streaming.DataStreamWriter.foreachBatch.html). Here we create a function `total_sales(batch_df, batch_id)` that receives each batch dataframe and runs a Flypipe node, passing the batch as the node's input.

## Cleaning the environment

In [None]:
import shutil

# Root directory of the `flypipe` database. Managed tables of this database are
# created underneath it, one directory per table.
FLYPIPE_DB_LOCATION = "/data/warehouse/flypipe"

spark.sql(
    f"""
    CREATE DATABASE IF NOT EXISTS flypipe
    LOCATION '{FLYPIPE_DB_LOCATION}'
"""
)

spark.sql("drop table if exists flypipe.total_sales")

# Remove leftovers from previous runs so the notebook starts from scratch:
# - the table directory (a stale, non-empty managed-table directory makes the
#   next `saveAsTable` fail because the location already exists),
# - the stream checkpoint (otherwise already-seen files are skipped),
# - the simulated JSON source.
for leftover_dir in (
    f"{FLYPIPE_DB_LOCATION}/total_sales",
    # NOTE(review): path cleaned by the original cell; it does not match the
    # LOCATION above. Kept in case `flypipe` already existed elsewhere
    # (CREATE ... IF NOT EXISTS does not move an existing database).
    "/spark-warehouse/flypipe/total_sales",
    "/data/tmp/stream/_checkpoints",
    "/data/tmp/stream/json",
):
    shutil.rmtree(leftover_dir, ignore_errors=True)

## Adding JSON files to simulate a source

In [None]:
import json
from time import time
import os
import uuid
import random
from pprint import pprint

# Directory that plays the role of the streaming source: the stream below reads
# every JSON file placed here as new input.
JSON_LOCATION = "/data/tmp/stream/json"
os.makedirs(name=JSON_LOCATION, exist_ok=True)

def add_sale(quantity, *, location=None):
    """Write ``quantity`` random sale records, one JSON file per record.

    Each file simulates one sale event arriving at the streaming source and is
    named ``<sale_id>.json``. The record and its path are printed as well.

    Args:
        quantity: number of sale files to create (``0`` creates none).
        location: directory to write into; defaults to ``JSON_LOCATION``.
    """
    if location is None:
        location = JSON_LOCATION

    for _ in range(quantity):
        sale_id = str(uuid.uuid4())

        data = {
            'sale_id': sale_id,
            'product_id': random.randrange(1, 5, 1),
            'price': random.randrange(100, 1000, 1),
            'quantity': random.randrange(1, 10, 1),
            'sale_datetime': int(time())
        }

        out_path = f"{location}/{sale_id}.json"
        # Spark's file source can pick up a file as soon as it appears, so the
        # record is written under a hidden name (Spark skips names starting
        # with ".") and renamed atomically once it is complete and closed.
        tmp_path = f"{location}/.{sale_id}.json.tmp"
        with open(tmp_path, "w", encoding="utf-8") as out_file:
            json.dump(data, out_file)
        os.replace(tmp_path, out_path)

        print(f"\nAdded {out_path}")
        pprint(data)
             
        
# Seed the source directory with five sale events before the stream starts.
add_sale(quantity=5)

## Flypipe graph to process the data

In [None]:
from flypipe import node
from flypipe.schema import Schema, Column
from flypipe.schema.types import Decimal, String
from flypipe.datasource.spark import Spark
import pyspark.sql.functions as F

@node(
    type="pyspark",
    dependencies=[Spark("sales")],
    output=Schema(
        Column("product_id", String(), "product identifier"),
        Column("total_sales", Decimal(18, 2), "total sales amount"),
    ),
)
def total_sales_node(sales):
    """Aggregate sales into a total amount (price * quantity) per product."""
    # Revenue of each individual sale line.
    line_amount = F.col("price") * F.col("quantity")
    per_product = sales.groupBy("product_id")
    return per_product.agg(F.sum(line_amount).alias("total_sales"))

## Define a batch function that wraps the Flypipe graph

In [None]:
from pyspark.sql.types import StructType, ArrayType, StructField, StringType, DecimalType, IntegerType, TimestampType


def total_sales(batch_df, batch_id):
    """Process one streaming micro-batch with the Flypipe graph.

    The batch dataframe is displayed, fed to ``total_sales_node`` as the
    ``Spark("sales")`` input, and the result is written to the Delta table
    ``flypipe.total_sales``.

    NOTE(review): the write uses ``mode('overwrite')``, so the table holds the
    totals of the latest batch only. With a trigger that produces several
    micro-batches, earlier totals are replaced rather than accumulated.

    Args:
        batch_df: dataframe with the rows of the current micro-batch.
        batch_id: identifier of the micro-batch (unused here).

    Returns:
        The dataframe computed by ``total_sales_node``. NOTE(review):
        foreachBatch does not appear to use this value.
    """
    print("Batch dataframe received:")
    display(batch_df)

    node_inputs = {Spark("sales"): batch_df}
    result_df = total_sales_node.run(inputs=node_inputs)

    print("===> Saving dataframe calculated with node `total_sales_node` into table `total_sales`")

    delta_writer = result_df.write.format('delta').mode('overwrite')
    delta_writer.saveAsTable("flypipe.total_sales")

    return result_df

## Set up and start the stream

In [None]:
# Create Stream
# Schema of the JSON sale files in JSON_LOCATION; every field is nullable.
json_schema = StructType(
    [
        StructField(field_name, field_type, True)
        for field_name, field_type in [
            ("sale_id", StringType()),
            ("product_id", StringType()),
            ("price", DecimalType(18, 2)),
            ("quantity", IntegerType()),
            ("sale_datetime", TimestampType()),
        ]
    ]
)


# Start the streaming query: each micro-batch of new JSON files is passed to
# `total_sales`. Progress is tracked in the checkpoint directory.
query = (
  spark
  .readStream
  .json(JSON_LOCATION, schema=json_schema)
  .writeStream
  .trigger(availableNow=True) # <-- Change the trigger as you wish
  .option("checkpointLocation", "/data/tmp/stream/_checkpoints/")
  .foreachBatch(total_sales)
  .start()
)

# Wait for the stream to process the files currently in JSON_LOCATION.
# processAllAvailable blocks until the available data has been committed by
# `total_sales`, returns once an availableNow query finishes, and re-raises
# the query's failure (StreamingQueryException). Polling the table inside a
# catch-all `except` would instead loop forever when a batch fails.
from time import sleep

query.processAllAvailable()

## Display results

In [None]:
# Read back what the streaming query wrote and render it in the notebook.
total_sales_result = spark.sql("select * from flypipe.total_sales")
display(total_sales_result)