In [0]:
from pyspark.sql import DataFrame
from pyspark.sql.window import Window
from pyspark.sql.functions import row_number, col
from mdd.utils import DecoratorUtil, DeltaTableUtil, FunctionUtil
from mdd.metadata import Metadata
from mdd.datareader import DeltaTableReader
from mdd.datawriter import DeltaTableWriter
from pyspark.sql import SparkSession
import logging


@DecoratorUtil.add_logger()
class TransformDataFlow:
    """
    Metadata-driven "transform" dataflow.

    Reads a Delta source table, deduplicates it on its primary key, applies
    transformations (currently a pass-through) and streams the result into the
    sink through ``DeltaTableWriter``. All settings come from the metadata YAML.
    """

    # Declared for type checkers; presumably populated by
    # DecoratorUtil.add_logger() — TODO confirm.
    logger: logging.Logger

    def __init__(self, spark: SparkSession, metadata_yml: str):
        """
        :param spark: Active Spark session
        :param metadata_yml: Metadata YAML file describing this dataflow
        :raises ValueError: If the metadata ``dataflow_type`` is not ``"transform"``
        """
        self.spark = spark
        self.metadata = Metadata(metadata_yml, False)

        # Validate first so a wrong metadata file fails with a clear message
        # instead of an error while resolving reader/writer table names.
        dataflow_type = self.metadata.get("dataflow_type")
        if dataflow_type != "transform":
            message = f"Invalid dataflow type: {dataflow_type}, it should be 'transform'"
            self.logger.error(message)
            raise ValueError(message)

        self.debug = self.metadata.get("debug")
        self.active = self.metadata.get("active")
        self.source_name = DeltaTableUtil.qualify_table_name(
            self.metadata.get("reader", "source_name")
        )
        self.sink_name = DeltaTableUtil.qualify_table_name(
            self.metadata.get("writer", "sink_name")
        )

    @DecoratorUtil.log_function()
    def read(self) -> DataFrame:
        """
        Reads the source Delta table according to ``sync_options``.

        ``sync_options.mode`` selects how the already-processed watermark is found:

        * ``"full"``: max ``_source_timestamp`` of the sink table.
        * ``"incremental"``: max ``source_commit_version`` recorded in
          ``mdd.table_control`` for this sink/source pair.

        Deduplication is not applied here; ``run`` calls ``deduplicate`` next.

        :return: Spark DataFrame returned by ``DeltaTableReader.read``
        :raises ValueError: If the mode is neither ``"full"`` nor ``"incremental"``
        """
        mode = self.metadata.get("sync_options", "mode")
        config = {
            "source_name": self.source_name,
            "mode": mode,
            "backfill_days": self.metadata.get("sync_options", "backfill_days"),
        }

        if mode == "full":
            config["full_max_processed_timestamp"] = DeltaTableUtil.get_max_column_value(
                self.spark, self.sink_name, "_source_timestamp"
            )
        elif mode == "incremental":
            config["incremental_max_processed_version"] = DeltaTableUtil.get_max_column_value(
                self.spark,
                "mdd.table_control",
                "source_commit_version",
                f"table_name = '{self.sink_name}' and source_name = '{self.source_name}'",
            )
        else:
            message = f"Invalid mode: {mode}"
            self.logger.error(message)
            raise ValueError(message)

        if self.debug:
            self.logger.info(f"Config: {config}")

        reader = DeltaTableReader(spark=self.spark, config=config, debug=self.debug)
        return reader.read()

    @staticmethod
    def _parse_sort_exprs(deduplication_columns: str) -> list:
        """Turns ``"col [asc|desc], ..."`` into a list of Column sort expressions."""
        sort_exprs = []
        for entry in deduplication_columns.split(","):
            parts = entry.strip().split()
            if len(parts) == 1:
                col_name, direction = parts[0], "asc"
            elif len(parts) == 2:
                col_name, direction = parts[0], parts[1].lower()
            else:
                raise ValueError(
                    f"Invalid format for deduplication column: '{entry.strip()}'"
                )

            if direction == "desc":
                sort_exprs.append(col(col_name).desc())
            elif direction == "asc":
                sort_exprs.append(col(col_name).asc())
            else:
                raise ValueError(
                    f"Unsupported sort order '{direction}' for column '{col_name}'"
                )
        return sort_exprs

    @DecoratorUtil.log_function()
    def deduplicate(
        self,
        df: DataFrame,
        primary_key: "str | None" = None,
        deduplication_columns: "str | None" = None,
    ) -> DataFrame:
        """
        Keeps one row per primary key, chosen by the deduplication sort order.

        Rows are ranked with ``row_number()`` over a window partitioned by the
        primary key columns and ordered by the deduplication columns; only the
        row ranked 1 is kept.

        NOTE(review): Spark does not support non-time-based windows on streaming
        DataFrames; if ``df`` is streaming this may fail when the query starts —
        confirm (a ``foreachBatch`` approach may be needed).

        :param df: Input Spark DataFrame
        :param primary_key: Comma-separated primary key columns (e.g. "id,sub_id");
            defaults to ``reader.source_primarykey`` from the metadata
        :param deduplication_columns: Comma-separated columns with an optional
            ``asc``/``desc`` direction (e.g. "event_time desc, event_sequence");
            the direction defaults to ``asc``. Defaults to
            ``reader.source_deduplication_columns`` from the metadata
        :return: Deduplicated DataFrame
        :raises ValueError: If either setting is missing or malformed
        """
        if primary_key is None:
            primary_key = self.metadata.get("reader", "source_primarykey")
        if deduplication_columns is None:
            deduplication_columns = self.metadata.get(
                "reader", "source_deduplication_columns"
            )

        if not primary_key or not deduplication_columns:
            message = "source_primarykey and source_deduplication_columns are required"
            self.logger.error(message)
            raise ValueError(message)

        primary_keys = FunctionUtil.string_to_list(primary_key)
        try:
            sort_exprs = self._parse_sort_exprs(deduplication_columns)
            window_spec = Window.partitionBy(*primary_keys).orderBy(*sort_exprs)
            return (
                df.withColumn("_row_number", row_number().over(window_spec))
                .filter(col("_row_number") == 1)
                .drop("_row_number")
            )
        except Exception as e:
            # Pass the exception so the "%s" placeholder is actually filled in.
            self.logger.error("Deduplication failed with exception: %s", e)
            raise

    @DecoratorUtil.log_function()
    def transform(self, df: DataFrame) -> DataFrame:
        """
        Hook for dataflow-specific transformations; currently returns ``df`` unchanged.
        """
        return df

    @DecoratorUtil.log_function()
    def write_stream(self, df: DataFrame):
        """
        Writes ``df`` with ``DeltaTableWriter.write_stream`` using the ``writer``
        metadata section, then blocks on ``awaitTermination()`` of the returned
        query (presumably a StreamingQuery — TODO confirm).

        :param df: DataFrame to write
        """
        config_writer = self.metadata.get("writer")
        writer = DeltaTableWriter(self.spark, df, config_writer, self.debug)
        query = writer.write_stream()
        query.awaitTermination()

    @DecoratorUtil.log_function()
    def run(self):
        """Runs the full dataflow: read -> deduplicate -> transform -> write."""
        self.logger.info("Transform start ...")

        self.logger.info("Read start...")
        df = self.read()
        self.logger.info("Read end...")

        self.logger.info("Deduplicate start...")
        # Key and sort columns come from the "reader" metadata section.
        df = self.deduplicate(df)
        self.logger.info("Deduplicate end...")

        self.logger.info("Transform step start...")
        df = self.transform(df)
        self.logger.info("Transform step end...")

        self.logger.info("Write start...")
        self.write_stream(df)
        self.logger.info("Write end...")

        self.logger.info("Transform end ...")

