# Split the raw data 60-20-20 into train, validation, and test sets and save them

In [1]:
import time
import pyspark
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import when, col, to_date, rank, monotonically_increasing_id
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, TimestampType, DateType, FloatType
import pathlib

# spark.stop()

In [2]:
# Build the Spark configuration (application name only) and start, or reuse,
# the session used by the rest of the notebook.
conf = pyspark.SparkConf().setAll([('spark.app.name', 'ReduceData')])
spark = (
    SparkSession.builder
    .config(conf=conf)
    .getOrCreate()
)

# Show the Spark version as the cell output.
spark.version

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


23/05/07 10:41:06 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
23/05/07 10:41:07 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


'3.3.1'

### Structs we might need

In [3]:
def _nullable_schema(columns):
    """Build a StructType from (column name, data type) pairs.

    Every field is created as nullable, matching how all schemas in this
    notebook are declared.
    """
    return StructType([StructField(name, dataType, True) for name, dataType in columns])


# The four glucose columns kept after loading the raw records.
glucose_data_schema = _nullable_schema([
    ('PatientId', StringType()),
    ('GlucoseDisplayTime', TimestampType()),
    ('GlucoseDisplayTimeRaw', StringType()),
    ('Value', FloatType()),
])

# Full column layout of the raw glucose-record CSV files.
raw_schema = _nullable_schema([
    ('_c0', IntegerType()),
    ('PostDate', TimestampType()),
    ('IngestionDate', TimestampType()),
    ('PostId', StringType()),
    ('PostTime', TimestampType()),
    ('PatientId', StringType()),
    ('Stream', StringType()),
    ('SequenceNumber', StringType()),
    ('TransmitterNumber', StringType()),
    ('ReceiverNumber', StringType()),
    ('RecordedSystemTime', TimestampType()),
    ('RecordedDisplayTime', TimestampType()),
    ('RecordedDisplayTimeRaw', TimestampType()),
    ('TransmitterId', StringType()),
    ('TransmitterTime', StringType()),
    ('GlucoseSystemTime', TimestampType()),
    ('GlucoseDisplayTime', TimestampType()),
    ('GlucoseDisplayTimeRaw', StringType()),
    ('Value', FloatType()),
    ('Status', StringType()),
    ('TrendArrow', StringType()),
    ('TrendRate', FloatType()),
    ('IsBackFilled', StringType()),
    ('InternalStatus', StringType()),
    ('SessionStartTime', StringType()),
])

# Layout of cohort.csv; the first column has an empty name here and is
# renamed when the file is read.
cohortSchema = _nullable_schema([
    ('', IntegerType()),
    ('UserId', StringType()),
    ('Gender', StringType()),
    ('DOB', TimestampType()),
    ('Age', IntegerType()),
    ('DiabetesType', StringType()),
    ('Treatment', StringType()),
])

### Get all paths to read from

In [4]:
# Collect every CSV under /cephfs/data whose path contains 'glucose_records',
# as path strings in sorted order.
allPaths = sorted(
    str(csvPath)
    for csvPath in pathlib.Path('/cephfs/data').glob('*.csv')
    if 'glucose_records' in str(csvPath)
)

# Earlier split by file index, left disabled:
# trainPaths = allPaths[:219]
# valPaths = allPaths[219:292]
# testPaths = allPaths[292:]
# print("train length:", len(trainPaths), "\nvalidation length:", len(valPaths),"\ntest length:", len(testPaths))

### Read in the Cohort dataframe

In [5]:
# Load cohort.csv with the explicit cohort schema and give the unnamed first
# column the name NumId; the elapsed wall-clock time is printed at the end.
startTime = time.time()

cohortReader = spark.read.options(delimiter=',')
cohortDf = (
    cohortReader
    .csv('/cephfs/data/cohort.csv', header=True, schema=cohortSchema)
    .withColumnRenamed('', 'NumId')
)

print(cohortDf.dtypes)
cohortDf.show(10)
print(time.time() - startTime)

