# Proof of concept - Predict Top-5 diagnosis categories

<ul> <li> D1: Diseases of the circulatory system </ul> </li>
<ul> <li> D2: External causes of injury and supplemental classification </ul> </li>
<ul> <li> D3: Endocrine, nutritional and metabolic diseases, and immunity disorders </ul> </li>
<ul> <li> D4: Diseases of the respiratory system </ul> </li>
<ul> <li> D5: Injury and poisoning </ul> </li>

## 1. Pre process data tables

In [177]:
import re
import os
import ast
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pyspark.sql import SQLContext
from pyspark.sql.types import *
from pyspark.sql.functions import collect_list
import pyspark.sql.functions as F
import pyspark.sql.types as T
from pyspark.sql import SparkSession
import nltk
from nltk.corpus import stopwords
import string
from nltk.stem import WordNetLemmatizer
import matplotlib.pyplot as plt
from pyspark.ml.feature import RegexTokenizer, StopWordsRemover, CountVectorizer, VectorAssembler, StringIndexer, VectorIndexer
from pyspark.ml.classification import LogisticRegression
from pyspark.ml import Pipeline

In [4]:
app_name = "ClickThrough"
master = "local[*]"
spark = SparkSession\
        .builder\
        .appName(app_name)\
        .master(master)\
        .getOrCreate()
sc = spark.sparkContext
sqlContext = SQLContext(sc)

Edit the S3 URL as per your config

In [23]:
S3_PREFIX = 's3://'
S3_BUCKET = '<bucket name>'
S3_PATH = '<folder path>'
S3_URL = S3_PREFIX + S3_BUCKET + S3_PATH

In [119]:
filename_LIST = ['D_ICD_DIAGNOSES.csv','DIAGNOSES_ICD.csv','NOTEEVENTS.csv']
for filename_LIST_ITEM in filename_LIST:
    filename,fileformat = filename_LIST_ITEM.split('.')
    exec(filename+'_DF = sqlContext.read.format("'+fileformat+'").option("header", "true").option("multiline",True).'+
         'option("escape",'+"'"+'"'+"')"+'.load("'+S3_URL+'/'
         +filename+'.csv")')
    exec(filename+'_DF.createOrReplaceTempView("'+filename+'")')

### 1.1 Build HADM level diagnosis flag tables

In [120]:
spark.sql("""
select A.*, B.SHORT_TITLE, B.LONG_TITLE, 
case when substr(A.ICD9_CODE,1,1) in ('E','V') then 'external causes of injury and supplemental classification'
when substr(A.ICD9_CODE,1,3) between 001 and 139 then 'infectious and parasitic diseases'
when substr(A.ICD9_CODE,1,3) between 140 and 239 then 'neoplasms'
when substr(A.ICD9_CODE,1,3) between 240 and 279 then 'endocrine, nutritional and metabolic diseases, and immunity disorders'
when substr(A.ICD9_CODE,1,3) between 280 and 289 then 'diseases of the blood and blood-forming organs'
when substr(A.ICD9_CODE,1,3) between 290 and 319 then 'mental disorders'
when substr(A.ICD9_CODE,1,3) between 320 and 389 then 'diseases of the nervous system and sense organs'
when substr(A.ICD9_CODE,1,3) between 390 and 459 then 'diseases of the circulatory system'
when substr(A.ICD9_CODE,1,3) between 460 and 519 then 'diseases of the respiratory system'
when substr(A.ICD9_CODE,1,3) between 520 and 579 then 'diseases of the digestive system'
when substr(A.ICD9_CODE,1,3) between 580 and 629 then 'diseases of the genitourinary system'
when substr(A.ICD9_CODE,1,3) between 630 and 679 then 'complications of pregnancy, childbirth, and the puerperium'
when substr(A.ICD9_CODE,1,3) between 680 and 709 then 'diseases of the skin and subcutaneous tissue'
when substr(A.ICD9_CODE,1,3) between 710 and 739 then 'diseases of the musculoskeletal system and connective tissue'
when substr(A.ICD9_CODE,1,3) between 740 and 759 then 'congenital anomalies'
when substr(A.ICD9_CODE,1,3) between 760 and 779 then 'certain conditions originating in the perinatal period'
when substr(A.ICD9_CODE,1,3) between 780 and 799 then 'symptoms, signs, and ill-defined conditions'
when substr(A.ICD9_CODE,1,3) between 800 and 999 then 'injury and poisoning' 
end as ICD_GROUP
from DIAGNOSES_ICD A
left join D_ICD_DIAGNOSES B
on A.ICD9_CODE = B.ICD9_CODE
""").createOrReplaceTempView('DIAGNOSES_ICD_WITH_GROUPING')

