# Save raw data to 60-20-20 split for train-test-validation sets

In [1]:
import time
import pathlib
import pyspark
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import when, col, to_date, date_trunc, rank, monotonically_increasing_id, min, max
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, TimestampType, DateType, FloatType

# spark.stop()

In [2]:
conf = pyspark.SparkConf().setAll([\
    ('spark.app.name', 'ReduceData')])
spark = SparkSession.builder.config(conf=conf)\
    .getOrCreate()

spark.version

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


23/05/17 03:49:51 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
23/05/17 03:49:52 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


'3.3.1'

### Structs we might need

In [3]:
glucose_data_schema=StructType([StructField('PatientId', StringType(), True),
                                StructField('Value', FloatType(), True),
                                StructField('GlucoseDisplayTime', TimestampType(), True),
                                StructField('GlucoseDisplayTimeRaw', StringType(), True),
                                StructField('GlucoseDisplayDate', DateType(), True),
                                StructField('NumId', IntegerType(), True)
                                ])

raw_schema=StructType([StructField('_c0', IntegerType(),True),
                                StructField('PostDate', TimestampType(),True),
                                StructField('IngestionDate', TimestampType(),True),
                                StructField('PostId', StringType(),True),
                                StructField('PostTime', TimestampType(), True),
                                StructField('PatientId', StringType(), True),
                                StructField('Stream', StringType(), True),
                                StructField('SequenceNumber', StringType(), True),
                                StructField('TransmitterNumber', StringType(), True),
                                StructField('ReceiverNumber', StringType(), True),
                                StructField('RecordedSystemTime', TimestampType(), True),
                                StructField('RecordedDisplayTime', TimestampType(), True),
                                StructField('RecordedDisplayTimeRaw', TimestampType(), True),
                                StructField('TransmitterId', StringType(), True),
                                StructField('TransmitterTime', StringType(), True),
                                StructField('GlucoseSystemTime', TimestampType(), True),
                                StructField('GlucoseDisplayTime', TimestampType(), True),
                                StructField('GlucoseDisplayTimeRaw', StringType(), True),
                                StructField('Value', FloatType(), True),
                                StructField('Status', StringType(), True),
                                StructField('TrendArrow', StringType(), True),
                                StructField('TrendRate', FloatType(), True),
                                StructField('IsBackFilled', StringType(), True),
                                StructField('InternalStatus', StringType(), True),
                                StructField('SessionStartTime', StringType(), True)])

cohortSchema = StructType([StructField('', IntegerType(), True),
                        StructField('UserId', StringType(), True),
                        StructField('Gender', StringType(), True),
                        StructField('DOB', TimestampType(), True),
                        StructField('Age', IntegerType(), True),
                        StructField('DiabetesType', StringType(), True),
                        StructField('Treatment', StringType(), True)
                        ])

### Get all paths in to read from

In [4]:
'''all CSVs of the raw data'''
# allPaths = [str(x) for x in list(pathlib.Path('/cephfs/data_by_patient').glob('*.parquet')) if 'part-00' in str(x)]
allPaths = [str(x) for x in list(pathlib.Path('/cephfs/data').glob('*.csv')) if 'glucose_records' in str(x)]

allPaths.sort()

# print("train length:", len(trainPaths), "\nvalidation length:", len(valPaths),"\ntest length:", len(testPaths))

# print(allPaths[-1])
allPaths = allPaths[:3]
len(allPaths)

3

### Read in the Cohort dataframe

In [5]:
# read in cohort dataframe, with Number ID properly labeled
startTime = time.time()

patientIds = spark.read.options(delimiter=',') \
                .csv('/cephfs/data/cohort.csv', header=True, schema=cohortSchema) \
                .withColumnRenamed('', 'NumId') \
                .select(col('UserId'), col('NumId')) \
                .distinct()

print(patientIds.dtypes)
patientIds.show(10)
print(time.time() - startTime)

[('UserId', 'string'), ('NumId', 'int')]


[Stage 0:>                                                          (0 + 1) / 1]

