In [19]:
from ts_train.step.aggregation import Aggregation
from ts_train.step.filling import Filling
from ts_train.step.time_bucketing import TimeBucketing
from ts_train.common.utils import (  # type: ignore
    cast_column_to_timestamp,  # type: ignore
)
from pyspark.sql import SparkSession
from ts_train.common.utils import (  # type: ignore
    cast_column_to_timestamp,  # type: ignore
    create_timestamps_struct,  # type: ignore
)

In [20]:
### CREATE INITIAL INPUT
spark = SparkSession.builder.getOrCreate()

def sample_dataframe_01(spark):
    df = spark.createDataFrame(
        data=[
            (348272371, "2023-01-01", 5, "shopping", "carta", "true"),
            (348272371, "2023-01-01", 6, "salute", "cash", "false"),
            (348272371, "2023-01-01", 8, "trasporti", "cash", "false"),
            (348272371, "2023-01-01", 1, "trasporti", "carta", "true"),
            (348272371, "2023-01-06", 20, "shopping", "bitcoin", "false"),
            (348272371, "2023-01-06", 43, "shopping", "carta", "true"),
            (348272371, "2023-01-06", 72, "shopping", "cash", "false"),
            (234984832, "2023-01-01", 15, "salute", "carta", "true"),
            (234984832, "2023-01-01", 36, "salute", "carta", "true"),
            (234984832, "2023-01-01", 78, "salute", "cash", "false"),
            (234984832, "2023-01-02", 2, "trasporti", "carta", "true"),
        ],
        schema=[
            "ID_BIC_CLIENTE",
            "DATA_TRANSAZIONE",
            "IMPORTO",
            "CA_CATEGORY_LIV0",
            "METODO_PAGAMENTO",
            "IS_CARTA",
        ],
    )

    return cast_column_to_timestamp(df=df, col_name="DATA_TRANSAZIONE")

data_df = sample_dataframe_01(spark)
data_df.show()

+--------------+-------------------+-------+----------------+----------------+--------+
|ID_BIC_CLIENTE|   DATA_TRANSAZIONE|IMPORTO|CA_CATEGORY_LIV0|METODO_PAGAMENTO|IS_CARTA|
+--------------+-------------------+-------+----------------+----------------+--------+
|     348272371|2023-01-01 00:00:00|      5|        shopping|           carta|    true|
|     348272371|2023-01-01 00:00:00|      6|          salute|            cash|   false|
|     348272371|2023-01-01 00:00:00|      8|       trasporti|            cash|   false|
|     348272371|2023-01-01 00:00:00|      1|       trasporti|           carta|    true|
|     348272371|2023-01-06 00:00:00|     20|        shopping|         bitcoin|   false|
|     348272371|2023-01-06 00:00:00|     43|        shopping|           carta|    true|
|     348272371|2023-01-06 00:00:00|     72|        shopping|            cash|   false|
|     234984832|2023-01-01 00:00:00|     15|          salute|           carta|    true|
|     234984832|2023-01-01 00:00

In [21]:
time_zone="Europe/Rome"
time_column_name="DATA_TRANSAZIONE"
time_bucket_size=1
time_bucket_granularity="days"
numerical_col_name=["IMPORTO"]
all_aggregation_filters=[ 
                            [("METODO_PAGAMENTO", [])],
                            [("CA_CATEGORY_LIV0", ["shopping","salute"])], 
                            [("CA_CATEGORY_LIV0", ["shopping"]),("CA_CATEGORY_LIV0", ["salute"])],
                            [("IS_CARTA", ["false"]),("METODO_PAGAMENTO", ["carta"])],
                            [("IS_CARTA", ["true"]),("METODO_PAGAMENTO", ["chash", "bitcoin"])],
                        ]
agg_funcs=["sum"]
time_bucket_col_name="bucket"
identifier_cols_name=["ID_BIC_CLIENTE"]
new_timestamp_col_name="new_timestamp"

