In [25]:
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when, split, explode, array, lit, concat_ws, expr, lag
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pyspark.sql.window import Window
from pyspark.sql import functions as F
# Start (or reuse, via getOrCreate) the Spark session used by every cell below.
spark = (
    SparkSession.builder
    .appName("LagFeatureExample")
    .getOrCreate()
)

In [206]:
# Example data including "all" intervals: two geohashes, 10 rows each, with the same
# per-geohash pattern. Rows whose interval is "all" are treated as hourly rows later;
# the other rows are 15-minute buckets labelled by their minute range ("0-14", ...).
data = {
    "geohashId": ["abcd"] * 10 + ["efgh"] * 10,
    "date": ["2023-08-01"] * 15 + ["2023-08-02"] * 5,
    "hour": [0, 0, 0, 1, 1, 1, 2, 2, 2, 3] * 2,
    "interval": ["0-14", "15-29", "30-44", "0-14", "15-29", "all",
                 "0-14", "15-29", "30-44", "all"] * 2,
    "featureA_variance": list(range(1, 11)) * 2,
    "featureA_median": list(range(10, 101, 10)) * 2,
    "featureB_variance": [1, 1, 1, 2, 2, 2, 3, 3, 3, 4] * 2,
    "featureB_median": [5, 5, 5, 6, 6, 6, 7, 7, 7, 8] * 2,
}
# pandas -> Spark, then print the first rows
df = pd.DataFrame(data)
sdf = spark.createDataFrame(df)
sdf.show()

+---------+----------+----+--------+-----------------+---------------+-----------------+---------------+
|geohashId|      date|hour|interval|featureA_variance|featureA_median|featureB_variance|featureB_median|
+---------+----------+----+--------+-----------------+---------------+-----------------+---------------+
|     abcd|2023-08-01|   0|    0-14|                1|             10|                1|              5|
|     abcd|2023-08-01|   0|   15-29|                2|             20|                1|              5|
|     abcd|2023-08-01|   0|   30-44|                3|             30|                1|              5|
|     abcd|2023-08-01|   1|    0-14|                4|             40|                2|              6|
|     abcd|2023-08-01|   1|   15-29|                5|             50|                2|              6|
|     abcd|2023-08-01|   1|     all|                6|             60|                2|              6|
|     abcd|2023-08-01|   2|    0-14|                7| 

In [207]:
from pyspark.sql.functions import col

# Split the data into hourly rows (interval == "all") and 15-minute rows
# (everything else). Rows with a null interval match neither filter.
is_hourly = col("interval") == "all"
sdf_hour = sdf.filter(is_hourly)
sdf_15min = sdf.filter(~is_hourly)

In [208]:
# Earliest date in the data; dates are "yyyy-MM-dd" strings, so the string minimum is chronological.
first_date = sdf.select(F.min("date")).first()[0]
first_date

'2023-08-01'

In [209]:
# Latest date in the data; dates are "yyyy-MM-dd" strings, so the string maximum is chronological.
last_date = sdf.select(F.max("date")).first()[0]
last_date

'2023-08-02'

In [210]:
# Unix seconds of 00:00:00 on first_date, used as the origin for the "time" slot indices.
# This is a Column expression, not a number; it is evaluated inside later withColumn calls.
first_midnight = f"{first_date} 00:00:00"
base_timestamp = F.unix_timestamp(lit(first_midnight))
base_timestamp

Column<'unix_timestamp(2023-08-01 00:00:00, yyyy-MM-dd HH:mm:ss)'>

In [211]:
# Build each hourly row's start as "yyyy-MM-dd HH:00:00" (hour zero-padded) and
# parse it into Unix seconds with unix_timestamp's default yyyy-MM-dd HH:mm:ss pattern.
padded_hour = F.lpad(F.col("hour").cast("string"), 2, "0")
hour_start_text = F.concat(F.col("date"), F.lit(" "), padded_hour, F.lit(":00:00"))
sdf_hour = sdf_hour.withColumn("timestamp", F.unix_timestamp(hour_start_text))

sdf_hour.toPandas().head()

