# Decision Trees for Predicting Dengue Fever: Subgroup Assessment

**Abstract:** A decision tree was trained on tabular medical data (age, temperature, WBC exam score, Plt exam score) to predict whether or not patients had dengue fever. Datasets were split according to various features to assess the model's predictive stability. Dengue fever is a serious illness that is initially not easily detectable, and initial symptoms tend to be similar to those of the flu. The advantage of using decision trees to predict the existence dengue fever in patients, as opposed to using other machine learning methods, is the method can be easily understood by doctors and is easily interpretable.

## Theory

### Decision Trees

Decision trees, or Classification and Regression Trees (CART), are widely-used machine learning tools that involve successively separating a set of samples into subsets or nodes of a tree. Typically, each node is separated into two subsequent nodes until there are no samples left to separate, or until a max-depth is reached for the tree. The nodes are split based upon a condition or characteristic that influences the dependent variable the most for samples in that node. In training, the most influential conditions are determined for each node when the tree is constructed, thereby grouping similar samples together into nodes, with similarity increasing with each split, (i.e., with increasing depth of the tree). Each node is classified based upon the most common value for the dependent variable in that node, thereby giving the nodes predictive power over the samples grouped into them. After the tree is constructed, the conditions that were used to split the nodes of the training set can be applied to similar datasets to group similar samples together, after which predictions can be made for samples based upon their nodes.

Note that while sample similarity for nodes of the tree increase as the depth of the tree increases, a tree that is too deep will overfit and describe conditions specific only to the training set. Thus, controlling the max-depth of the tree is a technique used to prevent overfitting, thereby forcing the tree to model more general trends inherent in the data.

### Subgroups

#### n-Fold Cross Validation: Training Multiple Models

In n-fold cross validation, the overall dataset is split into equal parts for the purpose of training multiple models, such that all of the data is used at least once in a validation set. The confusion matrices (true negatives, false positives, false negatives, true positives) for each model are then added together to calculate the overall metrics of the model.

For this experiment, 5-fold cross validation was performed. The dataset was split into five equal subsets (each 20% of original dataset). The subsets were grouped into five validation sets (20% of original dataset) and five training sets (80% of original dataset, i.e., original dataset minus validation set). Five logistic regression models were trained and then assessed with corresponding validation sets. The resulting five confusion matrices were summed, producing an overall confusion matrix that would be used to calculate overall metrics.

#### Validation Subgroup Analysis

The validation datasets of each model were split into multiple subgroups based on various features, *(Age, Gender, Epidemic periods, Body Temp, White Blood Cells, Platelet, Comorbidities, coming to ER)*, to assess stability of the model accross datasets. The subgroup confusion matrices were summed across multiple models, resulting in overall confusion matrices for each subgroup. Confusion matrices were then used to calculate overall metrics for each subgroup.

## Setup


This notebook was originally written with the following versions:

    'R version 3.6.1 (2019-07-05)'
    'rpart 4.1.15'
    'rpart 4.1.15'
    'rpart.plot 3.0.8'
    'caTools 1.17.1.2'
    'caret 6.0.83'
    'data.table 1.12.2'
    'pROC 1.15.3'
    

In [1]:
# Check your versions
version$version.string;
paste("rpart", packageVersion("rpart"));
paste("rpart", packageVersion("rpart"));
paste("rpart.plot", packageVersion("rpart.plot"));
paste("caTools", packageVersion("caTools"));
paste("caret", packageVersion("caret"));
paste("data.table", packageVersion("data.table"));
paste("pROC", packageVersion("pROC"));

##### Prepare Notebook

In [2]:
# Import Libraries
library(rpart)
library(rpart.plot)
library(caTools)

library(caret)
library(data.table)
library(pROC)
# library(tidyverse)

# Data Information
# filename <- 'patient_year_vital_lab_exam_add-on_death_outcome_comorbidity_TCIC_dengue_suspected_bmi_ER_label_missing_mask.csv'
filename <- 'patients_cleaned.csv'
path <- '../mydata/'
pathfile <- paste(path,filename,sep='')

