#### Setup For Colab


In [None]:
!apt-get install openjdk-8-jdk-headless -qq > /dev/null
!wget -q http://apache.mirror.amaze.com.au/spark/spark-3.0.2/spark-3.0.2-bin-hadoop3.2.tgz
!tar xf spark-3.0.2-bin-hadoop3.2.tgz
!pip install -q findspark

E: Failed to fetch http://security.ubuntu.com/ubuntu/pool/universe/o/openjdk-8/openjdk-8-jre-headless_8u282-b08-0ubuntu1~18.04_amd64.deb  404  Not Found [IP: 91.189.88.142 80]
E: Failed to fetch http://security.ubuntu.com/ubuntu/pool/universe/o/openjdk-8/openjdk-8-jdk-headless_8u282-b08-0ubuntu1~18.04_amd64.deb  404  Not Found [IP: 91.189.88.142 80]
E: Unable to fetch some archives, maybe run apt-get update or try with --fix-missing?


In [None]:
import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"
os.environ["SPARK_HOME"] = "spark-3.0.2-bin-hadoop3.2"

import findspark
findspark.init()

from pyspark.sql import SparkSession
spark = SparkSession.builder.master("local[*]").getOrCreate()
sc = spark.sparkContext
sc

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


#### Load Data from CSV

In [None]:
patients = spark.read.csv("/content/drive/MyDrive/synthea/patients.csv", header=True)
conditions = spark.read.csv("/content/drive/MyDrive/synthea/conditions.csv", header=True)
procedures = spark.read.csv("/content/drive/MyDrive/synthea/procedures.csv", header=True)

Setup tables for SparkSQL queries

In [None]:
import pyspark.sql
from pyspark.sql import Row

patients.createOrReplaceTempView("patients")
conditions.createOrReplaceTempView("conditions")
procedures.createOrReplaceTempView("procedures")

#### Create a Table for Analysis

In [None]:
data_1 = spark.sql("""
WITH T1 AS (
    SELECT X.Id, random(1234) AS RAND, X.BIRTHDATE, X.DEATHDATE, 
        2020 - CAST(substr(X.BIRTHDATE, 0, 4) AS NUMERIC) AS AGE,
        CASE WHEN isnull(X.DEATHDATE) THEN 0 ELSE 1 END AS DEATH, 
        X.RACE, X.ETHNICITY, X.GENDER,
        Y1.FEVER_FLAG, Y2.COUGH_FLAG, Y3.VENT_FLAG, Y4.OXYGEN_FLAG
    FROM patients X
    LEFT JOIN (
        SELECT DISTINCT Patient, 1 AS FEVER_FLAG 
        FROM conditions WHERE CODE = 386661006
        ) Y1 ON X.Id = Y1.Patient
    LEFT JOIN (
        SELECT DISTINCT Patient, 1 AS COUGH_FLAG 
        FROM conditions WHERE CODE = 49727002
        ) Y2 ON X.Id = Y2.Patient
    LEFT JOIN (
        SELECT DISTINCT Patient, 1 AS VENT_FLAG 
        FROM procedures WHERE CODE = 26763009
        ) Y3 ON X.Id = Y3.Patient
    LEFT JOIN (
        SELECT DISTINCT Patient, 1 AS OXYGEN_FLAG 
        FROM procedures WHERE CODE = 371908008
        ) Y4 ON X.Id = Y4.Patient
)
SELECT Id, RAND, DEATH, CAST(AGE AS INT) AGE, 
    CASE WHEN RACE = 'asian' THEN 1 ELSE 0 END AS ASIAN_FLAG,
    CASE WHEN RACE = 'other' THEN 1 ELSE 0 END AS OTHER_FLAG,
    CASE WHEN RACE = 'black' THEN 1 ELSE 0 END AS BLACK_FLAG,
    CASE WHEN RACE = 'native' THEN 1 ELSE 0 END AS NATIVE_FLAG,
    CASE WHEN ETHNICITY = 'hispanic' THEN 1 ELSE 0 END AS HISPANIC_FLAG,
    CASE WHEN GENDER = 'M' THEN 1 ELSE 0 END AS MALE_FLAG,
    nvl(FEVER_FLAG, 0) AS FEVER_FLAG,
    nvl(COUGH_FLAG, 0) AS COUGH_FLAG,
    nvl(VENT_FLAG, 0) AS VENT_FLAG,
    nvl(OXYGEN_FLAG, 0) AS OXYGEN_FLAG
FROM T1
""")
data_1.show()

