In [1]:
#importing libraries required
import findspark
findspark.init()
import pyspark # Call this only after findspark.init()
from pyspark.context import SparkContext
from pyspark.sql.session import SparkSession

sc = SparkContext.getOrCreate()
spark = SparkSession(sc)
from pyspark.sql import Row
from pyspark.sql.types import *
from pyspark.sql.functions import sum
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pyspark.sql.functions import rank, col, unix_timestamp, from_unixtime, to_timestamp
from pyspark.sql import functions as F
import seaborn as sns
timeFmt = "yyyy-MM-dd"
from pyspark.sql.functions import *

In [None]:
#Reading MIMIC data
df = spark.read.csv("mimic_master.csv", header='true', inferSchema='true')
df.show(5)


In [4]:
#count row and column to explore
print((df.count(), len(df.columns)))


(52643, 16)


In [5]:
df.printSchema()

root
 |-- person_id: integer (nullable = true)
 |-- age: integer (nullable = true)
 |-- paramCT: integer (nullable = true)
 |-- gpiCT: integer (nullable = true)
 |-- ndcCT: integer (nullable = true)
 |-- ahfsCT: integer (nullable = true)
 |-- medCT: integer (nullable = true)
 |-- gender: string (nullable = true)
 |-- birth_date: timestamp (nullable = true)
 |-- F_T2D_Diag: timestamp (nullable = true)
 |-- F_T1D_Diag: timestamp (nullable = true)
 |-- F_LD_Diag: timestamp (nullable = true)
 |-- F_KD_Diag: timestamp (nullable = true)
 |-- F_CVD_Diag: timestamp (nullable = true)
 |-- F_ALZ_Diag: timestamp (nullable = true)
 |-- F_ALZD_Diag: timestamp (nullable = true)



In [6]:
#Certain selected columns as required
df1=df.select('person_id','age','gender','birth_date','F_T2D_Diag','F_ALZ_Diag','F_ALZD_Diag')

In [7]:
df1.show()

+---------+---+------+-------------------+-------------------+----------+-----------+
|person_id|age|gender|         birth_date|         F_T2D_Diag|F_ALZ_Diag|F_ALZD_Diag|
+---------+---+------+-------------------+-------------------+----------+-----------+
|      148| 78|     F|2029-07-11 00:00:00|               null|      null|       null|
|      463| 62|     F|2136-09-25 00:00:00|               null|      null|       null|
|      471| 75|     F|2046-08-30 00:00:00|               null|      null|       null|
|      833|  0|     M|2137-05-23 00:00:00|               null|      null|       null|
|     1088| 68|     M|2102-03-05 00:00:00|               null|      null|       null|
|     1238|  0|     F|2197-03-27 00:00:00|               null|      null|       null|
|     1342| 72|     F|2034-03-20 00:00:00|               null|      null|       null|
|     1580| 44|     F|2081-05-14 00:00:00|               null|      null|       null|
|     1591|  0|     M|2106-04-01 00:00:00|            

In [8]:
type(df1)

pyspark.sql.dataframe.DataFrame

In [9]:
# Number of Rows and Columns in the sub dataset
print((df1.count(), len(df1.columns)))

(52643, 7)


In [10]:
# Printing the schema of the dataset
df1.printSchema()


root
 |-- person_id: integer (nullable = true)
 |-- age: integer (nullable = true)
 |-- gender: string (nullable = true)
 |-- birth_date: timestamp (nullable = true)
 |-- F_T2D_Diag: timestamp (nullable = true)
 |-- F_ALZ_Diag: timestamp (nullable = true)
 |-- F_ALZD_Diag: timestamp (nullable = true)



In [11]:
df1.columns

['person_id',
 'age',
 'gender',
 'birth_date',
 'F_T2D_Diag',
 'F_ALZ_Diag',
 'F_ALZD_Diag']

In [12]:
df1.groupby(["gender"]).count().show()

+------+-----+
|gender|count|
+------+-----+
|     F|23184|
|     M|29459|
+------+-----+



# Creating cohorts

