In [1]:
library(grid)
library(rpart)
library(rpart.plot)
library(partykit)
library(lattice)
library(ggplot2)
library(caret)

In [2]:
# Fix the random number generator seed so that reruns of the notebook are repeatable.
set.seed(1)

In [3]:
# Japanese display names for the road types. The reporting loop looks a name up
# as jRoadType[RoadType + 1], i.e. RoadType codes are treated as 0-based.
# Rough English glosses, in order (NOTE(review): translations, confirm with the
# data owner): intercity expressway, urban expressway, toll road, national road,
# prefectural road, major local road, general roads 1-3, other.
jRoadType <- c("都市間高速", "都市高速", "有料道路", "国道", "県道", "主要地方道", "一般道1、一般道2、一般道3", "その他")

In [4]:
# C-style printf: format the arguments with sprintf() and write the result to
# the console with cat(). No newline is added; put "\n" in the format string.
printf <- function(...) {
    formatted <- sprintf(...)
    cat(formatted)
}

In [5]:
# Column names that must not be used as predictors; the list includes the
# response column "flag" itself.
invalids <- c(
    'Time', 'Longitude', 'Latitude',
    'Brake', 'Accel', 'RoadType', 'flag'
)

# Vectorised predicate: TRUE for every name in x that is absent from `invalids`.
isValidColumn <- function(x) {
    is.na(match(x, invalids))
}

In [6]:
# Cross-validate an rpart classification tree on dfx.
#
# Arguments:
#   dfx     data frame containing the response column `flag` and every
#           predictor named in `expr`.
#   expr    model formula (this file passes it as a string, e.g. "flag ~ a + b").
#   verbose when TRUE, print a "FoldN" header and the per-class figures from
#           correctVsPredict() for each fold.
#
# Returns a numeric vector of length 3:
#   [1] mean misclassification rate over the folds,
#   [2] mean of the per-fold "Red" correct/predicted ratio,
#   [3] mean of the per-fold "Blue" correct/predicted ratio.
#
# A fold in which a class is never predicted has an undefined ratio (0/0 = NaN).
# Such folds are left out of the corresponding mean instead of turning the
# whole average into NaN; the mean is still NaN when the ratio is undefined in
# every fold.
CV <- function(dfx, expr, verbose=FALSE) {
    # Folds are built on the response column using caret's default settings.
    folds <- createFolds(dfx$flag)
    nFolds <- length(folds)

    # One slot per fold, allocated up front rather than grown inside the loop.
    errs <- numeric(nFolds)
    reds <- numeric(nFolds)
    blues <- numeric(nFolds)

    for (k in seq_along(folds)) {
        ids <- folds[[k]]
        train <- dfx[-ids, ]
        test <- dfx[ids, ]

        fit <- rpart(expr, data=train, method="class", cp=0.013)

        # predict() yields one probability column per class here; take the
        # label of the most probable class (first one wins on ties).
        p <- predict(fit, newdata=test)
        predictedFlags <- colnames(p)[max.col(p, ties.method = "first")]

        if (verbose) {
            printf("Fold%d\n", k)
        }
        result <- correctVsPredict(test, predictedFlags, verbose)
        if (verbose) {
            printf("\n")
        }

        reds[k] <- result[1]
        blues[k] <- result[2]
        errs[k] <- sum(predictedFlags != test$flag) / nrow(test)
    }

    c(mean(errs), mean(reds, na.rm=TRUE), mean(blues, na.rm=TRUE))
}

In [7]:
# Per-class "correct / predicted" ratio for the classes "Red" and "Blue".
#
# Arguments:
#   test           data frame of held-out rows with the true label in `flag`.
#   predictedFlags predicted labels, one per row of `test`.
#   verbose        when TRUE, print counts and ratio for both classes.
#
# Returns c(redRatio, blueRatio). A ratio is NaN when the class was never
# predicted (0/0).
correctVsPredict <- function(test, predictedFlags, verbose=FALSE) {
    # Of the rows predicted as `label`, count how many truly carry it.
    tally <- function(label) {
        predictedRows <- test[predictedFlags == label, ]
        list(
            correct = sum(predictedRows$flag == label),
            predicted = nrow(predictedRows)
        )
    }

    red <- tally("Red")
    blue <- tally("Blue")

    redRatio <- red$correct / red$predicted
    blueRatio <- blue$correct / blue$predicted

    if (verbose) {
        printf("As for Red: correct/predict = %d/%d = %f\n", red$correct, red$predicted, redRatio)
        printf("As for Blue: correct/predict = %d/%d = %f\n", blue$correct, blue$predicted, blueRatio)
    }

    c(redRatio, blueRatio)
}

In [8]:
# Print the class balance of dfx: the number of rows whose `flag` is "Red" and
# the number of remaining rows, each as count/total and as a fraction.
printRedRatios <- function(dfx) {
    total <- nrow(dfx)
    redRows <- dfx[dfx$flag == "Red", ]
    redCount <- nrow(redRows)
    redShare <- redCount / total

    printf("Red/All = %d/%d = %f\n", redCount, total, redShare)
    printf("1 - Red/All = %d/%d = %f\n", total - redCount, total, 1 - redShare)
}

# Predict Reds

In [9]:
# Load the data set. Text columns stay character (stringsAsFactors=FALSE) so
# that `flag` can be recoded before it is converted to a factor.
df3 <- read.csv("../data/middle/sp1.csv", stringsAsFactors=FALSE)