In [22]:
#### TIME BUCKETING
time_bucket_step =  TimeBucketing(
        time_zone=time_zone,
        time_column_name=time_column_name,
        time_bucket_size=time_bucket_size,
        time_bucket_granularity=time_bucket_granularity, # type: ignore
        time_bucket_col_name=time_bucket_col_name,
    )

time_bucket_df = time_bucket_step(data_df, spark)
time_bucket_df.show(truncate=False)

+--------------+-------------------+-------+----------------+----------------+--------+------------------------------------------+
|ID_BIC_CLIENTE|DATA_TRANSAZIONE   |IMPORTO|CA_CATEGORY_LIV0|METODO_PAGAMENTO|IS_CARTA|bucket                                    |
+--------------+-------------------+-------+----------------+----------------+--------+------------------------------------------+
|348272371     |2023-01-01 00:00:00|5      |shopping        |carta           |true    |{2022-12-31 01:00:00, 2023-01-01 01:00:00}|
|348272371     |2023-01-01 00:00:00|6      |salute          |cash            |false   |{2022-12-31 01:00:00, 2023-01-01 01:00:00}|
|348272371     |2023-01-01 00:00:00|8      |trasporti       |cash            |false   |{2022-12-31 01:00:00, 2023-01-01 01:00:00}|
|348272371     |2023-01-01 00:00:00|1      |trasporti       |carta           |true    |{2022-12-31 01:00:00, 2023-01-01 01:00:00}|
|348272371     |2023-01-06 00:00:00|20     |shopping        |bitcoin         |false

In [25]:

df_before_agg = spark.createDataFrame(
    data=[
        (348272371, "2023-01-01", "2023-01-02", 5, "shopping", "carta", "true"),
        (348272371, "2023-01-01", "2023-01-02", 6, "salute", "cash", "false"),
        (348272371, "2023-01-01", "2023-01-02", 8, "trasporti", "cash", "false"),
        (348272371, "2023-01-01", "2023-01-02", 1, "trasporti", "carta", "true"),
        (
            348272371,
            "2023-01-06",
            "2023-01-07",
            20,
            "shopping",
            "bitcoin",
            "false",
        ),
        (348272371, "2023-01-06", "2023-01-07", 43, "shopping", "carta", "true"),
        (348272371, "2023-01-06", "2023-01-07", 72, "shopping", "cash", "false"),
        (234984832, "2023-01-01", "2023-01-01", 15, "salute", "carta", "true"),
        (234984832, "2023-01-01", "2023-01-01", 36, "salute", "carta", "true"),
        (234984832, "2023-01-01", "2023-01-01", 78, "salute", "cash", "false"),
        (234984832, "2023-01-02", "2023-01-02", 2, "trasporti", "carta", "true"),
    ],
    schema=[
        "ID_BIC_CLIENTE",
        "bucket_start",
        "bucket_end",
        "IMPORTO",
        "CA_CATEGORY_LIV0",
        "METODO_PAGAMENTO",
        "IS_CARTA",
    ],
)

df_before_agg = create_timestamps_struct(
    df=df_before_agg, cols_name=("bucket_start", "bucket_end"), struct_col_name="bucket", format="yyyy-MM-dd HH:mm:ss",
)

df_before_agg.show(truncate=False)

+--------------+-------+----------------+----------------+--------+------------+
|ID_BIC_CLIENTE|IMPORTO|CA_CATEGORY_LIV0|METODO_PAGAMENTO|IS_CARTA|bucket      |
+--------------+-------+----------------+----------------+--------+------------+
|348272371     |5      |shopping        |carta           |true    |{null, null}|
|348272371     |6      |salute          |cash            |false   |{null, null}|
|348272371     |8      |trasporti       |cash            |false   |{null, null}|
|348272371     |1      |trasporti       |carta           |true    |{null, null}|
|348272371     |20     |shopping        |bitcoin         |false   |{null, null}|
|348272371     |43     |shopping        |carta           |true    |{null, null}|
|348272371     |72     |shopping        |cash            |false   |{null, null}|
|234984832     |15     |salute          |carta           |true    |{null, null}|
|234984832     |36     |salute          |carta           |true    |{null, null}|
|234984832     |78     |salu

