# MIMIC III Preprocessing

## Initialization and Data Loading

In [1]:
from pyspark import SparkContext, SparkConf
from pyspark.sql.types import *

In [2]:
conf = SparkConf().setAppName("preprocess").setMaster("local")
sc = SparkContext.getOrCreate(conf)
spark = SparkSession.builder.master("local").appName("preprocess").getOrCreate()

ne_struct = StructType([StructField("row_id", IntegerType(), True),
                      StructField("subject_id", IntegerType(), True),
                      StructField("hadm_id", IntegerType(), True),
                      StructField("chartdate", DateType(), True),
                      StructField("category", StringType(), True),
                      StructField("description", StringType(), True),
                      StructField("cgid", IntegerType(), True),
                      StructField("iserror", IntegerType(), True),
                      StructField("text", StringType(), True)])
df_ne = spark.read.csv("./data/NOTEEVENTS-2.csv",
# df_ne = spark.read.csv("./data/NOTEEVENTS-2sample.csv",
                       header=True,
                       schema=ne_struct)
df_ne.registerTempTable("noteevents")
df_ne.filter(df_ne.category=="Discharge summary") \
    .registerTempTable("noteevents2")
    
# i want to cache noteevents, but it's too big

# many icd to one hadm_id
diag_struct = StructType([StructField("ROW_ID", IntegerType(), True),
                          StructField("SUBJECT_ID", IntegerType(), True),
                          StructField("HADM_ID", IntegerType(), True),
                          StructField("SEQ_NUM", IntegerType(), True),
                          StructField("ICD9_CODE", StringType(), True)])
df_diag_m = spark.read.csv("./data/DIAGNOSES_ICD.csv",
                           header=True,
                           schema=diag_struct) \
            .selectExpr("ROW_ID as row_id", 
                        "SUBJECT_ID as subject_id",
                        "HADM_ID as hadm_id",
                        "SEQ_NUM as seq_num",
                        "ICD9_CODE as icd9_code")
df_diag_m.registerTempTable("diagnoses_icd_m")
df_diag_m.cache()

# one icd to one hadm_id (take the smallest seq number as primary)
diag_o_rdd = df_diag_m.rdd.sortBy(lambda x: (x.hadm_id, x.subject_id, x.seq_num)) \
    .groupBy(lambda x: x.hadm_id) \
    .mapValues(list) \
    .reduceByKey(lambda x, y: x if x.seq_num < y.seq_num else y) \
    .map(lambda (hid, d): d[0])
df_diag_o = spark.createDataFrame(diag_o_rdd)
df_diag_o.registerTempTable("diagnoses_icd_o")
df_diag_o.cache()

# get hadm_id list in noteevents
df_hadm_id_list = spark.sql("""
SELECT DISTINCT hadm_id FROM noteevents2
""")
df_hadm_id_list.registerTempTable("hadm_id_list")
df_hadm_id_list.cache()

# get subject_id list in noteevents
df_subject_id_list = spark.sql("""
SELECT DISTINCT subject_id FROM noteevents2
""")
df_subject_id_list.registerTempTable("subject_id_list")
df_subject_id_list.cache()

print df_ne.dtypes
print df_diag_m.dtypes
print df_diag_o.dtypes
print df_hadm_id_list.dtypes
print df_subject_id_list.dtypes

[('row_id', 'int'), ('subject_id', 'int'), ('hadm_id', 'int'), ('chartdate', 'date'), ('category', 'string'), ('description', 'string'), ('cgid', 'int'), ('iserror', 'int'), ('text', 'string')]
[('row_id', 'int'), ('subject_id', 'int'), ('hadm_id', 'int'), ('seq_num', 'int'), ('icd9_code', 'string')]
[('row_id', 'bigint'), ('subject_id', 'bigint'), ('hadm_id', 'bigint'), ('seq_num', 'bigint'), ('icd9_code', 'string')]
[('hadm_id', 'int')]
[('subject_id', 'int')]


In [3]:
df_diag_o2 = spark.sql("""
SELECT row_id, subject_id, diagnoses_icd_o.hadm_id AS hadm_id,
seq_num, icd9_code
FROM diagnoses_icd_o JOIN hadm_id_list
ON diagnoses_icd_o.hadm_id = hadm_id_list.hadm_id
""")
df_diag_o2.registerTempTable("diagnoses_icd_o2")
df_diag_o2.cache()