+--------------------+-----+
|              UserId|NumId|
+--------------------+-----+
|afuXDu4gswOv1nPz8...|  252|
|2xvLF6iyzUsM3KlN3...|  299|
|/me7Mcqd+uJvuwNzH...|  574|
|pRzpPZcuJxjdcDk9R...|  893|
|GMkzjcKvy/rP0iyhV...| 1579|
|ufcKFPML1EYMZBOmL...| 1929|
|2MjuVdaQH+LpaKwGN...| 1931|
|dJ8la8IA6j03CO4jS...| 2160|
|iQ0rMtZFN6kDkjr4G...| 2460|
|S/XdC0rnkBPEB1n6v...| 2656|
+--------------------+-----+
only showing top 10 rows

6.270871877670288


                                                                                

### Load in the raw data

In [5]:
startTime = time.time()

# df = spark.read \
#            .schema(glucose_data_schema) \
#            .format('parquet') \
#            .load(allPaths)
df = spark.read\
    .format('csv')\
    .option('delimiter', ',')\
    .option("mode", "DROPMALFORMED")\
    .option("header", True)\
    .schema(raw_schema)\
    .load(allPaths)\
    .select(col("PatientId"), col("Value"), \
            col("GlucoseDisplayTime"))

print(time.time() - startTime)

3.4299309253692627


In [None]:
startTime = time.time()

df.printSchema()
patIds = [i.PatientId for i in df.select('PatientId').distinct().collect()]
print(len(patIds), "total patients")
print("row count: ", df.count())
df.show()

print(time.time() - startTime)

root
 |-- PatientId: string (nullable = true)
 |-- Value: float (nullable = true)
 |-- GlucoseDisplayTime: timestamp (nullable = true)



                                                                                

6036 total patients


                                                                                

row count:  4994511
+--------------------+-----+--------------------+
|           PatientId|Value|  GlucoseDisplayTime|
+--------------------+-----+--------------------+
|1Jxgxke6R3Uh2c9aR...|  0.0|2022-02-01 14:45:...|
|toBStbTTYI2GU28Yd...|  0.0|2022-02-01 17:46:...|
|+XAhHhm+BkhqusxsZ...|  0.0|2022-02-01 14:58:...|
|+XAhHhm+BkhqusxsZ...|  0.0|2022-01-31 22:53:...|
|+XAhHhm+BkhqusxsZ...|  0.0|2022-01-31 22:38:...|
|+XAhHhm+BkhqusxsZ...|  0.0|2022-02-01 05:03:...|
|+XAhHhm+BkhqusxsZ...|  0.0|2022-02-01 09:33:...|
|+XAhHhm+BkhqusxsZ...|  0.0|2022-02-01 16:38:...|
|+XAhHhm+BkhqusxsZ...|  0.0|2022-01-31 20:58:...|
|+XAhHhm+BkhqusxsZ...|  0.0|2022-01-31 23:29:...|
|+XAhHhm+BkhqusxsZ...|  0.0|2022-02-01 03:48:...|
|+XAhHhm+BkhqusxsZ...|  0.0|2022-02-01 11:28:...|
|+XAhHhm+BkhqusxsZ...|  0.0|2022-01-31 20:08:...|
|+XAhHhm+BkhqusxsZ...|  0.0|2022-02-01 10:43:...|
|+XAhHhm+BkhqusxsZ...|  0.0|2022-02-01 06:43:...|
|+XAhHhm+BkhqusxsZ...|  0.0|2022-02-01 02:13:...|
|+XAhHhm+BkhqusxsZ...|  0.0|20

### Clean-up

In [20]:
# """the following cell is leslie's cleanup code, yoinked from fill-missing to read-data to here"""
startTime = time.time()

'''get rid of any dates from before the actual start-date of Feb 1, 2022'''
df = df.where("GlucoseDisplayTime > '2022-01-31 23:59:59'")

'''replace 0s with NaN and dropna'''
df = df.withColumn("Value", when(col("Value")=="0", None) \
                                                       .otherwise(col("Value")))
df = df.na.drop(subset=['PatientId','Value','GlucoseDisplayTime'])
# df = df.where(df.Value>0)

df = df.withColumn("GlucoseDisplayTime",
                   date_trunc("minute",
                   col("GlucoseDisplayTime")))