+--------------------+-------------------+-----+---+----------+----------+----------+-----------+-------------+---------+----------+----------+---------+-----------+
|                  Id|               RAND|DEATH|AGE|ASIAN_FLAG|OTHER_FLAG|BLACK_FLAG|NATIVE_FLAG|HISPANIC_FLAG|MALE_FLAG|FEVER_FLAG|COUGH_FLAG|VENT_FLAG|OXYGEN_FLAG|
+--------------------+-------------------+-----+---+----------+----------+----------+-----------+-------------+---------+----------+----------+---------+-----------+
|0053c053-09b3-408...| 0.7151043443924859|    1| 51|         0|         0|         0|          0|            0|        0|         1|         1|        1|          1|
|00d38be8-920a-4e6...|  0.833439862021827|    1| 51|         0|         0|         0|          0|            0|        0|         0|         0|        0|          0|
|0120542e-62f7-4b6...|0.20939665324719448|    0| 15|         0|         0|         0|          0|            0|        1|         1|         0|        0|          0|
|013

In [None]:
data_train = data_1.filter(data_1.RAND <= 0.7)
data_test = data_1.filter(data_1.RAND > 0.7)

#### Create Models

Format table as labeled points

In [None]:
from pyspark.mllib.classification import *
from pyspark.mllib.tree import *
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.evaluation import BinaryClassificationMetrics

In [None]:
df_train = data_train.rdd.map(lambda x: LabeledPoint(x['DEATH'], x[-11:]))
df_test = data_test.rdd.map(lambda x: LabeledPoint(x['DEATH'], x[-11:]))

Train models

In [None]:
model_log_reg = LogisticRegressionWithSGD.train(df_train)
model_svm = SVMWithSGD.train(df_train)
model_bayes = NaiveBayes.train(df_train)
model_rf = RandomForest.trainClassifier(df_train, 2, {}, 10, seed=1234)

#### Evaluate Models

In [None]:
def eval_label(x, th=0.5):
    y = x[0]
    pred = x[1]
    if (y == 0 and pred < th): return ("TP", 1)
    elif (y == 0 and pred >= th): return ("FP", 1)
    elif (y == 1 and pred < th): return ("FN", 1)
    elif (y == 1 and pred >= th): return ("TN", 1)
    else: return ("ope", 1)

def list_to_metrics(m, model_name = "model"):
    d = dict(m)
    TP = d['TP']
    try:
        FP = d['FP']
    except:
        FP = 0
    try:
        FN = d['FN']
    except:
        FN = 0
    try:
        TN = d['TN']
    except:
        TN = 0
    n = TP + FP + FN + TN
    acc = (TP + TN)/n
    F1 = TP/(TP + 0.5*(FP + FN))
    if (TP + FN) == 0: 
        sens = 0.0
    else: 
        sens = TP/(TP + FN)
    if (TN + FP) == 0: 
        spec = 0.0
    else:
        spec = TN/(TN + FP)
    metrics = {"model":model_name, "n":n, "acc":acc, "F1":F1, "sens":sens, "spec":spec}
    return metrics

In [None]:
# Have to predict w/ Random Forests outside of Spark for some reason
pred_rf = [(x.label, model_rf.predict(x.features)) for x in df_test.collect()]
pred_eval_rf = sc.parallelize(pred_rf).map(eval_label).reduceByKey(lambda x, y: x + y).collect()

In [None]:
pred_log = df_test.map(lambda x: (x.label, model_log_reg.predict(x.features)))
pred_svm = df_test.map(lambda x: (x.label, model_svm.predict(x.features)))
pred_nb = df_test.map(lambda x: (x.label, model_bayes.predict(x.features)))

pred_eval_log = pred_log.map(eval_label).reduceByKey(lambda x, y: x + y).collect()  
pred_eval_svm = pred_svm.map(eval_label).reduceByKey(lambda x, y: x + y).collect()
pred_eval_nb = pred_nb.map(eval_label).reduceByKey(lambda x, y: x + y).collect()