DataFrame[row_id: bigint, subject_id: bigint, hadm_id: bigint, seq_num: bigint, icd9_code: string]

In [4]:
df_diag_m2 = spark.sql("""
SELECT row_id, subject_id, diagnoses_icd_m.hadm_id AS hadm_id,
seq_num, icd9_code
FROM diagnoses_icd_m JOIN hadm_id_list
ON diagnoses_icd_m.hadm_id = hadm_id_list.hadm_id
""")
df_diag_m2.registerTempTable("diagnoses_icd_m2")
df_diag_m2.cache()

DataFrame[row_id: int, subject_id: int, hadm_id: int, seq_num: int, icd9_code: string]

## Descriptive Statistics

### noteevents
Basic Counts:

In [5]:
spark.sql("""
SELECT COUNT(*), COUNT(DISTINCT subject_id), COUNT(DISTINCT hadm_id)
FROM noteevents
""").show()
spark.sql("""
SELECT COUNT(*), COUNT(DISTINCT subject_id), COUNT(DISTINCT hadm_id)
FROM noteevents2
""").show()

+--------+--------------------------+-----------------------+
|count(1)|count(DISTINCT subject_id)|count(DISTINCT hadm_id)|
+--------+--------------------------+-----------------------+
| 2083180|                     46146|                  58361|
+--------+--------------------------+-----------------------+

+--------+--------------------------+-----------------------+
|count(1)|count(DISTINCT subject_id)|count(DISTINCT hadm_id)|
+--------+--------------------------+-----------------------+
|   59652|                     41127|                  52726|
+--------+--------------------------+-----------------------+



Categories:

In [6]:
spark.sql("""
SELECT DISTINCT(category)
FROM noteevents
""").show()

+-----------------+
|         category|
+-----------------+
|              ECG|
|     Respiratory |
|          Nursing|
|          General|
|          Consult|
|             Echo|
|        Nutrition|
|       Physician |
|         Pharmacy|
|   Rehab Services|
| Case Management |
|        Radiology|
|    Nursing/other|
|Discharge summary|
|      Social Work|
+-----------------+



### diagnoses_icd: many (icd_code) to one (hadm_id)
Basic Counts:

In [7]:
spark.sql("""
SELECT COUNT(*), COUNT(DISTINCT subject_id), 
COUNT(DISTINCT hadm_id), COUNT(DISTINCT ICD9_CODE)
FROM diagnoses_icd_m
""").show()

spark.sql("""
SELECT COUNT(*), COUNT(DISTINCT subject_id), 
COUNT(DISTINCT hadm_id), COUNT(DISTINCT LOWER(ICD9_CODE))
FROM diagnoses_icd_m
""").show()

+--------+--------------------------+-----------------------+-------------------------+
|count(1)|count(DISTINCT subject_id)|count(DISTINCT hadm_id)|count(DISTINCT ICD9_CODE)|
+--------+--------------------------+-----------------------+-------------------------+
|  651047|                     46520|                  58976|                     6984|
+--------+--------------------------+-----------------------+-------------------------+

+--------+--------------------------+-----------------------+--------------------------------+
|count(1)|count(DISTINCT subject_id)|count(DISTINCT hadm_id)|count(DISTINCT lower(ICD9_CODE))|
+--------+--------------------------+-----------------------+--------------------------------+
|  651047|                     46520|                  58976|                            6984|
+--------+--------------------------+-----------------------+--------------------------------+



### diagnoses_icd: one (icd_code) to one (hadm_id)
Basic Counts:

In [8]:
spark.sql("""
SELECT COUNT(*), COUNT(DISTINCT subject_id), 
COUNT(DISTINCT hadm_id), COUNT(DISTINCT ICD9_CODE)
FROM diagnoses_icd_o
""").show()

spark.sql("""
SELECT COUNT(*), COUNT(DISTINCT subject_id), 
COUNT(DISTINCT hadm_id), COUNT(DISTINCT LOWER(ICD9_CODE))
FROM diagnoses_icd_o
""").show()

