# Creation of the dataset

## Imports & helpers

In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import pyspark.sql.functions as F
import pyspark.sql.types as T
from pyspark.sql.window import Window
from pyspark.sql.functions import pandas_udf, PandasUDFType, udf

In [None]:
database = "dialysis_blood_pressure_v3"

# Create the target database if needed and make it the default for all
# subsequent unqualified table names used in this notebook.
for statement in (f"CREATE DATABASE IF NOT EXISTS {database}", f"USE {database}"):
    spark.sql(statement)

Helper function, implemented as a Spark user-defined function (UDF), that takes an array of values and
performs a simple linear regression to return the slope of the regression line as an indicator of whether and how much the values are increasing or decreasing

In [None]:
@udf(returnType=T.FloatType())
def extract_trend_from_array(arr, X=None):
    """Return the slope of the least-squares regression line through ``arr``.

    Registered as a Spark UDF with a FloatType result. A positive slope means the
    values tend to increase and a negative slope means they decrease. The magnitude
    is the change per unit of X.
    Adapted from
    https://stackoverflow.com/questions/10048571/python-finding-a-trend-in-a-set-of-numbers

    Args:
        arr: Sequence of numeric values (the dependent variable Y), or None.
        X: Optional sequence of x positions, paired element-wise with ``arr``.
            Defaults to 0, 1, 2, ... (the position in ``arr``).

    Returns:
        The slope as a float. Returns ``0.0`` when the slope is undefined (fewer
        than two distinct x positions) and ``None`` when ``arr`` is null.
    """
    if arr is None:
        # A null array column gives a null result instead of a TypeError on the executor.
        return None
    if X is None:
        X = range(len(arr))

    # N counts only the pairs actually used. zip() stops at the shorter sequence,
    # and pairs containing a null are skipped, so len(X) is not reliable here.
    N = 0
    Sx = Sy = Sxx = Sxy = 0.0
    for x, y in zip(X, arr):
        if x is None or y is None:
            continue
        N += 1
        Sx += x
        Sy += y
        Sxx += x * x
        Sxy += x * y

    det = Sxx * N - Sx * Sx
    if det == 0:
        # Must be a float. A Python int returned from a FloatType UDF is not
        # coerced by Spark and ends up as null instead of 0.0.
        return 0.0
    return (Sxy * N - Sy * Sx) / det

## Loading the basis dataset

The basis of this dataset is downloaded from https://figshare.com/articles/dataset/Hemrec_VIP_csv/6260654

The research paper "Dataset supporting blood pressure prediction for the management of chronic hemodialysis",
describing the creation and content of the dataset can be found here https://www.nature.com/articles/s41597-019-0319-8

In [None]:
# Load the three raw CSV files with pandas: patient information (idp),
# dialysis sessions (d1) and values recorded during dialysis (vip).
pdf_idp, pdf2_d1, pdf3_vip = (
    pd.read_csv(csv_path)
    for csv_path in (
        "/dbfs/FileStore/shared_uploads/dialysis/idp.csv",
        "/dbfs/FileStore/shared_uploads/dialysis/d1.csv",
        "/dbfs/FileStore/shared_uploads/dialysis/vip.csv",
    )
)

## Data cleaning & some first augmentation
Two data quality issues were found:
* The timestamps of the VIP table are offset by 8 hours relative to the d1 table
    * fixed by subtracting 8 hours from the VIP timestamps
* In a few dialysis sessions the documented weight after dialysis is higher than before dialysis, and in some others the weight after dialysis is a lot lower than the target weight (dry weight). I could not find any indication of whether or why this can happen in reality
    * The corresponding sessions are filtered out

In [None]:
# Sessions of the same patient on the same day, ordered by start time. Used to
# number them (session_number = 1, 2, ...). The numbering happens before the
# weight filter below, so the remaining numbers may have gaps.
sessionWindow = Window.partitionBy(["pid", "date"]).orderBy("tsp_dialysisstart")

# Weight helper columns, computed on the pandas session table: the absolute
# distance of the start and end weight to the dry weight, and the weight lost.
pdf2_d1 = pdf2_d1.assign(
    diff_to_dry_start=(pdf2_d1["dryweight"] - pdf2_d1["weightstart"]).abs(),
    diff_to_dry_end=(pdf2_d1["dryweight"] - pdf2_d1["weightend"]).abs(),
    weight_loss=pdf2_d1["weightstart"] - pdf2_d1["weightend"],
)

