
# Sliding Window Features

To minimize the risk of conflating different ways to derive features, we are going to define the following categories of features:
 - **Point-in-Time Features**: A type of data join that ensures features are derived from the most recent observation available at or before a given timestamp. Commonly, results are pivoted so that every lab has its own column.
 - **Sliding Window Features**: A way to aggregate events that happened within a specific rolling window (e.g., last 7 days, last 30 minutes). Typically these are numeric aggregates that return a scalar, such as `mean`, `min`, or `max`.
 - **Event-Based Features**: Features computed from occurrences (one or many) of specific events before the observation point.
 - **Cohort-Based Features**: Features generated from historical groupings within a fixed observation window. The difference between Event-Based and Cohort-Based features is that in a cohort the observation timestamp is the same for every patient, whereas for event-based features each patient has their own event time.
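
To make the sliding window category concrete, here is a minimal plain-Python sketch (no Spark) of the idea: keep the observations that fall inside the look-back window ending at the observation time, then reduce them to scalars. The `labs` values and the `window_aggregates` helper are hypothetical and exist only for illustration.

```python
from datetime import datetime, timedelta
from statistics import mean

# Hypothetical lab observations for one patient: (timestamp, value)
labs = [
    (datetime(2024, 1, 1), 5.5),
    (datetime(2024, 1, 20), 6.0),
    (datetime(2024, 2, 10), 7.0),
]

def window_aggregates(observations, end_ts, window_days):
    """Aggregate values observed in [end_ts - window_days, end_ts]."""
    start_ts = end_ts - timedelta(days=window_days)
    values = [v for ts, v in observations if start_ts <= ts <= end_ts]
    if not values:
        # No observations in the window: every aggregate is missing
        return {"min": None, "max": None, "mean": None}
    return {"min": min(values), "max": max(values), "mean": mean(values)}

features = window_aggregates(labs, end_ts=datetime(2024, 2, 15), window_days=30)
```

Here `features` covers only the two observations from the 30 days before 2024-02-15; the 2024-01-01 value falls outside the window and is ignored.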

In this notebook, we'll write a convenience function, `sliding_window_numeric_aggregates`, that retrieves **Sliding Window Features** from the lab data we created in <a href="$./00_Data_Generation" target="_blank">00_Data_Generation</a>, `main.default.patient_lab`.

**NOTE**: The features returned by `sliding_window_numeric_aggregates` follow the naming convention `<lab_type>_<num_days>_<agg_func_name>` (for example, `ua_ph_30_min`), which is the order in which the code below assembles them. This convention is necessary because the feature names are created dynamically and must not collide.
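
As a rough illustration of why a systematic name avoids collisions, the sketch below enumerates the names in plain Python. It follows the order the code later in this notebook produces (lab type, then window in days, then aggregate function name); the `feature_names` helper is hypothetical and is not part of the notebook's function.

```python
from itertools import product

def feature_names(lab_types, windows_in_days, agg_names):
    """Hypothetical helper: one name per (lab, window, aggregate) combination."""
    return [
        f"{lab}_{days}_{agg}"
        for lab, days, agg in product(lab_types, windows_in_days, agg_names)
    ]

names = feature_names(["ua_ph"], [30, 270], ["min", "max", "mean"])

# Every combination yields a distinct name, so no two feature columns collide
assert len(names) == len(set(names))
```

With one lab type, two windows, and three aggregates this gives six names, starting with `ua_ph_30_min` and ending with `ua_ph_270_mean`.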


### `sliding_window_numeric_aggregates`

Sliding window features are the most commonly thought-of feature aggregates. While it is possible to write aggregate functions that reduce strings or any other supported Spark column type, the largest number of built-in aggregate functions are written for numeric values. So that we can demonstrate the approach by simply passing existing built-in functions, we will write the function on the assumption that `lab_value` can be coerced to a FLOAT type. We will therefore work with `ua_ph` labs, which are numeric.
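
The pattern of passing aggregate functions around as values can be sketched in plain Python. This is only an analogy: it uses Python's built-in `min` and `max` and `statistics.mean` as stand-ins for the Spark functions, and the `lab_values` list is hypothetical.

```python
from statistics import mean

# Plain-Python stand-ins for the aggregate functions (not the pyspark versions)
agg_funcs = [min, max, mean]

# Hypothetical ua_ph values, already coerced to float
lab_values = [5.5, 6.0, 7.0]

# Each function is applied to the same values; its __name__ labels the result
aggregates = {f.__name__: f(lab_values) for f in agg_funcs}
```

Using `__name__` to label each result is the same trick the notebook code relies on to name the aggregate columns.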

So that we can see all the transforms that will go into creating `sliding_window_numeric_aggregates`, we will show each step and put the steps together in a final function. 

In [0]:
# To do a range join, we first define the look-back window sizes in days
from pyspark.sql.functions import col, explode, lit, date_sub

# Approximately one month and nine months
windows_in_days = [1*30, 9*30]

# We'll use lab_types to filter for only the labs of interest
lab_types = ['ua_ph']

patient_lab = spark.table("main.default.patient_lab").alias('pl')
patient_event = spark.table("main.default.patient_event").alias('pe')

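# Build one row per (event, window size), derive the start of each window, then
# range-join the labs that fall inside it. The left outer join keeps events that
# have no labs in the window. date_sub returns a DATE, so after the cast
# start_window_ts falls at midnight of the start day.
patient_event_labs = patient_event.withColumn("window_days", explode(lit(windows_in_days))) \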
                                  .withColumnRenamed("event_ts","end_window_ts") \
                                  .withColumn("start_window_ts", date_sub(col("end_window_ts"), col("window_days")).cast("TIMESTAMP")) \
                                  .join(patient_lab.filter(col("pl.lab_type").isin(lab_types)),
                                        (patient_lab.patient_id == patient_event.patient_id) &
                                        (patient_lab.event_ts.between(col("start_window_ts"), col("end_window_ts"))),
                                        "leftouter") \
                                  .drop(col("pl.patient_id"), "start_window_ts") \
                                  .withColumn('lab_value', col('lab_value').cast("FLOAT"))

display(patient_event_labs)

In [0]:
from pyspark.sql.functions import struct, col, concat_ws
# Note: these imports shadow Python's built-in min and max in this notebook
from pyspark.sql.functions import min, max, mean

agg_funcs = [min, max, mean]

# One aggregate expression per function, named after the function (e.g. "min")
aggs = [x("lab_value").alias(x.__name__) for x in agg_funcs]

agg_cols = [x.__name__ for x in agg_funcs]

# Group by patient, lab/window label, and event time, then pack the aggregates into a single struct column
grouped_labs = patient_event_labs.withColumn("days_lab", concat_ws("_", "lab_type","window_days")) \
                                 .groupBy("patient_id", "days_lab", "end_window_ts") \
                                 .agg(*aggs) \
                                 .withColumn("aggregates", struct(*[col(c) for c in agg_cols])) \
                                 .drop(*agg_cols)

display(grouped_labs)

In [0]:
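from pyspark.sql.functions import first_value

# Pivot so each lab/window label becomes its own struct column. Each group holds
# exactly one struct after the previous step, so first_value simply picks it.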
pivot_labs = grouped_labs.withColumnRenamed("end_window_ts","event_ts") \
                         .groupBy("patient_id", "event_ts") \
                         .pivot("days_lab") \
                         .agg(first_value("aggregates"))

display(pivot_labs)

In [0]:
from pyspark.sql.functions import col
from itertools import product

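# Every column other than the keys is a struct of aggregates for one lab/window label
window_cols = [c for c in pivot_labs.columns if c not in ['patient_id', 'event_ts']]

# Flatten each struct field into its own column, e.g. ua_ph_30_min
feat_cols = [col(f'{w}.{a}').alias(f'{w}_{a}') for w, a in product(window_cols, agg_cols)]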

rslt = pivot_labs.select("patient_id", "event_ts", *feat_cols)

display(rslt)

In [0]:
# Putting it all together, we can write a sliding_window_numeric_aggregates function such as:
from itertools import product
from typing import Callable, List

from pyspark.sql import Column, DataFrame
from pyspark.sql.functions import col, explode, lit, date_sub, struct, concat_ws, first_value
from pyspark.sql.functions import min, max, mean

def sliding_window_numeric_aggregates(patient_event_df: DataFrame,
                                      agg_funcs: List[Callable[[str], Column]],
                                      lab_types: List[str],
                                      windows_in_days: List[int]) -> DataFrame:
    """Return one row per patient event with a column for each lab type, window, and aggregate.

    Columns are named <lab_type>_<window_days>_<agg_func_name>, e.g. ua_ph_30_min.
    """
    # Labs are always read from the table created in 00_Data_Generation
    patient_lab = spark.table("main.default.patient_lab").alias('pl')

    patient_event_labs = patient_event_df.alias('pe') \
                                         .withColumn("window_days", explode(lit(windows_in_days))) \
                                         .withColumnRenamed("event_ts","end_window_ts") \
                                         .withColumn("start_window_ts", date_sub(col("end_window_ts"), col("window_days")).cast("TIMESTAMP")) \
                                         .join(patient_lab.filter(col("pl.lab_type").isin(lab_types)),
                                               (patient_lab.patient_id == patient_event_df.patient_id) &
                                               (patient_lab.event_ts.between(col("start_window_ts"), col("end_window_ts"))),
                                               "leftouter") \
                                         .drop(col("pl.patient_id"), "start_window_ts") \
                                         .withColumn('lab_value', col('lab_value').cast("FLOAT"))
    
    aggs = [x("lab_value").alias(f"{x.__name__}") for x in agg_funcs]
    agg_cols = [f"{x.__name__}" for x in agg_funcs]

    grouped_labs = patient_event_labs.withColumn("days_lab", concat_ws("_", "lab_type","window_days")) \
                                     .groupBy("patient_id", "days_lab", "end_window_ts") \
                                     .agg(*aggs) \
                                     .withColumn("aggregates", struct(*[col(c) for c in agg_cols])) \
                                     .drop(*agg_cols)

    pivot_labs = grouped_labs.withColumnRenamed("end_window_ts","event_ts") \
                             .groupBy("patient_id", "event_ts") \
                             .pivot("days_lab") \
                             .agg(first_value("aggregates"))

    window_cols = [c for c in pivot_labs.columns if c not in ['patient_id', 'event_ts']]
    feat_cols = [col(f'{w}.{a}').alias(f'{w}_{a}') for w, a in product(window_cols, agg_cols)]

    rslt = pivot_labs.select("patient_id", "event_ts", *feat_cols)

    return rslt

In [0]:
dat = sliding_window_numeric_aggregates(patient_event_df=spark.table("main.default.patient_event"),
                                        agg_funcs=[min, max, mean],
                                        lab_types=['ua_ph',],
                                        windows_in_days=[1*30, 9*30])

display(dat)