In [2]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from ts_train.step.time_bucketing import TimeBucketing  # type: ignore
from ts_train.step.filling import Filling  # type: ignore
from ts_train.step.aggregation import Aggregation  # type: ignore
from pyspark_assert import assert_frame_equal
from pyspark_assert._assertions import DifferentSchemaAssertionError


In [3]:
# Attach to the active SparkSession for this kernel, creating one if none exists yet.
spark = (
    SparkSession.builder
    .getOrCreate()
)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/07/24 11:48:38 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [37]:
def create_timestamps_struct(
    df,
    cols_name,
    struct_col_name: str,
    struct_fields_name=("start", "end"),
    format: str = "yyyy-MM-dd",
):
    """Pack two string date columns into a single struct column of timestamps.

    Args:
        df: input Spark DataFrame.
        cols_name: names of the source columns. The first two are parsed into
            the struct (in that order); *all* listed columns are dropped from
            the result.
        struct_col_name: name of the struct column to add.
        struct_fields_name: field names given, in order, to the parsed first
            and second source columns inside the struct.
        format: datetime pattern passed to ``F.to_timestamp`` when parsing the
            source strings. (Shadows the ``format`` builtin; the name is kept
            so existing keyword callers keep working.)

    Returns:
        ``df`` with ``struct_col_name`` added and the ``cols_name`` columns removed.
    """
    parsed_fields = [
        F.to_timestamp(F.col(cols_name[position]), format).alias(
            struct_fields_name[position]
        )
        for position in (0, 1)
    ]
    with_struct = df.withColumn(struct_col_name, F.struct(*parsed_fields))
    return with_struct.drop(*cols_name)

def sample_dataframe_02(spark):
    """Build a small two-client sample DataFrame with a ``bucket`` struct column.

    Every row carries an ``ID_BIC_CLIENTE`` (repeated in ``ID_BIC_CLIENTE_2``),
    one-day bucket bounds given as ``yyyy-MM-dd`` strings, and three nullable
    numeric columns (``salute``, ``shopping``, ``trasporti``). The string bounds
    are turned into ``bucket.start`` / ``bucket.end`` timestamps by
    ``create_timestamps_struct`` and the original string columns are dropped.

    Args:
        spark: the SparkSession used to create the DataFrame.

    Returns:
        The sample DataFrame.
    """
    column_names = [
        "ID_BIC_CLIENTE",
        "bucket_start",
        "bucket_end",
        "salute",
        "shopping",
        "trasporti",
        "ID_BIC_CLIENTE_2",
    ]
    first_client, second_client = 348272371, 234984832
    rows = [
        (first_client, "2023-01-01", "2023-01-02", 61, 55, 97, first_client),
        (first_client, "2023-01-06", "2023-01-07", None, 1354, None, first_client),
        (second_client, "2023-01-01", "2023-01-02", 1298, None, None, second_client),
        (second_client, "2023-01-02", "2023-01-03", None, None, 22, second_client),
    ]

    raw_df = spark.createDataFrame(rows, schema=column_names)
    return create_timestamps_struct(
        df=raw_df,
        cols_name=("bucket_start", "bucket_end"),
        struct_col_name="bucket",
    )
    

In [39]:
# WINDOW-FUNCTION VERSION: compare each timestamp with the previous one (lag) inside every id group

from pyspark.sql.functions import col, lag 
from pyspark.sql.window import Window

def test_process_samples_timestamp_distance_with_spark_utility(
    spark, sample_dataframe_pre_filling
):
    """Test that filled samples are evenly spaced in time, using Spark window functions.

    The input is run through the ``Filling`` step (10-day buckets). The test
    then asserts that:

    * the ``timestamp`` column contains no nulls;
    * within each (``ID_BIC_CLIENTE``, ``ID_BIC_CLIENTE_2``) group ordered by
      ``timestamp``, exactly one row has no predecessor (the group's first row);
    * all non-null gaps between consecutive samples, across every group,
      share a single value.

    Args:
        spark: active SparkSession, passed through to ``Filling``.
        sample_dataframe_pre_filling: DataFrame with the identifier columns and
            a ``bucket`` struct column, as expected by ``Filling``.
    """
    time_bucket_size = 10
    time_bucket_granularity = "days"

    time_column_name = "timestamp"
    identifier_cols_name = ["ID_BIC_CLIENTE", "ID_BIC_CLIENTE_2"]
    shifted_col_name = f"shifted_{time_column_name}"

    standard_filling = Filling(
        time_bucket_col_name="bucket",
        identifier_cols_name=identifier_cols_name,
        time_bucket_size=time_bucket_size,
        time_bucket_granularity=time_bucket_granularity,
    )

    df_after_filling = standard_filling(df=sample_dataframe_pre_filling, spark=spark)

    # The time column must be fully populated after filling.
    num_null_timestamps = df_after_filling.where(
        F.col(time_column_name).isNull()
    ).count()
    assert num_null_timestamps == 0, (
        f"{num_null_timestamps} null value(s) in column {time_column_name!r}"
    )

    # Per id group, ordered by time: pair each sample with the previous one and
    # take the gap in seconds (a timestamp cast to long is epoch seconds).
    window_spec = Window.partitionBy(*identifier_cols_name).orderBy(time_column_name)
    df_after_filling = df_after_filling.withColumn(
        shifted_col_name, F.lag(time_column_name, 1).over(window_spec)
    ).withColumn(
        "time_diff",
        F.col(time_column_name).cast("long") - F.col(shifted_col_name).cast("long"),
    )

    # lag() is null only on the first row of each group, so the number of null
    # gaps must equal the number of distinct id combinations.
    num_unique_combinations = (
        df_after_filling.select(*identifier_cols_name).distinct().count()
    )
    num_null_time_diff = df_after_filling.filter(F.col("time_diff").isNull()).count()
    assert num_unique_combinations == num_null_time_diff, (
        f"expected {num_unique_combinations} null gap(s) (one per group), "
        f"found {num_null_time_diff}"
    )

    # Only compare gaps when at least one group has more than one sample.
    if df_after_filling.count() > num_null_time_diff:
        # collect_set ignores nulls, leaving the distinct non-null gaps.
        distinct_time_diffs = df_after_filling.select(
            F.collect_set("time_diff")
        ).collect()[0][0]
        assert len(distinct_time_diffs) == 1, (
            f"samples are not evenly spaced; distinct gaps (seconds): "
            f"{sorted(distinct_time_diffs)}"
        )

    df_after_filling.show(truncate=False)
    