In [10]:
# Merge the A/B sub-labels into the two classes "Red" and "Blue"; any other
# value is left unchanged. The column becomes a factor afterwards.
df3$flag[df3$flag %in% c("RedA", "RedB")] <- "Red"
df3$flag[df3$flag %in% c("BlueA", "BlueB")] <- "Blue"
df3$flag <- as.factor(df3$flag)

In [11]:
# Candidate predictors: every column of df3 except those listed in `invalids`.
allFeatures <- colnames(df3)
features <- allFeatures[isValidColumn(allFeatures)]

In [12]:
# Model formula as a string: "flag ~ " followed by all features joined with " + ".
expr <- paste("flag ~ ", paste(features, collapse=" + "))

In [13]:
# Distinct RoadType codes present in the data, in order of first appearance.
roadTypes <- unique(df3$RoadType)

In [14]:
# For each road type: cross-validate a tree on that subset and print the class
# balance, the mean per-class correct/predicted ratios and the mean fold error.
for (i in roadTypes) {
    # jRoadType is indexed from 1, hence i+1 for the RoadType code i.
    printf("RoadType: %d (%s)\n", i, jRoadType[i+1])
    dfx <- df3[df3$RoadType == i, ]
    # result = c(mean error rate, mean Red ratio, mean Blue ratio)
    result <- CV(dfx, expr, FALSE)
    printRedRatios(dfx)
    printf("Red: Mean correct/predict = %f\n", result[2])
    printf("Blue: Mean correct/predict = %f\n", result[3])
    printf("CV value: %f", result[1])
    printf("\n\n")
}

RoadType: 7 (その他)
Red/All = 25/27 = 0.925926
1 - Red/All = 2/27 = 0.074074
Red: Mean correct/predict = 0.933333
Blue: Mean correct/predict = NaN
CV value: 0.066667

RoadType: 6 (一般道1、一般道2、一般道3)
Red/All = 90/189 = 0.476190
1 - Red/All = 99/189 = 0.523810
Red: Mean correct/predict = 0.613969
Blue: Mean correct/predict = 0.644714
CV value: 0.380994

RoadType: 4 (県道)
Red/All = 36/71 = 0.507042
1 - Red/All = 35/71 = 0.492958
Red: Mean correct/predict = 0.483333
Blue: Mean correct/predict = 0.350000
CV value: 0.540476

RoadType: 5 (主要地方道)
Red/All = 154/303 = 0.508251
1 - Red/All = 149/303 = 0.491749
Red: Mean correct/predict = 0.569190
Blue: Mean correct/predict = 0.551010
CV value: 0.441613

RoadType: 3 (国道)
Red/All = 202/324 = 0.623457
1 - Red/All = 122/324 = 0.376543
Red: Mean correct/predict = 0.626721
Blue: Mean correct/predict = 0.399503
CV value: 0.453504

RoadType: 0 (都市間高速)
Red/All = 55/81 = 0.679012
1 - Red/All = 26/81 = 0.320988
Red: Mean correct/predict = 0.654524
Blue: Mean corr

In [15]:
# Fit one classification tree on the whole data set (all road types together),
# leaving rpart's control settings at their defaults.
fit3 <- rpart(expr, data=df3, method="class")

In [16]:
# CP value of the cptable row with the smallest cross-validated error (xerror).
fit3$cptable[which.min(fit3$cptable[,"xerror"]),"CP"]

In [17]:
# Show the variables used by the tree and its complexity-parameter table.
printcp(fit3)


Classification tree:
rpart(formula = expr, data = df3, method = "class")

Variables actually used in tree construction:
 [1] AccelerationSpeed AheadDistance     Curve             DistManBicycle   
 [5] DistSignal        Pitch             RiskFactor        Speed            
 [9] TimeHeadway       TimeToCollision  

Root node error: 449/1042 = 0.4309

n= 1042 

        CP nsplit rel error  xerror     xstd
1 0.031180      0   1.00000 1.00000 0.035602
2 0.023385      6   0.81069 0.92873 0.035223
3 0.018931      8   0.76392 0.94655 0.035331
4 0.013363     10   0.72606 0.91091 0.035106
5 0.011136     12   0.69933 0.91759 0.035151
6 0.010000     13   0.68820 0.92650 0.035209


In [18]:
# Detailed summary of the fitted tree (CP table, variable importance, ...).
summary(fit3)

Call:
rpart(formula = expr, data = df3, method = "class")
  n= 1042 

          CP nsplit rel error    xerror       xstd
1 0.03118040      0 1.0000000 1.0000000 0.03560167
2 0.02338530      6 0.8106904 0.9287305 0.03522312
3 0.01893096      8 0.7639198 0.9465479 0.03533107
4 0.01336303     10 0.7260579 0.9109131 0.03510616
5 0.01113586     12 0.6993318 0.9175947 0.03515108
6 0.01000000     13 0.6881960 0.9265033 0.03520900

Variable importance
       RiskFactor       TimeHeadway   TimeToCollision             Speed 
               14                13                12                12 
    AheadDistance AccelerationSpeed        DistSignal             Curve 
               11                 8                 8                 7 
   DistManBicycle            Engine             Pitch     SteeringAngle 
                4                 3                 3                 2 
  ManBicycleCount        ManBicycle         LaneCount              Jerk 
                1                 1      