# Shift the VIP timestamps back by 8 hours so they line up with the d1 table.
df3_vip = (
    spark.createDataFrame(pdf3_vip)
    .withColumn(
        "timestamp",
        F.to_timestamp(F.col("datatime"), "yyyy-MM-dd HH:mm:ss") - F.expr("INTERVAL 8 HOURS"),
    )
    .withColumn("date", F.to_date(F.col("timestamp")))
)

# Session table: parse the session day and build full start and end timestamps from it.
# NOTE(review): start and end both use the same calendar date, so a session running
# past midnight would get an end timestamp before its start. Confirm this cannot occur.
df2_d1 = (
    spark.createDataFrame(pdf2_d1)
    .withColumn("timestamp", F.to_timestamp(F.col("keyindate"), "yyyy-MM-dd HH:mm:ss"))
    .withColumn("date", F.to_date(F.col("timestamp")))
    .withColumn("tsp_dialysisstart", F.to_timestamp(F.concat(F.col("date"), F.lit(" "), F.col("dialysisstart"))))
    .withColumn("tsp_dialysisend", F.to_timestamp(F.concat(F.col("date"), F.lit(" "), F.col("dialysisend"))))
    .withColumn("session_number", F.row_number().over(sessionWindow))
)

# Drop sessions with implausible weights: no weight lost during dialysis, or an
# end weight at or below 90 % of the dry weight.
print("Session count before filtering weird weight values", df2_d1.count())
df2_d1 = df2_d1.filter("weight_loss > 0 and weightend > dryweight - dryweight * 0.1")
print("Session count after  filtering weird weight values", df2_d1.count())

df_idp = spark.createDataFrame(pdf_idp)

It turned out that in a few cases a patient has two dialysis sessions on the same day instead of the usual one.
So here the VIP table (values during dialysis) and the d1 Table (values about dialysis sessions) are joined to assign 
a session_number to the measurements in the VIP table which enables us to distinguish between sessions on the same day

In [None]:
# Attach the session (number, start and end time) to every VIP row of the same
# patient and day. Keep only rows recorded between 60 minutes before the session
# start and 60 minutes after the session end, so each VIP row gets a session_number.
# NOTE(review): if two sessions on one day are closer than 2 hours, a VIP row can
# fall into both windows and be duplicated. Confirm this is acceptable.
df_vip_augmented = (
    df2_d1.select("pid", "date", "session_number", "tsp_dialysisstart", "tsp_dialysisend")
    .join(df3_vip.drop("datatime", "measuretime", "time"), ["pid", "date"], how="inner")
    .where(
        ((F.col("timestamp").cast("long") - F.col("tsp_dialysisstart").cast("long")) / 60 >= -60)
        & ((F.col("timestamp").cast("long") - F.col("tsp_dialysisend").cast("long")) / 60 <= 60)
    )
)

## Configuration parameters

* Definition of hypotension according to https://www.mayoclinic.org/diseases-conditions/low-blood-pressure/symptoms-causes/syc-20355465 
* SBP is systolic blood pressure

In [None]:
# rounding of float variables 
precision = 3
# Definition of hypotension
df_vip_augmented_w_hypo_raw = df_vip_augmented.withColumn("is_hypotension", F.col("sbp") < 90)

## Create Measurements master table including classes for prediction
In the VIP table, the parameters from the periodic blood pressure measurements (taken every 15 or 30 minutes) are mixed together with the parameters recorded automatically by the dialysis machine (recorded instantly whenever something changes). The goal of this dataset is to predict hypotension, which is identified by a low SBP (systolic blood pressure) value. The idea behind the measurements master table is therefore to extract the blood pressure measurements explicitly.
Moreover, the classes for prediction are defined as follows: 
```python 
     .withColumn("class", F.when( (F.col("hypotension_in_session") == True)  & ((F.col("latest_measurement") - 2  < F.col("measurement_id")) | (F.col("minutes_to_first_hypotension") <= 60)), \
                                "hypotension_is_coming").otherwise("no_hypotension_in_sight")) 
```
* In plain language: measurements are assigned to the class "hypotension_is_coming" if they are among the last 2 measurements before a hypotension appeared, or if they fall within a timeframe of 60 minutes before a hypotension appeared
*  **Additionally please be aware of this:**
```python
    .filter("minutes_to_first_hypotension >= 15")
```
* All measurements taken within **15 minutes** before a hypotension are sorted out. Since we want to predict hypotensions, this ensures that in a real scenario there would be at least 15 minutes to react to a prediction (it is real data, after all)
* Moreover, all measurements taken after the first hypotension are sorted out as well
* Measurements that have the class **"no_hypotension_in_sight"** are:
  * Measured more than 60 minutes before a hypotension appeared (and not among the last 2 measurements before it)
  * Never taken after a hypotension appeared (since all these measurements are dropped)
  * Or part of a dialysis session in which no hypotension appeared at all