Loading required package: lattice
Loading required package: ggplot2
Registered S3 methods overwritten by 'ggplot2':
  method         from 
  [.quosures     rlang
  c.quosures     rlang
  print.quosures rlang
Type 'citation("pROC")' for a citation.

Attaching package: ‘pROC’

The following objects are masked from ‘package:stats’:

    cov, smooth, var



In [3]:
# View all CSV column names

# df <- read.csv(pathfile)
# names(df)

## Parameters

In [4]:
splits = 10 # number of folds for cross validation
recalls = c(0.85,0.90,0.95) # sensitivities used for calculating results
iters = 20 # iterations experiment is run

# Imported columns from CSV
desired_cols = c('age','sex','Temp','exam_WBC','exam_Plt', 'Opd_Visit_Date',
                'ER', 'Heart Disease', 'CVA', 'CKD', 'Severe Liver Disease', 
                'DM', 'Hypertension', 'Cancer without Metastasis', 'Cancer with Metastasis',
                'lab_result')

# Features used for training + dependent variable
train_cols = c('age','Temp','exam_WBC','exam_Plt','lab_result')


# Features used for creating validation subgroups (includes features from train_cols)
subgroup_cols = c('age','sex','Temp','exam_WBC','exam_Plt', 'week',
                'ER', 'Heart Disease', 'CVA', 'CKD', 'Severe Liver Disease', 
                'DM', 'Hypertension', 'Cancer without Metastasis', 'Cancer with Metastasis',
                'lab_result')

prior <- seq(0.1, 0.9, 0.008)

# Columns to be dropped after creating validation subgroups
# drop_cols = list(set(subgroup_cols) - set(train_cols))

## Data Preparation

In [5]:
# Create a dataframe with just train_cols as features
df <- read.csv(pathfile)
df <- df[,train_cols]
df <- as.data.frame(df)
# df <- cbind(index = as.numeric(row.names(df)), df) # adds indices?

# remove negative Temp values
df$Temp[which(df$Temp == -1)] <- NA

# remove rows with NA values
df <- na.omit(df)

# shuffle the df by row
df <- df[sample(nrow(df)),]

print(nrow(df))
head(df) # df[seq(6),]

[1] 4894


Unnamed: 0,age,Temp,exam_WBC,exam_Plt,lab_result
2647,70,36.9,5.0,62,False
1859,39,37.0,9.4,123,False
2484,22,37.3,4.8,153,False
4399,68,38.2,2.7,57,True
4042,66,39.1,7.1,202,True
117,36,39.3,9.9,247,True


In [6]:
summary(df)

      age              Temp          exam_WBC         exam_Plt     lab_result  
 Min.   :  0.00   Min.   :34.80   Min.   : 0.600   Min.   :  2.0   False:1952  
 1st Qu.: 27.00   1st Qu.:37.50   1st Qu.: 4.000   1st Qu.:124.0   True :2942  
 Median : 47.00   Median :38.30   Median : 6.000   Median :169.0               
 Mean   : 46.82   Mean   :38.18   Mean   : 6.814   Mean   :171.1               
 3rd Qu.: 66.00   3rd Qu.:38.90   3rd Qu.: 8.500   3rd Qu.:216.0               
 Max.   :104.00   Max.   :41.30   Max.   :41.600   Max.   :976.0               

## Training

In [7]:
parms = 33
prior <- seq(0.1, 0.9, 0.008)
ncols = 14
result_table <- matrix(ncol=ncols) #matrix(NA, ncol = 282 ) # nrow(valid_df))
result_tables <- list()
overall_table <- matrix(ncol=ncols)

# ppv<-rep(NA,iters)
# npv<-rep(NA,iters)
# tpr<-rep(NA,iters)
# tnr<-rep(NA,iters)
# acc<-rep(NA,iters)
# LR_p<-rep(NA,iters)
# LR_n<-rep(NA,iters)
# pred<-rep(NA,iters)


# for (i in prior)

for (iter in 1:iters) {
    result_table <- matrix(ncol=ncols) #matrix(NA, ncol = 282 ) # nrow(valid_df))
    flds <- createFolds(df$lab_result, k = splits, list = TRUE, returnTrain = FALSE) # separate df into subsets
    for (i in 1:splits) {
        idx <- flds[[i]]
        train_df <- df[-idx,]
        valid_df <- df[idx,]
        dtreeM <- rpart(formula = lab_result ~ ., data = train_df, method = "class", 
            parms = list(prior = c(prior[parms], 1 - prior[parms])))
        cp <- dtreeM$cptable[which.min(dtreeM$cptable[, 3]), 1]
        dtreeM_pruned <- prune(dtreeM, cp = cp)
        preds <- predict(dtreeM_pruned, newdata = valid_df, type = "class")

    #     print(mean(preds == valid_df[,5])) # accuracy

        accuracy = mean(preds == valid_df[,5])

        valid_result <- cbind(df[rownames(valid_df),], preds)

        # Confusion Matrix stuff
        FN <- nrow(subset(valid_result, (preds != lab_result) & (preds == 'False')))
        TN <- nrow(subset(valid_result, (preds == lab_result) & (preds == 'False')))
        FP <- nrow(subset(valid_result, (preds != lab_result) & (preds == 'True')))
        TP <- nrow(subset(valid_result, (preds == lab_result) & (preds == 'True')))
        size <- FN + TN + FP + TP

        PPV = if (TP+FP) TP / (TP + FP) else 0 # positive predict value
        NPV = if (TN+FN) TN / (TN + FN) else 0 # negative predict value
        F1 = 2*TP / (2*TP + FP + FN) #
        accuracy = (TP + TN) / (TP + TN + FP + FN)
        sensitivity = if (TP+FN) TP /(TP + FN) else 0 
        specificity = if (TN+FP) TN /(TN + FP) else 0 
        odds_ratio = if (FP*FN) (TP * TN) /(FP * FN) else 0

        model_num = i

        roc_obj <- roc(as.numeric(valid_df$lab_result), as.numeric(preds))
        roc_auc <- auc(roc_obj)

        model_metrics <- cbind(model_num, roc_auc, PPV, NPV, F1, accuracy, sensitivity, specificity, odds_ratio, FN, TN, FP, TP, size)

        result_table <- rbind(result_table, model_metrics)

    #     toprint = paste(format(accuracy, digits=4), FN, TN, FP, TP, sep=" | ")
    #     print(toprint)

    #     result_table <- rbind(result_table, cbind(df[rownames(valid_df), ], 
    #         result, validation = nrow(valid_df))) #rep(j, nrow(valid_df))))
        }
    result_table <- result_table[-1,]


    # Combine Metrics -> Overall Metrics
    overall <- data.frame(NA, stringsAsFactors=FALSE)
    for(i in 2:14){
        overall <- cbind(overall, sum(result_table[,i]))
    }
    colnames(overall) <- (colnames(result_table))

    FN = overall[1,10]
    TN = overall[1,11]
    FP = overall[1,12]
    TP = overall[1,13]
    size <- sum(FN,TN,FP,TP)

    PPV = if (TP+FP) TP / (TP + FP) else 0 # positive predict value
    NPV = if (TN+FN) TN / (TN + FN) else 0 # negative predict value
    F1 = 2*TP / (2*TP + FP + FN) #
    accuracy = (TP + TN) / (TP + TN + FP + FN)
    sensitivity = if (TP+FN) TP /(TP + FN) else 0 
    specificity = if (TN+FP) TN /(TN + FP) else 0 
    odds_ratio = if (FP*FN) (TP * TN) /(FP * FN) else 0
    model_num = NA


    # roc_obj <- roc(as.numeric(valid_df$lab_result), as.numeric(preds))
    # roc_auc <- auc(roc_obj)
    roc_auc <- mean(result_table[1:splits,2])
    # roc_auc <- simple_auc(sensitivity,1-specificity)
    # roc_auc <- 0

    overall <- cbind(model_num, roc_auc, PPV, NPV, F1, accuracy, sensitivity, specificity, odds_ratio, FN, TN, FP, TP, size)

    result_table <- rbind(result_table, overall)

    iteration <- iter
    overall <- cbind(iteration, roc_auc, PPV, NPV, F1, accuracy, sensitivity, specificity, odds_ratio, FN, TN, FP, TP, size)
    overall_table <-rbind(overall_table, overall)
    
    result_tables[[iter]] <- result_table
    }
overall_table <- overall_table[-1,]


Setting levels: control = 1, case = 2
Setting direction: controls < cases
Setting levels: control = 1, case = 2
Setting direction: controls < cases
Setting levels: control = 1, case = 2
Setting direction: controls < cases
Setting levels: control = 1, case = 2
Setting direction: controls < cases
Setting levels: control = 1, case = 2
Setting direction: controls < cases
Setting levels: control = 1, case = 2
Setting direction: controls < cases
Setting levels: control = 1, case = 2
Setting direction: controls < cases
Setting levels: control = 1, case = 2
Setting direction: controls < cases
Setting levels: control = 1, case = 2
Setting direction: controls < cases
Setting levels: control = 1, case = 2
Setting direction: controls < cases
Setting levels: control = 1, case = 2
Setting direction: controls < cases
Setting levels: control = 1, case = 2
Setting direction: controls < cases
Setting levels: control = 1, case = 2
Setting direction: controls < cases
Setting levels: control = 1, case = 2


Setting levels: control = 1, case = 2
Setting direction: controls < cases
Setting levels: control = 1, case = 2
Setting direction: controls < cases
Setting levels: control = 1, case = 2
Setting direction: controls < cases
Setting levels: control = 1, case = 2
Setting direction: controls < cases
Setting levels: control = 1, case = 2
Setting direction: controls < cases
Setting levels: control = 1, case = 2
Setting direction: controls < cases
Setting levels: control = 1, case = 2
Setting direction: controls < cases
Setting levels: control = 1, case = 2
Setting direction: controls < cases
Setting levels: control = 1, case = 2
Setting direction: controls < cases
Setting levels: control = 1, case = 2
Setting direction: controls < cases
Setting levels: control = 1, case = 2
Setting direction: controls < cases
Setting levels: control = 1, case = 2
Setting direction: controls < cases
Setting levels: control = 1, case = 2
Setting direction: controls < cases
Setting levels: control = 1, case = 2


In [8]:
# average across iterations
avgs <- data.frame(NA, stringsAsFactors=FALSE)
sds <- data.frame(NA, stringsAsFactors=FALSE)
for(i in 2:14){
    avgs <- cbind(avgs, mean(overall_table[,i]))
    sds <- cbind(sds, sd(overall_table[,i]))
}

colnames(avgs) <- (colnames(overall_table))
colnames(sds) <- (colnames(overall_table))
avgs[,1] <- 'Avg'
sds[,1] <- 'SD'
combined <- rbind(avgs,sds)
# overall_table <-rbind(overall_table, avgs)
# overall_table <-rbind(overall_table, sds)

In [9]:
combined

iteration,roc_auc,PPV,NPV,F1,accuracy,sensitivity,specificity,odds_ratio,FN,TN,FP,TP,size
Avg,0.767603526,0.785922609,0.818958713,0.842554383,0.796005313,0.907987763,0.627228484,16.6255866,270.7,1224.35,727.65,2671.3,4894
SD,0.002559414,0.001990486,0.004896385,0.001908505,0.002423384,0.003006209,0.004328738,0.6587088,8.844267,8.449696,8.449696,8.844267,0


In [10]:
overall_table

iteration,roc_auc,PPV,NPV,F1,accuracy,sensitivity,specificity,odds_ratio,FN,TN,FP,TP,size
1,0.7729092,0.7900324,0.8250166,0.8460445,0.8007765,0.910605,0.6352459,17.74021,263,1240,712,2679,4894
2,0.7695911,0.7861948,0.8277966,0.8451501,0.7987331,0.9136642,0.6255123,17.6764,254,1221,731,2688,4894
3,0.767206,0.7858616,0.8172115,0.8420388,0.7954638,0.9068661,0.6275615,16.40729,274,1225,727,2668,4894
4,0.7663355,0.7847283,0.8186702,0.8419726,0.7950552,0.9082257,0.6244877,16.45782,270,1219,733,2672,4894
5,0.7683877,0.7860288,0.8224613,0.8435974,0.7970985,0.9102651,0.6265369,17.01788,264,1223,729,2678,4894
6,0.7685545,0.786365,0.8215962,0.8434988,0.7970985,0.9095853,0.6275615,16.95142,266,1225,727,2676,4894
7,0.7688284,0.7865961,0.8217158,0.8436318,0.7973028,0.9095853,0.6280738,16.98863,266,1226,726,2676,4894
8,0.7655733,0.785124,0.812749,0.8404423,0.7936248,0.9041468,0.6270492,15.85925,282,1224,728,2660,4894
9,0.7662794,0.7851045,0.8162993,0.8414576,0.7946465,0.9065262,0.6260246,16.23449,275,1222,730,2667,4894
10,0.7648182,0.7844294,0.8123752,0.8400442,0.7930119,0.9041468,0.6255123,15.75545,282,1221,731,2660,4894


In [11]:
save_file = 'DT_dengue_Overall_Results.csv'
save_pathfile <- paste(path,save_file,sep='')

# Write CSV
write.csv(combined, file = save_pathfile)

## Overall Results

In [12]:
result_table[6,]

In [13]:
round_df <- function(x, digits) {
    # round all numeric variables
    # x: data frame 
    # digits: number of digits to round
    y = x
    numeric_columns <- sapply(x, mode) == 'numeric'
    y[numeric_columns] <- round(x[numeric_columns], digits)
    y
}

round_df(combined, 5)

iteration,roc_auc,PPV,NPV,F1,accuracy,sensitivity,specificity,odds_ratio,FN,TN,FP,TP,size
Avg,0.7676,0.78592,0.81896,0.84255,0.79601,0.90799,0.62723,16.62559,270.7,1224.35,727.65,2671.3,4894
SD,0.00256,0.00199,0.0049,0.00191,0.00242,0.00301,0.00433,0.65871,8.84427,8.4497,8.4497,8.84427,0