In [121]:
DIAGNOSIS_GROUPING_DF = spark.sql("""select distinct HADM_ID, 
case when ICD_GROUP = 'diseases of the circulatory system' then 'D1'
when ICD_GROUP = 'external causes of injury and supplemental classification' then 'D2'
when ICD_GROUP = 'endocrine, nutritional and metabolic diseases, and immunity disorders' then 'D3'
when ICD_GROUP = 'diseases of the respiratory system' then 'D4'
when ICD_GROUP = 'injury and poisoning' then 'D5' end as ICD_GROUP_ID
from DIAGNOSES_ICD_WITH_GROUPING 
where ICD_GROUP in ('diseases of the circulatory system',
'external causes of injury and supplemental classification',
'endocrine, nutritional and metabolic diseases, and immunity disorders',
'diseases of the respiratory system',
'injury and poisoning')""")
DIAGNOSIS_GROUPING_PIVOT_DF = DIAGNOSIS_GROUPING_DF.groupBy('HADM_ID').pivot('ICD_GROUP_ID').count().fillna(0)
DIAGNOSIS_GROUPING_PIVOT_DF.createOrReplaceTempView('DIAGNOSIS_GROUPING_PIVOT')

In [122]:
spark.sql("""select * from DIAGNOSIS_GROUPING_PIVOT limit 5""").show()

+-------+---+---+---+---+---+
|HADM_ID| D1| D2| D3| D4| D5|
+-------+---+---+---+---+---+
| 153296|  1|  0|  1|  0|  0|
| 109960|  1|  1|  0|  0|  0|
| 155280|  1|  1|  1|  1|  0|
| 144037|  1|  1|  1|  0|  0|
| 149907|  0|  1|  0|  1|  1|
+-------+---+---+---+---+---+

In [123]:
DIAGNOSIS_GROUPING_PIVOT_DF.show(5)

+-------+---+---+---+---+---+
|HADM_ID| D1| D2| D3| D4| D5|
+-------+---+---+---+---+---+
| 166435|  1|  1|  1|  0|  1|
| 125592|  1|  1|  1|  0|  1|
| 100263|  1|  1|  1|  1|  1|
| 179104|  1|  1|  1|  0|  0|
| 118388|  1|  0|  1|  1|  0|
+-------+---+---+---+---+---+
only showing top 5 rows

### 1.2 Combine all notes at HADM level

In [124]:
NOTEEVENTS_DF = spark.sql("""SELECT A.HADM_ID, B.D1, B.D2, B.D3, B.D4, B.D5, lower(A.TEXT) as TEXT_LOWER 
from NOTEEVENTS A inner join DIAGNOSIS_GROUPING_PIVOT B on A.HADM_ID = B.HADM_ID""")

#### 1.2.1 Group all Text at HADM level

In [125]:
NOTEEVENTS_GROUPED_DF = NOTEEVENTS_DF.groupby('HADM_ID','D1','D2','D3','D4','D5') \
                        .agg(F.concat_ws("", F.collect_list(NOTEEVENTS_DF.TEXT_LOWER)).alias('TEXT'))

In [126]:
NOTEEVENTS_GROUPED_DF.show(5)

+-------+---+---+---+---+---+--------------------+
|HADM_ID| D1| D2| D3| D4| D5|                TEXT|
+-------+---+---+---+---+---+--------------------+
| 100010|  0|  0|  1|  0|  0|admission date:  ...|
| 100140|  0|  1|  0|  1|  0|[**2117-6-17**] 1...|
| 100227|  1|  1|  0|  1|  1|admission date:  ...|
| 100263|  1|  1|  1|  1|  1|admission date:  ...|
| 100320|  1|  1|  1|  0|  0|admission date:  ...|
+-------+---+---+---+---+---+--------------------+
only showing top 5 rows

