In [None]:
import random
from datetime import datetime, timedelta
from time import time
from typing import List, Optional
from uuid import uuid4

from pyspark.sql import SparkSession, DataFrame, Window, Column
from pyspark.sql import functions as psf
from pyspark.sql.types import StructType, StructField, StringType, TimestampType

In [None]:
# Session shared by every cell below; getOrCreate() hands back an already
# running session when there is one instead of starting a second.
spark = (
    SparkSession.builder
    .appName("closest_timestamp_join")
    .getOrCreate()
)

In [None]:
# Both sides of the join share the same shape: a string id plus a timestamp.
# add_prev_and_next_record_info() later locates these columns by the "_id" /
# "_time" fragments in their names, so keep that naming if you add a schema.
event_schema = StructType([
    StructField(name="event_id", dataType=StringType(), nullable=True),
    StructField(name="event_time", dataType=TimestampType(), nullable=True),
])
record_schema = StructType([
    StructField(name="record_id", dataType=StringType(), nullable=True),
    StructField(name="record_time", dataType=TimestampType(), nullable=True),
])

In [None]:
def generate_dataframe(spark: SparkSession, schema: StructType, row_count: int, start_time: Optional[datetime] = None) -> DataFrame:
    """Build a synthetic two-column DataFrame of (random id, increasing timestamp).

    Row ``i`` gets a fresh UUID4 string and the timestamp
    ``start_time + (i + random.random())`` seconds, i.e. roughly one row per
    second with sub-second jitter.

    Args:
        spark: Session used to create the DataFrame.
        schema: Two-field schema (id, timestamp) applied to the generated rows.
        row_count: Number of rows to generate; ``0`` yields an empty DataFrame.
        start_time: Base timestamp of the series. Defaults to the moment of the
            call.

    Returns:
        A DataFrame with ``row_count`` rows matching ``schema``.
    """
    # Resolve the default at call time. A `datetime.now()` default in the
    # signature is evaluated once, when the function is defined, so every later
    # call would silently reuse that stale timestamp.
    if start_time is None:
        start_time = datetime.now()

    records = [
        (str(uuid4()), start_time + timedelta(seconds=offset + random.random()))
        for offset in range(row_count)
    ]
    return spark.createDataFrame(records, schema)

# Closest-timestamp join using the `rank` window function

|Row count|execution time|
|:-:|:-:|
|100|1.712 sec|
|1,000|2.525 sec|
|10,000|44.915 sec|
|100,000|timeout|
|1,000,000|timeout|

In [None]:
def join_using_rank(event_df: DataFrame, record_df: DataFrame) -> DataFrame:
    """Match every event with its closest record by brute force.

    Pairs every event with every record (cartesian product), computes the
    absolute time difference for each pair, ranks the pairs per event by that
    difference and keeps rank 1.

    Args:
        event_df: DataFrame with ``event_id`` and ``event_time`` columns.
        record_df: DataFrame with ``record_id`` and ``record_time`` columns.

    Returns:
        DataFrame with columns ``event_id``, ``record_id``, ``event_time``,
        ``record_time``, ``rank`` and ``diff``. Because ``rank()`` keeps ties,
        an event can appear more than once when several records are equally
        close.
    """
    # Timestamps are cast to double so the difference is a plain number.
    time_gap = psf.abs(
        psf.col("e.event_time").cast("double") - psf.col("r.record_time").cast("double")
    )
    closest_first = Window.partitionBy("e.event_id").orderBy("diff")

    return (
        event_df.alias("e")
        # Explicit crossJoin: a condition-less join() is an implicit cartesian
        # product, which Spark 2.x rejects unless spark.sql.crossJoin.enabled is
        # set. The inputs are not pre-sorted because the window below orders
        # each partition itself, so sorting before the join is wasted work.
        .crossJoin(record_df.alias("r"))
        .withColumn("diff", time_gap)
        .withColumn("rank", psf.rank().over(closest_first))
        .filter("rank == 1")
        .select("e.event_id", "r.record_id", "e.event_time", "r.record_time", "rank", "diff")
    )