In [None]:
spark.createDataFrame([
                       Row(**list_to_metrics(pred_eval_svm, "Support Vector Machines")),
                       Row(**list_to_metrics(pred_eval_log, "Logistic Regression")),
                       Row(**list_to_metrics(pred_eval_nb, "Naive Bayes")),
                       Row(**list_to_metrics(pred_eval_rf, "Random Forests"))
                       ]).show()

+--------------------+-----+------------------+-------------------+------------------+------------------+
|               model|    n|               acc|                 F1|              sens|              spec|
+--------------------+-----+------------------+-------------------+------------------+------------------+
|Support Vector Ma...|37277|0.8030689164900608| 0.8907800574293664|0.8030689164900608|               0.0|
| Logistic Regression|37277|0.3272259033720525|0.27940120104588684|0.9989726731045819|0.2263498920086393|
|         Naive Bayes|37277|0.9069667623467553|  0.941645633518425|0.9487014307994847|0.7488115122703328|
|      Random Forests|37277|0.9271132333610538| 0.9556169038012317|0.9350724081710943|0.8855903935957304|
+--------------------+-----+------------------+-------------------+------------------+------------------+



Random Forests peformed the best but Naive Bayes also did pretty well.

## Investigate Sub Models

See if we can get just as good test performance using submodels.

In [None]:
def nb_sub_model(pred_vars):
    all_vars = ["DEATH"] + pred_vars
    sub_train = data_train.select(all_vars).rdd.map(lambda x: LabeledPoint(x[0], x[1:]))
    sub_test = data_test.select(all_vars).rdd.map(lambda x: LabeledPoint(x[0], x[1:]))
    model_bayes = NaiveBayes.train(sub_train)
    pred = sub_test.map(lambda x: (x.label, model_bayes.predict(x.features)))
    pred_eval = pred.map(eval_label).reduceByKey(lambda x, y: x + y).collect()
    return pred_eval

In [None]:
vars_demo = ["AGE", "ASIAN_FLAG", "OTHER_FLAG", "BLACK_FLAG", "NATIVE_FLAG", "HISPANIC_FLAG", "MALE_FLAG"]
vars_proc = ["VENT_FLAG", "OXYGEN_FLAG"]
vars_cond = ["FEVER_FLAG", "COUGH_FLAG"]

spark.createDataFrame([
                       Row(**list_to_metrics(nb_sub_model(vars_demo), "Demographics Only")),
                       Row(**list_to_metrics(nb_sub_model(vars_proc), "Procedures Only")),
                       Row(**list_to_metrics(nb_sub_model(vars_cond), "Conditions Only")),
                       Row(**list_to_metrics(nb_sub_model(vars_demo + vars_proc), "Demographics & Procedures")),
                       Row(**list_to_metrics(nb_sub_model(vars_demo + vars_cond), "Demographics & Conditions")),
                       Row(**list_to_metrics(nb_sub_model(vars_cond + vars_proc), "Procedures & Conditions")),
                       Row(**list_to_metrics(nb_sub_model(vars_demo + vars_proc + vars_cond), "Demographics, Conds & Procs")),
                       ]).show()

+--------------------+-----+------------------+------------------+------------------+------------------+
|               model|    n|               acc|                F1|              sens|              spec|
+--------------------+-----+------------------+------------------+------------------+------------------+
|   Demographics Only|37277|0.8030689164900608|0.8907800574293664|0.8030689164900608|               0.0|
|     Procedures Only|37277|0.8202108538777262|0.8989597467209407|0.8192009671923943|0.8618346545866364|
|     Conditions Only|37277|0.8030689164900608|0.8907800574293664|0.8030689164900608|               0.0|
|Demographics & Pr...|37277|0.8180111060439413|0.8979297063071738|0.8169075777485765|0.8718291054739653|
|Demographics & Co...|37277|0.9027282238377552|0.9391896424498558|0.9430486326283174|0.7448912326961108|
|Procedures & Cond...|37277|0.8202108538777262|0.8989597467209407|0.8192009671923943|0.8618346545866364|
|Demographics, Con...|37277|0.9069667623467553| 0.94164

## Incorporate Observations