In [0]:
%load_ext autoreload
%autoreload 2
import logging
import datetime
from mdd.logger import *
#from mdd.transformer import *

log_folder = "mdd_test"
log_file_name = "test_transformer"
log_timestamp = datetime.datetime.now()
debug = False
Logger.init(log_folder, log_file_name, log_timestamp, debug)

try:
    metadata_yml = "transform/gold_fact_combinedcards.yml"
    dataflow = TransformDataFlow(spark, metadata_yml)
    df_read = dataflow.read()
    display(df_read)
except Exception as e:
    print(e)
finally:
    logging.shutdown()



In [0]:
# Row count of the source read. Spark rejects count() on a streaming DataFrame,
# and df_read is later written with write_stream(), so only count batch reads.
"count() is not supported on a streaming DataFrame" if df_read.isStreaming else df_read.count()

In [0]:
%sql
-- Row count of the bronze table (presumably the source behind df_read — TODO confirm).
SELECT
  COUNT(*)
FROM
  lakehouse.bronze.paytronix_mid352_combinedcards;

In [0]:
from mdd.helper.deltatable import DeltaTableUtils as dtu  # NOTE(review): `dtu` is not used in the visible cells
from mdd.utils import DeltaTableUtil  # the name this cell actually calls

# Underscore-prefixed bookkeeping columns to remove before writing
# (_change_type/_commit_* look like Delta change-data-feed columns — confirm).
drop_columns = [
    "_corrupt_record",
    "_rescued_data",
    "_mode",
    "_change_type",
    "_commit_version",
    "_commit_timestamp",
]

df_transformed = DeltaTableUtil.safe_drop_columns(df_read, drop_columns)
display(df_transformed)

In [0]:
# Sanity check before the next cell hands df_transformed to
# TransformDataFlow.write_stream(); presumably it must be a streaming DataFrame — TODO confirm.
df_transformed.isStreaming

In [0]:
# Writes to the sink configured in the metadata "writer" section. This call blocks
# until the query terminates, because write_stream() calls query.awaitTermination().
dataflow.write_stream(df_transformed)