Skip to content

Commit

Permalink
A Gentle Introduction to k-fold Cross-Validation
Browse files Browse the repository at this point in the history
  • Loading branch information
khanhnamle1994 committed Jun 4, 2018
1 parent e4b23e5 commit 6faaecb
Show file tree
Hide file tree
Showing 2 changed files with 548 additions and 0 deletions.
@@ -0,0 +1,274 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# A Gentle Introduction to k-fold Cross-Validation\n",
"Cross-validation is a statistical method used to estimate the skill of machine learning models.\n",
"\n",
"It is commonly used in applied machine learning to compare and select a model for a given predictive modeling problem because it is easy to understand, easy to implement, and results in skill estimates that generally have a lower bias than other methods.\n",
"\n",
"In this tutorial, you will discover a gentle introduction to the k-fold cross-validation procedure for estimating the skill of machine learning models.\n",
"\n",
"After completing this tutorial, you will know:\n",
"\n",
"* That k-fold cross validation is a procedure used to estimate the skill of the model on new data.\n",
"* There are common tactics that you can use to select the value of k for your dataset.\n",
"* There are commonly used variations on cross-validation such as stratified and repeated that are available in scikit-learn.\n",
"\n",
"Let’s get started."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tutorial Overview\n",
"This tutorial is divided into 5 parts; they are:\n",
"\n",
"1. k-Fold Cross-Validation\n",
"2. Configuration of k\n",
"3. Worked Example\n",
"4. Cross-Validation API\n",
"5. Variations on Cross-Validation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## k-Fold Cross-Validation\n",
"Cross-validation is a resampling procedure used to evaluate machine learning models on a limited data sample.\n",
"\n",
"The procedure has a single parameter called k that refers to the number of groups that a given data sample is to be split into. As such, the procedure is often called k-fold cross-validation. When a specific value for k is chosen, it may be used in place of k in the reference to the model, such as k=10 becoming 10-fold cross-validation.\n",
"\n",
"Cross-validation is primarily used in applied machine learning to estimate the skill of a machine learning model on unseen data. That is, to use a limited sample in order to estimate how the model is expected to perform in general when used to make predictions on data not used during the training of the model.\n",
"\n",
"It is a popular method because it is simple to understand and because it generally results in a less biased or less optimistic estimate of the model skill than other methods, such as a simple train/test split.\n",
"\n",
"The general procedure is as follows:\n",
"\n",
"1. Shuffle the dataset randomly.\n",
"2. Split the dataset into k groups\n",
"3. For each unique group:\n",
" * Take the group as a hold out or test data set\n",
" * Take the remaining groups as a training data set\n",
" * Fit a model on the training set and evaluate it on the test set\n",
" * Retain the evaluation score and discard the model\n",
"4. Summarize the skill of the model using the sample of model evaluation scores\n",
"\n",
"Importantly, each observation in the data sample is assigned to an individual group and stays in that group for the duration of the procedure. This means that each sample is given the opportunity to be used in the hold out set 1 time and used to train the model k-1 times.\n",
"\n",
"It is also important that any preparation of the data prior to fitting the model occur on the CV-assigned training dataset within the loop rather than on the broader data set. This also applies to any tuning of hyperparameters. A failure to perform these operations within the loop may result in data leakage and an optimistic estimate of the model skill.\n",
"\n",
"The results of a k-fold cross-validation run are often summarized with the mean of the model skill scores. It is also good practice to include a measure of the variance of the skill scores, such as the standard deviation or standard error."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Configuration of k\n",
"The k value must be chosen carefully for your data sample.\n",
"\n",
"A poorly chosen value for k may result in a mis-representative idea of the skill of the model, such as a score with a high variance (that may change a lot based on the data used to fit the model), or a high bias, (such as an overestimate of the skill of the model).\n",
"\n",
"Three common tactics for choosing a value for k are as follows:\n",
"\n",
"* Representative: The value for k is chosen such that each train/test group of data samples is large enough to be statistically representative of the broader dataset.\n",
"* k=10: The value for k is fixed to 10, a value that has been found through experimentation to generally result in a model skill estimate with low bias a modest variance.\n",
"* k=n: The value for k is fixed to n, where n is the size of the dataset to give each test sample an opportunity to be used in the hold out dataset. This approach is called leave-one-out cross-validation.\n",
"\n",
"A value of k=10 is very common in the field of applied machine learning, and is recommend if you are struggling to choose a value for your dataset.\n",
"\n",
"If a value for k is chosen that does not evenly split the data sample, then one group will contain a remainder of the examples. It is preferable to split the data sample into k groups with the same number of samples, such that the sample of model skill scores are all equivalent.\n",
"\n",
"## Worked Example\n",
"To make the cross-validation procedure concrete, let’s look at a worked example.\n",
"\n",
"Imagine we have a data sample with 6 observations:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The first step is to pick a value for k in order to determine the number of folds used to split the data. Here, we will use a value of k=3. That means we will shuffle the data and then split the data into 3 groups. Because we have 6 observations, each group will have an equal number of 2 observations.\n",
"\n",
"For example:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"Fold1: [0.5, 0.2]\n",
"Fold2: [0.1, 0.3]\n",
"Fold3: [0.4, 0.6]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can then make use of the sample, such as to evaluate the skill of a machine learning algorithm.\n",
"\n",
"Three models are trained and evaluated with each fold given a chance to be the held out test set.\n",
"\n",
"For example:\n",
"\n",
"* Model1: Trained on Fold1 + Fold2, Tested on Fold3\n",
"* Model2: Trained on Fold2 + Fold3, Tested on Fold1\n",
"* Model3: Trained on Fold1 + Fold3, Tested on Fold2\n",
"\n",
"The models are then discarded after they are evaluated as they have served their purpose.\n",
"\n",
"The skill scores are collected for each model and summarized for use.\n",
"\n",
"## Cross-Validation API\n",
"We do not have to implement k-fold cross-validation manually. The scikit-learn library provides an implementation that will split a given data sample up.\n",
"\n",
"The KFold() scikit-learn class can be used. It takes as arguments the number of splits, whether or not to shuffle the sample, and the seed for the pseudorandom number generator used prior to the shuffle.\n",
"\n",
"For example, we can create an instance that splits a dataset into 3 folds, shuffles prior to the split, and uses a value of 1 for the pseudorandom number generator."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"kfold = KFold(3, True, 1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The split() function can then be called on the class where the data sample is provided as an argument. Called repeatedly, the split will return each group of train and test sets. Specifically, arrays are returned containing the indexes into the original data sample of observations to use for train and test sets on each iteration.\n",
"\n",
"For example, we can enumerate the splits of the indices for a data sample using the created KFold instance as follows:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# enumerate splits\n",
"for train, test in kfold.split(data):\n",
"\tprint('train: %s, test: %s' % (train, test))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can tie all of this together with our small dataset used in the worked example of the prior section."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train: [0.1 0.4 0.5 0.6], test: [0.2 0.3]\n",
"train: [0.2 0.3 0.4 0.6], test: [0.1 0.5]\n",
"train: [0.1 0.2 0.3 0.5], test: [0.4 0.6]\n"
]
}
],
"source": [
"# scikit-learn k-fold cross-validation\n",
"from numpy import array\n",
"from sklearn.model_selection import KFold\n",
"# data sample\n",
"data = array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])\n",
"# prepare cross validation\n",
"kfold = KFold(3, True, 1)\n",
"# enumerate splits\n",
"for train, test in kfold.split(data):\n",
"\tprint('train: %s, test: %s' % (data[train], data[test]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Running the example prints the specific observations chosen for each train and test set. The indices are used directly on the original data array to retrieve the observation values.\n",
"\n",
"Usefully, the k-fold cross validation implementation in scikit-learn is provided as a component operation within broader methods, such as grid-searching model hyperparameters and scoring a model on a dataset.\n",
"\n",
"Nevertheless, the KFold class can be used directly in order to split up a dataset prior to modeling such that all models will use the same data splits. This is especially helpful if you are working with very large data samples. The use of the same splits across algorithms can have benefits for statistical tests that you may wish to perform on the data later.\n",
"\n",
"## Variations on Cross-Validation\n",
"There are a number of variations on the k-fold cross validation procedure.\n",
"\n",
"Three commonly used variations are as follows:\n",
"\n",
"* Train/Test Split: Taken to one extreme, k may be set to 1 such that a single train/test split is created to evaluate the model.\n",
"* LOOCV: Taken to another extreme, k may be set to the total number of observations in the dataset such that each observation is given a chance to be the held out of the dataset. This is called leave-one-out cross-validation, or LOOCV for short.\n",
"* Stratified: The splitting of data into folds may be governed by criteria such as ensuring that each fold has the same proportion of observations with a given categorical value, such as the class outcome value. This is called stratified cross-validation.\n",
"* Repeated: This is where the k-fold cross-validation procedure is repeated n times, where importantly, the data sample is shuffled prior to each repetition, which results in a different split of the sample.\n",
"\n",
"The scikit-learn library provides a suite of cross-validation implementation. You can see the full list in the Model Selection API.\n",
"\n",
"## Summary\n",
"In this tutorial, you discovered a gentle introduction to the k-fold cross-validation procedure for estimating the skill of machine learning models.\n",
"\n",
"Specifically, you learned:\n",
"\n",
"* That k-fold cross validation is a procedure used to estimate the skill of the model on new data.\n",
"* There are common tactics that you can use to select the value of k for your dataset.\n",
"* There are commonly used variations on cross-validation, such as stratified and repeated, that are available in scikit-learn."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit 6faaecb

Please sign in to comment.