Unnamed: 0,geohashId,date,hour,interval,featureA_variance,featureA_median,featureB_variance,featureB_median,timestamp
0,abcd,2023-08-01,1,all,6,60,2,6,1690819200
1,abcd,2023-08-01,3,all,10,100,4,8,1690826400
2,efgh,2023-08-02,1,all,6,60,2,6,1690905600
3,efgh,2023-08-02,3,all,10,100,4,8,1690912800


In [212]:
# Convert the absolute Unix timestamp into an integer hour-slot index counted from
# base_timestamp (slot 0 = 00:00 on first_date).
# Divide first, then round: rounding the whole-second difference before dividing was
# a no-op. Cast to int so the key has the same type as the IntegerType "time" column
# of the time grid it is joined against.
sdf_hour = sdf_hour.withColumn(
    "time",
    F.round((F.col("timestamp") - base_timestamp) / 3600).cast("int"),
)

sdf_hour.toPandas().head()

Unnamed: 0,geohashId,date,hour,interval,featureA_variance,featureA_median,featureB_variance,featureB_median,timestamp,time
0,abcd,2023-08-01,1,all,6,60,2,6,1690819200,1.0
1,abcd,2023-08-01,3,all,10,100,4,8,1690826400,3.0
2,efgh,2023-08-02,1,all,6,60,2,6,1690905600,25.0
3,efgh,2023-08-02,3,all,10,100,4,8,1690912800,27.0


In [213]:
# Build each 15-minute row's start as "yyyy-MM-dd HH:mm:00". mm is the lower bound of
# the interval label (e.g. "15-29" -> "15"), zero-padded. The result is parsed into
# Unix seconds with unix_timestamp's default yyyy-MM-dd HH:mm:ss pattern.
padded_hour = F.lpad(F.col("hour").cast("string"), 2, "0")
start_minute = F.lpad(F.split(F.col("interval"), '-').getItem(0), 2, "0")
slot_start_text = F.concat(
    F.col("date"), F.lit(" "), padded_hour, F.lit(":"), start_minute, F.lit(":00")
)
sdf_15min = sdf_15min.withColumn("timestamp", F.unix_timestamp(slot_start_text))

sdf_15min.toPandas().head()

Unnamed: 0,geohashId,date,hour,interval,featureA_variance,featureA_median,featureB_variance,featureB_median,timestamp
0,abcd,2023-08-01,0,0-14,1,10,1,5,1690815600
1,abcd,2023-08-01,0,15-29,2,20,1,5,1690816500
2,abcd,2023-08-01,0,30-44,3,30,1,5,1690817400
3,abcd,2023-08-01,1,0-14,4,40,2,6,1690819200
4,abcd,2023-08-01,1,15-29,5,50,2,6,1690820100


In [214]:
# Convert the absolute Unix timestamp into an integer 15-minute slot index counted
# from base_timestamp (slot 0 = 00:00-00:14 on first_date).
# Divide first, then round: rounding the whole-second difference before dividing was
# a no-op. Cast to int so the key has the same type as the IntegerType "time" column
# of the time grid it is joined against.
sdf_15min = sdf_15min.withColumn(
    "time",
    F.round((F.col("timestamp") - base_timestamp) / 900).cast("int"),
)

sdf_15min.toPandas().head()

Unnamed: 0,geohashId,date,hour,interval,featureA_variance,featureA_median,featureB_variance,featureB_median,timestamp,time
0,abcd,2023-08-01,0,0-14,1,10,1,5,1690815600,0.0
1,abcd,2023-08-01,0,15-29,2,20,1,5,1690816500,1.0
2,abcd,2023-08-01,0,30-44,3,30,1,5,1690817400,2.0
3,abcd,2023-08-01,1,0-14,4,40,2,6,1690819200,4.0
4,abcd,2023-08-01,1,15-29,5,50,2,6,1690820100,5.0


In [215]:
# For loop (disabled): an earlier, driver-side version of the time-grid construction.
# Every line below is commented out, so this cell executes nothing. It collected the
# geohash ids to the driver and built the (geohashId, time) tuples in Python; the
# crossJoin cell that follows builds the same grids inside Spark instead.
# NOTE(review): consider deleting this cell once the crossJoin version is trusted.