[('NumId', 'int'), ('UserId', 'string'), ('Gender', 'string'), ('DOB', 'timestamp'), ('Age', 'int'), ('DiabetesType', 'string'), ('Treatment', 'string')]
+-----+--------------------+------+-------------------+---+------------+---------+
|NumId|              UserId|Gender|                DOB|Age|DiabetesType|Treatment|
+-----+--------------------+------+-------------------+---+------------+---------+
|    0|5lZPrCk6qk8L6Jw+S...|Female|1931-01-01 00:00:00| 92|    type-two|       no|
|    1|9qY9mZ+GV5Kd/O/NB...|  Male|1937-01-01 00:00:00| 86|    type-two|       no|
|    2|uhsyLhr4Zl6NfGbNB...|Female|1938-01-01 00:00:00| 85|    type-two|       no|
|    3|9uAVHBOgoCJ9hfcrL...|  Male|1938-01-01 00:00:00| 85|    type-two|       no|
|    4|Fyb156jU1edGykL7N...|Female|1939-01-01 00:00:00| 84|    type-two|       no|
|    5|86XfZ0fNI0VWOzWrl...|Female|1939-01-01 00:00:00| 84|    type-two|       no|
|    6|JfJMH1qCpiYNuPOp/...|Female|1940-01-01 00:00:00| 83|    type-two|       no|
|    7|EkW0PD80r

                                                                                

In [6]:
# Small lookup frame of the distinct (string UserId, integer NumId) pairs.
patientIds = cohortDf.select('UserId', 'NumId').distinct()
patientIds

DataFrame[UserId: string, NumId: int]

### Load in the raw data

In [7]:
# Read all raw CSVs with the explicit raw schema in DROPMALFORMED mode and
# keep only the four columns used downstream.
startTime = time.time()

rawReader = (
    spark.read
    .format('csv')
    .options(delimiter=',', mode='DROPMALFORMED', header=True)
    .schema(raw_schema)
)
df = rawReader.load(allPaths).select(
    'PatientId', 'Value', 'GlucoseDisplayTime', 'GlucoseDisplayTimeRaw'
)

print(time.time() - startTime)



5.827446699142456


                                                                                

### Clean-up

In [8]:
# Derive the calendar date of each reading from its display timestamp, then
# sort by patient and display time (ascending).
startTime = time.time()

df = (
    df
    .withColumn('GlucoseDisplayDate', to_date(col('GlucoseDisplayTime')))
    .orderBy('PatientId', 'GlucoseDisplayTime', ascending=True)
)

print(time.time() - startTime)

0.04919123649597168


In [9]:
# """the following cell is leslie's cleanup code, yoinked from fill-missing to read-data to here"""
startTime = time.time()

# 1. Keep only readings dated after 2022-01-31 (the actual start date is
#    Feb 1, 2022).
afterStart = df.filter("GlucoseDisplayDate > date'2022-01-31'")

# 2. Turn zero readings into nulls, then drop rows that are missing the
#    patient, the value, or the display time.
valueOrNull = when(col("Value") == "0", None).otherwise(col("Value"))
nonNull = (
    afterStart
    .withColumn("Value", valueOrNull)
    .dropna(subset=['PatientId', 'Value', 'GlucoseDisplayTime'])
)
# df = df.where(df.Value>0)

# 3. Keep a single row per (GlucoseDisplayTime, PatientId): rows are numbered
#    with a unique id, ranked within each group, and only rank 1 survives.
dedupWindow = Window.partitionBy('GlucoseDisplayTime', 'PatientId').orderBy('tiebreak')
ranked = (
    nonNull
    .withColumn('tiebreak', monotonically_increasing_id())
    .withColumn('rank', rank().over(dedupWindow))
)
df = ranked.filter(col('rank') == 1).drop('rank', 'tiebreak')

print(time.time() - startTime)

0.29024362564086914


### Add in (join) the NumId column

## Split into Train-Test-Val groups of 60-20-20

plan:
* get total count of values per patient
* make 2 cols:
    * int that's the total number of rows that should make up 60% of the patient's data
    * int that's the total number of rows that should make up 80% of the patient's data
