In [16]:
import os
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *

import pandas as pd
import pandasql as ps
import pixiedust

import pyarrow.parquet as pq
import pyarrow as pa
!pip install duckdb

import duckdb

import numpy as np
import matplotlib.pyplot as plt


pixiedust.enableJobMonitor()

con = duckdb.connect()

0,1,2
▸,:,


Please see https://github.com/pypa/pip/issues/5599 for advice on fixing the underlying issue.
To avoid this problem you can invoke Python with '-m pip' instead of running pip directly.
Spark Job Progress Monitor already enabled


In [17]:
# Load datasets
cohort_new = spark.read.parquet('file:/home/z_han/work/Oklahoma State/Zheng_Han/epilepsy/allcohort2024')
cohort_demo = spark.read.parquet('file:/home/z_han/work/Oklahoma State/Zheng_Han/epilepsy/allcohort_new_demo')
cohort_como = spark.read.parquet('file:/home/z_han/work/Oklahoma State/Zheng_Han/epilepsy/allcohort_new_como')


0,1,2
▸,:,


In [18]:
cohort_new.printSchema()

0,1,2
▸,:,


root
 |-- personid: string (nullable = true)
 |-- date: string (nullable = true)
 |-- EPI: integer (nullable = true)



In [19]:
# Preprocess comorbidity data
cohort_como_reformat = cohort_como.withColumn(
    'comorbidityid', expr("regexp_replace(comorbidityid, '([0-9a-zA-Z]+\\.[0-9a-zA-Z]).*', '$1')").cast("string"))


0,1,2
▸,:,


In [20]:
def get_top_comorbidity(spark, como, topn):
    top_como = como.groupBy('comorbidityid').count().orderBy(col('count').desc()) \
        .select(collect_list('comorbidityid')).collect()[0] \
        .__getitem__('collect_list(comorbidityid)')[0:topn]
    return top_como

0,1,2
▸,:,


In [21]:
cohort_como_length = cohort_como_reformat.join(cohort_new.select('personid','date','EPI'), 'personid') \
    .withColumn('len_como', months_between(col('date'), col('effectivedate')) / 12)


0,1,2
▸,:,


In [22]:
cohort_como_length.printSchema()

0,1,2
▸,:,


root
 |-- personid: string (nullable = true)
 |-- encounterid: string (nullable = true)
 |-- comorbidityid: string (nullable = true)
 |-- effectivedate: string (nullable = true)
 |-- date: string (nullable = true)
 |-- EPI: integer (nullable = true)
 |-- len_como: double (nullable = true)



In [23]:
# Preprocess demographic data
cohort_demo_processed = cohort_demo.withColumn('age_at_diagnosis', months_between(col('date'), col('birthdate.value')) / 12).drop('birthdate')


0,1,2
▸,:,


In [24]:
cohort_demo_processed.printSchema()

0,1,2
▸,:,


root
 |-- personid: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- date: string (nullable = true)
 |-- race: string (nullable = true)
 |-- age_at_diagnosis: double (nullable = true)



In [25]:
# Encode gender and race
df = cohort_demo_processed.withColumn("gender",
    when(lower(col("gender")).isin("male"), 1).when(lower(col("gender")).isin("female"), 2)
    .when(col("gender").isNull(), 0).otherwise(3))

0,1,2
▸,:,


In [26]:
df = df.withColumn("race",
    when(lower(col("race")).like("%white%") | lower(col("race")).like("%caucasian%"), 1)
    .when(lower(col("race")).like("%black%") | lower(col("race")).like("%african%"), 2)
    .when(lower(col("race")).like("%hispanic%"), 3)
    .when(lower(col("race")).like("%asian%") | lower(col("race")).like("%chinese%") | lower(col("race")).like("%korean%") | lower(col("race")).like("%japanese%"), 4)
    .when(lower(col("race")).like("%indian american%") | lower(col("race")).like("%native american%"), 5)
    .when(col("race").isNull(), 0).otherwise(6))

0,1,2
▸,:,


In [27]:
cohort_demo_processed = df

0,1,2
▸,:,