'''drop duplicate datetimes for each patient'''
window = Window.partitionBy('GlucoseDisplayTime','PatientId').orderBy('tiebreak')
df = (df
 .withColumn('tiebreak', monotonically_increasing_id())
 .withColumn('rank', rank().over(window))
 .filter(col('rank') == 1).drop('rank','tiebreak')
)

print(time.time() - startTime)

0.08103632926940918


In [8]:
startTime = time.time()

df.printSchema()
print("row count: ", df.count())
df.show()

print(time.time() - startTime)

root
 |-- PatientId: string (nullable = true)
 |-- Value: float (nullable = true)
 |-- GlucoseDisplayTime: timestamp (nullable = true)



                                                                                

row count:  4311678




+--------------------+-----+-------------------+
|           PatientId|Value| GlucoseDisplayTime|
+--------------------+-----+-------------------+
|/rAxYocQpbKaUml2y...|266.0|2022-02-01 00:00:00|
|1FAUCirnkLqrYiNWb...|177.0|2022-02-01 00:00:00|
|1VZ3RiHSjH9RIy9GB...|121.0|2022-02-01 00:00:00|
|1uUAsobV9i087qUq4...|295.0|2022-02-01 00:00:00|
|2Oc9FE04nl4AbfE/z...|188.0|2022-02-01 00:00:00|
|3A2/x042CHD09Plf7...|149.0|2022-02-01 00:00:00|
|9IkmK0geGi4BXGhT8...|168.0|2022-02-01 00:00:00|
|9urL8fduVWtONauiT...|116.0|2022-02-01 00:00:00|
|AxWf3T1YEzDIOYKik...|141.0|2022-02-01 00:00:00|
|CH4+gLYbwWF0+FUw1...| 73.0|2022-02-01 00:00:00|
|CapkbhvAry1quSf4I...|124.0|2022-02-01 00:00:00|
|DX0Pbdq+EBN4/PQIR...|222.0|2022-02-01 00:00:00|
|DY0Ni7z8FswFAdcSl...|186.0|2022-02-01 00:00:00|
|E5GUqb4eVAUBIU1EX...|314.0|2022-02-01 00:00:00|
|ER+gxKFxHvXyfESG6...|128.0|2022-02-01 00:00:00|
|F/JsFxzP/YSrZl3Fj...|125.0|2022-02-01 00:00:00|
|FgSnoYnhijtsbDFNb...|209.0|2022-02-01 00:00:00|
|GlILdY2V8rf7zvUqi..

                                                                                

### **CHECKPOINT**

In [9]:
startTime = time.time()

df = df.checkpoint()

print(time.time() - startTime)

[Stage 4:>                                                       (0 + 16) / 200]

23/05/16 23:57:08 ERROR Executor: Exception in task 5.0 in stage 4.0 (TID 3185)
java.lang.OutOfMemoryError: Java heap space
23/05/16 23:57:08 ERROR Executor: Exception in task 3.0 in stage 4.0 (TID 3183)
java.lang.OutOfMemoryError: Java heap space
	at org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillReader.<init>(UnsafeSorterSpillReader.java:50)
	at org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter.getReader(UnsafeSorterSpillWriter.java:159)
	at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.getSortedIterator(UnsafeExternalSorter.java:553)
	at org.apache.spark.sql.execution.UnsafeExternalRowSorter.sort(UnsafeExternalRowSorter.java:172)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegen

