Added an implementation to Stratify Data #2671

Merged
merged 27 commits into from
Nov 15, 2020
Changes from 9 commits
Commits
c0f725e
StratifiedSplit file added with basic code
Abilityguy Oct 8, 2020
db94a2c
Added Stratified split implementation
Abilityguy Oct 10, 2020
e6985fc
Minor changes to preprocess_split file
Abilityguy Oct 10, 2020
9d34e23
Stratified Split implementation done
Abilityguy Oct 11, 2020
1f8697b
Added Stratified Split implementation and integrated with preprocess …
Abilityguy Oct 13, 2020
196d561
Minor style fix
Abilityguy Oct 13, 2020
dd11c74
Basic idea added in comments
Abilityguy Oct 13, 2020
a7befb5
Refactoring code and style fixes
Abilityguy Oct 14, 2020
2f36940
Moved StratifiedSplit templates moved into split_data.hpp and refacto…
Abilityguy Oct 19, 2020
fdc01e5
Fix for documentation failing test
Abilityguy Oct 19, 2020
d06cf96
Code review changes
Abilityguy Oct 21, 2020
86bc837
Added tests for stratified split
Abilityguy Oct 23, 2020
9eabc40
Possible fix for failing tests
Abilityguy Oct 23, 2020
e7539c6
Review changes
Abilityguy Oct 24, 2020
9f988a3
Merge branch 'master' into StratifiedSplit
Abilityguy Oct 24, 2020
58b8202
Fix for the failing tests
Abilityguy Oct 24, 2020
4244a9d
Replaced the casting with floor
Abilityguy Oct 24, 2020
0bdc8b8
Style fixes and algorithm fix
Abilityguy Oct 30, 2020
76fe709
Changed unordered map implementation to uvec implementation
Abilityguy Nov 1, 2020
1a49c5f
Removed the 'ReportIgnoredParam' attribute
Abilityguy Nov 8, 2020
b1e8382
Fix for failing style check
Abilityguy Nov 8, 2020
2fc09bf
Changes made based on review comments
Abilityguy Nov 11, 2020
3f2220f
Changed to direct looping of labels
Abilityguy Nov 12, 2020
67dc33e
Added tests to preprocess_split_test
Abilityguy Nov 13, 2020
04f2d46
Removed unused variables in a test case
Abilityguy Nov 13, 2020
db62edf
Merge branch 'master' into StratifiedSplit
Abilityguy Nov 14, 2020
765d278
Modified HISTORY.md
Abilityguy Nov 14, 2020
133 changes: 130 additions & 3 deletions src/mlpack/core/data/split_data.hpp
@@ -17,6 +17,124 @@

namespace mlpack {
namespace data {

/**
* Given an input dataset and labels, split them into a training set and a test
* set such that each label keeps roughly the same proportion in both sets.
* Example usage below. This overload places the stratified dataset into the
* four output parameters given (trainData, testData, trainLabel, and
* testLabel).
*
* @code
* arma::mat input = loadData();
* arma::Row<size_t> label = loadLabel();
* arma::mat trainData;
* arma::mat testData;
* arma::Row<size_t> trainLabel;
* arma::Row<size_t> testLabel;
* math::RandomSeed(100); // Set the seed if you like.
*
* // Stratify the dataset into a training and test set, with 30% of the data
* // being held out for the test set.
* StratifiedSplit(input, label, trainData,
* testData, trainLabel, testLabel, 0.3);
* @endcode
*
* @param input Input dataset to stratify.
* @param inputLabel Input labels to stratify.
* @param trainData Matrix to store training data into.
* @param testData Matrix to store test data into.
* @param trainLabel Vector to store training labels into.
* @param testLabel Vector to store test labels into.
* @param testRatio Fraction of the dataset to use for the test set (between 0
* and 1).
* @param shuffleData If true, the sample order is shuffled; otherwise, the
* samples are visited in linear order. (Default true.)
*/
template<typename T, typename U>
void StratifiedSplit(const arma::Mat<T>& input,
const arma::Row<U>& inputLabel,
arma::Mat<T>& trainData,
arma::Mat<T>& testData,
arma::Row<U>& trainLabel,
arma::Row<U>& testLabel,
const double testRatio,
const bool shuffleData = true)
{
/**
* Basic idea:
* Let us say we have to stratify a dataset based on labels:
* 0 0 0 0 0 1 1 1 1 1 1 (5 0s, 6 1s)
*
* Let our test ratio be 0.2.
* We visit each label and keep a running count of how many times we have seen
* it so far in an unordered map (labelMap).
*
* Whenever we encounter a label, we calculate
* current_count * test_ratio --- labelMap[label]*testRatio and
* (current_count + 1) * test_ratio --- (labelMap[label]+1)*testRatio
*
* We then static_cast both products to size_t, which truncates the fractional
* part. If the two truncated values are equal, we add the point to our train
* set. If they differ, we add the point to our test set.
*
* Considering our example
* 0 -- train set ( 0*0.2 == 1*0.2 ) (After casting)
* 0 -- train set ( 1*0.2 == 2*0.2 ) (After casting)
* 0 -- train set ( 2*0.2 == 3*0.2 ) (After casting)
* 0 -- train set ( 3*0.2 == 4*0.2 ) (After casting)
* 0 -- test set ( 4*0.2 < 5*0.2 ) (After casting)
*
* 1 -- train set ( 0*0.2 == 1*0.2 ) (After casting)
* 1 -- train set ( 1*0.2 == 2*0.2 ) (After casting)
* 1 -- train set ( 2*0.2 == 3*0.2 ) (After casting)
* 1 -- train set ( 3*0.2 == 4*0.2 ) (After casting)
* 1 -- test set ( 4*0.2 < 5*0.2 ) (After casting)
* 1 -- train set ( 5*0.2 == 6*0.2 ) (After casting)
*
* Finally
* train set,
* 0 0 0 0 1 1 1 1 1 (4 0s, 5 1s)
*
* test set,
* 0 1
*/
arma::uvec Indexes;
Indexes.set_size(inputLabel.n_cols);

size_t trainIdx = 0;
size_t testIdx = inputLabel.n_cols - 1;
std::unordered_map<U, size_t> labelMap;

arma::uvec order =
arma::linspace<arma::uvec>(0, input.n_cols - 1, input.n_cols);

if (shuffleData)
{
order = arma::shuffle(order);
}

for (auto i: order)
{
auto label = inputLabel[i];
if (static_cast<size_t>(labelMap[label]*testRatio) <
static_cast<size_t>((labelMap[label]+1)*testRatio))
{
Indexes[testIdx] = i;
testIdx -= 1;
}
else
{
Indexes[trainIdx] = i;
trainIdx += 1;
}
labelMap[label] += 1;
}

labelMap.clear();
testData = input.cols(Indexes.subvec(trainIdx, Indexes.n_rows-1));
testLabel = inputLabel.cols(Indexes.subvec(trainIdx, Indexes.n_rows-1));
trainData = input.cols(Indexes.subvec(0, trainIdx-1));
trainLabel = inputLabel.cols(Indexes.subvec(0, trainIdx-1));
}

/**
* Given an input dataset and labels, split into a training set and test set.
* Example usage below. This overload places the split dataset into the four
@@ -176,15 +294,24 @@ std::tuple<arma::Mat<T>, arma::Mat<T>, arma::Row<U>, arma::Row<U>>
Split(const arma::Mat<T>& input,
const arma::Row<U>& inputLabel,
const double testRatio,
const bool shuffleData = true)
const bool shuffleData = true,
const bool stratifyData = false)
{
arma::Mat<T> trainData;
arma::Mat<T> testData;
arma::Row<U> trainLabel;
arma::Row<U> testLabel;

Split(input, inputLabel, trainData, testData, trainLabel, testLabel,
testRatio, shuffleData);
if (stratifyData)
{
StratifiedSplit(input, inputLabel, trainData, testData, trainLabel,
testLabel, testRatio, shuffleData);
}
else
{
Split(input, inputLabel, trainData, testData, trainLabel, testLabel,
testRatio, shuffleData);
}

return std::make_tuple(std::move(trainData),
std::move(testData),
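
Since Split hands off to StratifiedSplit when stratifyData is true, the number of test points per label follows from the floor comparison in that function. A short derivation, assuming exact arithmetic (floating-point rounding of the products is ignored) and 0 ≤ r ≤ 1, where r stands for testRatio and n_ℓ for the number of points with label ℓ (both symbols are introduced here for illustration only):

```latex
\begin{align*}
t_c &= \lfloor (c+1)\,r \rfloor - \lfloor c\,r \rfloor \in \{0, 1\},
  \qquad c = 0, 1, \dots, n_\ell - 1, \\
\sum_{c=0}^{n_\ell - 1} t_c
  &= \lfloor n_\ell\, r \rfloor - \lfloor 0 \cdot r \rfloor
   = \lfloor n_\ell\, r \rfloor .
\end{align*}
```

Here t_c = 1 means that the point seen after c earlier points of the same label goes to the test set. The sum telescopes, so each label contributes floor(n_ℓ · r) test points regardless of visiting order. For the example in the StratifiedSplit comment, floor(5 · 0.2) = 1 and floor(6 · 0.2) = 1, matching the one test point per label.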
29 changes: 24 additions & 5 deletions src/mlpack/methods/preprocess/preprocess_split_main.cpp
@@ -69,6 +69,21 @@ BINDING_EXAMPLE(
"test_ratio", 0.3, "training", "X_train", "training_labels", "y_train",
"test", "X_test", "test_labels", "y_test"));

BINDING_EXAMPLE(
"The " + PRINT_PARAM_STRING("stratify_data") + " option stratifies the "
"dataset according to the labels, so that each label keeps roughly the same "
"proportion in the training and test sets." +
"\n\n" +
"If we had a dataset " + PRINT_DATASET("X") + " and associated labels " +
PRINT_DATASET("y") + ", and we wanted to stratify these into " +
PRINT_DATASET("X_train") + ", " + PRINT_DATASET("y_train") + ", " +
PRINT_DATASET("X_test") + ", and " + PRINT_DATASET("y_test") + ", with 30% "
"of the data in the test set, we could run"
"\n\n" +
PRINT_CALL("preprocess_split", "input", "X", "input_labels", "y",
"test_ratio", 0.3, "training", "X_train", "training_labels", "y_train",
"test", "X_test", "test_labels", "y_test", "stratify_data", true));

// See also...
BINDING_SEE_ALSO("@preprocess_binarize", "#preprocess_binarize");
BINDING_SEE_ALSO("@preprocess_describe", "#preprocess_describe");
@@ -90,6 +105,7 @@ PARAM_DOUBLE_IN("test_ratio", "Ratio of test set; if not set,"

PARAM_INT_IN("seed", "Random seed (0 for std::time(NULL)).", "s", 0);
PARAM_FLAG("no_shuffle", "Avoid shuffling the data before splitting.", "S");
PARAM_FLAG("stratify_data", "Stratify the data according to labels.", "z");

using namespace mlpack;
using namespace mlpack::data;
@@ -102,6 +118,7 @@ static void mlpackMain()
// Parse command line options.
const double testRatio = IO::GetParam<double>("test_ratio");
const bool shuffleData = IO::GetParam<bool>("no_shuffle");
const bool stratifyData = IO::GetParam<bool>("stratify_data");

if (IO::GetParam<int>("seed") == 0)
mlpack::math::RandomSeed(std::time(NULL));
@@ -125,6 +142,7 @@ static void mlpackMain()
{
ReportIgnoredParam({{ "input_labels", true }}, "training_labels");
ReportIgnoredParam({{ "input_labels", true }}, "test_labels");
ReportIgnoredParam("stratify_data", "input_labels is not provided");
}

// Check test_ratio.
@@ -148,11 +166,12 @@ static void mlpackMain()
IO::GetParam<arma::Mat<size_t>>("input_labels");
arma::Row<size_t> labelsRow = labels.row(0);

const auto value = data::Split(data, labelsRow, testRatio, !shuffleData);
Log::Info << "Training data contains " << get<0>(value).n_cols << " points."
<< endl;
Log::Info << "Test data contains " << get<1>(value).n_cols << " points."
<< endl;
const auto value =
data::Split(data, labelsRow, testRatio, !shuffleData, stratifyData);
Log::Info << "Training data contains "
<< get<0>(value).n_cols << " points." << endl;
Log::Info << "Test data contains "
<< get<1>(value).n_cols << " points." << endl;

if (IO::HasParam("training"))
IO::GetParam<arma::mat>("training") = std::move(get<0>(value));