In [28]:
def create_comorbidity_features(cohort_como_length_filter, top_comorbidity):
    # Trim whitespaces in comorbidityid to ensure consistent formatting
    cohort_como_length_filter = cohort_como_length_filter.withColumn("comorbidityid", trim(col("comorbidityid")))

    # Create comorbidity columns based on the presence of each comorbidity in the data
    for comorbidity in top_comorbidity:
        # Use backticks for the comorbidity value in case of special characters
        cohort_como_length_filter = cohort_como_length_filter.withColumn(
            f"has_{comorbidity.replace('.', '_')}",  # Replace '.' with '_' for safe column naming
            when(col("comorbidityid") == comorbidity, col("len_como")).otherwise(None)
        )

    # Define the new list of comorbidity columns to select
    comorbidity_cols = [f"has_{comorbidity.replace('.', '_')}" for comorbidity in top_comorbidity]

    # Select personid and the newly created comorbidity columns
    cohort_como_features = cohort_como_length_filter.select(
        "personid",
        *[col(f"has_{comorbidity.replace('.', '_')}") for comorbidity in top_comorbidity]
    )

    # Group by personid to aggregate comorbidity values using max
    cohort_como_features = cohort_como_features.groupBy("personid").agg(
        *[max(col(f"has_{comorbidity.replace('.', '_')}")).alias(f"has_{comorbidity.replace('.', '_')}") for comorbidity in top_comorbidity]
    )

    return cohort_como_features


0,1,2
▸,:,


In [29]:
spark

0,1,2
▸,:,


In [36]:
from pyspark.ml.feature import VectorAssembler, StandardScaler
from pyspark.ml.classification import DecisionTreeClassifier, RandomForestClassifier, LogisticRegression, GBTClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.sql.functions import col, when, lit

def train_baseline_models_sparkml(cohort_demo_processed, cohort_como_length, topn):
    print('-----preparing data--------')

    top_comorbidity = get_top_comorbidity(spark, cohort_como_reformat, topn)
    cohort_como_length_filter = cohort_como_length.filter(col('comorbidityid').isin(top_comorbidity))
    cohort_comorbidity_features = create_comorbidity_features(cohort_como_length_filter, top_comorbidity)
    cohort_demo_processed = cohort_demo_processed.join(cohort_comorbidity_features, 'personid')

    x_person = cohort_demo_processed.drop('date')
    x_person = x_person.join(cohort_new.select('personid', 'EPI'), 'personid', 'inner')

    # Ensure label column is present
    data = x_person.withColumnRenamed("EPI", "label")

    # Drop unnecessary columns to prepare feature set
    feature_columns = [col for col in data.columns if col not in ["personid", "date", "label"]]

    # Handle missing values by replacing NaNs with a unique value
    data = data.fillna({col: -1 for col in feature_columns})

    # Assemble features after handling NaN values
    assembler = VectorAssembler(inputCols=feature_columns, outputCol="features_assembled")
    data = assembler.transform(data)

    # Ensure label column is properly set
    data = data.select("features_assembled", "label")
    
    # Split data into training and testing sets
    train_data, test_data = data.randomSplit([0.8, 0.2], seed=42)

    # Compute weights for imbalanced classes
    positive_count = train_data.filter(col("label") == 1).count()
    negative_count = train_data.filter(col("label") == 0).count()
    total_count = positive_count + negative_count

    weight_positive = total_count / (2 * positive_count)
    weight_negative = total_count / (2 * negative_count)

    train_data = train_data.withColumn("classWeight", when(col("label") == 1, weight_positive).otherwise(weight_negative))

    

    # Scale the features
    scaler = StandardScaler(inputCol="features_assembled", outputCol="scaledFeatures", withMean=True, withStd=True)
    scaler_model = scaler.fit(train_data)
    train_data = scaler_model.transform(train_data)
    test_data = scaler_model.transform(test_data)

    def evaluate_model(model, train_data, test_data):
        model = model.fit(train_data)
        train_predictions = model.transform(train_data)
        test_predictions = model.transform(test_data)

        # Evaluate AUC
        evaluator = BinaryClassificationEvaluator(labelCol="label", rawPredictionCol="prediction")
        train_auc = evaluator.evaluate(train_predictions, {evaluator.metricName: "areaUnderROC"})
        test_auc = evaluator.evaluate(test_predictions, {evaluator.metricName: "areaUnderROC"})
        print(f"Training AUC: {train_auc}")
        print(f"Test AUC: {test_auc}")
        
        evaluator.setMetricName("areaUnderPR")
        train_prauc = evaluator.evaluate(train_predictions)
        test_prauc = evaluator.evaluate(test_predictions)
        print(f"Training PRAUC: {train_prauc}")
        print(f"Test PRAUC: {test_prauc}")

        # Evaluate Accuracy
        accuracy_evaluator = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="accuracy")
        train_accuracy = accuracy_evaluator.evaluate(train_predictions)
        test_accuracy = accuracy_evaluator.evaluate(test_predictions)
        print(f"Training Accuracy: {train_accuracy}")
        print(f"Test Accuracy: {test_accuracy}")

        # Compute confusion matrix elements
        tp = test_predictions.filter((col("prediction") == 1) & (col("label") == 1)).count()
        tn = test_predictions.filter((col("prediction") == 0) & (col("label") == 0)).count()
        fp = test_predictions.filter((col("prediction") == 1) & (col("label") == 0)).count()
        fn = test_predictions.filter((col("prediction") == 0) & (col("label") == 1)).count()

        print(f"Confusion Matrix: TP={tp}, TN={tn}, FP={fp}, FN={fn}")

        # Calculate sensitivity and specificity
        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

        print(f"Sensitivity (Recall): {sensitivity}")
        print(f"Specificity: {specificity}")

    print('---------training decision tree-----------')
    dt_classifier = DecisionTreeClassifier(featuresCol="scaledFeatures", labelCol="label", weightCol="classWeight", seed=45)
    evaluate_model(dt_classifier, train_data, test_data)

    print('---------training random forest-----------')
    rf_classifier = RandomForestClassifier(featuresCol="scaledFeatures", labelCol="label", weightCol="classWeight", seed=42)
    evaluate_model(rf_classifier, train_data, test_data)

    print('---------training logistic regression-----------')
    lr_classifier = LogisticRegression(featuresCol="scaledFeatures", labelCol="label", weightCol="classWeight", maxIter=100, regParam=0.01)
    evaluate_model(lr_classifier, train_data, test_data)
    
    print('---------training gradient boosting tree-----------')
    gbt_classifier = GBTClassifier(featuresCol="scaledFeatures", labelCol="label", maxIter=100)
    evaluate_model(gbt_classifier, train_data, test_data)


    print('---------training multilayer perceptron-----------')
    from pyspark.ml.classification import MultilayerPerceptronClassifier
    mlp_classifier = MultilayerPerceptronClassifier(featuresCol="scaledFeatures", labelCol="label", maxIter=100, layers=[len(feature_columns), 64, 32, 2], blockSize=128, seed=123)
    evaluate_model(mlp_classifier, train_data, test_data)

    print('---------training support vector machine-----------')
    from pyspark.ml.classification import LinearSVC
    svm_classifier = LinearSVC(featuresCol="scaledFeatures", labelCol="label", weightCol="classWeight", maxIter=100, regParam=0.01)
    evaluate_model(svm_classifier, train_data, test_data)


    print('---completed------')