In [13]:
#Age diag
Age_T2D_diag =F.round((F.col("F_T2D_Diag").cast("long") - F.col("birth_date").cast("long"))/(365*60*60*24), 3)
Age_AD_diag =F.round((F.col("F_ALZ_Diag").cast("long") - F.col("birth_date").cast("long"))/(365*60*60*24), 3)
Age_T2DAD_both_diag=F.round((F.col("F_T2D_Diag").cast("long") - F.col("F_ALZ_Diag").cast("long"))/(365*60*60*24), 3)
Age_DEM_diag=F.round((F.col("F_ALZD_Diag").cast("long") - F.col("F_ALZ_Diag").cast("long"))/(365*60*60*24), 3)
Age_T2DDEM_both_diag=F.round((F.col("F_T2D_Diag").cast("long") - F.col("F_ALZD_Diag").cast("long"))/(365*60*60*24), 3)

In [14]:
df1=df1.withColumn("T2D_ONLY",Age_T2D_diag).withColumn("AD_ONLY",Age_AD_diag).withColumn("BOTH_T2D_AD"\
,Age_T2DAD_both_diag).withColumn("DEM_ONLY",Age_DEM_diag).withColumn("BOTH_AD_DEM",Age_T2DDEM_both_diag)

In [15]:
df1.show(3)

+---------+---+------+-------------------+----------+----------+-----------+--------+-------+-----------+--------+-----------+
|person_id|age|gender|         birth_date|F_T2D_Diag|F_ALZ_Diag|F_ALZD_Diag|T2D_ONLY|AD_ONLY|BOTH_T2D_AD|DEM_ONLY|BOTH_AD_DEM|
+---------+---+------+-------------------+----------+----------+-----------+--------+-------+-----------+--------+-----------+
|      148| 78|     F|2029-07-11 00:00:00|      null|      null|       null|    null|   null|       null|    null|       null|
|      463| 62|     F|2136-09-25 00:00:00|      null|      null|       null|    null|   null|       null|    null|       null|
|      471| 75|     F|2046-08-30 00:00:00|      null|      null|       null|    null|   null|       null|    null|       null|
+---------+---+------+-------------------+----------+----------+-----------+--------+-------+-----------+--------+-----------+
only showing top 3 rows



In [16]:
df1

DataFrame[person_id: int, age: int, gender: string, birth_date: timestamp, F_T2D_Diag: timestamp, F_ALZ_Diag: timestamp, F_ALZD_Diag: timestamp, T2D_ONLY: double, AD_ONLY: double, BOTH_T2D_AD: double, DEM_ONLY: double, BOTH_AD_DEM: double]

In [17]:
df1.columns

['person_id',
 'age',
 'gender',
 'birth_date',
 'F_T2D_Diag',
 'F_ALZ_Diag',
 'F_ALZD_Diag',
 'T2D_ONLY',
 'AD_ONLY',
 'BOTH_T2D_AD',
 'DEM_ONLY',
 'BOTH_AD_DEM']

In [18]:
#df1.filter(df1.age >=100).show()

In [19]:
df2=df1.filter(df1.age <=110) #filtering invalid dates as age with 300 or more found in data.

In [20]:
#df2.describe().show(3,False)

In [21]:
df2.printSchema()

root
 |-- person_id: integer (nullable = true)
 |-- age: integer (nullable = true)
 |-- gender: string (nullable = true)
 |-- birth_date: timestamp (nullable = true)
 |-- F_T2D_Diag: timestamp (nullable = true)
 |-- F_ALZ_Diag: timestamp (nullable = true)
 |-- F_ALZD_Diag: timestamp (nullable = true)
 |-- T2D_ONLY: double (nullable = true)
 |-- AD_ONLY: double (nullable = true)
 |-- BOTH_T2D_AD: double (nullable = true)
 |-- DEM_ONLY: double (nullable = true)
 |-- BOTH_AD_DEM: double (nullable = true)



In [22]:
df2.show(4)