Using Kavitha's work

In [None]:
obs_subset = spark.read.csv("/content/drive/MyDrive/synthea/Observations_Subset.csv", header=True)
obs_subset.createOrReplaceTempView("observations")

In [None]:
observations_wide = spark.sql("""
Select X.DATE, X.PATIENT, X.ENCOUNTER,
C.Diastolic_BP,       
D.Systolic_BP,      
F.Tobacco_smoking,      
G.BMI,      
Q.Glucose_in_Serum, 
R.HDL_Cholesterol
from 
(Select DISTINCT DATE, PATIENT, ENCOUNTER     from  observations) X
LEFT JOIN                                                                                                         
(Select DATE, PATIENT, ENCOUNTER,VALUE as Diastolic_BP        from  observations where CODE = '8462-4'     ) C
on    X.DATE = C.DATE and X.PATIENT = C.PATIENT and X.ENCOUNTER = C.ENCOUNTER 
LEFT JOIN
(Select DATE, PATIENT, ENCOUNTER,VALUE as Systolic_BP       from  observations where CODE = '8480-6'     ) D
on    X.DATE = D.DATE and X.PATIENT = D.PATIENT and X.ENCOUNTER = D.ENCOUNTER 
LEFT JOIN
(Select DATE, PATIENT, ENCOUNTER,VALUE as Tobacco_smoking     from  observations where CODE = '72166-2'    ) F
on    X.DATE = F.DATE and X.PATIENT = F.PATIENT and X.ENCOUNTER = F.ENCOUNTER 
LEFT JOIN
(Select DATE, PATIENT, ENCOUNTER,VALUE as BMI                 from  observations where CODE = '39156-5'    ) G
on    X.DATE = G.DATE and X.PATIENT = G.PATIENT and X.ENCOUNTER = G.ENCOUNTER 
LEFT JOIN
(Select DATE, PATIENT, ENCOUNTER,VALUE as Glucose_in_Serum      from  observations where CODE = '2345-7'     ) Q
on    X.DATE = q.DATE and X.PATIENT = q.PATIENT and X.ENCOUNTER = q.ENCOUNTER 
LEFT JOIN
(Select DATE, PATIENT, ENCOUNTER,VALUE as HDL_Cholesterol     from  observations where CODE = '2085-9'     ) R
on    X.DATE = r.DATE and X.PATIENT = r.PATIENT and X.ENCOUNTER = r.ENCOUNTER 
""")
observations_wide.show()


+----------+--------------------+--------------------+------------+-----------+---------------+----+----------------+---------------+
|      DATE|             PATIENT|           ENCOUNTER|Diastolic_BP|Systolic_BP|Tobacco_smoking| BMI|Glucose_in_Serum|HDL_Cholesterol|
+----------+--------------------+--------------------+------------+-----------+---------------+----+----------------+---------------+
|1924-04-23|cca21f9e-b16e-44c...|ce75bc62-15e3-4f2...|        90.0|      125.0|   Never smoker|null|            null|           null|
|1936-11-07|b7bd8d9d-ae46-497...|5ed6737b-3c3f-457...|        72.0|      126.0|   Never smoker|null|            null|           null|
|1943-09-21|8e50dfc4-f548-449...|c5a2f758-308b-44a...|        79.0|      126.0|   Never smoker|null|            null|           null|
|1958-01-31|8a926679-db7a-40b...|8bd68d0f-b3d8-419...|        81.0|      110.0|   Never smoker|null|            null|           null|
|1964-02-05|189b094b-f7e9-4c4...|652bd349-8317-459...|        

In [None]:
observations_wide.createOrReplaceTempView("observations_wide")
spark.sql("SELECT DISTINCT Tobacco_smoking FROM observations_wide").collect()

[Row(Tobacco_smoking=None),
 Row(Tobacco_smoking='Never smoker'),
 Row(Tobacco_smoking='Current every day smoker'),
 Row(Tobacco_smoking='Former smoker')]