+--------+--------------------------+-----------------------+-------------------------+
|count(1)|count(DISTINCT subject_id)|count(DISTINCT hadm_id)|count(DISTINCT ICD9_CODE)|
+--------+--------------------------+-----------------------+-------------------------+
|   58976|                     46520|                  58976|                     2789|
+--------+--------------------------+-----------------------+-------------------------+

+--------+--------------------------+-----------------------+--------------------------------+
|count(1)|count(DISTINCT subject_id)|count(DISTINCT hadm_id)|count(DISTINCT lower(ICD9_CODE))|
+--------+--------------------------+-----------------------+--------------------------------+
|   58976|                     46520|                  58976|                            2789|
+--------+--------------------------+-----------------------+--------------------------------+



Just to check if I really did get "seq_num = 1" for all diagnosis, the code below should return empty. 

In [9]:
# check code
spark.sql("""
SELECT *
FROM diagnoses_icd_o
WHERE seq_num <> 1
""").show()

+------+----------+-------+-------+---------+
|row_id|subject_id|hadm_id|seq_num|icd9_code|
+------+----------+-------+-------+---------+
+------+----------+-------+-------+---------+



### noteevents and diagnoses_icd (one to one)
Basic Counts:

In [10]:
spark.sql("""
SELECT COUNT(DISTINCT subject_id), 
COUNT(DISTINCT hadm_id), COUNT(DISTINCT icd9_code)
FROM diagnoses_icd_o2
""").show()

+--------------------------+-----------------------+-------------------------+
|count(DISTINCT subject_id)|count(DISTINCT hadm_id)|count(DISTINCT icd9_code)|
+--------------------------+-----------------------+-------------------------+
|                     41127|                  52726|                     2706|
+--------------------------+-----------------------+-------------------------+



Top 50 ICD 9 codes based on "subject_id" count

In [11]:
spark.sql("""
SELECT icd9_code, COUNT(DISTINCT subject_id) AS sid_count
FROM diagnoses_icd_o2
GROUP BY icd9_code
ORDER BY sid_count DESC
LIMIT 50
""").show(n=50)

+---------+---------+
|icd9_code|sid_count|
+---------+---------+
|    41401|     3435|
|     0389|     1837|
|    41071|     1672|
|    V3001|     1390|
|     4241|     1128|
|    51881|      986|
|    V3000|      949|
|      431|      948|
|    V3101|      851|
|      486|      654|
|     5070|      581|
|     4240|      552|
|     4280|      513|
|     5849|      500|
|      430|      491|
|    41011|      469|
|    41041|      465|
|     5789|      410|
|     5770|      343|
|    41519|      334|
|     1983|      330|
|    43411|      330|
|    43491|      318|
|    42731|      307|
|    99859|      299|
|    03842|      291|
|    85221|      279|
|    56212|      251|
|    42823|      246|
|    V3401|      235|
|    99662|      229|
|    42833|      229|
|     4271|      221|
|    51884|      220|
|     4321|      214|
|     5712|      212|
|    99811|      209|
|    49121|      208|
|     4373|      202|
|    03849|      201|
|    85220|      197|
|    03811|      192|
|     4414

Top 50 ICD 9 codes based on "hadm_id" count

In [12]:
spark.sql("""
SELECT icd9_code, COUNT(DISTINCT hadm_id) AS hadm_count
FROM diagnoses_icd_o2
GROUP BY icd9_code
ORDER BY hadm_count DESC
LIMIT 50
""").show(n=50)

+---------+----------+
|icd9_code|hadm_count|
+---------+----------+
|    41401|      3464|
|     0389|      1976|
|    41071|      1719|
|    V3001|      1390|
|     4241|      1136|
|    51881|      1089|
|      431|       966|
|    V3000|       949|
|    V3101|       851|
|      486|       703|
|     5070|       641|
|     4240|       558|
|     4280|       553|
|     5849|       518|
|      430|       495|
|    41011|       472|
|    41041|       467|
|     5789|       432|
|     5770|       365|
|     1983|       355|
|    41519|       337|
|    43411|       331|
|    43491|       318|
|    42731|       317|
|    99859|       308|
|    03842|       304|
|    85221|       289|
|    99662|       285|
|    25013|       284|
|    42823|       284|
|    56212|       267|
|    42833|       267|
|    49121|       264|
|     4271|       255|
|     5712|       247|
|    51884|       245|
|     4373|       242|
|    V3401|       235|
|     4321|       231|
|    29181|       229|
|    99811|