+---------+---+------+-------------------+----------+----------+-----------+--------+-------+-----------+--------+-----------+
|person_id|age|gender|         birth_date|F_T2D_Diag|F_ALZ_Diag|F_ALZD_Diag|T2D_ONLY|AD_ONLY|BOTH_T2D_AD|DEM_ONLY|BOTH_AD_DEM|
+---------+---+------+-------------------+----------+----------+-----------+--------+-------+-----------+--------+-----------+
|      148| 78|     F|2029-07-11 00:00:00|      null|      null|       null|    null|   null|       null|    null|       null|
|      463| 62|     F|2136-09-25 00:00:00|      null|      null|       null|    null|   null|       null|    null|       null|
|      471| 75|     F|2046-08-30 00:00:00|      null|      null|       null|    null|   null|       null|    null|       null|
|      833|  0|     M|2137-05-23 00:00:00|      null|      null|       null|    null|   null|       null|    null|       null|
+---------+---+------+-------------------+----------+----------+-----------+--------+-------+-----------+------

In [23]:
df_final=df2.select('person_id','gender','T2D_ONLY','AD_ONLY','DEM_ONLY',\
                'BOTH_T2D_AD','BOTH_AD_DEM')


## missing data checking and handling

In [24]:
df_final.show()

+---------+------+--------+-------+--------+-----------+-----------+
|person_id|gender|T2D_ONLY|AD_ONLY|DEM_ONLY|BOTH_T2D_AD|BOTH_AD_DEM|
+---------+------+--------+-------+--------+-----------+-----------+
|      148|     F|    null|   null|    null|       null|       null|
|      463|     F|    null|   null|    null|       null|       null|
|      471|     F|    null|   null|    null|       null|       null|
|      833|     M|    null|   null|    null|       null|       null|
|     1088|     M|    null|   null|    null|       null|       null|
|     1238|     F|    null|   null|    null|       null|       null|
|     1342|     F|    null|   null|    null|       null|       null|
|     1580|     F|    null|   null|    null|       null|       null|
|     1591|     M|    null|   null|    null|       null|       null|
|     1645|     F|    null|   null|    null|       null|       null|
|     1829|     M|  53.854|   null|    null|       null|       null|
|     1959|     F|    null|   null

In [25]:
from pyspark.sql.functions import isnan, when, count, col
df_final.select([count(when(isnan(c) | col(c).isNull(), c)).alias(c) for c in df_final.columns]).show()

+---------+------+--------+-------+--------+-----------+-----------+
|person_id|gender|T2D_ONLY|AD_ONLY|DEM_ONLY|BOTH_T2D_AD|BOTH_AD_DEM|
+---------+------+--------+-------+--------+-----------+-----------+
|        0|     0|   38695|  49915|   49969|      50232|      49458|
+---------+------+--------+-------+--------+-----------+-----------+



### Imputation of mean in selected null columns

In [26]:
def fill_with_mean(df, exclude=set()): 
    stats = df.agg(*(
     avg(c).alias(c) for c in df.columns if c not in exclude 
    )) 
    return df.na.fill(stats.first().asDict()) 

df_final=fill_with_mean(df_final, ["person_id", "gender"])

In [27]:
df_final.show(5)

+---------+------+-----------------+-----------------+--------------------+------------------+-------------------+
|person_id|gender|         T2D_ONLY|          AD_ONLY|            DEM_ONLY|       BOTH_T2D_AD|        BOTH_AD_DEM|
+---------+------+-----------------+-----------------+--------------------+------------------+-------------------+
|      148|     F|66.59624051391883|80.04218681318676|0.008650872817955116|-0.773536231884058|-1.0035460526315794|
|      463|     F|66.59624051391883|80.04218681318676|0.008650872817955116|-0.773536231884058|-1.0035460526315794|
|      471|     F|66.59624051391883|80.04218681318676|0.008650872817955116|-0.773536231884058|-1.0035460526315794|
|      833|     M|66.59624051391883|80.04218681318676|0.008650872817955116|-0.773536231884058|-1.0035460526315794|
|     1088|     M|66.59624051391883|80.04218681318676|0.008650872817955116|-0.773536231884058|-1.0035460526315794|
+---------+------+-----------------+-----------------+--------------------+-----