In [None]:
obs_analytic = spark.sql("""
WITH X1 AS(
    SELECT *, to_date(DATE, 'yyyy-MM-dd') AS DATE_F
    FROM observations_wide
), X2 AS (
    SELECT DISTINCT Id as PATIENT
    FROM patients
), X3 AS (
    SELECT DISTINCT PATIENT, 1 AS FLAG_FORMER_SMOKER
    FROM observations_wide WHERE Tobacco_smoking = 'Former smoker'
), X4 AS (
    SELECT DISTINCT PATIENT, 1 AS FLAG_CURRENT_SMOKER
    FROM observations_wide WHERE Tobacco_smoking = 'Current every day smoker'
), X5 AS (
    SELECT PATIENT, BMI, BMI_FLAG
    FROM (
        SELECT X1.DATE_F, X1.PATIENT, X1.BMI, 1 AS BMI_FLAG,
            row_number() OVER (PARTITION BY X1.PATIENT ORDER BY X1.DATE_F DESC) AS RN
        FROM X1
        WHERE X1.BMI IS NOT NULL
    )
    WHERE RN = 1
), X6 AS (
    SELECT PATIENT, Systolic_BP, Systolic_FLAG
    FROM (
        SELECT X1.DATE_F, X1.PATIENT, X1.Systolic_BP, 1 AS Systolic_FLAG,
            row_number() OVER (PARTITION BY X1.PATIENT ORDER BY X1.DATE_F DESC) AS RN
        FROM X1
        WHERE X1.Systolic_BP IS NOT NULL
    )
    WHERE RN = 1
), X7 AS (
    SELECT PATIENT, Glucose_in_Serum, Glucose_FLAG
    FROM (
        SELECT X1.DATE_F, X1.PATIENT, X1.Glucose_in_Serum, 1 AS Glucose_FLAG,
            row_number() OVER (PARTITION BY X1.PATIENT ORDER BY X1.DATE_F DESC) AS RN
        FROM X1
        WHERE X1.Glucose_in_Serum IS NOT NULL
    )
    WHERE RN = 1
), X8 AS (
    SELECT PATIENT, HDL_Cholesterol, HDL_FLAG
    FROM (
        SELECT X1.DATE_F, X1.PATIENT, X1.HDL_Cholesterol, 1 AS HDL_FLAG,
            row_number() OVER (PARTITION BY X1.PATIENT ORDER BY X1.DATE_F DESC) AS RN
        FROM X1
        WHERE X1.HDL_Cholesterol IS NOT NULL
    )
    WHERE RN = 1
), X_J AS (
    SELECT X2.PATIENT, 
        X3.FLAG_FORMER_SMOKER, X4.FLAG_CURRENT_SMOKER,
        X5.BMI, X5.BMI_FLAG,
        X6.Systolic_BP, X6.Systolic_FLAG,
        X7.Glucose_in_Serum, X7.Glucose_FLAG,
        X8.HDL_Cholesterol, X8.HDL_Flag
    FROM X2
    LEFT JOIN X3 ON X2.PATIENT = X3.PATIENT
    LEFT JOIN X4 ON X2.PATIENT = X4.PATIENT
    LEFT JOIN X5 ON X2.PATIENT = X5.PATIENT
    LEFT JOIN X6 ON X2.PATIENT = X6.PATIENT
    LEFT JOIN X7 ON X2.PATIENT = X7.PATIENT
    LEFT JOIN X8 ON X2.PATIENT = X8.PATIENT
)
SELECT PATIENT,
    NVL(FLAG_FORMER_SMOKER, 0) AS FLAG_FORMER_SMOKER,
    NVL(FLAG_CURRENT_SMOKER, 0) AS FLAG_CURRENT_SMOKER,
    NVL(BMI, 0) AS BMI, NVL(BMI_FLAG, 0) AS BMI_FLAG,
    NVL(Systolic_BP, 0) AS Systolic_BP, NVL(Systolic_FLAG, 0) AS Systolic_FLAG,
    NVL(Glucose_in_Serum, 0) AS Glucose_in_Serum, 
    NVL(Glucose_FLAG, 0) AS Glucose_FLAG,
    NVL(HDL_Cholesterol, 0) AS HDL_Cholesterol, 
    NVL(HDL_FLAG, 0) AS HDL_FLAG
FROM X_J
""")

obs_analytic.show()