In [None]:
# Rows of one dialysis session (patient, day, session_number) in chronological order.
lagWindow  = Window.partitionBy(["pid", "date", "session_number"]).orderBy("timestamp")

# One row per blood pressure measurement. VIP rows are collapsed by
# (session, sbp, dbp). Each group keeps its latest timestamp, its hypotension flag,
# and the smallest minute offsets to dialysis start and end. The gap to the previous
# measurement of the session is then added (filled with 0 for the first one).
# NOTE(review): grouping by the sbp/dbp *values* presumably removes VIP rows that
# repeat the latest reading. It also merges two distinct readings of a session that
# happen to have identical sbp and dbp into one row, and it pairs max(timestamp)
# with min(time_m) taken from different rows. Confirm whether a per-reading key
# (e.g. the "measuretime" column dropped when building df_vip_augmented) is needed.
df_measurements_raw = df_vip_augmented_w_hypo_raw \
  .withColumn("time_m",  F.round((F.col("timestamp").cast("long") - F.col("tsp_dialysisstart").cast("long")) / 60).cast("int")) \
  .withColumn("time_m_end",  F.round((F.col("tsp_dialysisend").cast("long") - F.col("timestamp").cast("long")) / 60).cast("int")) \
  .groupBy("pid", "date", "session_number","sbp", "dbp") \
  .agg( \
         F.max("timestamp").alias("timestamp"), \
         F.max("is_hypotension").alias("is_hypotension"), \
         F.min("time_m").alias("minutes_since_dialysisstart"), \
         F.min("time_m_end").alias("minutes_to_dialysisend") \
      ) \
  .withColumn("minutes_since_last_measurement", F.round((F.col("timestamp").cast("long") - F.lag("timestamp", 1).over(lagWindow).cast("long")) / 60)) \
  .na.fill({"minutes_since_last_measurement" :0})

# Sessions with at least one hypotension (the inner join drops all other sessions).
# Only measurements strictly before the first hypotension are kept. Each one gets
# the minutes remaining until that hypotension and a 1-based measurement_id.
df_measurements = df_measurements_raw \
  .join(df_measurements_raw.filter("is_hypotension  = True").groupBy("pid", "date", "session_number").agg(F.min("timestamp").alias("first_hypotension")), ["pid", "date", "session_number"]) \
  .filter("timestamp < first_hypotension") \
  .withColumn("minutes_to_first_hypotension", F.round((F.col("first_hypotension").cast("long") - F.col("timestamp").cast("long")) / 60).cast("int")) \
  .withColumn("measurement_id", F.row_number().over(lagWindow)) \
  .withColumn("hypotension_in_session", F.lit(True))

# Sessions that have a hypotension, with the time of their first one.
df_first_hypotensions = df_measurements.select("pid", "date", "session_number", "first_hypotension").distinct()

# Sessions without any hypotension (left_anti join removes the sessions above).
# first_hypotension is null, and minutes_to_first_hypotension gets the sentinel
# 2000 so these rows always count as far away from a hypotension.
# NOTE(review): F.lit(None) is untyped. The union below presumably widens it to the
# timestamp type of df_measurements.first_hypotension. Confirm the resulting schema.
df_measurements_non_hypo = df_measurements_raw \
  .join(df_first_hypotensions.select("pid", "date", "session_number"), ["pid", "date", "session_number"], how="left_anti") \
  .withColumn("first_hypotension",F.lit(None)) \
  .withColumn("minutes_to_first_hypotension", F.lit(2000)) \
  .withColumn("measurement_id", F.row_number().over(lagWindow)) \
  .withColumn("hypotension_in_session", F.lit(False))

# union() matches columns by position, so both frames must keep the same column order.
df_measurements =  df_measurements.union(df_measurements_non_hypo)

def sorter(l, name):
    """Return the ``name`` field of every row in ``l``, ordered by minutes since dialysis start.

    ``l`` is a list of rows/structs (here the result of ``collect_list(struct(...))``)
    that support item access by field name. The sort is stable, so rows with equal
    ``minutes_since_dialysisstart`` keep their collected order.
    """
    by_minutes = sorted(l, key=lambda row: row["minutes_since_dialysisstart"])
    return [row[name] for row in by_minutes]