Py4JJavaError: An error occurred while calling o116.checkpoint.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 2 in stage 4.0 failed 1 times, most recent failure: Lost task 2.0 in stage 4.0 (TID 3182) (jupyter-ljoe-40ucsd-2eedu executor driver): java.lang.OutOfMemoryError: Java heap space

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2672)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2608)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2607)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2607)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1182)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1182)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1182)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2860)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2802)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2791)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:952)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2228)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2249)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2281)
	at org.apache.spark.rdd.ReliableCheckpointRDD$.writeRDDToCheckpointDirectory(ReliableCheckpointRDD.scala:166)
	at org.apache.spark.rdd.ReliableRDDCheckpointData.doCheckpoint(ReliableRDDCheckpointData.scala:60)
	at org.apache.spark.rdd.RDDCheckpointData.checkpoint(RDDCheckpointData.scala:75)
	at org.apache.spark.rdd.RDD.$anonfun$doCheckpoint$1(RDD.scala:1906)
	at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDD.doCheckpoint(RDD.scala:1896)
	at org.apache.spark.sql.Dataset.$anonfun$checkpoint$1(Dataset.scala:687)
	at org.apache.spark.sql.Dataset.$anonfun$withAction$2(Dataset.scala:3858)
	at org.apache.spark.sql.execution.QueryExecution$.withInternalError(QueryExecution.scala:510)
	at org.apache.spark.sql.Dataset.$anonfun$withAction$1(Dataset.scala:3856)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$6(SQLExecution.scala:109)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:169)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:95)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:779)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:64)
	at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3856)
	at org.apache.spark.sql.Dataset.checkpoint(Dataset.scala:678)
	at org.apache.spark.sql.Dataset.checkpoint(Dataset.scala:641)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:77)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:568)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.base/java.lang.Thread.run(Thread.java:833)
Caused by: java.lang.OutOfMemoryError: Java heap space


### Add in (join) the NumId column

In [21]:
'''add numId to df'''
startTime = time.time()

df = df.join(patientIds, df.PatientId == patientIds.UserId)\
            .select(df.PatientId, patientIds.NumId, df.Value, df.GlucoseDisplayTime)

print(time.time() - startTime)

0.02255535125732422


## Split into Train-Test-Val groups of 60-20-20

plan:
* get total count of values per patient
* make 2 cols:
    * int that's the total number of rows that should make up 60% of the patient's data
    * int that's the total number of rows that should make up 80% of the patient's data
* merge that dataframe back onto the main df dataframe
* (is it possible to filter based on another column?)
* add ranks based on patient ID
* filter once that splits at the 60% mark, asking for 'rank'< or > to the 60% mark
    * save the <60% as training
* filter the >60% again on the 80% mark
    * save the <80% as validation
    * save the >80% as test

somewhere in there, get rid of patients with less than 80% data

In [22]:
'''get total counts of values per patient'''
counter = df.groupBy('NumId').count()
# print("row count: ", counter.count())

'''(?) filter out patients with too little usable data'''
minUsable = 0.80 * (60/5 * 24 * 365)
# temp = counter.filter(col('count') < minUsable)
print("the minimum number of glucose recordings we should be allowing are", minUsable)
# print("but there are only", temp.count(), "patients that have at least that much.")
# counter.write.parquet('temp_data/')

'''get the numbers that indicate where the splits between train-val and val-test will be'''
counter = counter.withColumn("split60",(col("count")* 0.6).cast("Integer"))
counter = counter.withColumn("split80",(col("count")* 0.8).cast("Integer"))
counter = counter.drop('count')

# '''if i want to merge both small dataframes first before merging on the big main df'''
# patientIds = patientIds.join(counter, patientIds.UserId == counter.PatientId)\
#             .select(patientIds.NumId, patientIds.UserId, counter.split60, counter.split80)

'''rename for future merge (will get an "ambiguous" error without this)'''
counter = counter.withColumnRenamed("NumId","UserId")

the minimum number of glucose recordings we should be allowing are 84096.0


In [23]:
'''get everything into order for ranking/sorting by 60%-20%-20%'''
startTime = time.time()

df = df.join(counter, df.NumId == counter.UserId)\
            .select(df.PatientId, df.NumId, df.Value, df.GlucoseDisplayTime, \
                    counter.split60, counter.split80)

df = df.orderBy("NumId", "GlucoseDisplayTime", ascending=True)

print(time.time() - startTime)

0.050591230392456055


In [24]:
startTime = time.time()

df.printSchema()
print("row count: ", df.count())
df.show()

print(time.time() - startTime)

root
 |-- PatientId: string (nullable = true)
 |-- NumId: integer (nullable = true)
 |-- Value: float (nullable = true)
 |-- GlucoseDisplayTime: timestamp (nullable = true)
 |-- split60: integer (nullable = true)
 |-- split80: integer (nullable = true)



                                                                                

row count:  4311678


                                                                                