+--------------------+------------------+-------------------+----+--------+-----------+-------------+----------------+------------+---------------+--------+
|             PATIENT|FLAG_FORMER_SMOKER|FLAG_CURRENT_SMOKER| BMI|BMI_FLAG|Systolic_BP|Systolic_FLAG|Glucose_in_Serum|Glucose_FLAG|HDL_Cholesterol|HDL_FLAG|
+--------------------+------------------+-------------------+----+--------+-----------+-------------+----------------+------------+---------------+--------+
|0053c053-09b3-408...|                 1|                  0|28.0|       1|      105.0|            1|            87.5|           1|              0|       0|
|00d38be8-920a-4e6...|                 0|                  0|30.2|       1|      115.0|            1|               0|           0|           74.8|       1|
|0120542e-62f7-4b6...|                 0|                  0|17.9|       1|      134.0|            1|               0|           0|              0|       0|
|01358bd4-cdcb-4de...|                 1|                 

#### Compare Naive Bayes Model
Start by joining tables together

In [None]:
obs_analytic.createOrReplaceTempView("obs_analytic")
data_1.createOrReplaceTempView("data_1")
data_2 = spark.sql("""
SELECT X.*, Y.*
FROM data_1 X
INNER JOIN obs_analytic Y ON X.Id = Y.PATIENT
""")
data_2.count()

124150

In [None]:
data_train_2 = data_2.filter(data_1.RAND <= 0.7)
data_test_2 = data_2.filter(data_1.RAND > 0.7)

In [None]:
def nb_sub_model_2(pred_vars):
    all_vars = ["DEATH"] + pred_vars
    sub_train = data_train_2.select(all_vars).rdd.map(lambda x: LabeledPoint(x[0], x[1:]))
    sub_test = data_test_2.select(all_vars).rdd.map(lambda x: LabeledPoint(x[0], x[1:]))
    model_bayes = NaiveBayes.train(sub_train)
    pred = sub_test.map(lambda x: (x.label, model_bayes.predict(x.features)))
    pred_eval = pred.map(eval_label).reduceByKey(lambda x, y: x + y).collect()
    return pred_eval

In [None]:
vars_demo = ["AGE", "ASIAN_FLAG", "OTHER_FLAG", "BLACK_FLAG", "NATIVE_FLAG", "HISPANIC_FLAG", "MALE_FLAG"]
vars_proc = ["VENT_FLAG", "OXYGEN_FLAG"]
vars_cond = ["FEVER_FLAG", "COUGH_FLAG"]
vars_data_1 = vars_demo + vars_proc + vars_cond
vars_obs = ["FLAG_FORMER_SMOKER", "FLAG_CURRENT_SMOKER", "BMI", "BMI_FLAG", "Systolic_BP", 
            "Systolic_FLAG", "Glucose_in_Serum", "Glucose_FLAG", "HDL_Cholesterol", "HDL_FLAG"]

spark.createDataFrame([
                       Row(**list_to_metrics(nb_sub_model_2(vars_data_1), "Non observations")),
                       Row(**list_to_metrics(nb_sub_model_2(vars_obs), "Observations Only"))
                       ]).show()


+-----------------+-----+------------------+------------------+------------------+-------------------+
|            model|    n|               acc|                F1|              sens|               spec|
+-----------------+-----+------------------+------------------+------------------+-------------------+
| Non observations|37277|0.9069667623467553| 0.941645633518425|0.9487014307994847| 0.7488115122703328|
|Observations Only|37277|0.6677575985191941|0.7755486688776527|0.8476409301588559|0.29042712315107194|
+-----------------+-----+------------------+------------------+------------------+-------------------+



In [None]:
spark.createDataFrame([
                       Row(**list_to_metrics(nb_sub_model_2(vars_data_1 + vars_obs), "Both"))
                       ]).show()

+-----+-----+------------------+------------------+------------------+-------------------+
|model|    n|               acc|                F1|              sens|               spec|
+-----+-----+------------------+------------------+------------------+-------------------+
| Both|37277|0.7523137591544384|0.8290343486714193|0.9300760314096971|0.42837674136886733|
+-----+-----+------------------+------------------+------------------+-------------------+



In general, the combination of BMI, blood pressure, cholesterol, and glucose wasn't very helpful.

## Investigate Medications

Shruthi's list of medications