# Spark UDF wrapper around sorter(). The result is declared as an array of longs.
# NOTE(review): if sbp/dbp arrive as doubles (e.g. pandas float columns containing
# NaN), values that do not match LongType may come back as null. Confirm the input types.
sort_udf = F.udf(sorter, T.ArrayType(T.LongType()))

# Per session: the trend (regression slope against minutes since dialysis start)
# of sbp and dbp. It uses all measurements before the first hypotension, or the
# whole session when there is no hypotension.
# NOTE(review): df_trends does not appear to be used later in this notebook. If it is
# joined as a feature, it summarizes data up to the hypotension itself (possible leakage).
df_trends = df2_d1.select("pid", "date", "session_number").distinct() \
  .join(df_first_hypotensions, ["pid", "date", "session_number"], how="left_outer") \
  .join(df_measurements_raw, ["pid", "date", "session_number"]) \
  .filter("first_hypotension is null or timestamp < first_hypotension" ) \
  .groupBy("pid", "date", "session_number").agg(F.min("first_hypotension").alias("first_hypotension"),  \
                                                F.collect_list(F.struct("minutes_since_dialysisstart", "sbp", "dbp")).alias("data")) \
  .withColumn("sbp", sort_udf(F.col("data"), F.lit("sbp")))  \
  .withColumn("dbp", sort_udf(F.col("data"), F.lit("dbp")))  \
  .withColumn("minutes_since_dialysisstart", sort_udf(F.col("data"), F.lit("minutes_since_dialysisstart"))) \
  .withColumn("trend_sbp", F.round(extract_trend_from_array(F.col("sbp"), F.col("minutes_since_dialysisstart")), precision) ) \
  .withColumn("trend_dbp", F.round(extract_trend_from_array(F.col("dbp"), F.col("minutes_since_dialysisstart")), precision) ) \
  .drop("data", "sbp", "dbp", "minutes_since_dialysisstart")

# Keep only the columns needed downstream. Drop every measurement taken less than
# 15 minutes before the first hypotension, so a prediction leaves time to react.
df_measurements = df_measurements.select(
    "pid", "date", "session_number", "measurement_id", "timestamp",
    "minutes_since_last_measurement", "minutes_since_dialysisstart", "minutes_to_dialysisend",
    "minutes_to_first_hypotension", "sbp", "dbp", "hypotension_in_session",
).filter("minutes_to_first_hypotension >= 15")

# The chain is fully parenthesized. The previous version ended in a trailing
# backslash that continued into the following blank line, which becomes a syntax
# error as soon as that blank line is removed.
df_measurements = (
    df_measurements.join(
        df_measurements.groupBy("pid", "date", "session_number").agg(F.max("measurement_id").alias("latest_measurement")),
        ["pid", "date", "session_number"],
    )
    # Label: in a session with hypotension, the last 2 measurements remaining after the
    # 15-minute filter, and every measurement within 60 minutes of the first hypotension,
    # are "hypotension_is_coming". Everything else is "no_hypotension_in_sight".
    .withColumn(
        "class",
        F.when(
            (F.col("hypotension_in_session") == True)
            & ((F.col("latest_measurement") - 2 < F.col("measurement_id")) | (F.col("minutes_to_first_hypotension") <= 60)),
            "hypotension_is_coming",
        ).otherwise("no_hypotension_in_sight"),
    )
    .drop("latest_measurement")
    # Change versus the previous measurement of the session (null for the first one).
    .withColumn("change_sbp", F.col("sbp") - F.lag("sbp", 1).over(lagWindow))
    .withColumn("change_dbp", F.col("dbp") - F.lag("dbp", 1).over(lagWindow))
    # Running trend: slope over all measurements of the session up to and including this one.
    .withColumn("dbps", F.collect_list("dbp").over(lagWindow))
    .withColumn("sbps", F.collect_list("sbp").over(lagWindow))
    .withColumn("minutes", F.collect_list("minutes_since_dialysisstart").over(lagWindow))
    .withColumn("trend_sbp_so_far", F.round(extract_trend_from_array(F.col("sbps"), F.col("minutes")), precision))
    .withColumn("trend_dbp_so_far", F.round(extract_trend_from_array(F.col("dbps"), F.col("minutes")), precision))
    .drop("dbps", "sbps", "minutes")
)

