# Proof of concept - Predict Top-5 diagnosis categories

<ul> <li> D1: Diseases of the circulatory system </ul> </li>
<ul> <li> D2: External causes of injury and supplemental classification </ul> </li>
<ul> <li> D3: Endocrine, nutritional and metabolic diseases, and immunity disorders </ul> </li>
<ul> <li> D4: Diseases of the respiratory system </ul> </li>
<ul> <li> D5: Injury and poisoning </ul> </li>

## 1. Pre process data tables

In [3]:
import re
import os
import ast
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pyspark.sql import SQLContext
from pyspark.sql.types import *
from pyspark.sql.functions import collect_list
import pyspark.sql.functions as F
import pyspark.sql.types as T
from pyspark.sql import SparkSession
import nltk
from nltk.corpus import stopwords
import string
from nltk.stem import WordNetLemmatizer
import matplotlib.pyplot as plt
#from wordcloud import WordCloud
#from transformers import AutoTokenizer, AutoModel

In [4]:
app_name = "ClickThrough"
master = "local[*]"
spark = SparkSession\
        .builder\
        .appName(app_name)\
        .master(master)\
        .getOrCreate()
sc = spark.sparkContext
sqlContext = SQLContext(sc)

Edit the S3 URL as per your config

In [23]:
S3_PREFIX = 's3://'
S3_BUCKET = '<Add your bucket>'
S3_PATH = '<Add your folder path>'
S3_URL = S3_PREFIX + S3_BUCKET + S3_PATH

In [119]:
filename_LIST = ['D_ICD_DIAGNOSES.csv','DIAGNOSES_ICD.csv','NOTEEVENTS.csv']
for filename_LIST_ITEM in filename_LIST:
    filename,fileformat = filename_LIST_ITEM.split('.')
    exec(filename+'_DF = sqlContext.read.format("'+fileformat+'").option("header", "true").option("multiline",True).'+
         'option("escape",'+"'"+'"'+"')"+'.load("'+S3_URL+'/'
         +filename+'.csv")')
    exec(filename+'_DF.createOrReplaceTempView("'+filename+'")')

### 1.1 Build HADM level diagnosis flag tables

In [120]:
spark.sql("""
select A.*, B.SHORT_TITLE, B.LONG_TITLE, 
case when substr(A.ICD9_CODE,1,1) in ('E','V') then 'external causes of injury and supplemental classification'
when substr(A.ICD9_CODE,1,3) between 001 and 139 then 'infectious and parasitic diseases'
when substr(A.ICD9_CODE,1,3) between 140 and 239 then 'neoplasms'
when substr(A.ICD9_CODE,1,3) between 240 and 279 then 'endocrine, nutritional and metabolic diseases, and immunity disorders'
when substr(A.ICD9_CODE,1,3) between 280 and 289 then 'diseases of the blood and blood-forming organs'
when substr(A.ICD9_CODE,1,3) between 290 and 319 then 'mental disorders'
when substr(A.ICD9_CODE,1,3) between 320 and 389 then 'diseases of the nervous system and sense organs'
when substr(A.ICD9_CODE,1,3) between 390 and 459 then 'diseases of the circulatory system'
when substr(A.ICD9_CODE,1,3) between 460 and 519 then 'diseases of the respiratory system'
when substr(A.ICD9_CODE,1,3) between 520 and 579 then 'diseases of the digestive system'
when substr(A.ICD9_CODE,1,3) between 580 and 629 then 'diseases of the genitourinary system'
when substr(A.ICD9_CODE,1,3) between 630 and 679 then 'complications of pregnancy, childbirth, and the puerperium'
when substr(A.ICD9_CODE,1,3) between 680 and 709 then 'diseases of the skin and subcutaneous tissue'
when substr(A.ICD9_CODE,1,3) between 710 and 739 then 'diseases of the musculoskeletal system and connective tissue'
when substr(A.ICD9_CODE,1,3) between 740 and 759 then 'congenital anomalies'
when substr(A.ICD9_CODE,1,3) between 760 and 779 then 'certain conditions originating in the perinatal period'
when substr(A.ICD9_CODE,1,3) between 780 and 799 then 'symptoms, signs, and ill-defined conditions'
when substr(A.ICD9_CODE,1,3) between 800 and 999 then 'injury and poisoning' 
end as ICD_GROUP
from DIAGNOSES_ICD A
left join D_ICD_DIAGNOSES B
on A.ICD9_CODE = B.ICD9_CODE
""").createOrReplaceTempView('DIAGNOSES_ICD_WITH_GROUPING')