In [None]:
#### AGGREGATION
aggregation_step =  Aggregation(
    numerical_col_name=numerical_col_name,
    identifier_cols_name=identifier_cols_name,
    all_aggregation_filters=all_aggregation_filters,
    agg_funcs=agg_funcs,
)


aggregated_df = aggregation_step(time_bucket_df, spark)
aggregated_df.show(truncate=False)

pivot
pivot
+--------------+------------------------------------------+---------------------------------------+--------------------------------------+-----------------------------------------+------------------------------------------+----------------------------------------+----------------------------------------------------+------------------------------------------------------------------+-------------------------------------------------------------------------+
|ID_BIC_CLIENTE|bucket                                    |sum_IMPORTO_by_METODO_PAGAMENTO_(carta)|sum_IMPORTO_by_METODO_PAGAMENTO_(cash)|sum_IMPORTO_by_METODO_PAGAMENTO_(bitcoin)|sum_IMPORTO_by_CA_CATEGORY_LIV0_(shopping)|sum_IMPORTO_by_CA_CATEGORY_LIV0_(salute)|sum_of_IMPORTO_by_CA_CATEGORY_LIV0_(shopping_salute)|sum_of_IMPORTO_by_IS_CARTA_(false)_and_by_METODO_PAGAMENTO_(carta)|sum_of_IMPORTO_by_IS_CARTA_(true)_and_by_METODO_PAGAMENTO_(chash_bitcoin)|
+--------------+------------------------------------------+-----------

In [18]:
#### Filter
filling_step = Filling(
        time_bucket_col_name=time_bucket_col_name,
        identifier_cols_name=identifier_cols_name,
        time_bucket_size=time_bucket_size,
        time_bucket_granularity=time_bucket_granularity, # type: ignore
        new_timestamp_col_name=new_timestamp_col_name,
    )
filled_df = filling_step(aggregated_df, spark)
filled_df.show(truncate=False)

+--------------+-------------------+---------------------------------------+--------------------------------------+-----------------------------------------+------------------------------------------+----------------------------------------+----------------------------------------------------+------------------------------------------------------------------+-------------------------------------------------------------------------+
|ID_BIC_CLIENTE|new_timestamp      |sum_IMPORTO_by_METODO_PAGAMENTO_(carta)|sum_IMPORTO_by_METODO_PAGAMENTO_(cash)|sum_IMPORTO_by_METODO_PAGAMENTO_(bitcoin)|sum_IMPORTO_by_CA_CATEGORY_LIV0_(shopping)|sum_IMPORTO_by_CA_CATEGORY_LIV0_(salute)|sum_of_IMPORTO_by_CA_CATEGORY_LIV0_(shopping_salute)|sum_of_IMPORTO_by_IS_CARTA_(false)_and_by_METODO_PAGAMENTO_(carta)|sum_of_IMPORTO_by_IS_CARTA_(true)_and_by_METODO_PAGAMENTO_(chash_bitcoin)|
+--------------+-------------------+---------------------------------------+--------------------------------------+-------------

In [11]:
from pyspark.sql import SparkSession, Row
from pyspark.sql.types import StructType, StructField, StringType, FloatType
from typing import *
from pyspark_assert import assert_frame_equal  # type: ignore

from pyspark.sql.dataframe import DataFrame
from pyspark.sql import functions as F
from pyspark.sql.types import DataTypeSingleton, DateType, TimestampType
from pyspark.sql.types import StringType, BooleanType
from pyspark.sql.types import (
    ByteType,
    ShortType,
    IntegerType,
    LongType,
    FloatType,
    ArrayType,
    DoubleType,
)