### noteevents and diagnoses_icd (many to one)
Basic Counts:

In [13]:
spark.sql("""
SELECT COUNT(DISTINCT subject_id), 
COUNT(DISTINCT hadm_id), COUNT(DISTINCT icd9_code)
FROM diagnoses_icd_m2
""").show()

+--------------------------+-----------------------+-------------------------+
|count(DISTINCT subject_id)|count(DISTINCT hadm_id)|count(DISTINCT icd9_code)|
+--------------------------+-----------------------+-------------------------+
|                     41127|                  52726|                     6918|
+--------------------------+-----------------------+-------------------------+



Top ICD 9 codes based on "subject_id" count

In [14]:
spark.sql("""
SELECT icd9_code, COUNT(DISTINCT subject_id) AS sid_count
FROM diagnoses_icd_m2
GROUP BY icd9_code
ORDER BY sid_count DESC
LIMIT 50
""").show(n=50)

+---------+---------+
|icd9_code|sid_count|
+---------+---------+
|     4019|    17138|
|    41401|    10579|
|    42731|    10053|
|     4280|     9669|
|     5849|     7505|
|     2724|     7324|
|    25000|     7181|
|    51881|     6493|
|     5990|     5687|
|     2720|     5199|
|    53081|     5148|
|     2859|     4895|
|      486|     4329|
|     2851|     4195|
|     2762|     4021|
|     2449|     3732|
|      496|     3491|
|    99592|     3449|
|     5070|     3319|
|     0389|     3304|
|    V5861|     3126|
|     3051|     2926|
|    41071|     2861|
|      311|     2859|
|     5859|     2855|
|    40390|     2784|
|     2761|     2757|
|     2875|     2751|
|      412|     2723|
|     4240|     2613|
|     5119|     2528|
|     V290|     2519|
|    V1582|     2486|
|    78552|     2335|
|     4241|     2285|
|     9971|     2278|
|    42789|     2259|
|    V4581|     2256|
|    V4582|     2202|
|     7742|     2174|
|     5845|     2122|
|     V053|     2116|
|     5180

Top ICD 9 codes based on "hadm_id" count

In [15]:
spark.sql("""
SELECT icd9_code, COUNT(DISTINCT hadm_id) AS hadm_count
FROM diagnoses_icd_m2
GROUP BY icd9_code
ORDER BY hadm_count DESC
LIMIT 50
""").show(n=50)

+---------+----------+
|icd9_code|hadm_count|
+---------+----------+
|     4019|     20046|
|     4280|     12842|
|    42731|     12589|
|    41401|     12178|
|     5849|      8906|
|    25000|      8783|
|     2724|      8503|
|    51881|      7249|
|     5990|      6442|
|    53081|      6154|
|     2720|      5766|
|     2859|      5295|
|     2449|      4785|
|      486|      4732|
|     2851|      4499|
|     2762|      4358|
|      496|      4296|
|    99592|      3792|
|    V5861|      3697|
|     5070|      3592|
|     0389|      3580|
|     5859|      3367|
|    40390|      3350|
|      311|      3347|
|     3051|      3272|
|      412|      3203|
|     2875|      3002|
|    41071|      3001|
|     2761|      2985|
|    V4581|      2943|
|     4240|      2876|
|    V1582|      2741|
|     5119|      2693|
|    V4582|      2651|
|    40391|      2566|
|     V290|      2529|
|     4241|      2517|
|    78552|      2501|
|    V5867|      2497|
|    42789|      2396|
|    32723|

## Data Preprocessing (all icd9 codes)

Returns RDD[(hadm_id, list(icd9_codes))]

In [16]:
icd9_score_hadm = spark.sql("""
SELECT icd9_code, COUNT(DISTINCT hadm_id) AS score
FROM diagnoses_icd_m2
GROUP BY icd9_code
""").rdd.cache()

icd9_score_subj = spark.sql("""
SELECT icd9_code, COUNT(DISTINCT subject_id) AS score
FROM diagnoses_icd_m2
GROUP BY icd9_code
""").rdd.cache()