In [28]:
df_final.groupby(["gender"]).count().show()

+------+-----+
|gender|count|
+------+-----+
|     F|21733|
|     M|28637|
+------+-----+



In [29]:
df_final.select('person_id','gender','T2D_ONLY').show(5)

+---------+------+-----------------+
|person_id|gender|         T2D_ONLY|
+---------+------+-----------------+
|      148|     F|66.59624051391883|
|      463|     F|66.59624051391883|
|      471|     F|66.59624051391883|
|      833|     M|66.59624051391883|
|     1088|     M|66.59624051391883|
+---------+------+-----------------+
only showing top 5 rows



In [30]:
df_final.crosstab('T2D_ONLY', 'gender').show()

+---------------+---+---+
|T2D_ONLY_gender|  F|  M|
+---------------+---+---+
|         64.013|  1|  0|
|         63.479|  0|  1|
|         54.665|  0|  1|
|         74.361|  0|  1|
|         77.988|  0|  2|
|         84.822|  0|  1|
|         58.253|  3|  0|
|         76.129|  0|  1|
|         71.094|  3|  0|
|         64.898|  0|  1|
|         59.402|  0|  1|
|          67.86|  0|  1|
|          52.18|  4|  0|
|         55.432|  0|  1|
|         40.366|  1|  0|
|         70.671|  1|  0|
|         78.558|  1|  0|
|         78.211|  3|  0|
|         61.638|  0|  1|
|         52.156|  0|  1|
+---------------+---+---+
only showing top 20 rows



In [31]:
df_final.select('gender').show(5)

+------+
|gender|
+------+
|     F|
|     F|
|     F|
|     M|
|     M|
+------+
only showing top 5 rows



# Converting to pandas

In [32]:
pd_final=df_final.toPandas()

In [33]:
# Checking for missing values
pd_final.isna().sum()

person_id      0
gender         0
T2D_ONLY       0
AD_ONLY        0
DEM_ONLY       0
BOTH_T2D_AD    0
BOTH_AD_DEM    0
dtype: int64

In [34]:
#counting gender 
pd_final['gender'].value_counts()

M    28637
F    21733
Name: gender, dtype: int64

## Initial visualization 

In [None]:
%%time
sns.pairplot(pd_final, kind='reg', plot_kws={'line_kws':{'color': 'cyan'}})
plt.show()

In [67]:
df_final1=df_final.drop('person_id') # PID not required

In [89]:
df_final

DataFrame[person_id: int, gender: string, T2D_ONLY: double, AD_ONLY: double, DEM_ONLY: double, BOTH_T2D_AD: double, BOTH_AD_DEM: double]

In [90]:
df_final1=df_final.select('person_id','T2D_ONLY','AD_ONLY','DEM_ONLY','BOTH_T2D_AD','BOTH_AD_DEM')

In [91]:
df_final1.show()

+---------+-----------------+-----------------+--------------------+------------------+-------------------+
|person_id|         T2D_ONLY|          AD_ONLY|            DEM_ONLY|       BOTH_T2D_AD|        BOTH_AD_DEM|
+---------+-----------------+-----------------+--------------------+------------------+-------------------+
|      148|66.59624051391883|80.04218681318676|0.008650872817955116|-0.773536231884058|-1.0035460526315794|
|      463|66.59624051391883|80.04218681318676|0.008650872817955116|-0.773536231884058|-1.0035460526315794|
|      471|66.59624051391883|80.04218681318676|0.008650872817955116|-0.773536231884058|-1.0035460526315794|
|      833|66.59624051391883|80.04218681318676|0.008650872817955116|-0.773536231884058|-1.0035460526315794|
|     1088|66.59624051391883|80.04218681318676|0.008650872817955116|-0.773536231884058|-1.0035460526315794|
|     1238|66.59624051391883|80.04218681318676|0.008650872817955116|-0.773536231884058|-1.0035460526315794|
|     1342|66.59624051391883