def create_timestamps_struct(
    df: DataFrame,
    cols_name: Tuple[str, str],
    struct_col_name: str,
    struct_fields_name: Tuple[str, str] = ("start", "end"),
    format: str = "yyyy-MM-dd",
) -> DataFrame:
    return df.withColumn(
        struct_col_name,
        F.struct(
            F.to_timestamp(F.col(cols_name[0]), format).alias(struct_fields_name[0]),
            F.to_timestamp(F.col(cols_name[1]), format).alias(struct_fields_name[1]),
        ),
    ).drop(*cols_name)

# Creazione della SparkSession
#spark = SparkSession.builder.appName("DataFrameCreation").getOrCreate()

spark.conf.set("spark.sql.session.timeZone", "Europe/Rome")

# Schema del DataFrame
schema = [
    "ID_BIC_CLIENTE", 
    "bucket_start",
    "bucket_end",
    "sum_IMPORTO_by_METODO_PAGAMENTO_(carta)", 
    "sum_IMPORTO_by_METODO_PAGAMENTO_(cash)", 
    "sum_IMPORTO_by_METODO_PAGAMENTO_(bitcoin)", 
    "sum_IMPORTO_by_CA_CATEGORY_LIV0_(shopping)", 
    "sum_IMPORTO_by_CA_CATEGORY_LIV0_(salute)", 
    "sum_of_IMPORTO_by_CA_CATEGORY_LIV0_(shopping_salute)", 
    "sum_of_IMPORTO_by_IS_CARTA_(false)_and_by_METODO_PAGAMENTO_(carta)", 
    "sum_of_IMPORTO_by_IS_CARTA_(true)_and_by_METODO_PAGAMENTO_(chash_bitcoin)", 
]

# Dati per il DataFrame
data = [
    (348272371, "2023-01-05 01:00:00", "2023-01-06 01:00:00", 43, 72, 20, 135, 0,135, 0,0 ),
    (234984832, "2022-12-31 01:00:00", "2023-01-01 01:00:00", 51,78,0,0,129,129,0,0),
    (234984832, "2023-01-01 01:00:00", "2023-01-02 01:00:00", 2,0,0,0,0,0,0,0),
    (348272371, "2022-12-31 01:00:00", "2023-01-01 01:00:00", 6,14,0,5,6,11,0,0),
]

# Creazione del DataFrame
expected_df = spark.createDataFrame(data, schema)

expected_df =  create_timestamps_struct(
        df=expected_df, cols_name=("bucket_start", "bucket_end"), struct_col_name="bucket", format="yyyy-MM-dd HH:mm:ss"
    )

# Riordina le colonne
expected_df = expected_df.select("ID_BIC_CLIENTE", "bucket",
               "sum_IMPORTO_by_METODO_PAGAMENTO_(carta)", "sum_IMPORTO_by_METODO_PAGAMENTO_(cash)",
               "sum_IMPORTO_by_METODO_PAGAMENTO_(bitcoin)", "sum_IMPORTO_by_CA_CATEGORY_LIV0_(shopping)",
               "sum_IMPORTO_by_CA_CATEGORY_LIV0_(salute)", "sum_of_IMPORTO_by_CA_CATEGORY_LIV0_(shopping_salute)",
               "sum_of_IMPORTO_by_IS_CARTA_(false)_and_by_METODO_PAGAMENTO_(carta)",
               "sum_of_IMPORTO_by_IS_CARTA_(true)_and_by_METODO_PAGAMENTO_(chash_bitcoin)")

expected_df = expected_df.orderBy(["ID_BIC_CLIENTE", "bucket"])

# Visualizzare il DataFrame risultante
expected_df.show(truncate=False)