def get_id_to_topicd9(id_type, topX):
    if id_type == "hadm_id":
        icd9_score = icd9_score_hadm
    else:
        icd9_score = icd9_score_subj
        
    icd9_topX = set([i.icd9_code for i in icd9_score.takeOrdered(topX, key=lambda x: -x.score)])
    
    id_to_topicd9 = df_diag_m2.rdd \
        .map(lambda x: (x.hadm_id if id_type=="hadm_id" else x.subject_id, x.icd9_code)) \
        .groupByKey() \
        .mapValues(lambda x: set(x) & icd9_topX) \
        .filter(lambda (x, y): y)
        
    return id_to_topicd9, list(icd9_topX)

# for i in get_id_to_topicd9("hadm_id", 10)[0].take(3):
#     print i
# for i in get_id_to_topicd9("subject_id", 50)[0].take(3):
#     print i

Obtain dataframe for the merged noteevents and ID-to-ICD9 mapping

In [17]:
def sparse2vec(mapper, data):
    out = [0] * len(mapper)
    for i in data:
        out[mapper[i]] = 1
    return out

def get_id_to_texticd9(id_type, topX):
    id_to_topicd9, topicd9 = get_id_to_topicd9(id_type, topX)
    mapper = dict(zip(topicd9, range(topX)))
    
    ne_topX = df_ne.rdd \
        .filter(lambda x: x.category == "Discharge summary") \
        .map(lambda x: (x.hadm_id if id_type=="hadm_id" else x.subject_id, x.text)) \
        .groupByKey() \
        .mapValues(lambda x: " ".join(x)) \
        .join(id_to_topicd9) \
        .map(lambda (id_, (text, icd9)): \
             [id_, text]+sparse2vec(mapper, icd9))
#              list(Vectors.sparse(topX, dict.fromkeys(map(lambda x: mapper[x], icd9), 1))))
        
    return spark.createDataFrame(ne_topX, ["id", "text"]+topicd9), mapper

# get_id_to_texticd9("hadm_id", 10)[0].show()

## Feature Extraction

### TF-IDF
Input df must be RDD[(label, text)]

In [18]:
from pyspark.ml.feature import HashingTF, IDF, RegexTokenizer, StopWordsRemover

def create_TFIDF(sentenceData, inputCol="text", outputCol="features", minDocFreq=3, numFeatures=20):
    tokenizer = RegexTokenizer(pattern="[.:\s]+", inputCol=inputCol, outputCol="z_words")
    wordsData = tokenizer.transform(sentenceData)
    
    remover = StopWordsRemover(inputCol="z_words", outputCol="z_filtered")
    wordsDataFiltered = remover.transform(wordsData)
    
    hashingTF = HashingTF(inputCol="z_filtered", outputCol="z_rawFeatures", numFeatures=numFeatures)
    featurizedData = hashingTF.transform(wordsDataFiltered)
    # alternatively, CountVectorizer can also be used to get term frequency vectors

    idf = IDF(inputCol="z_rawFeatures", outputCol=outputCol, minDocFreq=minDocFreq)
    idfModel = idf.fit(featurizedData)
    rescaledData = idfModel.transform(featurizedData)
    
    return rescaledData.drop("z_words", "z_filtered", "z_rawFeatures", inputCol)

In [19]:
from pyspark.mllib.util import Vectors
from pyspark.mllib.linalg import VectorUDT
from pyspark.sql.functions import UserDefinedFunction
from pyspark.sql.types import DataType, StringType

def output_csv(df, path):
    udf = UserDefinedFunction(lambda x: Vectors.stringify(x), StringType())
    new_df = df.withColumn('features', udf(df.features))
    
    new_df.write.csv(path, header=True)
    
def read_csv(path):
    df = spark.read.csv(path, header=True, inferSchema=True)
    
    udf = UserDefinedFunction(lambda x: Vectors.parse(x), VectorUDT())
    new_df = df.withColumn('features', udf(df.features))
    
    return new_df

Output to pickle file

In [20]:
df_id2texticd9, topicd9_mapper = get_id_to_texticd9("hadm_id", 10)
df_id2featurelabel = create_TFIDF(df_id2texticd9, numFeatures=40000)

