In [1]:
import pathlib
from datetime import datetime
from typing import List, Tuple, Union, Dict

import pyspark.sql.functions as F
import pyspark.sql.types as T
from pyspark.sql import Column

from pyspark.sql import DataFrame
from pyspark.sql import SparkSession
from pyspark.sql import Window


In [2]:
# Run Spark locally, limited to 4 CPU cores.
spark = SparkSession.builder.master("local[4]").getOrCreate()

# Keep the number of shuffle partitions small for local work.
# The option name is plural ("partitions"); the singular spelling is not a
# recognized Spark SQL setting, so it would leave the partition count unchanged.
spark.conf.set("spark.sql.shuffle.partitions", 4)

# Evaluate timestamps in UTC for this session.
spark.conf.set("spark.sql.session.timeZone", "UTC")


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/07/21 09:58:39 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
def deduplicate(df: DataFrame, config: dict) -> DataFrame:
    """Keep only the top-ranked row(s) of every key group.

    Rows are partitioned by ``config["key_columns"]``, ordered inside each
    partition according to ``config["order_by"]`` and numbered with a window
    function; only rows numbered 1 are returned.

    Args:
        df: Input DataFrame.
        config: Mapping-like object supporting ``config[key]`` and
            ``config.get(key, default)`` with the entries:

            * ``"key_columns"``: column(s) handed to ``Window.partitionBy``.
            * ``"order_by"``: dict of ``{column_name: direction}``. A direction
              equal to ``"desc"`` (case-insensitive, surrounding whitespace
              ignored) sorts descending; any other value sorts ascending.
            * ``"type"`` (optional): ``"row_number"`` (default) or ``"rank"``.
              ``row_number`` keeps exactly one row per key group, while
              ``rank`` keeps every row tied for first place, so ties and
              exact duplicate rows survive.

    Returns:
        A DataFrame with the same columns as ``df`` containing only the rows
        numbered 1 within their key group.

    Raises:
        ValueError: If ``config["type"]`` is neither ``"row_number"`` nor
            ``"rank"``.
        KeyError: If a dict-like ``config`` lacks ``"order_by"`` or
            ``"key_columns"``.
    """
    order_by_dict = config["order_by"]
    partition_by = config["key_columns"]
    dedup_type = config.get("type", "row_number")

    # Validate first so a bad configuration fails before any Spark object is
    # built. The message text (including its embedded newline/indent) is kept
    # exactly as callers have always seen it.
    if dedup_type not in ("row_number", "rank"):
        raise ValueError(
            f"""Deduplication type should be either row_number,
                empty (row_number as default) or rank. Not {dedup_type}"""
        )

    # Compare directions case-insensitively so "DESC" is not silently treated
    # as ascending.
    ordering = [
        F.col(name).desc()
        if str(direction).strip().lower() == "desc"
        else F.col(name).asc()
        for name, direction in order_by_dict.items()
    ]
    window_spec = Window.partitionBy(partition_by).orderBy(*ordering)

    # Pick a helper column name that does not collide with an input column;
    # otherwise withColumn would overwrite it and drop() would then remove it.
    # Names are compared lower-cased because Spark resolves column names
    # case-insensitively by default.
    taken = {name.lower() for name in df.columns}
    helper_col = "wf_col"
    while helper_col.lower() in taken:
        helper_col = f"_{helper_col}"

    numbering = F.row_number() if dedup_type == "row_number" else F.rank()
    numbered = df.withColumn(helper_col, numbering.over(window_spec))
    return numbered.filter(F.col(helper_col) == 1).drop(helper_col)

In [5]:
class DummyConfig:
    """Minimal stand-in for a configuration object backed by a plain dict.

    Exposes just the two access styles used on it: subscription
    (``cfg[key]``) and ``cfg.get(key, default)``.
    """

    def __init__(self, config: dict):
        # The mapping is stored as given (no copy is made).
        self.config = config

    def __getitem__(self, key):
        """Return the value stored under ``key``; a missing key propagates
        whatever the underlying mapping raises."""
        value = self.config[key]
        return value

    def get(self, key, default=None):
        """Return the value stored under ``key``, or ``default`` if absent."""
        lookup = self.config.get
        return lookup(key, default)


In [10]:

# Build the fixture path from the current user's home directory instead of
# hard-coding one user's absolute path.
# NOTE(review): assumes the repository is checked out at
# ~/repositories/pipelines-tooling -- adjust REPO_ROOT if it lives elsewhere.
REPO_ROOT = pathlib.Path.home() / "repositories" / "pipelines-tooling"
input_path = str(
    REPO_ROOT
    / "tests"
    / "unit"
    / "spark_operations"
    / "deduplicate"
    / "resources"
    / "input_rank.json"
)

# multiLine=True lets a single JSON record span several lines of the file.
input_df = spark.read.json(input_path, multiLine=True)


# Deduplicate per Subject, keeping the highest Score, using rank semantics.
config = DummyConfig(
    {
        "key_columns": ["Subject"],
        "order_by": {"Score": "desc"},
        "type": "rank",
    }
)

actual_df = deduplicate(input_df, config)

In [11]:
# With type="rank", every row tied for the best Score in its Subject gets
# rank 1, so ties (including exact duplicate rows) all pass the `== 1` filter
# and show up in the result.
actual_df.show()

+-----+-------+-------+
|Score|Student|Subject|
+-----+-------+-------+
|   95|Charlie|   Math|
|   95|    Eva|Science|
|   95|    Eva|Science|
+-----+-------+-------+