# Attach per-session weight and temperature features from d1 and per-patient data
# from idp, give two columns clearer names, and drop every row with any null value.
# NOTE(review): dropna() also removes the first measurement of every session, whose
# change_sbp/change_dbp are null from the lag. Confirm this is intended.
df_measurements_p = df2_d1.select(
    "pid", "date", "session_number", "weightstart", "dryweight",
    F.round("diff_to_dry_start", precision).alias("diff_to_dry_start"), "temperature",
).join(df_measurements, ["pid", "date", "session_number"], how="inner")
df_measurements_p = (
    df_idp.join(df_measurements_p, ["pid"], how="inner")
    .withColumnRenamed("DM", "has_diabetes")
    .withColumnRenamed("diff_to_dry_start", "diff_dryweight_weightstart")
    .dropna()
)

## Creating stats for measurements of the dialysis machine
* There are the following parameters and sensor values of the dialysis machine:
    * dia_temp_value -> Temperature of the dialysate in °C
        * BTW: "Effect of cool temperature dialysate on the quality and patients’ perception of haemodialysis" https://academic.oup.com/ndt/article/19/1/190/1813435
    * conductivity   -> The conductivity of the dialysis fluid in mS/cm
        * "Dialysis fluid consists of a solution of inorganic salts that are dissociated in electrically charged ions. These ions can move in an electric field giving the salt solution electrically conducting properties, called conductivity. The conductivity of the dialysis fluid is a parameter well suited for measuring the total concentration of salt." https://pubmed.ncbi.nlm.nih.gov/16083025/
    * uf -> Ultrafiltration Rate in liters per hour   
        * The rate of ultrafiltration (ultrafiltration = fluid removal during hemodialysis).
        * "The ultrafiltration rate, as well as length of dialysis treatment time, control the amount of fluid to be removed. Your dialysis staff will set the ultrafiltration rate of your treatment based on your fluid weight gain since your last treatment. The goal is to get to your target or “dry weight”. If you drink too much fluid between dialysis treatments and your body cannot tolerate a higher ultrafiltration rate because fluid is being removed too fast, you may experience low blood pressure and cramping. Additionally, you may require a longer dialysis treatment and/or an extra treatment day if the extra fluid cannot be removed safely at one time" https://www.kidney.org/atoz/content/ultrafiltration
    * blood_flow -> Blood flow in ml/min (milliliters per minute)
        * "During hemodialysis, a blood pump is set to a constant speed to push your blood through the dialyzer and back to your body. Your doctor prescribes the blood flow rate. It’s usually between 300 and 500 mL/min (milliliters per minute). Ask your technician to show you how to see the blood flow rate on your machine. With many dialyzers, blood flow rates greater than 400 mL/min can increase the removal of toxins. Blood flow rate is limited by the size of your access, the tubing and the needles." https://www.davita.com/treatment-services/dialysis/on-dialysis/how-does-my-doctor-know-if-dialysis-is-working

The idea behind df_assignments is to assign each change of these parameters between two blood pressure measurements to the following blood pressure measurement. This lets us compute the average, min and max values since the last blood pressure measurement. 

In [None]:
window = Window.partitionBy(["pid", "date", "session_number"]).orderBy("timestamp")
stat_cols = ["dia_temp_value", "conductivity", "uf", "blood_flow"]

# Aggregations per blood pressure measurement: the total covered duration, and for
# each machine parameter the duration-weighted sum (for a time-weighted average),
# the minimum and the maximum.
aggs = [F.sum(F.col("diff_last_ts_s")).alias("sum_diff_last_ts_s")]
for c in stat_cols:
    aggs.append(F.sum(F.col(c) * F.col("diff_last_ts_s")).alias("sum_mult_" + c))
    aggs.append(F.min(F.col(c)).alias("min_" + c))
    aggs.append(F.max(F.col(c)).alias("max_" + c))

df_measurements_f_assign = df_measurements_p.select("pid", "date", "session_number", "measurement_id", "timestamp", "hypotension_in_session")


def _backfill_session(pdf):
    """Sort one session's VIP rows by time and fill every null from the next non-null row.

    measurement_id and measurement_ts are only set on rows that are themselves a
    measurement. Backfilling therefore assigns every other row to the *following*
    blood pressure measurement.
    """
    # Uses bfill() because the DataFrame.backfill() alias is deprecated since pandas 2.0.
    # NOTE(review): this fills all columns, including stat_cols. Confirm that
    # backfilling the machine values themselves (not only the measurement keys) is intended.
    return pdf.sort_values(by=["timestamp"]).bfill()


