In [None]:
from datetime import datetime

from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import DateType
from pyspark.sql.functions import expr, col
from pyspark.sql.types import DateType

packages = [
    'org.apache.hadoop:hadoop-aws:3.3.4',
    'org.apache.hadoop:hadoop-client-api:3.3.4',
    'org.apache.hadoop:hadoop-client-runtime:3.3.4',
    'io.delta:delta-core_2.12:2.4.0',
]

conf = SparkConf() \
    .setAppName("MyApp") \
    .set("spark.driver.memory", "8g") \
    .set("spark.executor.memory", "8g") \
    .set('spark.jars.packages', ','.join(packages)) \
    .set("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") \
    .set("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")

sc = SparkContext(conf=conf)

hadoop_config = sc._jsc.hadoopConfiguration()
hadoop_config.set('fs.s3a.impl', 'org.apache.hadoop.fs.s3a.S3AFileSystem')
hadoop_config.set('com.amazonaws.services.s3.enableV4', 'true')

spark = SparkSession(sc)

In [5]:
class PartitionProcessor:
    def __init__(self, group_by_cols, agg_dict, date_col="ds", partition_col="partitionArea", median_accuracy=100, variance_default=0.0):
        self.date_col = date_col
        self.partition_col = partition_col
        self.group_by_cols = group_by_cols
        self._group_by_cols = [self.date_col, self.partition_col] + self.group_by_cols
        self.agg_dict = agg_dict
        self.median_accuracy = median_accuracy
        self.variance_default = variance_default

        self._agg_exprs = self._get_aggregation_expressions()
    
    def _get_aggregation_expressions(self):
        agg_exprs = []
        for col_name, agg_funcs in self.agg_dict.items():
            for func in agg_funcs:
                if func == 'variance':
                    agg_exprs.append(F.var_pop(col_name).alias(f"{col_name}_variance"))
                elif func == 'count':
                    agg_exprs.append(F.count(col_name).alias(f"{col_name}_count"))
                elif func == 'sum':
                    agg_exprs.append(F.sum(col_name).alias(f"{col_name}_sum"))
                elif func == 'avg':
                    agg_exprs.append(F.avg(col_name).alias(f"{col_name}_avg"))
                elif func == 'median':
                    agg_exprs.append(F.expr(f'percentile_approx({col_name}, 0.5, {self.median_accuracy})').alias(f"{col_name}_median"))
                else:
                    raise ValueError(f"Unsupported aggregation function: {func}")
        return agg_exprs

    def group_and_aggregate(self, df):
        result_df = df.groupBy(self._group_by_cols).agg(*self._agg_exprs)
        
        for col_name, agg_funcs in self.agg_dict.items():
            if 'variance' in agg_funcs:
                result_df = result_df.fillna({f"{col_name}_variance": self.variance_default})
        
        return result_df
    
    def process_dates(self, dates, table_name, num_partitions=200):
        for date in dates:
            print(date)

            df_date = df.filter(col(self.date_col) == date)
            df_date_repart = df_date.repartition(num_partitions, self._group_by_cols)
            result_df = self.group_and_aggregate(df_date_repart)

            (
                result_df
                # .repartition(num_partitions)
                .write
                .partitionBy(self.date_col, self.partition_col)
                .format("delta")
                .mode('overwrite')
                .option('replaceWhere', f"{self.date_col} = '{date}'")
                .option('mergeSchema', 'true')
                .saveAsTable(table_name)
            )

In [15]:
# initial data
data = [
    ("2021-07-30", "partition1", "A", "X", 10),
    ("2021-07-30", "partition1", "A", "X", 20),
    ("2021-07-30", "partition2", "A", "Y", 30),
    ("2021-07-31", "partition1", "B", "X", 40),
    ("2021-07-31", "partition2", "B", "Y", 50),
    ("2021-07-31", "partition2", "C", "Y", None),  # Example with a null value
]