In [92]:
from functools import partial
from pyspark.sql import Row
def flatten_table(column_names, column_values):
    row = zip(column_names, column_values)
    _, person_id = next(row)  # Special casing retrieving the first column
    return [
        Row(person_id=person_id,Disease_group=column, Disease_Age=value)
        for column, value in row
    ]
    
df_final1=df_final1.rdd.flatMap(partial(flatten_table, df_final1.columns)).toDF()

In [93]:
df_final1.show()

+---------+-------------+--------------------+
|person_id|Disease_group|         Disease_Age|
+---------+-------------+--------------------+
|      148|     T2D_ONLY|   66.59624051391883|
|      148|      AD_ONLY|   80.04218681318676|
|      148|     DEM_ONLY|0.008650872817955116|
|      148|  BOTH_T2D_AD|  -0.773536231884058|
|      148|  BOTH_AD_DEM| -1.0035460526315794|
|      463|     T2D_ONLY|   66.59624051391883|
|      463|      AD_ONLY|   80.04218681318676|
|      463|     DEM_ONLY|0.008650872817955116|
|      463|  BOTH_T2D_AD|  -0.773536231884058|
|      463|  BOTH_AD_DEM| -1.0035460526315794|
|      471|     T2D_ONLY|   66.59624051391883|
|      471|      AD_ONLY|   80.04218681318676|
|      471|     DEM_ONLY|0.008650872817955116|
|      471|  BOTH_T2D_AD|  -0.773536231884058|
|      471|  BOTH_AD_DEM| -1.0035460526315794|
|      833|     T2D_ONLY|   66.59624051391883|
|      833|      AD_ONLY|   80.04218681318676|
|      833|     DEM_ONLY|0.008650872817955116|
|      833|  

In [95]:
df_final2=df_final.select('person_id','gender')

In [96]:
df_final=df_final2.join(df_final1.select('person_id', 'Disease_group','Disease_Age'), ['person_id'])

In [97]:
df_final.show()

+---------+------+-------------+--------------------+
|person_id|gender|Disease_group|         Disease_Age|
+---------+------+-------------+--------------------+
|      148|     F|     T2D_ONLY|   66.59624051391883|
|      148|     F|      AD_ONLY|   80.04218681318676|
|      148|     F|     DEM_ONLY|0.008650872817955116|
|      148|     F|  BOTH_T2D_AD|  -0.773536231884058|
|      148|     F|  BOTH_AD_DEM| -1.0035460526315794|
|      463|     F|     T2D_ONLY|   66.59624051391883|
|      463|     F|      AD_ONLY|   80.04218681318676|
|      463|     F|     DEM_ONLY|0.008650872817955116|
|      463|     F|  BOTH_T2D_AD|  -0.773536231884058|
|      463|     F|  BOTH_AD_DEM| -1.0035460526315794|
|      471|     F|     T2D_ONLY|   66.59624051391883|
|      471|     F|      AD_ONLY|   80.04218681318676|
|      471|     F|     DEM_ONLY|0.008650872817955116|
|      471|     F|  BOTH_T2D_AD|  -0.773536231884058|
|      471|     F|  BOTH_AD_DEM| -1.0035460526315794|
|      833|     M|     T2D_O

# FEATURE ENGINEERING

In [98]:
from pyspark.ml.feature import OneHotEncoder, StringIndexer, VectorAssembler

categoricalColumns = ['person_id','gender']
stages = []

for categoricalCol in categoricalColumns:
    stringIndexer = StringIndexer(inputCol = categoricalCol, outputCol = categoricalCol + 'Index')
    encoder = OneHotEncoder(inputCols=[stringIndexer.getOutputCol()], outputCols=[categoricalCol + "classVec"])
    stages += [stringIndexer, encoder]

label_stringIdx = StringIndexer(inputCol = 'Disease_group', outputCol = 'label')
stages += [label_stringIdx]

numericCols = ['Disease_Age']
assemblerInputs = [c + "classVec" for c in categoricalColumns] + numericCols
assembler = VectorAssembler(inputCols=assemblerInputs, outputCol="features")
stages += [assembler]