0,1,2
▸,:,


In [35]:
train_baseline_models_sparkml(cohort_demo_processed,cohort_como_length, 20)

0,1,2
▸,:,


-----preparing data--------
---------training decision tree-----------
Training AUC: 0.7064949123481028
Test AUC: 0.7028155360525713
Training Accuracy: 0.6984619730841231
Test Accuracy: 0.6978489244622311
Confusion Matrix: TP=4240, TN=103175, FP=44761, FN=1747
Sensitivity (Recall): 0.7082011023885084
Specificity: 0.6974299697166342
---------training random forest-----------
Training AUC: 0.7200667076129935
Test AUC: 0.7171251821760254
Training Accuracy: 0.6747315799642486
Test Accuracy: 0.6732911910500705
Confusion Matrix: TP=4578, TN=99057, FP=48879, FN=1409
Sensitivity (Recall): 0.7646567563053283
Specificity: 0.6695936080467229
---------training logistic regression-----------
Training AUC: 0.6530001360885425
Test AUC: 0.650051585265624
Training Accuracy: 0.6468918802634762
Test Accuracy: 0.6444845799523138
Confusion Matrix: TP=3928, TN=95273, FP=52663, FN=2059
Sensitivity (Recall): 0.6560881910806748
Specificity: 0.6440149794505732
---------training gradient boosting tree-----------

ImportError: cannot import name 'KNNClassifier' from 'pyspark.ml.classification' (/usr/local/spark-3.4.4-bin-hadoop3/python/pyspark/ml/classification.py)

In [37]:
train_baseline_models_sparkml(cohort_demo_processed,cohort_como_length, 50)

0,1,2
▸,:,


-----preparing data--------
---------training decision tree-----------
Training AUC: 0.6864270161095266
Test AUC: 0.6874686752847758
Training Accuracy: 0.8053803867527448
Test Accuracy: 0.806543406832469
Confusion Matrix: TP=3927, TN=157347, FP=35590, FN=3093
Sensitivity (Recall): 0.5594017094017094
Specificity: 0.8155356411678424
---------training random forest-----------
Training AUC: 0.7125517985509491
Test AUC: 0.7121705851992164
Training Accuracy: 0.7782025933304374
Test Accuracy: 0.7787174242462129
Confusion Matrix: TP=4497, TN=151213, FP=41724, FN=2523
Sensitivity (Recall): 0.6405982905982905
Specificity: 0.783742879800142
---------training logistic regression-----------
Training AUC: 0.6713523829293934
Test AUC: 0.6686864807308268
Training Accuracy: 0.6861828013361094
Test Accuracy: 0.687517816330511
Confusion Matrix: TP=4552, TN=132922, FP=60015, FN=2468
Sensitivity (Recall): 0.6484330484330484
Specificity: 0.6889399130286052
---------training gradient boosting tree-----------