print topicd9_mapper
print df_id2featurelabel.dtypes
df_id2featurelabel.show()

output_csv(df_id2featurelabel, "./data/DATA_TFIDF_HADM_TOP10")

{u'4019': 0, u'2724': 1, u'25000': 2, u'4280': 3, u'41401': 4, u'42731': 7, u'5849': 8, u'53081': 5, u'51881': 6, u'5990': 9}
[('id', 'bigint'), ('4019', 'bigint'), ('2724', 'bigint'), ('25000', 'bigint'), ('4280', 'bigint'), ('41401', 'bigint'), ('53081', 'bigint'), ('51881', 'bigint'), ('42731', 'bigint'), ('5849', 'bigint'), ('5990', 'bigint'), ('features', 'vector')]
+------+----+----+-----+----+-----+-----+-----+-----+----+----+--------------------+
|    id|4019|2724|25000|4280|41401|53081|51881|42731|5849|5990|            features|
+------+----+----+-----+----+-----+-----+-----+-----+----+----+--------------------+
|117760|   0|   0|    0|   0|    0|    1|    1|    0|   0|   0|(40000,[69,372,69...|
|129030|   1|   1|    0|   0|    0|    1|    0|    0|   0|   0|(40000,[13,32,83,...|
|172040|   0|   0|    0|   0|    1|    0|    0|    0|   1|   0|(40000,[10,69,152...|
|156170|   0|   0|    1|   1|    0|    0|    0|    1|   1|   0|(40000,[3,78,130,...|
|199180|   0|   0|    1|   1|  

[Test] Load csv file
count should be the same with the sql query

In [21]:
testdf = read_csv("./data/DATA_TFIDF_HADM_TOP10")
print testdf.count()
testdf.show()

40562
+------+----+----+-----+----+-----+-----+-----+-----+----+----+--------------------+
|    id|4019|2724|25000|4280|41401|53081|51881|42731|5849|5990|            features|
+------+----+----+-----+----+-----+-----+-----+-----+----+----+--------------------+
|185344|   0|   0|    0|   0|    1|    0|    0|    1|   0|   1|(40000,[20,32,69,...|
|169474|   1|   0|    0|   0|    0|    0|    0|    0|   0|   0|(40000,[63,80,207...|
|180054|   1|   0|    0|   0|    0|    0|    0|    0|   1|   1|(40000,[32,115,13...|
|137734|   0|   1|    0|   0|    0|    0|    0|    1|   0|   0|(40000,[48,148,20...|
|121864|   1|   0|    0|   1|    0|    0|    0|    1|   0|   0|(40000,[273,379,8...|
|115884|   1|   0|    1|   0|    0|    0|    0|    0|   0|   0|(40000,[100,361,5...|
|105994|   1|   0|    0|   1|    0|    0|    0|    0|   0|   0|(40000,[20,32,207...|
|110594|   0|   0|    0|   0|    1|    0|    0|    0|   0|   0|(40000,[78,107,14...|
|176144|   0|   1|    1|   1|    1|    0|    0|    0|   0| 

In [22]:
spark.sql("""
SELECT icd9_code
FROM diagnoses_icd_m2
GROUP BY icd9_code
ORDER BY COUNT(DISTINCT hadm_id) DESC
LIMIT 10
""").show()
    
id_to_topicd9, topicd9 = get_id_to_topicd9("hadm_id", 10)
print id_to_topicd9.count()

spark.sql("""
SELECT COUNT(DISTINCT hadm_id) AS hadm_count
FROM diagnoses_icd_m2
WHERE icd9_code IN
    (SELECT icd9_code
    FROM diagnoses_icd_m2
    GROUP BY icd9_code
    ORDER BY COUNT(DISTINCT hadm_id) DESC
    LIMIT 10)
""").show()

+---------+
|icd9_code|
+---------+
|     4019|
|     4280|
|    42731|
|    41401|
|     5849|
|    25000|
|     2724|
|    51881|
|     5990|
|    53081|
+---------+

40562
+----------+
|hadm_count|
+----------+
|     40562|
+----------+



### Top 10 ICD 9 codes category (cleaned) -- to follow
### Top 50 ICD 9 codes category (cleaned) -- to follow

In [23]:
sc.stop()