In [99]:
from pyspark.ml import Pipeline
cols = df_final.columns
pipeline = Pipeline(stages = stages)
pipelineModel = pipeline.fit(df_final)
df_final = pipelineModel.transform(df_final)
selectedCols = ['label', 'features'] + cols
df_final =df_final.select(selectedCols)
df_final.printSchema()

root
 |-- label: double (nullable = false)
 |-- features: vector (nullable = true)
 |-- person_id: integer (nullable = true)
 |-- gender: string (nullable = true)
 |-- Disease_group: string (nullable = true)
 |-- Disease_Age: double (nullable = true)



# MACHINE LEARNING USING SPARK ML ON THE PIPELINED MODEL

In [100]:
train, test = df_final.randomSplit([0.7, 0.3], seed = 4000)
print("Training Dataset Count: " + str(train.count()))
print("Test Dataset Count: " + str(test.count()))

Training Dataset Count: 232043
Test Dataset Count: 99467


## LOGISTIC REGRESSION CLASSIFIER

Parameters : maxIter = 10

In [101]:
from pyspark.ml.classification import LogisticRegression
lr = LogisticRegression(featuresCol = 'features', labelCol = 'label', maxIter=10)
lrModel = lr.fit(train)

# Making predictions on test data using the transform() method.

In [102]:
predictions = lrModel.transform(test)

In [103]:
selected = predictions.select("label","prediction", "probability")
selected.show()