# # Step 2: Generate a Complete Time Grid
# # Generate a list of all 15-minute intervals within a day
# # Create a DataFrame with all intervals for each geohashId and date
# geohash_ids = sdf.select("geohashId").distinct().rdd.flatMap(lambda x: x).collect()
# num_dates = len(pd.date_range(first_date, last_date))

# complete_hour_grid = []
# for geohash in geohash_ids:
#     for i in range(num_dates * 24):
#         complete_hour_grid.append((geohash, i))

# complete_15min_grid = []
# for geohash in geohash_ids:
#     for i in range(num_dates * 24 * 4):
#         complete_15min_grid.append((geohash, i))

# schema = StructType([
#     StructField("geohashId", StringType(), True),
#     StructField("time", IntegerType(), True),
# ])

# # Join the original dataframe with the complete time grid

# complete_hour_df = spark.createDataFrame(pd.DataFrame(complete_hour_grid, columns=["geohashId", "time"]), schema)
# sdf_hour = complete_hour_df.join(sdf_hour, on=["geohashId", "time"], how="left")

# complete_15min_df = spark.createDataFrame(pd.DataFrame(complete_15min_grid, columns=["geohashId", "time"]), schema)
# sdf_15min = complete_15min_df.join(sdf_15min, on=["geohashId", "time"], how="left")

In [219]:
# Cross join: build a dense (geohashId, time) grid so that every hourly / 15-minute
# slot exists for every geohash, then left-join the observed rows onto it. Slots with
# no observation come back with null value columns. The lag over this grid therefore
# means "previous time slot", not "previous observed row".

# Number of calendar days covered, counting both first_date and last_date.
# Defined in this cell because the only other definition is inside a commented-out
# cell, which made this cell fail with NameError on a fresh kernel.
num_dates = len(pd.date_range(first_date, last_date))

geohash_ids = sdf.select("geohashId").distinct()
times_hours = list(range(num_dates * 24))
times_15min = list(range(num_dates * 24 * 4))

schema = StructType([
    StructField("time", IntegerType(), True),
])

# Perform cross join between all unique geohashIds and all slot indices
complete_hour_grid = geohash_ids.crossJoin(spark.createDataFrame([(i,) for i in times_hours], schema))
complete_15min_grid = geohash_ids.crossJoin(spark.createDataFrame([(i,) for i in times_15min], schema))

# Join the original dataframe with the complete time grid.
# Observed rows whose (geohashId, time) is not in the grid are dropped by the left join.
sdf_hour = complete_hour_grid.join(sdf_hour, on=["geohashId", "time"], how="left")

sdf_15min = complete_15min_grid.join(sdf_15min, on=["geohashId", "time"], how="left")

In [220]:
# Zero-fill the feature value columns on grid slots that had no observation.
# Only the columns listed in fill_dict are touched; key and descriptive columns
# (date, hour, interval, timestamp) stay null on those filler rows.
agg_dict = {
    "featureA": ["variance", "median"],
    "featureB": ["variance", "median"]
}
# Ordered and de-duplicated list of "<feature>_<metric>" names. A set here made the
# order of the *_lag columns built from it vary between Python processes (str hashes
# are randomized per process).
value_columns = list(dict.fromkeys(
    f"{key}_{metric}" for key, metrics in agg_dict.items() for metric in metrics
))
fill_dict = {column: 0 for column in value_columns}

sdf_hour = sdf_hour.fillna(fill_dict)
sdf_15min = sdf_15min.fillna(fill_dict)

In [221]:
# Step 3: Define Window Specification
# One partition per geohash, rows ordered by slot index, so lag(..., 1) looks at the
# preceding "time" row of the same geohash (15 minutes or one hour earlier).
windowSpec = Window.partitionBy(col("geohashId")).orderBy(col("time"))