* merge that dataframe back onto the main df dataframe
* (is it possible to filter based on another column?)
* add ranks based on patient ID
* filter once that splits at the 60% mark, asking for 'rank'< or > to the 60% mark
    * save the <60% as training
* filter the >60% again on the 80% mark
    * save the <80% as validation
    * save the >80% as test

Somewhere in there, drop patients who have fewer than 80% of the readings expected for a full year of data.

In [10]:
# Count the cleaned readings available for each patient.
counter = df.groupBy('PatientId').count()
# print("row count: ", counter.count())

# Drop patients with too little usable data. The threshold is 80% of
# 60/5 * 24 * 365 readings (presumably one reading every 5 minutes for a
# year -- confirm the sampling interval).
minUsable = 0.80 * (60/5 * 24 * 365)
# filter() KEEPS the rows matching the predicate, so the comparison must be
# ">=" to retain patients that reach the threshold; "<" kept exactly the
# patients that were supposed to be removed.
counter = counter.filter(col('count') >= minUsable)
minUsable

84096.0

In [11]:
# Per-patient row ranks at which the 60% and 80% marks fall (cast to integer).
counter = (
    counter
    .withColumn("split60", (col("count") * 0.6).cast("Integer"))
    .withColumn("split80", (col("count") * 0.8).cast("Integer"))
)

# Attach the split marks to the UserId/NumId lookup by matching the cohort's
# UserId against the readings' PatientId.
idsMatch = patientIds.UserId == counter.PatientId
patientIds = patientIds.join(counter, idsMatch).select(
    patientIds.NumId, patientIds.UserId, counter.split60, counter.split80
)

In [12]:
# Attach NumId and the per-patient 60%/80% split marks to every reading.
joined = df.join(patientIds, df.PatientId == patientIds.UserId)\
            .select(patientIds.NumId, df.PatientId, df.Value, df.GlucoseDisplayTime, \
                    df.GlucoseDisplayTimeRaw, df.GlucoseDisplayDate, \
                    patientIds.split60, patientIds.split80)

# Rank each patient's readings chronologically so that the 60/20/20 cut is a
# split in time. Ordering only by monotonically_increasing_id() is not
# chronological after the join's shuffle, and the ids can differ every time
# the lineage is recomputed (once per saved split), which could make the
# train/val/test sets overlap or miss rows. 'tiebreak' is kept as a secondary
# key so the ranking stays strict even if two rows share a timestamp.
window = Window.partitionBy('PatientId').orderBy('GlucoseDisplayTime', 'tiebreak')
joined = joined \
 .withColumn('tiebreak', monotonically_increasing_id()) \
 .withColumn('rank', rank().over(window))

In [13]:
# Helper columns that exist only to compute the split and are not saved.
helperCols = ('rank', 'tiebreak', 'split60', 'split80')

# Row-level membership tests, expressed on each patient's rank.
inTrain = col('rank') <= col('split60')
inVal = (col('rank') > col('split60')) & (col('rank') <= col('split80'))
inTest = col('rank') > col('split80')

# Training set: ranks up to and including the 60% mark.
trainSet = joined.filter(inTrain).drop(*helperCols)

# Validation set: ranks after the 60% mark, up to and including the 80% mark.
valSet = joined.filter(inVal).drop(*helperCols)

# Test set: ranks after the 80% mark.
testSet = joined.filter(inTest).drop(*helperCols)

## Save out into parquet files

In [14]:
# Write each split to its own parquet directory and print the total elapsed
# wall-clock time for all three writes.
startTime = time.time()

# df.repartition('PatientId')\
#     .write.parquet('/cephfs/train_test_val/train/')
splitOutputs = [
    (trainSet, '/cephfs/train_test_val/train_set/'),
    (valSet, '/cephfs/train_test_val/val_set/'),
    (testSet, '/cephfs/train_test_val/test_set/'),
]
for splitDf, outPath in splitOutputs:
    splitDf.write.parquet(outPath)

print(time.time()-startTime)

                                                                                

45887.63158559799


In [16]:
# Convert the roughly 45887 seconds printed by the write cell into hours.
45887/60/60

12.74638888888889