In [None]:
# Benchmark the rank-based join on inputs of increasing size.
result = []
for row_count in [100, 1_000, 10_000, 100_000, 1_000_000]:
    event_df = generate_dataframe(spark, event_schema, row_count)
    record_df = generate_dataframe(spark, record_schema, row_count)

    started_at = time()
    df_using_rank_join = join_using_rank(event_df, record_df)
    # count() is the action that actually executes the join, so it has to sit
    # inside the timed section.
    print(df_using_rank_join.count())
    exec_time = time() - started_at
    print(exec_time)
    result.append({"row_count": row_count, "exec_time": exec_time})
result

# Closest-timestamp join using tumbling-window start and end

|Row count|execution time|
|:-:|:-:|
|100|0.860 sec|
|1,000|0.796 sec|
|10,000|0.855 sec|
|100,000|1.756 sec|
|1,000,000|24.059 sec|

In [None]:
def add_prev_and_next_record_info(df: DataFrame) -> DataFrame:
    """Annotate each row with its time-ordered neighbours and its 1-minute bucket.

    The id and time columns are discovered by name: the first schema field
    whose name contains ``"_id"`` and the first whose name contains ``"_time"``.

    Added columns:
        ``prev_time`` / ``prev_id``: values of the previous row in time order
            (NULL for the first row).
        ``current_time`` / ``current_id``: copies of the row's own time and id.
        ``next_time`` / ``next_id``: values of the next row in time order
            (NULL for the last row).
        ``window_start`` / ``window_end``: bounds of the 1-minute tumbling
            window that contains the row's time.

    Args:
        df: DataFrame with one ``*_id*`` column and one ``*_time*`` column.

    Returns:
        ``df`` with the columns listed above appended.

    Raises:
        ValueError: If no column name contains ``"_id"`` or ``"_time"``.
    """
    def first_column_containing(fragment: str) -> str:
        # A missing column raises ValueError rather than a bare StopIteration,
        # which carries no message and is easy to swallow by accident.
        found = next((field.name for field in df.schema if fragment in field.name), None)
        if found is None:
            raise ValueError(
                f"expected a column whose name contains {fragment!r}, got {df.columns}"
            )
        return found

    id_col = first_column_containing("_id")
    time_col = first_column_containing("_time")

    # The window spec has no partitionBy, so lag/lead run over the whole
    # dataset in time order. No DataFrame-level orderBy is needed before it:
    # the window does its own sorting.
    w = Window.orderBy(time_col)
    minute_bucket = psf.window(timeColumn=time_col, windowDuration="1 minute")

    return (
        df
        .withColumn("prev_time", psf.lag(time_col).over(w))
        .withColumn("current_time", psf.col(time_col))
        .withColumn("next_time", psf.lead(time_col).over(w))
        .withColumn("prev_id", psf.lag(id_col).over(w))
        .withColumn("current_id", psf.col(id_col))
        .withColumn("next_id", psf.lead(id_col).over(w))
        .withColumn("window_start", minute_bucket["start"])
        .withColumn("window_end", minute_bucket["end"])
    )

def generate_diff_from_join_base_time(base_time_col_name: str, comp_col_prefixes: List[str]) -> Column:
    """Build a column holding the candidate that sorts first by time difference.

    For each prefix ``p`` a struct ``(diff, id, time)`` is created from the
    columns ``{p}_time`` and ``{p}_id``, where ``diff`` is the absolute
    difference between the base time and ``{p}_time`` (both cast to double).
    The structs are put in an array, the array is sorted ascending, and its
    first element is returned.

    Args:
        base_time_col_name: Name of the column every candidate is compared to.
        comp_col_prefixes: Column-name prefixes of the candidates, e.g.
            ``"r.prev"`` for ``r.prev_time`` / ``r.prev_id``.

    Returns:
        A struct Column with fields ``diff``, ``id`` and ``time``.
    """
    def candidate_struct(prefix: str) -> Column:
        # `diff` is the first field, so it dominates the sort order of the struct.
        return psf.struct(
            psf.abs(
                psf.col(base_time_col_name).cast("double")
                - psf.col(f"{prefix}_time").cast("double")
            ).alias("diff"),
            psf.col(f"{prefix}_id").alias("id"),
            psf.col(f"{prefix}_time").alias("time"),
        )

    candidates = [candidate_struct(prefix) for prefix in comp_col_prefixes]
    ordered_candidates = psf.array_sort(psf.array(*candidates))
    return psf.element_at(ordered_candidates, 1)