* Baricitinib
* Tocilizumab  
* Hydroxychloroquine
* Chloroquine
* Remdesivir
* Desflurane
* Pulmozyme
* Colchicine

In [None]:
medications = spark.read.csv("/content/drive/MyDrive/synthea/medications.csv", header=True)
medications.createOrReplaceTempView("medications")

In [None]:
spark.sql("""
SELECT COUNT(*) AS baricitinib_count
FROM medications
WHERE lower(DESCRIPTION) like '%baricitinib%'
""").show()

+-----------------+
|baricitinib_count|
+-----------------+
|              149|
+-----------------+



In [None]:
spark.sql("""
SELECT COUNT(*) AS tocilizumab_count
FROM medications
WHERE lower(DESCRIPTION) like '%tocilizumab%'
""").show()

+-----------------+
|baricitinib_count|
+-----------------+
|              161|
+-----------------+



In [None]:
spark.sql("""
SELECT COUNT(*) AS hydroxychloroquine_count
FROM medications
WHERE lower(DESCRIPTION) like '%hydroxychloroquine%'
""").show()

+-----------------+
|baricitinib_count|
+-----------------+
|              208|
+-----------------+



In [None]:
spark.sql("""
SELECT COUNT(*) AS chloroquine_count
FROM medications
WHERE lower(DESCRIPTION) like '%chloroquine%'
""").show()

+-----------------+
|baricitinib_count|
+-----------------+
|              351|
+-----------------+



In [None]:
spark.sql("""
SELECT COUNT(*) AS remdesivir_count
FROM medications
WHERE lower(DESCRIPTION) like '%remdesivir%'
""").show()

+----------------+
|remdesivir_count|
+----------------+
|             164|
+----------------+



In [None]:
spark.sql("""
SELECT COUNT(*) AS desflurane_count
FROM medications
WHERE lower(DESCRIPTION) like '%desflurane%'
""").show()

+----------------+
|desflurane_count|
+----------------+
|               7|
+----------------+



In [None]:
spark.sql("""
SELECT COUNT(*) AS pulmozyme_count
FROM medications
WHERE lower(DESCRIPTION) like '%pulmozyme%'
""").show()

+---------------+
|pulmozyme_count|
+---------------+
|             39|
+---------------+



In [None]:
spark.sql("""
SELECT COUNT(*) AS colchicine_count
FROM medications
WHERE lower(DESCRIPTION) like '%colchicine%'
""").show()

+----------------+
|colchicine_count|
+----------------+
|              51|
+----------------+



In [None]:
meds_analytic = spark.sql("""
WITH PATS AS (
    SELECT DISTINCT Id AS PATIENT FROM patients
), X1 AS (
    SELECT DISTINCT PATIENT, 1 AS tocilizumab_FLAG
    FROM medications WHERE lower(DESCRIPTION) like '%tocilizumab%'
), X2 AS (
    SELECT DISTINCT PATIENT, 1 AS baricitinib_FLAG
    FROM medications WHERE lower(DESCRIPTION) like '%baricitinib%'
), X3 AS (
    SELECT DISTINCT PATIENT, 1 AS hydroxychloroquine_FLAG
    FROM medications WHERE lower(DESCRIPTION) like '%hydroxychloroquine%'
), X4 AS (
    SELECT DISTINCT PATIENT, 1 AS remdesivir_FLAG
    FROM medications WHERE lower(DESCRIPTION) like '%remdesivir%'
), X5 AS (
    SELECT DISTINCT PATIENT, 1 AS pulmozyme_FLAG
    FROM medications WHERE lower(DESCRIPTION) like '%pulmozyme%'
), X6 AS (
    SELECT DISTINCT PATIENT, 1 AS colchicine_FLAG
    FROM medications WHERE lower(DESCRIPTION) like '%colchicine%'
)
SELECT PATS.PATIENT,
    NVL(baricitinib_FLAG, 0) AS baricitinib_FLAG,
    NVL(tocilizumab_FLAG, 0) AS tocilizumab_FLAG,
    NVL(hydroxychloroquine_FLAG, 0) AS hydroxychloroquine_FLAG,
    NVL(remdesivir_FLAG, 0) AS remdesivir_FLAG,
    NVL(pulmozyme_FLAG, 0) AS pulmozyme_FLAG,
    NVL(colchicine_FLAG, 0) AS colchicine_FLAG
FROM PATS
LEFT JOIN X1 ON PATS.PATIENT = X1.PATIENT
LEFT JOIN X2 ON PATS.PATIENT = X2.PATIENT
LEFT JOIN X3 ON PATS.PATIENT = X3.PATIENT
LEFT JOIN X4 ON PATS.PATIENT = X4.PATIENT
LEFT JOIN X5 ON PATS.PATIENT = X5.PATIENT
LEFT JOIN X6 ON PATS.PATIENT = X6.PATIENT
""")