In [38]:
train_baseline_models_sparkml(cohort_demo_processed,cohort_como_length, 100)

0,1,2
▸,:,


-----preparing data--------
---------training decision tree-----------
Training AUC: 0.676246145545819
Test AUC: 0.6774451359168702
Training Accuracy: 0.8199171914972312
Test Accuracy: 0.820475177791498
Confusion Matrix: TP=4008, TN=175393, FP=35608, FN=3646
Sensitivity (Recall): 0.5236477658740528
Specificity: 0.8312425059596874
---------training random forest-----------
Training AUC: 0.7038232011624732
Test AUC: 0.709746037201272
Training Accuracy: 0.7873303219074951
Test Accuracy: 0.7890146577942421
Confusion Matrix: TP=4780, TN=167742, FP=43259, FN=2874
Sensitivity (Recall): 0.6245100600992944
Specificity: 0.7949820143032498
---------training logistic regression-----------
Training AUC: 0.680175248451147
Test AUC: 0.6867466121492901
Training Accuracy: 0.7057227364933898
Test Accuracy: 0.7073243237062953
Confusion Matrix: TP=5087, TN=149573, FP=61428, FN=2567
Sensitivity (Recall): 0.6646198066370526
Specificity: 0.7088734176615277
---------training gradient boosting tree-----------


In [40]:
train_baseline_models_sparkml(cohort_demo_processed,cohort_como_length, 200)

0,1,2
▸,:,


-----preparing data--------
---------training decision tree-----------
Training AUC: 0.735510848347608
Test AUC: 0.7341349310081899
Training Accuracy: 0.8697564921397338
Test Accuracy: 0.8689867769935913
Confusion Matrix: TP=4880, TN=198648, FP=27280, FN=3405
Sensitivity (Recall): 0.5890162945081473
Specificity: 0.8792535675082327
---------training random forest-----------
Training AUC: 0.7501883124656592
Test AUC: 0.747204766363794
Training Accuracy: 0.8386777743792942
Test Accuracy: 0.8385700195975458
Confusion Matrix: TP=5376, TN=191028, FP=34900, FN=2909
Sensitivity (Recall): 0.6488835244417622
Specificity: 0.8455260082858256
---------training logistic regression-----------
Training AUC: 0.7436294746530319
Test AUC: 0.7431827999229923
Training Accuracy: 0.803459774684092
Test Accuracy: 0.802658264058784
Confusion Matrix: TP=5627, TN=182366, FP=43562, FN=2658
Sensitivity (Recall): 0.6791792395896198
Specificity: 0.8071863602563648
---------training gradient boosting tree-----------


In [43]:
train_baseline_models_sparkml(cohort_demo_processed,cohort_como_length, 500)

0,1,2
▸,:,


-----preparing data--------


Exception ignored in: <function JavaWrapper.__del__ at 0x7fa50d1dc290>
Traceback (most recent call last):
  File "/usr/local/spark-3.4.4-bin-hadoop3/python/pyspark/ml/wrapper.py", line 53, in __del__
    if SparkContext._active_spark_context and self._java_obj is not None:
AttributeError: 'MultilayerPerceptronClassifier' object has no attribute '_java_obj'


---------training decision tree-----------
Training AUC: 0.7496369212130346
Test AUC: 0.750183598126805
Training Accuracy: 0.931933347714424
Test Accuracy: 0.931873390068147
Confusion Matrix: TP=4822, TN=223815, FP=12843, FN=3872
Sensitivity (Recall): 0.5546353807223372
Specificity: 0.945731815531273
---------training random forest-----------
Training AUC: 0.7744413535023054
Test AUC: 0.7732597325860141
Training Accuracy: 0.8797920575677777
Test Accuracy: 0.8794588998663145
Confusion Matrix: TP=5729, TN=210048, FP=26610, FN=2965
Sensitivity (Recall): 0.6589602024384633
Specificity: 0.8875592627335649
---------training logistic regression-----------
Training AUC: 0.7702854549583009
Test AUC: 0.7698998247514047
Training Accuracy: 0.8496690992417366
Test Accuracy: 0.8498932146467116
Confusion Matrix: TP=5945, TN=202578, FP=34080, FN=2749
Sensitivity (Recall): 0.6838049229353578
Specificity: 0.8559947265674518
---------training gradient boosting tree-----------
Training AUC: 0.677207935517