+--------------------+-----+-----+-------------------+-------+-------+
|           PatientId|NumId|Value| GlucoseDisplayTime|split60|split80|
+--------------------+-----+-----+-------------------+-------+-------+
|5lZPrCk6qk8L6Jw+S...|    0|155.0|2022-02-01 00:02:00|    481|    641|
|5lZPrCk6qk8L6Jw+S...|    0|155.0|2022-02-01 00:07:00|    481|    641|
|5lZPrCk6qk8L6Jw+S...|    0|154.0|2022-02-01 00:12:00|    481|    641|
|5lZPrCk6qk8L6Jw+S...|    0|153.0|2022-02-01 00:17:00|    481|    641|
|5lZPrCk6qk8L6Jw+S...|    0|153.0|2022-02-01 00:22:00|    481|    641|
|5lZPrCk6qk8L6Jw+S...|    0|153.0|2022-02-01 00:27:00|    481|    641|
|5lZPrCk6qk8L6Jw+S...|    0|151.0|2022-02-01 00:32:00|    481|    641|
|5lZPrCk6qk8L6Jw+S...|    0|151.0|2022-02-01 00:37:00|    481|    641|
|5lZPrCk6qk8L6Jw+S...|    0|149.0|2022-02-01 00:42:00|    481|    641|
|5lZPrCk6qk8L6Jw+S...|    0|147.0|2022-02-01 00:47:00|    481|    641|
|5lZPrCk6qk8L6Jw+S...|    0|145.0|2022-02-01 00:52:00|    481|    641|
|5lZPr

In [26]:
'''rank, then next cell is filtering'''
startTime = time.time()

window = Window.partitionBy('PatientId').orderBy('GlucoseDisplayTime')
df = df \
 .withColumn('rank', rank().over(window))

print(time.time() - startTime)

0.02889847755432129


In [27]:
startTime = time.time()

df.printSchema()
print("row count: ", df.count())
df.show()

print(time.time() - startTime)

root
 |-- PatientId: string (nullable = true)
 |-- NumId: integer (nullable = true)
 |-- Value: float (nullable = true)
 |-- GlucoseDisplayTime: timestamp (nullable = true)
 |-- split60: integer (nullable = true)
 |-- split80: integer (nullable = true)
 |-- rank: integer (nullable = false)



                                                                                

row count:  4311678




+--------------------+-----+-----+-------------------+-------+-------+----+
|           PatientId|NumId|Value| GlucoseDisplayTime|split60|split80|rank|
+--------------------+-----+-----+-------------------+-------+-------+----+
|+Gr/1qOf9OWMa4LOL...| 4660|196.0|2022-02-01 00:02:00|    465|    620|   1|
|+Gr/1qOf9OWMa4LOL...| 4660|193.0|2022-02-01 00:07:00|    465|    620|   2|
|+Gr/1qOf9OWMa4LOL...| 4660|192.0|2022-02-01 00:12:00|    465|    620|   3|
|+Gr/1qOf9OWMa4LOL...| 4660|193.0|2022-02-01 00:17:00|    465|    620|   4|
|+Gr/1qOf9OWMa4LOL...| 4660|191.0|2022-02-01 00:22:00|    465|    620|   5|
|+Gr/1qOf9OWMa4LOL...| 4660|189.0|2022-02-01 00:27:00|    465|    620|   6|
|+Gr/1qOf9OWMa4LOL...| 4660|186.0|2022-02-01 00:32:00|    465|    620|   7|
|+Gr/1qOf9OWMa4LOL...| 4660|183.0|2022-02-01 00:37:00|    465|    620|   8|
|+Gr/1qOf9OWMa4LOL...| 4660|180.0|2022-02-01 00:42:00|    465|    620|   9|
|+Gr/1qOf9OWMa4LOL...| 4660|177.0|2022-02-01 00:47:00|    465|    620|  10|
|+Gr/1qOf9OW

                                                                                

In [28]:
'''training set'''
trainSet = df.filter(col('rank') <= col('split60')) \
                 .drop('rank','tiebreak','split60','split80')

'''validation set'''
valSet = df.filter((col('rank') > col('split60')) & (col('rank') <= col('split80'))) \
               .drop('rank','tiebreak','split60','split80')

'''test set'''
testSet = df.filter(col('rank') > col('split80')) \
                .drop('rank','tiebreak','split60','split80')