+--------------+------------------------------------------+---------------------------------------+--------------------------------------+-----------------------------------------+------------------------------------------+----------------------------------------+----------------------------------------------------+------------------------------------------------------------------+-------------------------------------------------------------------------+
|ID_BIC_CLIENTE|bucket                                    |sum_IMPORTO_by_METODO_PAGAMENTO_(carta)|sum_IMPORTO_by_METODO_PAGAMENTO_(cash)|sum_IMPORTO_by_METODO_PAGAMENTO_(bitcoin)|sum_IMPORTO_by_CA_CATEGORY_LIV0_(shopping)|sum_IMPORTO_by_CA_CATEGORY_LIV0_(salute)|sum_of_IMPORTO_by_CA_CATEGORY_LIV0_(shopping_salute)|sum_of_IMPORTO_by_IS_CARTA_(false)_and_by_METODO_PAGAMENTO_(carta)|sum_of_IMPORTO_by_IS_CARTA_(true)_and_by_METODO_PAGAMENTO_(chash_bitcoin)|
+--------------+------------------------------------------+-----------------------

In [8]:
expected_df.schema

StructType([StructField('ID_BIC_CLIENTE', LongType(), True), StructField('bucket', StructType([StructField('start', TimestampType(), True), StructField('end', TimestampType(), True)]), False), StructField('sum_IMPORTO_by_METODO_PAGAMENTO_(carta)', LongType(), True), StructField('sum_IMPORTO_by_METODO_PAGAMENTO_(cash)', LongType(), True), StructField('sum_IMPORTO_by_METODO_PAGAMENTO_(bitcoin)', LongType(), True), StructField('sum_IMPORTO_by_CA_CATEGORY_LIV0_(shopping)', LongType(), True), StructField('sum_IMPORTO_by_CA_CATEGORY_LIV0_(salute)', LongType(), True), StructField('sum_of_IMPORTO_by_CA_CATEGORY_LIV0_(shopping_salute)', LongType(), True), StructField('sum_of_IMPORTO_by_IS_CARTA_(false)_and_by_METODO_PAGAMENTO_(carta)', LongType(), True), StructField('sum_of_IMPORTO_by_IS_CARTA_(true)_and_by_METODO_PAGAMENTO_(chash_bitcoin)', LongType(), True)])

In [9]:
assert_frame_equal(
        expected_df,
        aggregated_df.fillna(0),
        check_metadata=False,
        check_column_order=False,
        check_row_order=False,
        check_nullable=False,
    )

In [10]:
aggregated_df.show()

+--------------+--------------------+---------------------------------------+--------------------------------------+-----------------------------------------+------------------------------------------+----------------------------------------+----------------------------------------------------+------------------------------------------------------------------+-------------------------------------------------------------------------+
|ID_BIC_CLIENTE|              bucket|sum_IMPORTO_by_METODO_PAGAMENTO_(carta)|sum_IMPORTO_by_METODO_PAGAMENTO_(cash)|sum_IMPORTO_by_METODO_PAGAMENTO_(bitcoin)|sum_IMPORTO_by_CA_CATEGORY_LIV0_(shopping)|sum_IMPORTO_by_CA_CATEGORY_LIV0_(salute)|sum_of_IMPORTO_by_CA_CATEGORY_LIV0_(shopping_salute)|sum_of_IMPORTO_by_IS_CARTA_(false)_and_by_METODO_PAGAMENTO_(carta)|sum_of_IMPORTO_by_IS_CARTA_(true)_and_by_METODO_PAGAMENTO_(chash_bitcoin)|
+--------------+--------------------+---------------------------------------+--------------------------------------+----------

23/07/26 16:03:58 WARN GarbageCollectionMetrics: To enable non-built-in garbage collector(s) List(G1 Concurrent GC), users should configure it(them) to spark.eventLog.gcMetrics.youngGenerationGarbageCollectors or spark.eventLog.gcMetrics.oldGenerationGarbageCollectors
