library(knitr)
opts_chunk$set(fig.path=paste0("figure/", sub("(.*).Rmd","\\1",basename(knitr:::knit_concord$get('infile'))), "-"))

Cross-validation

Here we describe cross-validation, one of the fundamental methods in machine learning for assessing methods and choosing parameters in a prediction task. Suppose we have a set of observations with many features, and each observation is associated with a label. We will call this set our training data. Our task is to predict the label of any new samples by learning patterns from the training data. For a concrete example, let's consider gene expression values, where each gene acts as a feature. We will be given a new set of unlabeled data (the test data) with the task of predicting the tissue type of the new samples.

If we choose a machine learning algorithm with a tunable parameter, we have to come up with a strategy for picking an optimal value for this parameter. We could try some values, and then just choose the one which performs the best on our training data, in terms of the number of errors the algorithm would make if we apply it to the samples we have been given for training. However, we have seen how this leads to over-fitting.

Let's start by loading the tissue gene expression dataset:

library(tissuesGeneExpression)
data(tissuesGeneExpression)

For illustration purposes, we will drop one of the tissues which doesn't have many samples:

table(tissue)
ind <- which(tissue != "placenta")
y <- tissue[ind]
X <- t( e[,ind] )

This tissue will not form part of our example.

Now let's try out k-nearest neighbors for classification, using $k=5$. What is our average error in predicting the tissue in the training set, when we've used the same data for training and for testing?

library(class)
pred <- knn(train =  X, test = X, cl=y, k=5)
mean(y != pred)

With $k=5$, some of the training samples are misclassified. What if we use $k=1$?

pred <- knn(train=X, test=X, cl=y, k=1)
mean(y != pred)

Classifying the same observations we used to train the model can be very misleading. In fact, for k-nearest neighbors, using $k=1$ will always give 0 classification error in the training set (as long as no two samples are identical but labeled differently), because each observation is its own nearest neighbor and is therefore used to classify itself. The reliable way to get a sense of the performance of an algorithm is to make it give a prediction for a sample it has never seen. Similarly, if we want to know what the best value for a tunable parameter is, we need to see how different values of the parameter perform on samples that are not in the training data.
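To make the contrast between training error and held-out error concrete, here is a minimal sketch on simulated data. The two-class toy data and the names `toy_x`, `toy_y` and `test_idx` are assumptions made up for this illustration; they are not part of the tissue dataset.

```r
library(class)  # provides knn()

set.seed(1)
# Toy data (assumed for illustration): two overlapping classes in 2 dimensions
n <- 100
toy_x <- rbind(matrix(rnorm(2 * n, mean = 0), ncol = 2),
               matrix(rnorm(2 * n, mean = 1), ncol = 2))
toy_y <- factor(rep(c("a", "b"), each = n))

# Hold out a random quarter of the samples as a test set
test_idx <- sample(nrow(toy_x), size = nrow(toy_x) / 4)

# Error when the training samples are also used as the test samples
train_pred <- knn(train = toy_x[-test_idx, ], test = toy_x[-test_idx, ],
                  cl = toy_y[-test_idx], k = 1)
train_err <- mean(toy_y[-test_idx] != train_pred)

# Error on samples the classifier has never seen
test_pred <- knn(train = toy_x[-test_idx, ], test = toy_x[test_idx, ],
                 cl = toy_y[-test_idx], k = 1)
test_err <- mean(toy_y[test_idx] != test_pred)

c(train = train_err, test = test_err)
```

The training error is 0 by construction, while the held-out error reflects how much the two classes actually overlap.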

Cross-validation is a widely used method in machine learning that solves this training and test data problem while still using all the data for testing the predictive accuracy. It accomplishes this by splitting the data into a number of folds. If we have $N$ folds, then the first step is to train the algorithm using $(N-1)$ of the folds and test its accuracy on the single left-out fold. This is then repeated $N$ times, until each fold has been used as the test set. If we have $M$ parameter settings to try out, then this is accomplished in an outer loop, so we have to fit the algorithm a total of $N \times M$ times.
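The procedure just described can be sketched with two explicit loops. This is only an illustration on simulated data: the toy data, the choice of 5 folds and 3 parameter settings, and the names `fold_id`, `param_values` and `cv_error` are all assumptions. The folds here are assigned purely at random and are not balanced over the classes.

```r
library(class)  # provides knn()

set.seed(1)
# Toy data (assumed for illustration only)
n <- 60
toy_x <- rbind(matrix(rnorm(2 * n, mean = 0), ncol = 2),
               matrix(rnorm(2 * n, mean = 1.5), ncol = 2))
toy_y <- factor(rep(c("a", "b"), each = n))

n_folds <- 5                 # N in the text
param_values <- c(1, 5, 9)   # the M parameter settings (here, the k of knn)

# Randomly assign each sample to one of the N folds
fold_id <- sample(rep(seq_len(n_folds), length.out = nrow(toy_x)))

cv_error <- matrix(NA, nrow = length(param_values), ncol = n_folds)
n_fits <- 0
for (m in seq_along(param_values)) {   # outer loop: parameter settings
  for (f in seq_len(n_folds)) {        # inner loop: folds
    held_out <- fold_id == f
    pred <- knn(train = toy_x[!held_out, ], test = toy_x[held_out, ],
                cl = toy_y[!held_out], k = param_values[m])
    cv_error[m, f] <- mean(toy_y[held_out] != pred)
    n_fits <- n_fits + 1
  }
}

rowMeans(cv_error)  # one cross-validated error per parameter setting
n_fits              # N times M fits in total
```

Every sample is used exactly once as a test sample for each parameter setting, which is what lets cross-validation use all the data for testing.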

We will use the createFolds function from the caret package to make folds of our gene expression data that are balanced over the tissues. Don't be confused by the fact that the createFolds function uses the same letter 'k' as the 'k' in k-nearest neighbors. These two 'k' are totally unrelated. In createFolds, k is the number of folds to create, the $N$ from above. In knn, k is the number of closest observations used to classify a new sample. Here we will create 10 folds:

library(caret)
set.seed(1)
idx <- createFolds(y, k=10)
sapply(idx, length)

The folds are returned as a list of numeric indices. The first fold of data is therefore:

y[idx[[1]]] ##the labels
head( X[idx[[1]], 1:3] ) ##the genes (only showing the first 3 genes...)

We can see that, in fact, the tissues are fairly equally represented across the 10 folds:

sapply(idx, function(i) table(y[i]))
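To see what "balanced over the tissues" means, here is a rough base R sketch of one way to build class-balanced folds by hand. It is not how caret implements createFolds; the labels and class sizes in `toy_labels` and the names `fold_id` and `toy_idx` are hypothetical.

```r
set.seed(1)
# Hypothetical labels with unequal class sizes
toy_labels <- factor(rep(c("liver", "kidney", "colon"), times = c(30, 40, 20)))
n_folds <- 10

# Within each class, shuffle the fold numbers 1..n_folds and deal them out,
# so every fold receives its share of that class
fold_id <- integer(length(toy_labels))
for (cls in levels(toy_labels)) {
  members <- which(toy_labels == cls)
  fold_id[members] <- sample(rep(seq_len(n_folds), length.out = length(members)))
}

# A list with one vector of row indices per fold
toy_idx <- split(seq_along(toy_labels), fold_id)

# Class counts per fold
counts <- sapply(toy_idx, function(i) table(toy_labels[i]))
counts
```

Because each class size here is a multiple of 10, every fold receives the same number of samples from each class. With real data the counts can differ by one between folds.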

Because tissues have very different gene expression profiles, predicting tissue with all genes will be very easy. For illustration purposes we will try to predict tissue type with just two-dimensional data. We will reduce the dimension of our data using cmdscale:

library(rafalib)
mypar()
Xsmall <- cmdscale(dist(X))
plot(Xsmall,col=as.fumeric(y))
legend("topleft",levels(factor(y)),fill=seq_along(levels(factor(y))))

Now we can try out the k-nearest neighbors method on a single fold. As the training set, we provide the knn function with all the samples in Xsmall except those in the first fold. We remove these samples using the code -idx[[1]] inside the square brackets. We then use the removed samples as the test set. The cl argument is for the true classifications or labels (here, tissue) of the training data. We use the 5 nearest observations to classify each test sample:

pred <- knn(train=Xsmall[ -idx[[1]] , ], test=Xsmall[ idx[[1]], ], cl=y[ -idx[[1]] ], k=5)
table(true=y[ idx[[1]] ], pred)
mean(y[ idx[[1]] ] != pred)

Now we have some misclassifications. How well do we do for the rest of the folds?

for (i in 1:10) {
  pred <- knn(train=Xsmall[ -idx[[i]] , ], test=Xsmall[ idx[[i]], ], cl=y[ -idx[[i]] ], k=5)
  print(paste0(i,") error rate: ", round(mean(y[ idx[[i]] ] != pred),3)))
}

So we can see there is some variation across folds, with error rates hovering around 0.1-0.3. But is k=5 the best setting for the k parameter? In order to explore the best setting for k, we need to create an outer loop, where we try different values for k, and then calculate the average test set error across all the folds.

We will try out each value of k from 1 to 12. Instead of using two for loops, we will use sapply:

set.seed(1)
ks <- 1:12
res <- sapply(ks, function(k) {
  ##try out each version of k from 1 to 12
  res.k <- sapply(seq_along(idx), function(i) {
    ##loop over each of the 10 cross-validation folds
    ##predict the held-out samples using k nearest neighbors
    pred <- knn(train=Xsmall[ -idx[[i]], ],
                test=Xsmall[ idx[[i]], ],
                cl=y[ -idx[[i]] ], k = k)
    ##the ratio of misclassified samples
    mean(y[ idx[[i]] ] != pred)
  })
  ##average over the 10 folds
  mean(res.k)
})

Now for each value of k, we have an associated test set error rate from the cross-validation procedure.

res

We can then plot the error rate for each value of k, which helps us to see in what region there might be a minimal error rate:

plot(ks, res, type="o",ylab="misclassification error")

Remember, because the training set is a random sample and because our fold-generation procedure involves random number generation, the "best" value of k we pick through this procedure is also a random variable. If we had new training data and if we recreated our folds, we might get a different value for the optimal k.
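One way to get a feel for this variability is to redraw the folds several times and record which k wins each time. The sketch below does this on simulated data. The toy data, the 20 repetitions, and the names `cv_errors` and `best_k` are assumptions for illustration, and the folds are plain random splits rather than caret's balanced ones.

```r
library(class)  # provides knn()

set.seed(1)
# Toy data (assumed for illustration only)
n <- 50
toy_x <- rbind(matrix(rnorm(2 * n, mean = 0), ncol = 2),
               matrix(rnorm(2 * n, mean = 1), ncol = 2))
toy_y <- factor(rep(c("a", "b"), each = n))
ks <- 1:12
n_folds <- 10

# Cross-validated error for each k, given one random fold assignment
cv_errors <- function(fold_id) {
  sapply(ks, function(k) {
    mean(sapply(seq_len(n_folds), function(f) {
      held_out <- fold_id == f
      pred <- knn(train = toy_x[!held_out, ], test = toy_x[held_out, ],
                  cl = toy_y[!held_out], k = k)
      mean(toy_y[held_out] != pred)
    }))
  })
}

# Redraw the folds 20 times and record the k with the smallest error each time
best_k <- replicate(20, {
  fold_id <- sample(rep(seq_len(n_folds), length.out = nrow(toy_x)))
  ks[which.min(cv_errors(fold_id))]
})
table(best_k)
```

If the table shows more than one value of k, that spread is the randomness in the selected parameter that comes from the fold assignment alone.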

Finally, we use 5 dimensions instead of 2, which results in perfect prediction of tissue:

Xsmall <- cmdscale(dist(X),k=5)
set.seed(1)
ks <- 1:12
res <- sapply(ks, function(k) {
  res.k <- sapply(seq_along(idx), function(i) {
    pred <- knn(train=Xsmall[ -idx[[i]], ],
                test=Xsmall[ idx[[i]], ],
                cl=y[ -idx[[i]] ], k = k)
    mean(y[ idx[[i]] ] != pred)
  })
  mean(res.k)
})
plot(ks, res, type="o",ylim=c(0,0.20),ylab="misclassification error")

Important note: we applied cmdscale to the entire dataset to create a smaller one for illustration purposes. However, in a real machine learning application, this may result in an underestimation of test set error for small sample sizes, where dimension reduction using the unlabeled full dataset gives a boost in performance. A safer choice would have been to transform the data separately for each fold, by calculating a rotation and dimension reduction using the training set only and applying this to the test set.
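As a rough sketch of that safer approach, the code below learns the dimension reduction from the training folds only and then applies it to the held-out fold. It swaps in PCA via prcomp for cmdscale, because prcomp has a predict method for projecting new samples (classical MDS on Euclidean distances gives the same coordinates as PCA, up to sign). The simulated data and the names `toy_x`, `toy_y`, `fold_id` and `fold_error` are assumptions for illustration.

```r
library(class)  # provides knn()

set.seed(1)
# Toy data (assumed): 80 samples, 20 features, two classes that differ
# in mean on the first 3 features
n <- 40
p <- 20
toy_x <- matrix(rnorm(2 * n * p), ncol = p)
toy_x[(n + 1):(2 * n), 1:3] <- toy_x[(n + 1):(2 * n), 1:3] + 2
toy_y <- factor(rep(c("a", "b"), each = n))

n_folds <- 5
fold_id <- sample(rep(seq_len(n_folds), length.out = nrow(toy_x)))

fold_error <- sapply(seq_len(n_folds), function(f) {
  held_out <- fold_id == f
  ## learn the centering and rotation from the training samples only
  pca <- prcomp(toy_x[!held_out, ], center = TRUE, scale. = FALSE)
  train_small <- pca$x[, 1:2]
  ## apply that same centering and rotation to the held-out samples
  test_small <- predict(pca, newdata = toy_x[held_out, ])[, 1:2]
  pred <- knn(train = train_small, test = test_small,
              cl = toy_y[!held_out], k = 5)
  mean(toy_y[held_out] != pred)
})
mean(fold_error)
```

The held-out samples never influence the rotation, so the error estimate is not helped by information from the test fold.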