In [222]:
# Step 4: Create Lag Features for All Specified Value Columns
# For each value column X, X_lag is X from the previous row of windowSpec
# (null on the first row of each geohash partition).
lag_exprs = []
for value_col in value_columns:
    lag_exprs.append(lag(col(value_col), 1).over(windowSpec).alias(value_col + "_lag"))

# Keep every existing column and append the lag columns; the same unresolved
# expressions are reused for both frames.
sdf_hour = sdf_hour.select("*", *lag_exprs)
sdf_15min = sdf_15min.select("*", *lag_exprs)

In [223]:
# The first slot of each geohash partition has no predecessor, so its *_lag values
# are null after lag(); replace those nulls with 0.
lag_columns = [column + "_lag" for column in value_columns]
fill_dict = dict.fromkeys(lag_columns, 0)
sdf_hour = sdf_hour.fillna(fill_dict)
sdf_15min = sdf_15min.fillna(fill_dict)

In [224]:
# Sort for display: by geohash, then chronologically by slot index.
sort_keys = ["geohashId", "time"]
sdf_hour = sdf_hour.orderBy(*sort_keys)
sdf_15min = sdf_15min.orderBy(*sort_keys)

In [225]:
# Turn slot indices back into wall-clock timestamps:
# origin (Unix seconds) + slot index * slot length in seconds.
# This replaces the earlier Unix-seconds "timestamp" column with a timestamp-typed one.
# NOTE(review): F.to_timestamp on a numeric relies on a numeric->timestamp cast;
# F.timestamp_seconds (Spark 3.1+) states the seconds-since-epoch intent explicitly.
SECONDS_PER_HOUR = 3600
SECONDS_PER_15MIN = 900

hour_epoch_seconds = base_timestamp + F.col("time") * SECONDS_PER_HOUR
sdf_hour = sdf_hour.withColumn("timestamp", F.to_timestamp(hour_epoch_seconds))

quarter_epoch_seconds = base_timestamp + F.col("time") * SECONDS_PER_15MIN
sdf_15min = sdf_15min.withColumn("timestamp", F.to_timestamp(quarter_epoch_seconds))

In [226]:
(
    sdf_hour
    .select("geohashId", "timestamp", "featureA_variance", "featureA_variance_lag")
    .show()
)

+---------+-------------------+-----------------+---------------------+
|geohashId|          timestamp|featureA_variance|featureA_variance_lag|
+---------+-------------------+-----------------+---------------------+
|     abcd|2023-08-01 00:00:00|                0|                    0|
|     abcd|2023-08-01 01:00:00|                6|                    0|
|     abcd|2023-08-01 02:00:00|                0|                    6|
|     abcd|2023-08-01 03:00:00|               10|                    0|
|     abcd|2023-08-01 04:00:00|                0|                   10|
|     abcd|2023-08-01 05:00:00|                0|                    0|
|     abcd|2023-08-01 06:00:00|                0|                    0|
|     abcd|2023-08-01 07:00:00|                0|                    0|
|     abcd|2023-08-01 08:00:00|                0|                    0|
|     abcd|2023-08-01 09:00:00|                0|                    0|
|     abcd|2023-08-01 10:00:00|                0|               

In [227]:
(
    sdf_15min
    .select("geohashId", "timestamp", "featureA_variance", "featureA_variance_lag")
    .show()
)

+---------+-------------------+-----------------+---------------------+
|geohashId|          timestamp|featureA_variance|featureA_variance_lag|
+---------+-------------------+-----------------+---------------------+
|     abcd|2023-08-01 00:00:00|                1|                    0|
|     abcd|2023-08-01 00:15:00|                2|                    1|
|     abcd|2023-08-01 00:30:00|                3|                    2|
|     abcd|2023-08-01 00:45:00|                0|                    3|
|     abcd|2023-08-01 01:00:00|                4|                    0|
|     abcd|2023-08-01 01:15:00|                5|                    4|
|     abcd|2023-08-01 01:30:00|                0|                    5|
|     abcd|2023-08-01 01:45:00|                0|                    0|
|     abcd|2023-08-01 02:00:00|                7|                    0|
|     abcd|2023-08-01 02:15:00|                8|                    7|
|     abcd|2023-08-01 02:30:00|                9|               