# Mark the VIP rows that coincide with a measurement (same session and timestamp).
df_assignments_raw = (
    df_vip_augmented.drop("tsp_dialysisstart", "tsp_dialysisend")
    .join(df_measurements_f_assign, ["pid", "date", "session_number", "timestamp"], how="left_outer")
    .withColumn("measurement_ts", F.when(F.col("measurement_id").isNotNull(), F.col("timestamp")))
)
df_assignments_raw = df_assignments_raw.groupBy("pid", "date", "session_number").applyInPandas(
    _backfill_session, df_assignments_raw.schema
)
# diff_last_ts_s is the number of seconds until the next VIP row of the session
# (lag with offset -1 looks one row ahead), i.e. how long this row's values were in effect.
# dropna() removes rows that are still unassigned after the backfill (after the last
# measurement), the last row of each session, and rows with any other null value.
df_assignments_raw = (
    df_assignments_raw
    .withColumn("diff_to_measurement_ts_s", (F.col("timestamp").cast("long") - F.col("measurement_ts").cast("long")))
    .withColumn("diff_last_ts_s", F.lag("timestamp", -1).over(window).cast("long") - F.col("timestamp").cast("long"))
    .dropna()
)

df_assignments = df_assignments_raw.groupBy("pid", "date", "session_number", "measurement_id").agg(*aggs)

# Turn the weighted sums into time-weighted averages, then drop the helper sums.
for c in stat_cols:
    df_assignments = df_assignments.withColumn(
        "avg_" + c, F.round(F.col("sum_mult_" + c) / F.col("sum_diff_last_ts_s"), precision)
    ).drop("sum_mult_" + c)
df_assignments = df_assignments.drop("sum_diff_last_ts_s")


## Finally building the dataset
Here the measurements (including the joined patient information like gender, age, first dialysis etc. ) are joined with the statistics of the dialysis machine. 
Some augmentations are done, such as computing the days since the first dialysis, and some columns are renamed to make them more self-explanatory

In [None]:
# Final dataset: measurements joined with the per-measurement machine statistics.
# Rows with nulls are dropped, then derived date/age columns are added and the
# columns are renamed to more descriptive names.
# age_at_treatment is the calendar year of treatment minus the year of birth.
df_dataset = (
    df_measurements_p.join(df_assignments, ["pid", "date", "session_number", "measurement_id"])
    .dropna()
    .withColumn("date_first_dialysis", F.to_date("first_dialysis", "yyyy-MM"))
    .withColumn("days_since_first_dialysis", F.datediff(F.col("date"), F.col("date_first_dialysis")))
    .withColumn("age_at_treatment", F.year("date") - F.col("birthday"))
    .withColumnRenamed("pid", "patient_id")
    .withColumnRenamed("date", "date_of_treatment")
    .withColumnRenamed("birthday", "year_of_birth")
    .withColumnRenamed("temperature", "body_temperature")
    .withColumnRenamed("timestamp", "timestamp_measurement")
    .withColumnRenamed("sbp", "sbp_systolic_blood_pressure")
    .withColumnRenamed("dbp", "dbp_diastolic_blood_pressure")
)

The data is persisted to two Delta Lake tables so that it can be downloaded and published on GitHub
* dataset is the computed dataset
* dataset_4training is the dataset but with the following columns excluded "hypotension_in_session", "minutes_to_first_hypotension", "minutes_to_dialysisend"
    * **this is the dataset you can find as CSV in this git repository**
    * the exclusion was done because this information should not be available to the machine learning training, since it could lead to unwanted behavior such as classifying everything as "no_hypotension_in_sight" whenever "hypotension_in_session" is false

In [None]:
# Persist the complete dataset (all columns) as a Delta table, replacing any
# previous contents and schema.
df_dataset_table = "dataset"
(
    df_dataset.write
    .format("delta")
    .mode("overwrite")
    .option("overwriteSchema", "true")
    .saveAsTable(df_dataset_table)
)

In [None]:
# Training variant of the table: the same data without the columns that would
# reveal the label, replacing any previous contents and schema.
(
    df_dataset
    .drop("hypotension_in_session", "minutes_to_first_hypotension", "minutes_to_dialysisend")
    .write.format("delta")
    .mode("overwrite")
    .option("overwriteSchema", "true")
    .saveAsTable(df_dataset_table + "_4training")
)

# Some example data

In [None]:
display(spark.sql("SELECT * FROM dataset").drop("hypotension_in_session", "minutes_to_first_hypotension", "minutes_to_dialysisend"))