def join_using_window_start_end(event_df: DataFrame, record_df: DataFrame) -> DataFrame:
    """Match every event with its closest record without a full cartesian product.

    Both inputs are annotated with their time-ordered neighbours and their
    1-minute tumbling window. Events and records are then equi-joined on the
    window bounds, and a pair is kept only when the event lies in
    ``[record.prev_time, record.next_time)``. Among that record and its two
    neighbours, the one with the smallest absolute time difference is selected.

    Only events and records that fall in the same 1-minute window are paired, so
    an event with no record in its own window produces no output row.

    Args:
        event_df: DataFrame with an ``*_id*`` and an ``*_time*`` column.
        record_df: DataFrame with an ``*_id*`` and an ``*_time*`` column.

    Returns:
        DataFrame with columns ``event_time``, ``event_id``, ``record_id``,
        ``record_time`` and ``diff``, de-duplicated on
        ``(event_id, record_id)``.
    """
    event_df_with_window = add_prev_and_next_record_info(event_df)

    # The first record has no previous neighbour and the last has no next one,
    # so their prev_*/next_* columns are NULL. Comparing against NULL yields
    # NULL, which the filter below would treat as "no match"; events before the
    # first or after the last record would then be dropped silently.
    # Fix: remember which rows sit on an edge, then fill the missing neighbour
    # with the record itself. The duplicate candidate cannot change the minimum
    # and keeps NULLs out of the candidate sort.
    record_df_with_window = (
        add_prev_and_next_record_info(record_df)
        # Flags must be computed before the NULLs are filled in.
        .withColumn("is_first_record", psf.col("prev_time").isNull())
        .withColumn("is_last_record", psf.col("next_time").isNull())
        .withColumn("prev_id", psf.when(psf.col("is_first_record"), psf.col("current_id")).otherwise(psf.col("prev_id")))
        .withColumn("prev_time", psf.when(psf.col("is_first_record"), psf.col("current_time")).otherwise(psf.col("prev_time")))
        .withColumn("next_id", psf.when(psf.col("is_last_record"), psf.col("current_id")).otherwise(psf.col("next_id")))
        .withColumn("next_time", psf.when(psf.col("is_last_record"), psf.col("current_time")).otherwise(psf.col("next_time")))
    )

    # A missing neighbour means the interval is open on that side.
    event_in_record_neighbourhood = (
        (psf.col("r.is_first_record") | (psf.col("e.current_time") >= psf.col("r.prev_time")))
        & (psf.col("r.is_last_record") | (psf.col("e.current_time") < psf.col("r.next_time")))
    )

    df = (event_df_with_window.alias("e")
        .join(record_df_with_window.alias("r"), ["window_start", "window_end"])
        .filter(event_in_record_neighbourhood)
        .withColumn("min_diff", generate_diff_from_join_base_time("e.current_time", ["r.prev", "r.current", "r.next"]))
        .select(
            psf.col("e.current_time").alias("event_time"),
            psf.col("e.current_id").alias("event_id"),
            psf.col("min_diff.id").alias("record_id"),
            psf.col("min_diff.time").alias("record_time"),
            psf.col("min_diff.diff").alias("diff"),
        )
        # An event usually passes the filter for two adjacent records that both
        # resolve to the same closest record; collapse those duplicates.
        .dropDuplicates(["event_id", "record_id"])
    )

    return df

In [None]:
# Benchmark the window-start/end join on inputs of increasing size.
result = []
for row_count in [100, 1_000, 10_000, 100_000, 1_000_000]:
    event_df = generate_dataframe(spark, event_schema, row_count)
    record_df = generate_dataframe(spark, record_schema, row_count)

    started_at = time()
    df_using_window_start_end_join = join_using_window_start_end(event_df, record_df)
    # count() is the action that actually executes the join, so it has to sit
    # inside the timed section.
    print(df_using_window_start_end_join.count())
    exec_time = time() - started_at
    print(exec_time)
    result.append({"row_count": row_count, "exec_time": exec_time})
result