In [127]:
NOTEEVENTS_GROUPED_DF.printSchema()

root
 |-- HADM_ID: string (nullable = true)
 |-- D1: long (nullable = true)
 |-- D2: long (nullable = true)
 |-- D3: long (nullable = true)
 |-- D4: long (nullable = true)
 |-- D5: long (nullable = true)
 |-- TEXT: string (nullable = false)

In [128]:
NOTEEVENTS_GROUPED_DF.describe().toPandas().T

             0  ...                                                  4
summary  count  ...                                                max
HADM_ID  57632  ...                                             199999
D1       57632  ...                                                  1
D2       57632  ...                                                  1
D3       57632  ...                                                  1
D4       57632  ...                                                  1
D5       57632  ...                                                  1
TEXT     57632  ...  y\nname:  [**known lastname 95474**], [**known...

[8 rows x 5 columns]

#### 1.2.2 Convert to Parquete

In [130]:
NOTEEVENTS_GROUPED_DF.write.mode('overwrite').parquet(S3_URL+'/NOTEEVENTS_GROUPED_DIAG')

### 2.0 ETL Pipeline

#### 2.1 Replace Anonymized Names and dates

In [275]:
NOTES_RDD = NOTEEVENTS_GROUPED_DF.select('HADM_ID','D1','D2','D3','D4','D5','TEXT').rdd
def ReplaceAnonym(x):
    return (x[0], x[1], x[2], x[3], x[4], x[5],
            16*x[1]+8*x[2]+4*x[3]+2*x[4]+1*x[5],
            re.sub('\n',' ',re.sub(r'\[\*\*.+\*\*\]','xxx',x[6])))
NOTES_GROUPED_ICD_DF = NOTES_RDD.map(ReplaceAnonym).toDF(['HADM_ID','D1','D2','D3','D4','D5','DX_INDEX','TEXT'])

In [190]:
NOTES_GROUPED_ICD_DF.show(5)

+-------+---+---+---+---+---+--------+--------------------+
|HADM_ID| D1| D2| D3| D4| D5|DX_INDEX|                TEXT|
+-------+---+---+---+---+---+--------+--------------------+
| 100010|  0|  0|  1|  0|  0|       4|admission date:  ...|
| 100140|  0|  1|  0|  1|  0|      10|xxx 12:28 pm  che...|
| 100227|  1|  1|  0|  1|  1|      27|admission date:  ...|
| 100263|  1|  1|  1|  1|  1|      31|admission date:  ...|
| 100320|  1|  1|  1|  0|  0|       7|admission date:  ...|
+-------+---+---+---+---+---+--------+--------------------+
only showing top 5 rows

In [323]:
stopword_LIST = ['did', "couldn't", 'herself', 'above', 'hers', 'ain', 'if', 'until', 'me', 'through', 'some', 'be', 'myself', 'because', 'don', "shouldn't", 'here', 'as', 'can', 'it', 'on', 'no', "you've", 'the', 'our', 'we', "that'll", 'do', 'then', 'will', 'most', 'yours', 'yourselves', 'he', 'yourself', 'few', 'with', 'mightn', 'doesn', 'at', 'y', 'only', "you're", 'down', 'how', 'any', 'very', 'wouldn', 'himself', "hasn't", 'll', 'm', 'its', 'off', 'themselves', 'other', 'own', 'are', 'from', 'just', 'itself', 'has', 'ourselves', 'each', 'which', 'weren', 'i', 'should', 'shan', 'having', 'those', 'have', 'than', 'or', 'there', 'were', 'up', "mustn't", "wouldn't", "needn't", 'was', 'why', 're', 'they', "you'd", 'she', 'her', 'isn', "you'll", 'under', 'shouldn', 'to', 'nor', 'and', 'd', 'my', 'o', 'a', 'by', 'after', 'against', 'your', 'does', "it's", 's', 'you', "should've", 'him', 'hasn', 'again', "aren't", 'into', 'where', 'couldn', 'below', 'ma', 'didn', 'ours', 'wasn', 'about', 'what', 'when', 'same', 'is', "don't", 'during', 'in', 'mustn', 'needn', 'had', 'while', 'too', 'both', 'but', 'whom', 'between', "isn't", 'theirs', 'won', 'out', 'an', 'that', 'for', "didn't", 'this', "doesn't", "wasn't", 'am', 'aren', 'these', 't', 'who', 'further', "hadn't", 'his', 'more', 'before', 'them', 'not', "shan't", "mightn't", 'such', 'been', 'haven', 'being', 'over', 'once', 've', 'hadn', 'doing', 'of', "won't", 'now', 'all', "haven't", 'their', "weren't", 'so', "she's"]

#### 2.2 Count Vectorize TEXT

In [324]:
# regular expression tokenizer
regexTokenizer = RegexTokenizer(inputCol="TEXT", outputCol="WORDS", pattern="\\W")
# remove stopwords
stopwordsRemover = StopWordsRemover(inputCol="WORDS", outputCol="WORDS_FILT").setStopWords(stopword_LIST)
# bag of words count
countVectors = CountVectorizer(inputCol="WORDS_FILT", outputCol="FEATURES", vocabSize=5000, minDF=5)
#combine all flags D1 thru D2 into a Dense vector
assembler = VectorAssembler(inputCols = ['D1','D2','D3','D4','D5'], outputCol = 'DX_VEC')

In [325]:
pipeline = Pipeline(stages=[regexTokenizer, stopwordsRemover, countVectors, assembler])
# Fit the pipeline to training documents.
pipelineFit = pipeline.fit(NOTES_GROUPED_ICD_DF)
NOTES_DX_DF = pipelineFit.transform(NOTES_GROUPED_ICD_DF)\
                .select('HADM_ID','D1','D2','D3','D4','D5','DX_INDEX','DX_VEC','FEATURES')

In [223]:
NOTES_DX_DF.select('HADM_ID','D1','D2','D3','D4','D5','DX_INDEX','DX_VEC').show(5, truncate=False)

+-------+---+---+---+---+---+--------+---------------------+
|HADM_ID|D1 |D2 |D3 |D4 |D5 |DX_INDEX|DX_VEC               |
+-------+---+---+---+---+---+--------+---------------------+
|100010 |0  |0  |1  |0  |0  |4       |(5,[2],[1.0])        |
|100140 |0  |1  |0  |1  |0  |10      |(5,[1,3],[1.0,1.0])  |
|100227 |1  |1  |0  |1  |1  |27      |[1.0,1.0,0.0,1.0,1.0]|
|100263 |1  |1  |1  |1  |1  |31      |[1.0,1.0,1.0,1.0,1.0]|
|100320 |1  |1  |1  |0  |0  |7       |[1.0,1.0,1.0,0.0,0.0]|
+-------+---+---+---+---+---+--------+---------------------+
only showing top 5 rows

In [287]:
NOTES_DX_LR_DF = NOTES_DX_DF.select('HADM_ID','DX_INDEX','FEATURES').cache()
NOTES_DX_LR_DF.show(5)

+-------+--------+--------------------+
|HADM_ID|DX_INDEX|            FEATURES|
+-------+--------+--------------------+
| 100010|       4|(5000,[0,1,2,4,5,...|
| 100140|      10|(5000,[0,1,3,4,5,...|
| 100227|      27|(5000,[0,1,2,3,4,...|
| 100263|      31|(5000,[0,1,2,3,4,...|
| 100320|      28|(5000,[0,1,2,3,4,...|
+-------+--------+--------------------+
only showing top 5 rows

### 3 Partition Training & Test sets

In [316]:
(trainingData, testData) = NOTES_DX_DF.randomSplit([0.8, 0.2], seed = 100)
print("Training Dataset Count: " + str(trainingData.count()))
print("Test Dataset Count: " + str(testData.count()))

Training Dataset Count: 46177
Test Dataset Count: 11455

### 4 Model Training and Evaluation

#### 4.1 Logistic Regression using Count Vector Features for combination of 5 therapy areas

In [317]:
lr = LogisticRegression(maxIter=50, featuresCol='FEATURES', labelCol='DX_INDEX')
lrModel = lr.fit(trainingData)
predictions = lrModel.transform(testData)

In [318]:
predictions_DF = predictions.select('HADM_ID','DX_INDEX','prediction').toPandas()

In [319]:
correct_combo = predictions_DF[predictions_DF['DX_INDEX']==predictions_DF['prediction']]['DX_INDEX'].count()
total_test_records = predictions_DF['DX_INDEX'].count()
print(str(round(100*correct_combo/total_test_records,2))+'%')

37.06%

#### 4.2 Simple LR to predict individual therapy areas

In [236]:
lr = LogisticRegression(maxIter=50, featuresCol='FEATURES', labelCol='D1')
lrModel = lr.fit(trainingData)
predictions_D1 = lrModel.transform(testData)

In [237]:
lr = LogisticRegression(maxIter=50, featuresCol='FEATURES', labelCol='D2')
lrModel = lr.fit(trainingData)
predictions_D2 = lrModel.transform(testData)

In [238]:
lr = LogisticRegression(maxIter=50, featuresCol='FEATURES', labelCol='D3')
lrModel = lr.fit(trainingData)
predictions_D3 = lrModel.transform(testData)

In [239]:
lr = LogisticRegression(maxIter=50, featuresCol='FEATURES', labelCol='D4')
lrModel = lr.fit(trainingData)
predictions_D4 = lrModel.transform(testData)

In [240]:
lr = LogisticRegression(maxIter=50, featuresCol='FEATURES', labelCol='D5')
lrModel = lr.fit(trainingData)
predictions_D5 = lrModel.transform(testData)

#### Calculate Accuracy

In [253]:
predictions_DF1 = predictions_D1.select('D1','D2','D3','D4','D5','prediction').toPandas()
correct_combo = predictions_DF1[predictions_DF1['D1']==predictions_DF1['prediction']]['prediction'].count()
total_test_records = predictions_DF1['prediction'].count()
print('D1 accuracy = '+str(round(100*correct_combo/total_test_records,2))+'%')

D1 accuracy = 89.31%

In [254]:
pd.crosstab(predictions_DF1.D1,predictions_DF1.prediction)

prediction   0.0   1.0
D1                    
0           2394   664
1            560  7837

In [255]:
predictions_DF2 = predictions_D2.select('D1','D2','D3','D4','D5','prediction').toPandas()
correct_combo = predictions_DF2[predictions_DF2['D2']==predictions_DF2['prediction']]['prediction'].count()
total_test_records = predictions_DF2['prediction'].count()
print('D2 accuracy = '+str(round(100*correct_combo/total_test_records,2))+'%')

D2 accuracy = 76.06%

In [256]:
pd.crosstab(predictions_DF2.D2,predictions_DF2.prediction)

prediction   0.0   1.0
D2                    
0           1583  1634
1           1108  7130

In [257]:
predictions_DF3 = predictions_D3.select('D1','D2','D3','D4','D5','prediction').toPandas()
correct_combo = predictions_DF3[predictions_DF3['D3']==predictions_DF3['prediction']]['prediction'].count()
total_test_records = predictions_DF3['prediction'].count()
print('D3 accuracy = '+str(round(100*correct_combo/total_test_records,2))+'%')

D3 accuracy = 79.3%

In [258]:
pd.crosstab(predictions_DF3.D3,predictions_DF3.prediction)

prediction   0.0   1.0
D3                    
0           3077  1389
1            982  6007

In [259]:
predictions_DF4 = predictions_D4.select('D1','D2','D3','D4','D5','prediction').toPandas()
correct_combo = predictions_DF4[predictions_DF4['D4']==predictions_DF4['prediction']]['prediction'].count()
total_test_records = predictions_DF4['prediction'].count()
print('D4 accuracy = '+str(round(100*correct_combo/total_test_records,2))+'%')

D4 accuracy = 82.18%

In [260]:
pd.crosstab(predictions_DF4.D4,predictions_DF4.prediction)

prediction   0.0   1.0
D4                    
0           5795   770
1           1271  3619

In [261]:
predictions_DF5 = predictions_D5.select('D1','D2','D3','D4','D5','prediction').toPandas()
correct_combo = predictions_DF5[predictions_DF5['D5']==predictions_DF5['prediction']]['prediction'].count()
total_test_records = predictions_DF5['prediction'].count()
print('D5 accuracy = '+str(round(100*correct_combo/total_test_records,2))+'%')

D5 accuracy = 77.96%

In [262]:
pd.crosstab(predictions_DF5.D5,predictions_DF5.prediction)

prediction   0.0   1.0
D5                    
0           6107   922
1           1603  2823