In [36]:
temp = trainSet.groupby('NumId')\
            .agg(max("GlucoseDisplayTime").alias("Max"),
                min("GlucoseDisplayTime").alias("Min"))
temp.show(5)



+-----+-------------------+-------------------+
|NumId|                Max|                Min|
+-----+-------------------+-------------------+
| 4818|2022-02-02 15:19:00|2022-02-01 00:04:00|
| 7754|2022-02-02 15:47:00|2022-02-01 00:02:00|
| 2235|2022-02-02 19:17:00|2022-02-01 11:07:00|
| 4219|2022-02-02 15:28:00|2022-02-01 00:03:00|
| 2711|2022-02-02 20:28:00|2022-02-01 00:03:00|
+-----+-------------------+-------------------+
only showing top 5 rows



                                                                                

In [37]:
temp = valSet.filter(col('NumId').isin([4818, 7754, 2235, 4219, 2711]))
temp = temp.groupby('NumId')\
            .agg(max("GlucoseDisplayTime").alias("Max"),
                min("GlucoseDisplayTime").alias("Min"))
temp.show()



+-----+-------------------+-------------------+
|NumId|                Max|                Min|
+-----+-------------------+-------------------+
| 4818|2022-02-03 04:24:00|2022-02-02 15:24:00|
| 7754|2022-02-03 04:47:00|2022-02-02 15:52:00|
| 2235|2022-02-03 04:18:00|2022-02-02 19:22:00|
| 4219|2022-02-03 04:38:00|2022-02-02 15:33:00|
| 2711|2022-02-03 07:33:00|2022-02-02 20:33:00|
+-----+-------------------+-------------------+



                                                                                

In [35]:
temp = testSet.filter(col('NumId').isin([4818, 7754, 2235, 4219, 2711]))
temp = temp.groupby('NumId')\
            .agg(max("GlucoseDisplayTime").alias("Max"),
                min("GlucoseDisplayTime").alias("Min"))
temp.show()



+-----+-------------------+-------------------+
|NumId|                Max|                Min|
+-----+-------------------+-------------------+
| 4818|2022-02-03 17:34:00|2022-02-03 04:29:00|
| 7754|2022-02-03 17:47:00|2022-02-03 04:52:00|
| 2235|2022-02-03 09:03:00|2022-02-03 04:23:00|
| 4219|2022-02-03 17:48:00|2022-02-03 04:43:00|
| 2711|2022-02-03 18:03:00|2022-02-03 07:38:00|
+-----+-------------------+-------------------+



                                                                                

## Save out into parquet files

In [15]:
'''godspeed'''
startTime = time.time()

# df.repartition('PatientId')\
#     .write.parquet('/cephfs/train_test_val/test/') 
trainSet.repartition('PatientId').write.parquet('/cephfs/train_test_val/train_set/')
valSet.repartition('PatientId').write.parquet('/cephfs/train_test_val/val_set/')
testSet.repartition('PatientId').write.parquet('/cephfs/train_test_val/test_set/')



23/05/16 23:18:14 ERROR Executor: Exception in task 11.0 in stage 8.0 (TID 6007)
java.lang.OutOfMemoryError: Java heap space
23/05/16 23:18:14 ERROR Executor: Exception in task 4.0 in stage 8.0 (TID 6000)
java.lang.OutOfMemoryError: Java heap space
23/05/16 23:18:14 ERROR Executor: Exception in task 7.0 in stage 8.0 (TID 6003)
java.lang.OutOfMemoryError: Java heap space
23/05/16 23:18:14 ERROR Executor: Exception in task 2.0 in stage 8.0 (TID 5998)
java.lang.OutOfMemoryError: Java heap space
	at org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillReader.<init>(UnsafeSorterSpillReader.java:50)
	at org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter.getReader(UnsafeSorterSpillWriter.java:159)
	at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.getSortedIterator(UnsafeExternalSorter.java:553)
	at org.apache.spark.sql.execution.UnsafeExternalRowSorter.sort(UnsafeExternalRowSorter.java:172)
	at org.apache.spark.sql.catalyst.expressions.GeneratedC