# additional data
# data = [
#     ("2021-08-01", "partition1", "D", "X", 20),
#     ("2021-08-01", "partition2", "A", "Z", 10),
#     ("2021-08-01", "partition2", "A", "Y", 30),
# ]

df = spark.createDataFrame(data, ["ds", "partitionArea", "col1", "col2", "value"])
df = df.withColumn("ds", df["ds"].cast(DateType()))

group_by_cols = ["col1", "col2"]
agg_dict = {
    "value": ["variance", "count", "sum", "avg", "median"]
}

table_name = 'test'
processor = PartitionProcessor(group_by_cols, agg_dict)

# initial data
dates = ["2021-07-30", "2021-07-31"]

# dates to append to the same table
# dates = ["2021-08-01"]

processor.process_dates(dates=dates, table_name=table_name)

2021-08-01


                                                                                

In [16]:
from delta.tables import DeltaTable

delta_table = DeltaTable.forName(spark, table_name)
delta_table.history().toPandas()

  series = series.astype(t, copy=False)


Unnamed: 0,version,timestamp,userId,userName,operation,operationParameters,job,notebook,clusterId,readVersion,isolationLevel,isBlindAppend,operationMetrics,userMetadata,engineInfo
0,4,2024-07-28 05:04:32.824,,,WRITE,"{'mode': 'Overwrite', 'partitionBy': '[""ds"",""p...",,,,3.0,Serializable,False,"{'numOutputRows': '3', 'numRemovedBytes': '0',...",,Apache-Spark/3.4.3 Delta-Lake/2.4.0
1,3,2024-07-28 04:59:20.984,,,WRITE,"{'mode': 'Overwrite', 'partitionBy': '[""ds"",""p...",,,,2.0,Serializable,False,"{'numOutputRows': '3', 'numRemovedBytes': '585...",,Apache-Spark/3.4.3 Delta-Lake/2.4.0
2,2,2024-07-28 04:59:16.903,,,WRITE,"{'mode': 'Overwrite', 'partitionBy': '[""ds"",""p...",,,,1.0,Serializable,False,"{'numOutputRows': '2', 'numRemovedBytes': '402...",,Apache-Spark/3.4.3 Delta-Lake/2.4.0
3,1,2024-07-28 04:56:55.637,,,WRITE,"{'mode': 'Overwrite', 'partitionBy': '[""ds"",""p...",,,,0.0,Serializable,False,"{'numOutputRows': '3', 'numRemovedBytes': '0',...",,Apache-Spark/3.4.3 Delta-Lake/2.4.0
4,0,2024-07-28 04:56:48.739,,,WRITE,"{'mode': 'Overwrite', 'partitionBy': '[""ds"",""p...",,,,,Serializable,True,"{'numOutputRows': '2', 'numAddedChangeFiles': ...",,Apache-Spark/3.4.3 Delta-Lake/2.4.0


In [17]:
loaded_df = spark.read.format("delta").table(table_name)

In [18]:
loaded_df.show()

+----------+-------------+----+----+--------------+-----------+---------+---------+------------+
|        ds|partitionArea|col1|col2|value_variance|value_count|value_sum|value_avg|value_median|
+----------+-------------+----+----+--------------+-----------+---------+---------+------------+
|2021-08-01|   partition1|   D|   X|           0.0|          1|       20|     20.0|          20|
|2021-07-31|   partition1|   B|   X|           0.0|          1|       40|     40.0|          40|
|2021-07-31|   partition2|   B|   Y|           0.0|          1|       50|     50.0|          50|
|2021-08-01|   partition2|   A|   Y|           0.0|          1|       30|     30.0|          30|
|2021-08-01|   partition2|   A|   Z|           0.0|          1|       10|     10.0|          10|
|2021-07-30|   partition2|   A|   Y|           0.0|          1|       30|     30.0|          30|
|2021-07-30|   partition1|   A|   X|          25.0|          2|       30|     15.0|          10|
|2021-07-31|   partition2|   C