Added an implementation to Stratify Data #2671

Merged
merged 27 commits into from
Nov 15, 2020
Changes from 25 commits
Commits
c0f725e
StratifiedSplit file added with basic code
Abilityguy Oct 8, 2020
db94a2c
Added Stratified split implementation
Abilityguy Oct 10, 2020
e6985fc
Minor changes to preprocess_split file
Abilityguy Oct 10, 2020
9d34e23
Stratified Split implementation done
Abilityguy Oct 11, 2020
1f8697b
Added Stratified Split implementation and integrated with preprocess …
Abilityguy Oct 13, 2020
196d561
Minor style fix
Abilityguy Oct 13, 2020
dd11c74
Basic idea added in comments
Abilityguy Oct 13, 2020
a7befb5
Refactoring code and style fixes
Abilityguy Oct 14, 2020
2f36940
Moved StratifiedSplit templates moved into split_data.hpp and refacto…
Abilityguy Oct 19, 2020
fdc01e5
Fix for documentation failing test
Abilityguy Oct 19, 2020
d06cf96
Code review changes
Abilityguy Oct 21, 2020
86bc837
Added tests for stratified split
Abilityguy Oct 23, 2020
9eabc40
Possible fix for failing tests
Abilityguy Oct 23, 2020
e7539c6
Review changes
Abilityguy Oct 24, 2020
9f988a3
Merge branch 'master' into StratifiedSplit
Abilityguy Oct 24, 2020
58b8202
Fix for the failing tests
Abilityguy Oct 24, 2020
4244a9d
Replaced the casting with floor
Abilityguy Oct 24, 2020
0bdc8b8
Style fixes and algorithm fix
Abilityguy Oct 30, 2020
76fe709
Changed unordered map implementation to uvec implementation
Abilityguy Nov 1, 2020
1a49c5f
Removed the 'ReportIgnoredParam' attribute
Abilityguy Nov 8, 2020
b1e8382
Fix for failing style check
Abilityguy Nov 8, 2020
2fc09bf
Changes made based on review comments
Abilityguy Nov 11, 2020
3f2220f
Changed to direct looping of labels
Abilityguy Nov 12, 2020
67dc33e
Added tests to preprocess_split_test
Abilityguy Nov 13, 2020
04f2d46
Removed unused variables in a test case
Abilityguy Nov 13, 2020
db62edf
Merge branch 'master' into StratifiedSplit
Abilityguy Nov 14, 2020
765d278
Modified HISTORY.md
Abilityguy Nov 14, 2020
152 changes: 148 additions & 4 deletions src/mlpack/core/data/split_data.hpp
@@ -17,6 +17,138 @@

namespace mlpack {
namespace data {

/**
* Given an input dataset and labels, stratify into a training set and test set.
* It is recommended that the input labels lie in the range [0, n), where n is
* the number of distinct labels, since each label is used as an index into a
* count vector of length (maximum label + 1). The NormalizeLabels() function
* in mlpack::data can be used for this.
* Example usage below. This overload places the stratified dataset into the
* four output parameters given (trainData, testData, trainLabel,
* and testLabel).
*
* @code
* arma::mat input = loadData();
* arma::Row<size_t> label = loadLabel();
* arma::mat trainData;
* arma::mat testData;
* arma::Row<size_t> trainLabel;
* arma::Row<size_t> testLabel;
* math::RandomSeed(100); // Set the seed if you like.
*
* // Stratify the dataset into a training and test set, with 30% of the data
* // being held out for the test set.
* StratifiedSplit(input, label, trainData,
* testData, trainLabel, testLabel, 0.3);
* @endcode
*
* @param input Input dataset to stratify.
* @param inputLabel Input labels to stratify.
* @param trainData Matrix to store training data into.
* @param testData Matrix to store test data into.
* @param trainLabel Vector to store training labels into.
* @param testLabel Vector to store test labels into.
* @param testRatio Fraction of dataset to use for test set (between 0 and 1).
* @param shuffleData If true, the sample order is shuffled; otherwise, each
* sample is visited in linear order. (Default true.)
*/
template<typename T, typename U>
void StratifiedSplit(const arma::Mat<T>& input,
const arma::Row<U>& inputLabel,
arma::Mat<T>& trainData,
arma::Mat<T>& testData,
arma::Row<U>& trainLabel,
arma::Row<U>& testLabel,
const double testRatio,
const bool shuffleData = true)
{
/**
* Basic idea:
* Let us say we have to stratify a dataset based on labels:
* 0 0 0 0 0 (5 0s)
* 1 1 1 1 1 1 1 1 1 1 1 (11 1s)
*
* Let our test ratio be 0.2.
* Then, the number of 0 labels in our test set = floor(5 * 0.2) = 1.
* The number of 1 labels in our test set = floor(11 * 0.2) = 2.
*
* In our first pass over the dataset, we visit each label and keep count of
* each label in our 'labelCounts' uvec.
*
* We then take a second pass over the dataset.
* We now maintain an additional uvec 'testLabelCounts' to hold the label
* counts of our test set.
*
* In this pass, when we encounter a label we check the 'testLabelCounts' uvec
* for the count of this label in the test set.
* If this count is less than the required number of labels in the test set,
* we add the data to the test set and increment the label count in the uvec.
* If this count is equal to or more than the required count in the test set,
* we add this data to the train set.
*
* Based on the above steps, we get the following labels in the split set:
* Train set (4 0s, 9 1s)
* 0 0 0 0
* 1 1 1 1 1 1 1 1 1
*
* Test set (1 0, 2 1s)
* 0
* 1 1
*/
size_t trainIdx = 0;
size_t testIdx = 0;
size_t trainSize = 0;
size_t testSize = 0;
arma::uvec labelCounts;
arma::uvec testLabelCounts;
U maxLabel = inputLabel.max();

labelCounts.zeros(maxLabel+1);
testLabelCounts.zeros(maxLabel+1);

arma::uvec order =
arma::linspace<arma::uvec>(0, input.n_cols - 1, input.n_cols);

if (shuffleData)
{
order = arma::shuffle(order);
}

for (U label : inputLabel)
{
++labelCounts[label];
}

for (arma::uword labelCount : labelCounts)
{
testSize += floor(labelCount * testRatio);
trainSize += labelCount - floor(labelCount * testRatio);
}

trainData.set_size(input.n_rows, trainSize);
testData.set_size(input.n_rows, testSize);
trainLabel.set_size(trainSize);
testLabel.set_size(testSize);

for (arma::uword i : order)
{
U label = inputLabel[i];
if (testLabelCounts[label] < floor(labelCounts[label] * testRatio))
{
testLabelCounts[label] += 1;
testData.col(testIdx) = input.col(i);
testLabel[testIdx] = inputLabel[i];
testIdx += 1;
}
else
{
trainData.col(trainIdx) = input.col(i);
trainLabel[trainIdx] = inputLabel[i];
trainIdx += 1;
}
}
}
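
The two-pass procedure described in the comment above can be sketched without Armadillo. This is a minimal sketch, not the mlpack implementation: `SplitIndices`, `StratifiedSplitIndices`, and the `seed` parameter are hypothetical names introduced for illustration, labels are assumed to be non-negative integers in [0, n), and the function returns point indices instead of copying matrix columns.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <numeric>
#include <random>
#include <vector>

// Indices of the points assigned to each side of the split.
struct SplitIndices
{
  std::vector<std::size_t> train;
  std::vector<std::size_t> test;
};

// Hypothetical helper (not part of mlpack): two-pass stratified split.
// Assumes every label lies in [0, n).
SplitIndices StratifiedSplitIndices(const std::vector<std::size_t>& labels,
                                    const double testRatio,
                                    const bool shuffleData = true,
                                    const unsigned seed = 100)
{
  SplitIndices result;
  if (labels.empty())
    return result;

  const std::size_t maxLabel =
      *std::max_element(labels.begin(), labels.end());

  // First pass: count how many points carry each label.
  std::vector<std::size_t> labelCounts(maxLabel + 1, 0);
  for (const std::size_t label : labels)
    ++labelCounts[label];

  // Per-class test quota: floor(count * testRatio).
  std::vector<std::size_t> testQuota(maxLabel + 1, 0);
  for (std::size_t c = 0; c <= maxLabel; ++c)
  {
    testQuota[c] =
        static_cast<std::size_t>(std::floor(labelCounts[c] * testRatio));
  }

  // Visiting order: 0, 1, ..., n - 1, optionally shuffled.
  std::vector<std::size_t> order(labels.size());
  std::iota(order.begin(), order.end(), std::size_t{0});
  if (shuffleData)
  {
    std::mt19937 rng(seed);
    std::shuffle(order.begin(), order.end(), rng);
  }

  // Second pass: fill each class's test quota first; every remaining point
  // of that class goes to the training set.
  std::vector<std::size_t> testLabelCounts(maxLabel + 1, 0);
  for (const std::size_t i : order)
  {
    const std::size_t label = labels[i];
    if (testLabelCounts[label] < testQuota[label])
    {
      ++testLabelCounts[label];
      result.test.push_back(i);
    }
    else
    {
      result.train.push_back(i);
    }
  }

  return result;
}
```

With five 0 labels, eleven 1 labels, and a test ratio of 0.2, this sketch puts 3 indices in the test set (one 0 and two 1s) and 13 in the training set, matching the worked example in the comment. Returning indices keeps the sketch independent of any matrix type; the caller would then gather the corresponding columns.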

/**
* Given an input dataset and labels, split into a training set and test set.
* Example usage below. This overload places the split dataset into the four
@@ -167,7 +299,10 @@ void Split(const arma::Mat<T>& input,
* @param inputLabel Input labels to split.
* @param testRatio Percentage of dataset to use for test set (between 0 and 1).
* @param shuffleData If true, the sample order is shuffled; otherwise, each
* sample is visited in linear order. (Default true).
* @param stratifyData If true, the train and test splits are stratified
* so that the ratio of each class in the training and test sets is the same
* as in the original dataset.
* @return std::tuple containing trainData (arma::Mat<T>), testData
* (arma::Mat<T>), trainLabel (arma::Row<U>), and testLabel (arma::Row<U>).
*/
@@ -176,15 +311,24 @@ std::tuple<arma::Mat<T>, arma::Mat<T>, arma::Row<U>, arma::Row<U>>
Split(const arma::Mat<T>& input,
const arma::Row<U>& inputLabel,
const double testRatio,
- const bool shuffleData = true)
+ const bool shuffleData = true,
+ const bool stratifyData = false)
{
arma::Mat<T> trainData;
arma::Mat<T> testData;
arma::Row<U> trainLabel;
arma::Row<U> testLabel;

- Split(input, inputLabel, trainData, testData, trainLabel, testLabel,
- testRatio, shuffleData);
+ if (stratifyData)
+ {
+ StratifiedSplit(input, inputLabel, trainData, testData, trainLabel,
+ testLabel, testRatio, shuffleData);
+ }
+ else
+ {
+ Split(input, inputLabel, trainData, testData, trainLabel, testLabel,
+ testRatio, shuffleData);
+ }

return std::make_tuple(std::move(trainData),
std::move(testData),
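
Because the test quota is rounded down separately for each class, the class ratios in the two sets match the original dataset only approximately, and very small classes may contribute no test points at all. The arithmetic can be checked with a minimal sketch; `StratifiedTestSize` is a hypothetical helper introduced here for illustration and is not part of mlpack.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of mlpack): total test-set size when the
// quota floor(count * testRatio) is computed separately for each class.
std::size_t StratifiedTestSize(const std::vector<std::size_t>& labelCounts,
                               const double testRatio)
{
  std::size_t testSize = 0;
  for (const std::size_t count : labelCounts)
    testSize += static_cast<std::size_t>(std::floor(count * testRatio));
  return testSize;
}

// StratifiedTestSize({5, 11}, 0.2) returns 3 (1 + 2).
// StratifiedTestSize({3, 3, 3}, 0.3) returns 0, because floor(0.9) is 0 for
// every class, even though 30% of the 9 points is 2.7.
```

When every class needs to be represented in the test set, it may be worth checking that each class count times the test ratio is at least 1 before splitting.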
23 changes: 18 additions & 5 deletions src/mlpack/methods/preprocess/preprocess_split_main.cpp
@@ -69,6 +69,10 @@ BINDING_EXAMPLE(
"test_ratio", 0.3, "training", "X_train", "training_labels", "y_train",
"test", "X_test", "test_labels", "y_test"));

BINDING_EXAMPLE(
"To maintain the ratio of each class in the train and test sets, the " +
PRINT_PARAM_STRING("stratify_data") + " option can be used.");

// See also...
BINDING_SEE_ALSO("@preprocess_binarize", "#preprocess_binarize");
BINDING_SEE_ALSO("@preprocess_describe", "#preprocess_describe");
@@ -90,6 +94,7 @@ PARAM_DOUBLE_IN("test_ratio", "Ratio of test set; if not set,"

PARAM_INT_IN("seed", "Random seed (0 for std::time(NULL)).", "s", 0);
PARAM_FLAG("no_shuffle", "Avoid shuffling and splitting the data.", "S");
PARAM_FLAG("stratify_data", "Stratify the data according to labels.", "z");

using namespace mlpack;
using namespace mlpack::data;
@@ -102,6 +107,7 @@ static void mlpackMain()
// Parse command line options.
const double testRatio = IO::GetParam<double>("test_ratio");
const bool shuffleData = IO::GetParam<bool>("no_shuffle");
const bool stratifyData = IO::GetParam<bool>("stratify_data");

if (IO::GetParam<int>("seed") == 0)
mlpack::math::RandomSeed(std::time(NULL));
@@ -148,11 +154,15 @@ static void mlpackMain()
IO::GetParam<arma::Mat<size_t>>("input_labels");
arma::Row<size_t> labelsRow = labels.row(0);

- const auto value = data::Split(data, labelsRow, testRatio, !shuffleData);
- Log::Info << "Training data contains " << get<0>(value).n_cols << " points."
- << endl;
- Log::Info << "Test data contains " << get<1>(value).n_cols << " points."
- << endl;
+ Timer::Start("splitting_data");
+ const auto value =
+ data::Split(data, labelsRow, testRatio, !shuffleData, stratifyData);
+ Timer::Stop("splitting_data");
+
+ Log::Info << "Training data contains "
+ << get<0>(value).n_cols << " points." << endl;
+ Log::Info << "Test data contains "
+ << get<1>(value).n_cols << " points." << endl;

if (IO::HasParam("training"))
IO::GetParam<arma::mat>("training") = std::move(get<0>(value));
@@ -167,7 +177,10 @@ static void mlpackMain()
}
else // We have no labels, so just split the dataset.
{
Timer::Start("splitting_data");
const auto value = data::Split(data, testRatio, !shuffleData);
Timer::Stop("splitting_data");

Log::Info << "Training data contains " << get<0>(value).n_cols << " points."
<< endl;
Log::Info << "Test data contains " << get<1>(value).n_cols << " points."