+-----+----------+--------------------+
|label|prediction|         probability|
+-----+----------+--------------------+
|  0.0|       0.0|[0.99932778726371...|
|  0.0|       0.0|[0.99932778726371...|
|  0.0|       0.0|[0.99932778726371...|
|  0.0|       0.0|[0.99932778726371...|
|  0.0|       0.0|[0.99932778726371...|
|  0.0|       0.0|[0.99932778726371...|
|  0.0|       0.0|[0.99932778726371...|
|  0.0|       0.0|[0.99932778726371...|
|  0.0|       0.0|[0.99932778726371...|
|  0.0|       0.0|[0.99932778726371...|
|  0.0|       0.0|[0.99932778726371...|
|  0.0|       0.0|[0.99932778726371...|
|  0.0|       0.0|[0.99932778726371...|
|  0.0|       0.0|[0.99932778726371...|
|  0.0|       0.0|[0.99932778726371...|
|  0.0|       0.0|[0.99932778726371...|
|  0.0|       0.0|[0.99932778726371...|
|  0.0|       0.0|[0.99932778726371...|
|  0.0|       0.0|[0.99932778726371...|
|  0.0|       0.0|[0.99932778726371...|
+-----+----------+--------------------+
only showing top 20 rows



In [104]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator

evaluator = BinaryClassificationEvaluator()
print('Test Area Under ROC curve: ', evaluator.evaluate(predictions))

Test Area Under ROC curve:  0.8774112038071622


In [105]:
cm = predictions.select("label", "prediction")
cm.groupby('label').agg({'label': 'count'}).show()
cm.groupby('prediction').agg({'prediction': 'count'}).show()

+-----+------------+
|label|count(label)|
+-----+------------+
|  0.0|       19930|
|  1.0|       19906|
|  4.0|       19711|
|  3.0|       20034|
|  2.0|       19886|
+-----+------------+

+----------+-----------------+
|prediction|count(prediction)|
+----------+-----------------+
|       0.0|            22180|
|       1.0|             6995|
|       4.0|            18365|
|       3.0|            17291|
|       2.0|            34636|
+----------+-----------------+



In [106]:
def accuracy_m(model): 
    predictions = model.transform(test)
    cm = predictions.select("label", "prediction")
    acc = cm.filter(cm.label == cm.prediction).count() / cm.count()
    print("Model accuracy: %.3f%%" % (acc * 100)) 
accuracy_m(model = lrModel)

Model accuracy: 23.418%


In [107]:
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

# Create ParamGrid for Cross Validation
paramGrid = (ParamGridBuilder()
             .addGrid(lr.regParam, [0.01, 0.5, 2.0])
             .addGrid(lr.elasticNetParam, [0.0, 0.5, 1.0])
             .addGrid(lr.maxIter, [1, 5, 10])
             .build())

cv = CrossValidator(estimator=lr, estimatorParamMaps=paramGrid, evaluator=evaluator, numFolds=10)

cvModel = cv.fit(train)
predictions2 = cvModel.transform(test)
print('Test Area Under ROC for k=5 CV :', evaluator.evaluate(predictions2))

Test Area Under ROC for k=5 CV : 0.9907905506631922


In [108]:
def accuracy_m(model): 
    predictions = model.transform(test)
    cm = predictions.select("label", "prediction")
    acc = cm.filter(cm.label == cm.prediction).count() / cm.count()
    print("Model accuracy for logistic regression model on test data with k=5 CV: %.3f%%" % (acc * 100)) 
accuracy_m(model = cvModel)

Model accuracy for logistic regression model on test data with k=5 CV: 20.050%


# DECISION TREE CLASSIFIER

In [None]:
from pyspark.ml.classification import DecisionTreeClassifier

dt = DecisionTreeClassifier(featuresCol = 'features', labelCol = 'label', maxDepth = 3)
dtModel = dt.fit(train)
predictions = dtModel.transform(test)
predictions.select('label', 'prediction', 'probability').show(5)

In [None]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator

evaluator = BinaryClassificationEvaluator()
print('Test Area Under ROC curve for DT classifier : ', evaluator.evaluate(predictions))

## Confusion matrix for Decision tree classifier model

In [None]:
cm = predictions.select("label", "prediction")
cm.groupby('label').agg({'label': 'count'}).show()
cm.groupby('prediction').agg({'prediction': 'count'}).show()

In [None]:
def accuracy_m(model): 
    predictions = model.transform(test)
    cm = predictions.select("label", "prediction")
    acc = cm.filter(cm.label == cm.prediction).count() / cm.count()
    print("Model accuracy for DT classifier: %.3f%%" % (acc * 100)) 
accuracy_m(model = dtModel)

# ENSEMBLE OF DECISION TREES : RANDOM FOREST CLASSIFIER

In [None]:
from pyspark.ml.classification import RandomForestClassifier

rf = RandomForestClassifier(featuresCol = 'features', labelCol = 'label')
rfModel = rf.fit(train)
predictions = rfModel.transform(test)
predictions.select('label','prediction', 'probability').show(10)

In [None]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator

evaluator = BinaryClassificationEvaluator()
print('Test Area Under ROC curve for RF classifier : ', evaluator.evaluate(predictions))

## Confusion matrix for random forest classifier model

In [None]:
cm = predictions.select("label", "prediction")
cm.groupby('label').agg({'label': 'count'}).show()
cm.groupby('prediction').agg({'prediction': 'count'}).show()

In [None]:
def accuracy_m(model): 
    predictions = model.transform(test)
    cm = predictions.select("label", "prediction")
    acc = cm.filter(cm.label == cm.prediction).count() / cm.count()
    print("Model accuracy for RF classifier: %.3f%%" % (acc * 100)) 
accuracy_m(model = rfModel)

# EXTRACTING FEATURE IMPORTANCE FROM RANDOM FOREST CLASSIFIER

In [None]:
def ExtractFeatureImp(featureImp, dataset, featuresCol):
    list_extract = []
    for i in dataset.schema[featuresCol].metadata["ml_attr"]["attrs"]:
        list_extract = list_extract + dataset.schema[featuresCol].metadata["ml_attr"]["attrs"][i]
    varlist = pd.DataFrame(list_extract)
    varlist['score'] = varlist['idx'].apply(lambda x: featureImp[x])
    return(varlist.sort_values('score', ascending = False))

In [None]:
result = ExtractFeatureImp(rfModel.featureImportances, df_final1, "features").head(10)
result

In [None]:
df=df_final1
df2 = rfModel.transform(df)

In [None]:
from pyspark.ml.feature import VectorSlicer
from pyspark.ml.linalg import Vectors
from pyspark.sql.types import Row

varlist = ExtractFeatureImp(rfModel.featureImportances,df, "features")
varidx = [x for x in varlist['idx'][0:10]]
varidx
slicer = VectorSlicer(inputCol="features", outputCol="features2", indices=varidx)

df3 = slicer.transform(df2)

df3 = df3.drop('rawPrediction', 'probability', 'prediction')
rf2 = RandomForestClassifier(labelCol="label", featuresCol="features2", seed = 8464,
                            numTrees=10, cacheNodeIds = True, subsamplingRate = 0.7)
mod2 = rf2.fit(df3)
pred = mod2.transform(df3)

In [None]:
pred.show(5)

# GRADIENT BOOST TREE CLASSIFIER

In [None]:
from pyspark.ml.classification import GBTClassifier

gbt = GBTClassifier(featuresCol = 'features', labelCol = 'label', maxIter=10)
gbtModel = gbt.fit(train)
predictions = gbtModel.transform(test)
predictions.select('label','prediction', 'probability').show(10)

In [None]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator

evaluator = BinaryClassificationEvaluator()
print('Test Area Under ROC curve for GBT classifier : ', evaluator.evaluate(predictions))

## Confusion matrix for GRADIENT BOOST TREE CLASSIFIER

In [None]:
cm = predictions.select("label", "prediction")
cm.groupby('label').agg({'label': 'count'}).show()
cm.groupby('prediction').agg({'prediction': 'count'}).show()

In [None]:
def accuracy_m(model): 
    predictions = model.transform(test)
    cm = predictions.select("label", "prediction")
    acc = cm.filter(cm.label == cm.prediction).count() / cm.count()
    print("Model accuracy for RF classifier: %.3f%%" % (acc * 100)) 
accuracy_m(model = rfModel)

# FEATURE IMPORTANCE FROM THE TREE CLASSIFIERS

In [None]:
rfModel.featureImportances # Add this for report
dtModel.featureImportances # Add this for report
gbtModel.featureImportances # Add this for report

In [None]:
f_imp = dtModel.featureImportances


def ExtractFeatureImp(featureImp, dataset, featuresCol):
    list_extract = []
    for i in dataset.schema[featuresCol].metadata["ml_attr"]["attrs"]:
        list_extract = list_extract + dataset.schema[featuresCol].metadata["ml_attr"]["attrs"][i]
    varlist = pd.DataFrame(list_extract)
    varlist['score'] = varlist['idx'].apply(lambda x: featureImp[x])
    return(varlist.sort_values('score', ascending = False))


result1 = ExtractFeatureImp(dtModel.featureImportances,df, "features").head(4)
result1

In [None]:
f_imp = gbtModel.featureImportances


def ExtractFeatureImp(featureImp, dataset, featuresCol):
    list_extract = []
    for i in dataset.schema[featuresCol].metadata["ml_attr"]["attrs"]:
        list_extract = list_extract + dataset.schema[featuresCol].metadata["ml_attr"]["attrs"][i]
    varlist = pd.DataFrame(list_extract)
    varlist['score'] = varlist['idx'].apply(lambda x: featureImp[x])
    return(varlist.sort_values('score', ascending = False))


result2 = ExtractFeatureImp(gbtModel.featureImportances,df, "features").head(10)
result2

In [None]:
f_imp = rfModel.featureImportances


def ExtractFeatureImp(featureImp, dataset, featuresCol):
    list_extract = []
    for i in dataset.schema[featuresCol].metadata["ml_attr"]["attrs"]:
        list_extract = list_extract + dataset.schema[featuresCol].metadata["ml_attr"]["attrs"][i]
    varlist = pd.DataFrame(list_extract)
    varlist['score'] = varlist['idx'].apply(lambda x: featureImp[x])
    return(varlist.sort_values('score', ascending = False))


result3 = ExtractFeatureImp(rfModel.featureImportances,df, "features").head(10)
result3

In [None]:
merged_df = pd.concat([result1, result2, result3])

In [None]:
merged_df.sort_values(by='score', ascending=False).head(10)

## Saving MIMIC  as mimic_final CSV file

In [None]:
df_final1=pd.read_csv("final_mimic.csv",index_col=0) #this is saving pyspark dataframe to a new csv file named final_mimic 