meds_analytic.show()

+--------------------+----------------+----------------+-----------------------+---------------+--------------+---------------+
|             PATIENT|baricitinib_FLAG|tocilizumab_FLAG|hydroxychloroquine_FLAG|remdesivir_FLAG|pulmozyme_FLAG|colchicine_FLAG|
+--------------------+----------------+----------------+-----------------------+---------------+--------------+---------------+
|0053c053-09b3-408...|               0|               0|                      0|              0|             0|              0|
|00d38be8-920a-4e6...|               0|               0|                      0|              0|             0|              0|
|0120542e-62f7-4b6...|               0|               0|                      0|              0|             0|              0|
|01358bd4-cdcb-4de...|               0|               0|                      0|              0|             0|              0|
|01c63fc5-0293-462...|               0|               0|                      0|              0|        

#### Compare Naive Bayes Model

In [None]:
meds_analytic.createOrReplaceTempView("meds_analytic")
data_1.createOrReplaceTempView("data_1")
data_3 = spark.sql("""
SELECT X.*, Y.*
FROM data_1 X
INNER JOIN meds_analytic Y ON X.Id = Y.PATIENT
""")
data_3.count()

124150

In [None]:
data_train_3 = data_3.filter(data_3.RAND <= 0.7)
data_test_3 = data_3.filter(data_3.RAND > 0.7)

def nb_sub_model_3(pred_vars):
    all_vars = ["DEATH"] + pred_vars
    sub_train = data_train_3.select(all_vars).rdd.map(lambda x: LabeledPoint(x[0], x[1:]))
    sub_test = data_test_3.select(all_vars).rdd.map(lambda x: LabeledPoint(x[0], x[1:]))
    model_bayes = NaiveBayes.train(sub_train)
    pred = sub_test.map(lambda x: (x.label, model_bayes.predict(x.features)))
    pred_eval = pred.map(eval_label).reduceByKey(lambda x, y: x + y).collect()
    return pred_eval

In [None]:
vars_demo = ["AGE", "ASIAN_FLAG", "OTHER_FLAG", "BLACK_FLAG", "NATIVE_FLAG", "HISPANIC_FLAG", "MALE_FLAG"]
vars_proc = ["VENT_FLAG", "OXYGEN_FLAG"]
vars_cond = ["FEVER_FLAG", "COUGH_FLAG"]
vars_data_1 = vars_demo + vars_proc + vars_cond
vars_meds = ["baricitinib_FLAG", "tocilizumab_FLAG", "hydroxychloroquine_FLAG", 
             "remdesivir_FLAG", "pulmozyme_FLAG", "colchicine_FLAG"]

spark.createDataFrame([
                       Row(**list_to_metrics(nb_sub_model_3(vars_data_1), "Non medication")),
                       Row(**list_to_metrics(nb_sub_model_3(vars_meds), "Medication Only")),
                       Row(**list_to_metrics(nb_sub_model_3(vars_meds + vars_data_1), "Combined")),
                       ]).show()

+---------------+-----+------------------+------------------+------------------+------------------+
|          model|    n|               acc|                F1|              sens|              spec|
+---------------+-----+------------------+------------------+------------------+------------------+
| Non medication|37277|0.9069667623467553| 0.941645633518425|0.9487014307994847|0.7488115122703328|
|Medication Only|37277|0.8034713093864849|0.8909724082507366|0.8034247678350959|0.8947368421052632|
|       Combined|37277|0.9071545456984199|0.9417506774155545| 0.949018011600692|0.7488455618265777|
+---------------+-----+------------------+------------------+------------------+------------------+