In [121]:
DIAGNOSIS_GROUPING_DF = spark.sql("""select distinct HADM_ID, 
case when ICD_GROUP = 'diseases of the circulatory system' then 'D1'
when ICD_GROUP = 'external causes of injury and supplemental classification' then 'D2'
when ICD_GROUP = 'endocrine, nutritional and metabolic diseases, and immunity disorders' then 'D3'
when ICD_GROUP = 'diseases of the respiratory system' then 'D4'
when ICD_GROUP = 'injury and poisoning' then 'D5' end as ICD_GROUP_ID
from DIAGNOSES_ICD_WITH_GROUPING 
where ICD_GROUP in ('diseases of the circulatory system',
'external causes of injury and supplemental classification',
'endocrine, nutritional and metabolic diseases, and immunity disorders',
'diseases of the respiratory system',
'injury and poisoning')""")
DIAGNOSIS_GROUPING_PIVOT_DF = DIAGNOSIS_GROUPING_DF.groupBy('HADM_ID').pivot('ICD_GROUP_ID').count().fillna(0)
DIAGNOSIS_GROUPING_PIVOT_DF.createOrReplaceTempView('DIAGNOSIS_GROUPING_PIVOT')

In [122]:
spark.sql("""select * from DIAGNOSIS_GROUPING_PIVOT limit 5""").show()

+-------+---+---+---+---+---+
|HADM_ID| D1| D2| D3| D4| D5|
+-------+---+---+---+---+---+
| 153296|  1|  0|  1|  0|  0|
| 109960|  1|  1|  0|  0|  0|
| 155280|  1|  1|  1|  1|  0|
| 144037|  1|  1|  1|  0|  0|
| 149907|  0|  1|  0|  1|  1|
+-------+---+---+---+---+---+

In [123]:
DIAGNOSIS_GROUPING_PIVOT_DF.show(5)

+-------+---+---+---+---+---+
|HADM_ID| D1| D2| D3| D4| D5|
+-------+---+---+---+---+---+
| 166435|  1|  1|  1|  0|  1|
| 125592|  1|  1|  1|  0|  1|
| 100263|  1|  1|  1|  1|  1|
| 179104|  1|  1|  1|  0|  0|
| 118388|  1|  0|  1|  1|  0|
+-------+---+---+---+---+---+
only showing top 5 rows

### 1.2 Combine all notes at HADM level

In [124]:
NOTEEVENTS_DF = spark.sql("""SELECT A.HADM_ID, B.D1, B.D2, B.D3, B.D4, B.D5, lower(A.TEXT) as TEXT_LOWER 
from NOTEEVENTS A inner join DIAGNOSIS_GROUPING_PIVOT B on A.HADM_ID = B.HADM_ID""")

#### 1.2.1 Group all Text at HADM level

In [125]:
NOTEEVENTS_GROUPED_DF = NOTEEVENTS_DF.groupby('HADM_ID','D1','D2','D3','D4','D5') \
                        .agg(F.concat_ws("", F.collect_list(NOTEEVENTS_DF.TEXT_LOWER)).alias('TEXT'))

In [126]:
NOTEEVENTS_GROUPED_DF.show(5)

+-------+---+---+---+---+---+--------------------+
|HADM_ID| D1| D2| D3| D4| D5|                TEXT|
+-------+---+---+---+---+---+--------------------+
| 100010|  0|  0|  1|  0|  0|admission date:  ...|
| 100140|  0|  1|  0|  1|  0|[**2117-6-17**] 1...|
| 100227|  1|  1|  0|  1|  1|admission date:  ...|
| 100263|  1|  1|  1|  1|  1|admission date:  ...|
| 100320|  1|  1|  1|  0|  0|admission date:  ...|
+-------+---+---+---+---+---+--------------------+
only showing top 5 rows

In [127]:
NOTEEVENTS_GROUPED_DF.printSchema()

root
 |-- HADM_ID: string (nullable = true)
 |-- D1: long (nullable = true)
 |-- D2: long (nullable = true)
 |-- D3: long (nullable = true)
 |-- D4: long (nullable = true)
 |-- D5: long (nullable = true)
 |-- TEXT: string (nullable = false)

In [128]:
NOTEEVENTS_GROUPED_DF.describe().toPandas().T

             0  ...                                                  4
summary  count  ...                                                max
HADM_ID  57632  ...                                             199999
D1       57632  ...                                                  1
D2       57632  ...                                                  1
D3       57632  ...                                                  1
D4       57632  ...                                                  1
D5       57632  ...                                                  1
TEXT     57632  ...  y\nname:  [**known lastname 95474**], [**known...

[8 rows x 5 columns]

#### 1.2.2 Convert to Parquete

In [130]:
NOTEEVENTS_GROUPED_DF.write.mode('overwrite').parquet(S3_URL+'/NOTEEVENTS_GROUPED_DIAG')