23/05/16 23:18:14 ERROR FileFormatWriter: Aborting job 414e7339-6f64-4779-88d6-28382fb4a3a3.
org.apache.spark.SparkException: Job aborted due to stage failure: Task 8 in stage 8.0 failed 1 times, most recent failure: Lost task 8.0 in stage 8.0 (TID 6004) (jupyter-ljoe-40ucsd-2eedu executor driver): java.lang.OutOfMemoryError: Java heap space

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2672)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2608)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2607)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2607)
	at org.apache.spark.scheduler.DAGScheduler

Py4JJavaError: An error occurred while calling o200.parquet.
: org.apache.spark.SparkException: Job aborted.
	at org.apache.spark.sql.errors.QueryExecutionErrors$.jobAbortedError(QueryExecutionErrors.scala:651)
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$.write(FileFormatWriter.scala:278)
	at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand.run(InsertIntoHadoopFsRelationCommand.scala:186)
	at org.apache.spark.sql.execution.command.DataWritingCommandExec.sideEffectResult$lzycompute(commands.scala:113)
	at org.apache.spark.sql.execution.command.DataWritingCommandExec.sideEffectResult(commands.scala:111)
	at org.apache.spark.sql.execution.command.DataWritingCommandExec.executeCollect(commands.scala:125)
	at org.apache.spark.sql.execution.QueryExecution$$anonfun$eagerlyExecuteCommands$1.$anonfun$applyOrElse$1(QueryExecution.scala:98)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$6(SQLExecution.scala:109)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:169)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:95)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:779)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:64)
	at org.apache.spark.sql.execution.QueryExecution$$anonfun$eagerlyExecuteCommands$1.applyOrElse(QueryExecution.scala:98)
	at org.apache.spark.sql.execution.QueryExecution$$anonfun$eagerlyExecuteCommands$1.applyOrElse(QueryExecution.scala:94)
	at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$1(TreeNode.scala:584)
	at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:176)
	at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:584)
	at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.org$apache$spark$sql$catalyst$plans$logical$AnalysisHelper$$super$transformDownWithPruning(LogicalPlan.scala:30)
	at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.transformDownWithPruning(AnalysisHelper.scala:267)
	at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.transformDownWithPruning$(AnalysisHelper.scala:263)
	at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.transformDownWithPruning(LogicalPlan.scala:30)
	at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.transformDownWithPruning(LogicalPlan.scala:30)
	at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:560)
	at org.apache.spark.sql.execution.QueryExecution.eagerlyExecuteCommands(QueryExecution.scala:94)
	at org.apache.spark.sql.execution.QueryExecution.commandExecuted$lzycompute(QueryExecution.scala:81)
	at org.apache.spark.sql.execution.QueryExecution.commandExecuted(QueryExecution.scala:79)
	at org.apache.spark.sql.execution.QueryExecution.assertCommandExecuted(QueryExecution.scala:116)
	at org.apache.spark.sql.DataFrameWriter.runCommand(DataFrameWriter.scala:860)
	at org.apache.spark.sql.DataFrameWriter.saveToV1Source(DataFrameWriter.scala:390)
	at org.apache.spark.sql.DataFrameWriter.saveInternal(DataFrameWriter.scala:363)
	at org.apache.spark.sql.DataFrameWriter.save(DataFrameWriter.scala:239)
	at org.apache.spark.sql.DataFrameWriter.parquet(DataFrameWriter.scala:793)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:77)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:568)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.base/java.lang.Thread.run(Thread.java:833)
Caused by: org.apache.spark.SparkException: Job aborted due to stage failure: Task 8 in stage 8.0 failed 1 times, most recent failure: Lost task 8.0 in stage 8.0 (TID 6004) (jupyter-ljoe-40ucsd-2eedu executor driver): java.lang.OutOfMemoryError: Java heap space

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2672)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2608)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2607)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2607)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1182)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1182)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1182)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2860)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2802)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2791)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
Caused by: java.lang.OutOfMemoryError: Java heap space


In [None]:
endTime = time.time()
endTime-startTime

In [16]:
# endTime = time.time()
# endTime-startTime

5.854722222222223

In [15]:
17300/60/60

4.805555555555555

In [16]:
45887/60/60

12.74638888888889