In [1]:
import sys 
import time
import json
import logging
import requests
from itertools import cycle

from delta import DeltaTable
from pyspark.sql import Window
from pyspark.sql import DataFrame
from pyspark.sql import SparkSession
from pyspark.sql import  functions as F
from pyspark.sql.avro.functions import from_avro
from pyspark.sql.streaming.query import StreamingQuery

#### Logging Init

In [2]:
# Named logger for the whole pipeline: INFO+ goes to the log file, DEBUG+ to the notebook output.
logger = logging.getLogger("postgres_delta_cdc")
logger.setLevel(level=logging.DEBUG)

log_format = logging.Formatter(
    fmt="%(levelname)s %(asctime)s  - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

# logging.getLogger returns the same logger object every time this cell runs, so
# drop (and close) any handlers added by a previous execution; otherwise each re-run
# stacks another pair and every message is printed/written once per execution.
for existing_handler in list(logger.handlers):
    logger.removeHandler(existing_handler)
    existing_handler.close()

# Adding file and console handlers for logging.

file_handler = logging.FileHandler(filename='/app/logs/streaming.log')
file_handler.setFormatter(log_format)
file_handler.setLevel(level=logging.INFO)

console_handler = logging.StreamHandler(stream=sys.stdout)
console_handler.setFormatter(log_format)
console_handler.setLevel(level=logging.DEBUG)

logger.addHandler(file_handler)
logger.addHandler(console_handler)

#### Spark Init

In [None]:
# Maven coordinates resolved when the session starts: Kafka source, Avro support
# and Delta Lake (Scala 2.12 builds).
SPARK_PACKAGES = ",".join([
    "org.apache.spark:spark-sql-kafka-0-10_2.12:3.4.0",
    "org.apache.spark:spark-avro_2.12:3.4.0",
    "io.delta:delta-core_2.12:2.4.0",
])

# Extra resolver URLs, passed through unchanged.
# NOTE(review): these are mvnrepository.com artifact pages rather than Maven repository
# roots; resolution presumably succeeds through the default Maven Central resolver —
# confirm and consider dropping this option.
SPARK_REPOSITORIES = ",".join([
    "https://mvnrepository.com/artifact/io.delta/delta-core",
    "https://mvnrepository.com/artifact/org.apache.hadoop/hadoop-core",
    "https://mvnrepository.com/artifact/org.apache.spark/spark-sql-kafka-0-10",
    "https://mvnrepository.com/artifact/org.apache.spark/spark-avro",
])

# Applied to the session builder in this order; the last two entries register the
# Delta SQL extension and make the default catalog Delta-aware.
SPARK_CONFIG = {
    "spark.jars.packages": SPARK_PACKAGES,
    "spark.jars.repositories": SPARK_REPOSITORIES,
    "spark.sql.extensions": "io.delta.sql.DeltaSparkSessionExtension",
    "spark.sql.catalog.spark_catalog": "org.apache.spark.sql.delta.catalog.DeltaCatalog",
}

_spark_builder = SparkSession.builder
for _conf_key, _conf_value in SPARK_CONFIG.items():
    _spark_builder = _spark_builder.config(_conf_key, _conf_value)
spark = _spark_builder.getOrCreate()

#### Functions

In [4]:
def get_schema_registry_latest(
    topic: str,
    registry_url: str = 'http://host.docker.internal:8081',
    timeout: float = 10.0,
) -> str:
    '''Get the Avro JSON schema for the current topic.

    Queries the schema registry for the latest version of the ``<topic>-value``
    subject and returns the response body re-serialized as a JSON string, which is
    the form ``from_avro(jsonFormatSchema=...)`` expects.

    Args:
        topic: Kafka topic name; the subject looked up is ``f"{topic}-value"``.
        registry_url: Base URL of the schema registry.
        timeout: Seconds to wait for the registry before giving up. Without a
            timeout ``requests`` can block indefinitely on an unreachable host.

    Returns:
        The JSON-dumped response body. The HTTP status is not checked, so an
        error body from the registry is returned as-is.

    Raises:
        requests.Timeout: If the registry does not answer within ``timeout``.
    '''
    response = requests.get(
        f'{registry_url}/subjects/{topic}-value/versions/latest/schema',
        timeout=timeout,
    )
    return json.dumps(response.json())

def desserialize_avro_column(df: DataFrame, schema: str) -> DataFrame:
    '''Decode the Confluent-framed Avro ``value`` column of a Kafka DataFrame.

    Confluent's serializer prepends 5 extra bytes to every message; they are stripped
    first, then the remaining payload is decoded with ``from_avro``.

    Args:
        df: Kafka source DataFrame with a binary ``value`` column.
        schema: Avro schema as a JSON string (as returned by the schema registry).

    Returns:
        ``df`` with ``value`` replaced by the decoded Avro struct.
    '''
    # SQL substring is 1-based, so starting at byte 6 skips the 5-byte prefix.
    payload_without_header = F.expr("substring(value, 6, length(value)-5)")
    stripped = df.withColumn("value", payload_without_header)
    decoded_value = from_avro('value', jsonFormatSchema=schema)
    return stripped.withColumn("value", decoded_value)

In [5]:
def get_avaiable_tables(table_properties: list) -> list:
    '''Return the ``Table`` value of every metadata entry, preserving input order.

    Entries that have no ``Table`` key contribute ``None``.
    '''
    return [entry.get('Table') for entry in table_properties]

def get_kafka_topics(table_properties: list) -> str:
    '''Build the comma-separated topic list used for the Kafka ``subscribe`` option.

    One entry per metadata record, taken from its ``Table`` value.
    '''
    topic_names = get_avaiable_tables(table_properties=table_properties)
    return ','.join(topic_names)

def get_table_properites(table_properties: list, table: str) -> dict:
    '''Return the first metadata entry whose ``Table`` key equals ``table``.

    Args:
        table_properties: List of per-table metadata dicts.
        table: Table (and Kafka topic) name to look up.

    Returns:
        The first matching metadata dict.

    Raises:
        IndexError: If no entry matches ``table``. This is the same exception type
            raised before, now with a message naming the missing table.
    '''
    match = next(
        (entry for entry in table_properties if entry.get('Table') == table),
        None,
    )
    if match is None:
        raise IndexError(f'No table properties found for table {table!r}')
    return match

In [6]:
def foreachBatch(
        df_cdc_read_stream_batch: DataFrame,
        batch_id: int,
        topic: str,
        table_schema: str,
        partition_window: Window,
        join_query: str,
    ) -> None:
    """
    Apply one micro-batch of CDC events to the Delta table at ``/app/data/{topic}/``.

    The Avro ``value`` column is deserialized. Each event is reduced to its row image:
    ``value.after`` for every op except ``'d'``, and ``value.before`` for ``'d'``. Only
    the first row per primary key according to ``partition_window`` (latest ``ts_ms``)
    is kept.

    * If no Delta table exists yet, the surviving rows whose latest op is not ``'d'``
      are written as a new Delta table.
    * Otherwise they are MERGEd on ``join_query``:
      - matched keys whose latest op is ``'d'`` are deleted;
      - matched keys with any other latest op are overwritten;
      - unmatched keys with any other latest op are inserted.

    The stored rows include a ``last_operation`` column holding that op.

    Args:
        df_cdc_read_stream_batch (DataFrame): Current batch Spark DataFrame (Kafka rows).
        batch_id (int): Current batch ID of the streaming query (unused; required by
            the foreachBatch callback signature).
        topic (str): Kafka topic name; also names the Delta table directory.
        table_schema (str): Table schema registry Avro-dumped JSON.
        partition_window (Window): Window partitioned by the primary keys and ordered
            so that row_number() == 1 is the latest event for each key.
        join_query (str): SQL merge condition matching ``stored_val`` and ``new_val``
            on the primary keys.
    """
    table_path = f'/app/data/{topic}/'

    # Kafka messages with a null value carry no row image; left in, they would
    # deserialize to all-null rows.
    df_non_null = df_cdc_read_stream_batch.filter(F.col('value').isNotNull())

    df_cdc_read_stream = (
        desserialize_avro_column(df_non_null, table_schema)
        .withColumn(
            'values',
            F.when(F.col('value.op') != 'd', F.col('value.after'))
             .otherwise(F.col('value.before')),
        )
        .select('values.*', F.col('value.op').alias('last_operation'), F.col('value.ts_ms').alias('ts_ms'))
        # Keep only the latest event per primary key within this batch.
        # NOTE(review): rows with equal ts_ms for the same key are ordered arbitrarily.
        .withColumn('max_value', F.row_number().over(partition_window))
        .filter(F.col('max_value') == 1)
        .drop('max_value', 'ts_ms')
    )

    if not DeltaTable.isDeltaTable(spark, table_path):
        logger.info(f'Saving topic {topic} for the first time')
        # Keys whose latest event is a delete must not appear in the new table.
        # The merge branch below never inserts them either.
        (
            df_cdc_read_stream
            .filter(F.col('last_operation') != 'd')
            .write.format('delta').save(table_path)
        )
        return

    logger.info(f'Saving a new batch for the topic {topic}')

    deltaTable = DeltaTable.forPath(spark, table_path)

    (
        deltaTable.alias("stored_val")
        .merge(df_cdc_read_stream.alias("new_val"), join_query)
        .whenMatchedDelete(condition="new_val.last_operation = 'd'")
        # Any non-delete latest op must overwrite the stored row, not only 'u'.
        # Example: a key deleted and re-inserted within one batch ends with a
        # non-'u' op. Matching only 'u' would leave the old values in place.
        .whenMatchedUpdateAll(condition="new_val.last_operation != 'd'")
        # If delete was the last operation for an unknown key, do nothing.
        .whenNotMatchedInsertAll(condition="new_val.last_operation != 'd'")
        .execute()
    )

In [7]:
def create_postgres_delta_cdc_table(current_table: str, df_cdc_read_stream: DataFrame, table_properties: list) -> StreamingQuery:
    """
    Start the streaming query that maintains the Delta table for one CDC topic.

    The shared Kafka stream is filtered to ``current_table``'s topic. Each micro-batch
    is handed to ``foreachBatch``, which performs the Delta writes itself.

    Args:
        current_table (str): The current CDC table name (equal to its Kafka topic).
        df_cdc_read_stream (DataFrame): The Kafka readStream Spark DataFrame.
        table_properties (list): Per-table metadata dicts. The entry for
            ``current_table`` must provide ``PrimaryKeys``. ``ProcessingTime`` is
            optional.

    Returns:
        StreamingQuery: The started Spark streaming query.

    Raises:
        ValueError: If no primary keys are configured for ``current_table``.
    """
    table_schema = get_schema_registry_latest(current_table)
    current_table_properties = get_table_properites(table_properties, current_table)
    processing_time = current_table_properties.get('ProcessingTime')
    table_primary_keys = current_table_properties.get('PrimaryKeys')

    # Without keys the window would span the whole batch (keeping a single row)
    # and the merge condition would be empty, so fail loudly instead.
    if not table_primary_keys:
        raise ValueError(f'No PrimaryKeys configured for table {current_table!r}')

    # The spaces around AND matter. Without them a composite key yields
    # "...new_val.aANDstored_val.b...", which Spark cannot parse as a condition.
    join_query = ' AND '.join(f'stored_val.{key} = new_val.{key}' for key in table_primary_keys)
    partition_window = Window.partitionBy(*table_primary_keys).orderBy(F.col('ts_ms').desc())

    def _process_batch(batch: DataFrame, batch_id: int) -> None:
        # Closure binding this table's settings into the foreachBatch callback.
        foreachBatch(
            batch,
            batch_id,
            topic=current_table,
            table_schema=table_schema,
            partition_window=partition_window,
            join_query=join_query,
        )

    writer = (
        df_cdc_read_stream
        .filter(F.col("topic") == current_table)
        .writeStream
        # foreachBatch replaces the sink, so no format/path is needed: the
        # Delta writes happen inside the callback.
        .foreachBatch(_process_batch)
        .option("checkpointLocation", f'/app/data/{current_table}/_checkpoints')
        .queryName(current_table)
    )
    # trigger() rejects None; with no ProcessingTime configured, keep Spark's default trigger.
    if processing_time is not None:
        writer = writer.trigger(processingTime=processing_time)
    return writer.start()

#### Main

In [None]:
# Getting metadata

# Per-table metadata entries (Table, PrimaryKeys, ProcessingTime) driving the pipeline.
with open('/app/metadata/table_properties.json') as metadata_file:
    table_properties = json.load(metadata_file)

kafka_bootstrap_server = 'host.docker.internal:29092'
kafka_topics = get_kafka_topics(table_properties)
avaiable_tables = get_avaiable_tables(table_properties)
# Schemas are fetched per topic inside create_postgres_delta_cdc_table; a lookup
# with the comma-joined topic list would query a subject that does not exist.

# Starting the readstream for multiple topics
# One Kafka source subscribed to every configured topic; each table's query below
# filters it down to its own topic.
# NOTE(review): every started query reads this source independently, i.e. each
# query consumes all subscribed topics — consider one subscription per table.

df_cdc_read_stream = (
  spark
    .readStream
    .format("kafka")
    .option("kafka.bootstrap.servers", kafka_bootstrap_server)
    .option("subscribe", kafka_topics)
    .option("startingOffsets","earliest")
    .load()
)

# Initializing the streaming: one streaming query per table, kept by table name.

started_streams = {
    table: create_postgres_delta_cdc_table(table, df_cdc_read_stream=df_cdc_read_stream, table_properties=table_properties)
    for table in avaiable_tables
}

#### Simple monitoring of the running streams

In [None]:
# Round-robin over the queries that are active right now. Log every newly completed
# micro-batch at DEBUG and warn once when a query stops; exit when none are left.
POLL_INTERVAL_SECONDS = 0.2

# Keep the StreamingQuery handle itself. spark.streams.get() only looks up *active*
# queries, so a stopped query could no longer be inspected by id.
active_streams = [
    {'id': val.id, 'name': val.name, 'query': val, 'last_batch': val.lastProgress}
    for val in spark.streams.active
]

if not active_streams:
    logger.warning('No active streaming queries to monitor.')

while active_streams:
    for current_stream_data in list(active_streams):
        time.sleep(POLL_INTERVAL_SECONDS)
        current_stream = current_stream_data['query']

        if not current_stream.isActive:
            logger.warning(f' ERROR {current_stream.name} stopped for some reason !')
            failure = current_stream.exception()
            if failure is not None:
                logger.warning(f'{current_stream.name} failed with: {failure}')
            # Stop polling it; otherwise the warning repeats on every round.
            active_streams.remove(current_stream_data)
            continue

        last_progress = current_stream.lastProgress
        if last_progress is None:
            continue  # no micro-batch has completed yet

        previous_progress = current_stream_data.get('last_batch')
        # previous_progress is None for queries that had not completed a batch when
        # the list above was built; their first completed batch must still be logged.
        if previous_progress is None or last_progress.get('batchId') > previous_progress.get('batchId'):
            logger.debug(f'last process - {current_stream.name}  |  {last_progress}')
            current_stream_data.update({'last_batch': last_progress})

logger.warning('No monitored streaming query is active anymore.')

In [None]:
# spark.read.format('delta').load('/app/data/exampledb.public.bounties_fact/').orderBy(F.col('character_id').asc()).show()

In [None]:
# spark.read.format('delta').load('/app/data/exampledb.public.characters_dim/').orderBy(F.col('character_id').asc()).show()