In [40]:
# Build the sample input, display it, then run the window-function check on it.
sample_dataframe = sample_dataframe_02(spark)
sample_dataframe.show(truncate=False)
test_process_samples_timestamp_distance_with_spark_utility(
    spark=spark,
    sample_dataframe_pre_filling=sample_dataframe,
)

+--------------+------+--------+---------+----------------+------------------------------------------+
|ID_BIC_CLIENTE|salute|shopping|trasporti|ID_BIC_CLIENTE_2|bucket                                    |
+--------------+------+--------+---------+----------------+------------------------------------------+
|348272371     |61    |55      |97       |348272371       |{2023-01-01 00:00:00, 2023-01-02 00:00:00}|
|348272371     |null  |1354    |null     |348272371       |{2023-01-06 00:00:00, 2023-01-07 00:00:00}|
|234984832     |1298  |null    |null     |234984832       |{2023-01-01 00:00:00, 2023-01-02 00:00:00}|
|234984832     |null  |null    |22       |234984832       |{2023-01-02 00:00:00, 2023-01-03 00:00:00}|
+--------------+------+--------+---------+----------------+------------------------------------------+

+--------------+----------------+-------------------+------+--------+---------+-----------------+---------+
|ID_BIC_CLIENTE|ID_BIC_CLIENTE_2|timestamp          |salute|shoppin

In [None]:

# SIMPLE VERSION: collect each id's timestamps to the driver and check the gaps in plain Python
from pyspark.sql.functions import col
from pyspark.sql.window import Window
from pyspark.sql.types import IntegerType, ArrayType

def test_process_samples_timestamp_distance(
    spark, sample_dataframe_02
):
    """Test that, for each id, consecutive filled samples are evenly spaced.

    The input is run through the ``Filling`` step with 20-minute buckets. The
    test then asserts that the ``timestamp`` column has no nulls and that, for
    every ``ID_BIC_CLIENTE``, the gaps between its time-sorted timestamps are
    all equal. Ids with a single sample have no gaps and pass trivially.

    Args:
        spark: active SparkSession, passed through to ``Filling``.
        sample_dataframe_02: DataFrame with an ``ID_BIC_CLIENTE`` column and a
            ``bucket`` struct column, as expected by ``Filling``.
    """
    time_column_name = "timestamp"
    identifier_cols_name = "ID_BIC_CLIENTE"
    time_bucket_size = 20
    time_bucket_granularity = "minutes"

    standard_filling = Filling(
        time_bucket_col_name="bucket",
        identifier_cols_name=identifier_cols_name,
        time_bucket_size=time_bucket_size,
        time_bucket_granularity=time_bucket_granularity,
    )

    # Run the Filling step under test.
    df_after_filling = standard_filling(df=sample_dataframe_02, spark=spark)
    df_after_filling.show(truncate=False)

    # The time column must be fully populated after filling.
    num_null_timestamps = df_after_filling.where(
        F.col(time_column_name).isNull()
    ).count()
    assert num_null_timestamps == 0, (
        f"{num_null_timestamps} null value(s) in column {time_column_name!r}"
    )

    # Gather every id's timestamps (as epoch seconds) in a single Spark job,
    # instead of re-filtering the aggregate once per id.
    timestamps_per_user = (
        df_after_filling.withColumn(
            "timestamp_unix", F.unix_timestamp(time_column_name)
        )
        .groupBy(identifier_cols_name)
        .agg(F.collect_list("timestamp_unix").alias("timestamps_list"))
        .collect()
    )

    for user_row in timestamps_per_user:
        # collect_list does not guarantee any ordering, so sort before diffing
        # consecutive elements.
        user_timestamps = sorted(user_row["timestamps_list"])
        distinct_gaps = {
            later - earlier
            for earlier, later in zip(user_timestamps, user_timestamps[1:])
        }
        assert len(distinct_gaps) <= 1, (
            f"uneven spacing for {identifier_cols_name}="
            f"{user_row[identifier_cols_name]}: gaps (seconds) {sorted(distinct_gaps)}"
        )
