From 21ef31f2a510f22fecce2809e34aee6bc0f9a7ab Mon Sep 17 00:00:00 2001 From: stereomatchingkiss Date: Wed, 24 Feb 2016 12:59:09 +0800 Subject: [PATCH 01/21] add train test split --- src/mlpack/core/util/split_data.hpp | 69 ++++++++++++++++++++++++++++ src/mlpack/tests/CMakeLists.txt | 1 + src/mlpack/tests/split_data_test.cpp | 50 ++++++++++++++++++++ 3 files changed, 120 insertions(+) create mode 100644 src/mlpack/core/util/split_data.hpp create mode 100644 src/mlpack/tests/split_data_test.cpp diff --git a/src/mlpack/core/util/split_data.hpp b/src/mlpack/core/util/split_data.hpp new file mode 100644 index 00000000000..2d4f1a713c2 --- /dev/null +++ b/src/mlpack/core/util/split_data.hpp @@ -0,0 +1,69 @@ +#ifndef __MLPACK_CORE_UTIL_SPLIT_DATA_HPP +#define __MLPACK_CORE_UTIL_SPLIT_DATA_HPP + +#include + +#include +#include +#include +#include +#include + +namespace mlpack { +namespace util { + +/** + *Split training data and test data + *@param input input data want to split + *@param label input label want to split + *@param testRatio the ratio of test data + *@param seed seed of the random device + *@code + *arma::mat trainData = loadData(); + *arma::Row label = loadLabel(); + *std::random_device rd; + *auto trainTest = TrainTestSplit(trainData, label, 0.25, rd()); + *@endcode + */ +template +std::tuple, arma::Mat, +arma::Row, arma::Row> +TrainTestSplit(arma::Mat const &input, + arma::Row const &label, + double testRatio, + unsigned int seed = 0) +{ + size_t const testSize = + static_cast(input.n_cols * testRatio); + size_t const trainSize = input.n_cols - testSize; + + arma::Mat trainData(input.n_rows, trainSize); + arma::Mat testData(input.n_rows, testSize); + arma::Row trainLabel(trainSize); + arma::Row testLabel(testSize); + + std::vector permutation(input.n_cols); + std::iota(std::begin(permutation), std::end(permutation), 0); + + std::mt19937 gen(seed); + std::shuffle(std::begin(permutation), std::end(permutation), gen); + + for(size_t i = 0; i != trainData.n_cols; ++i) + { + trainData.col(i) = input.col(permutation[i]); + trainLabel(i) = label(permutation[i]); + } + + for(size_t i = 0; i != testData.n_cols; ++i) + { + testData.col(i) = input.col(permutation[i + trainSize]); + testLabel(i) = label(permutation[i + trainSize]); + } + + return std::make_tuple(trainData, testData, trainLabel, testLabel); +} + +} // namespace util +} // namespace mlpack + +#endif diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt index fff69ead121..945ea97ca12 100644 --- a/src/mlpack/tests/CMakeLists.txt +++ b/src/mlpack/tests/CMakeLists.txt @@ -65,6 +65,7 @@ add_executable(mlpack_test sort_policy_test.cpp sparse_autoencoder_test.cpp sparse_coding_test.cpp + split_data_test.cpp termination_policy_test.cpp tree_test.cpp tree_traits_test.cpp diff --git a/src/mlpack/tests/split_data_test.cpp b/src/mlpack/tests/split_data_test.cpp new file mode 100644 index 00000000000..05c37d14b61 --- /dev/null +++ b/src/mlpack/tests/split_data_test.cpp @@ -0,0 +1,50 @@ +/** + * @file sparse_autoencoder_test.cpp + * @author Siddharth Agrawal + * + * Test the SparseAutoencoder class. + */ + +#include +#include + +#include +#include "old_boost_test_definitions.hpp" + +using namespace mlpack; +using namespace arma; + +BOOST_AUTO_TEST_SUITE(SplitDataTest); + +void compareData(arma::mat const &inputData, arma::mat const &compareData, + arma::Row const &inputLabel) +{ + for(size_t i = 0; i != compareData.n_cols; ++i){ + arma::mat const &lhsCol = inputData.col(inputLabel(i)); + arma::mat const &rhsCol = compareData.col(i); + for(size_t j = 0; j != lhsCol.n_rows; ++j){ + BOOST_REQUIRE_CLOSE(lhsCol(j), rhsCol(j), 1e-5); + } + } +} + +BOOST_AUTO_TEST_CASE(SplitDataSplitResult) +{ + arma::mat trainData(2,10); + trainData.randu(); + arma::Row labels(trainData.n_cols); + for(size_t i = 0; i != labels.n_cols; ++i){ + labels(i) = i; + } + + auto const value = util::TrainTestSplit(trainData, labels, 0.2); + BOOST_REQUIRE(std::get<0>(value).n_cols == 8); + BOOST_REQUIRE(std::get<1>(value).n_cols == 2); + BOOST_REQUIRE(std::get<2>(value).n_cols == 8); + BOOST_REQUIRE(std::get<3>(value).n_cols == 2); + + compareData(trainData, std::get<0>(value), std::get<2>(value)); + compareData(trainData, std::get<1>(value), std::get<3>(value)); +} + +BOOST_AUTO_TEST_SUITE_END(); From e77cb59052b5c9de706f8b2c0d87549507c0a89c Mon Sep 17 00:00:00 2001 From: stereomatchingkiss Date: Wed, 24 Feb 2016 13:01:30 +0800 Subject: [PATCH 02/21] refine comments --- src/mlpack/core/util/split_data.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/mlpack/core/util/split_data.hpp b/src/mlpack/core/util/split_data.hpp index 2d4f1a713c2..f3483015466 100644 --- a/src/mlpack/core/util/split_data.hpp +++ b/src/mlpack/core/util/split_data.hpp @@ -13,7 +13,8 @@ namespace mlpack { namespace util { /** - *Split training data and test data + *Split training data and test data, please define + *ARMA_USE_CXX11 to enable move of c++11 *@param input input data want to split *@param label input label want to split *@param testRatio the ratio of test data From e99632f750912a714063505135f56e2c4f1da259 Mon Sep 17 00:00:00 2001 From: stereomatchingkiss Date: Wed, 24 Feb 2016 17:10:00 +0800 Subject: [PATCH 03/21] split train test support arma cube --- src/mlpack/core/util/split_data.hpp | 79 ++++++++++++++++++++++++----- 1 file changed, 65 insertions(+), 14 deletions(-) diff --git a/src/mlpack/core/util/split_data.hpp b/src/mlpack/core/util/split_data.hpp index f3483015466..3d6cc8da3d3 100644 --- a/src/mlpack/core/util/split_data.hpp +++ b/src/mlpack/core/util/split_data.hpp @@ -12,8 +12,58 @@ namespace mlpack { namespace util { +namespace details{ + +template +inline +arma::Mat createData(arma::Mat const &input, + size_t dataSize) +{ + return arma::Mat(input.n_rows, dataSize); +} + +template +inline +arma::Cube createData(arma::Cube const &input, + size_t dataSize) +{ + return arma::Cube(input.n_rows, input.n_cols, dataSize); +} + +template +inline +void extractData(arma::Mat const &input, arma::Mat &output, + size_t inputIndex, size_t outputIndex) +{ + output.col(outputIndex) = input.col(inputIndex); +} + +template +inline +void extractData(arma::Cube const &input, arma::Cube &output, + size_t inputIndex, size_t outputIndex) +{ + output.slice(outputIndex) = input.slice(inputIndex); +} + +template +inline +size_t extractSize(arma::Mat const &input) +{ + return input.n_cols; +} + +template +inline +size_t extractSize(arma::Cube const &input) +{ + return input.n_slices; +} + +} + /** - *Split training data and test data, please define + *Split training data and test data, please define *ARMA_USE_CXX11 to enable move of c++11 *@param input input data want to split *@param label input label want to split @@ -27,38 +77,39 @@ namespace util { *@endcode */ template -std::tuple, arma::Mat, +std::tuple, arma::Row> -TrainTestSplit(arma::Mat const &input, +TrainTestSplit(T const &input, arma::Row const &label, double testRatio, unsigned int seed = 0) { size_t const testSize = - static_cast(input.n_cols * testRatio); - size_t const trainSize = input.n_cols - testSize; + static_cast(details::extractSize(input) * testRatio); + size_t const trainSize = details::extractSize(input) - testSize; - arma::Mat trainData(input.n_rows, trainSize); - arma::Mat testData(input.n_rows, testSize); + T trainData = details::createData(input, trainSize); + T testData = details::createData(input, testSize); arma::Row trainLabel(trainSize); arma::Row testLabel(testSize); - std::vector permutation(input.n_cols); + std::vector permutation(details::extractSize(input)); std::iota(std::begin(permutation), std::end(permutation), 0); std::mt19937 gen(seed); std::shuffle(std::begin(permutation), std::end(permutation), gen); - for(size_t i = 0; i != trainData.n_cols; ++i) + for(size_t i = 0; i != trainSize; ++i) { - trainData.col(i) = input.col(permutation[i]); - trainLabel(i) = label(permutation[i]); + details::extractData(input, trainData, permutation[i], i); + trainLabel(i) = label(permutation[i]); } - for(size_t i = 0; i != testData.n_cols; ++i) + for(size_t i = 0; i != testSize; ++i) { - testData.col(i) = input.col(permutation[i + trainSize]); - testLabel(i) = label(permutation[i + trainSize]); + details::extractData(input, testData, + permutation[i + trainSize], i); + testLabel(i) = label(permutation[i + trainSize]); } return std::make_tuple(trainData, testData, trainLabel, testLabel); From f7872fbc842cb6798b4076075aa3f54ab3d3273f Mon Sep 17 00:00:00 2001 From: stereomatchingkiss Date: Wed, 24 Feb 2016 17:10:11 +0800 Subject: [PATCH 04/21] add test of cube --- src/mlpack/tests/split_data_test.cpp | 33 +++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/src/mlpack/tests/split_data_test.cpp b/src/mlpack/tests/split_data_test.cpp index 05c37d14b61..bb0866fde1b 100644 --- a/src/mlpack/tests/split_data_test.cpp +++ b/src/mlpack/tests/split_data_test.cpp @@ -28,7 +28,19 @@ void compareData(arma::mat const &inputData, arma::mat const &compareData, } } -BOOST_AUTO_TEST_CASE(SplitDataSplitResult) +void compareData(arma::cube const &inputData, arma::cube const &compareData, + arma::Row const &inputLabel) +{ + for(size_t i = 0; i != compareData.n_slices; ++i){ + arma::mat const &lhsMat = inputData.slice(inputLabel(i)); + arma::mat const &rhsMat = compareData.slice(i); + for(size_t j = 0; j != lhsMat.size(); ++j){ + BOOST_REQUIRE_CLOSE(lhsMat(j), rhsMat(j), 1e-5); + } + } +} + +BOOST_AUTO_TEST_CASE(SplitDataSplitResultMat) { arma::mat trainData(2,10); trainData.randu(); @@ -47,4 +59,23 @@ BOOST_AUTO_TEST_CASE(SplitDataSplitResult) compareData(trainData, std::get<1>(value), std::get<3>(value)); } +BOOST_AUTO_TEST_CASE(SplitDataSplitResultCube) +{ + arma::cube trainData(2,2,10); + trainData.randu(); + arma::Row labels(trainData.n_slices); + for(size_t i = 0; i != labels.n_cols; ++i){ + labels(i) = i; + } + + auto const value = util::TrainTestSplit(trainData, labels, 0.2); + BOOST_REQUIRE(std::get<0>(value).n_slices == 8); + BOOST_REQUIRE(std::get<1>(value).n_slices == 2); + BOOST_REQUIRE(std::get<2>(value).n_cols == 8); + BOOST_REQUIRE(std::get<3>(value).n_cols == 2); + + compareData(trainData, std::get<0>(value), std::get<2>(value)); + compareData(trainData, std::get<1>(value), std::get<3>(value)); +} + BOOST_AUTO_TEST_SUITE_END(); From 7a7d8e555163d76c356b1a1caa3c774e97c7b1fa Mon Sep 17 00:00:00 2001 From: stereomatchingkiss Date: Sat, 27 Feb 2016 11:10:45 +0800 Subject: [PATCH 05/21] change function to class --- src/mlpack/core/util/split_data.hpp | 219 ++++++++++++++++++++++------ 1 file changed, 175 insertions(+), 44 deletions(-) diff --git a/src/mlpack/core/util/split_data.hpp b/src/mlpack/core/util/split_data.hpp index 3d6cc8da3d3..bc9fe373369 100644 --- a/src/mlpack/core/util/split_data.hpp +++ b/src/mlpack/core/util/split_data.hpp @@ -12,55 +12,186 @@ namespace mlpack { namespace util { -namespace details{ - -template -inline -arma::Mat createData(arma::Mat const &input, - size_t dataSize) +/** + *Split training data and test data, please define + *ARMA_USE_CXX11 to enable move of c++11 + */ +class TrainTestSplit { - return arma::Mat(input.n_rows, dataSize); -} +public: + /** + * @brief TrainTestSplit + * @param testRatio the ratio of test data + * @param slice indicate how many slice(depth) per image, + * this parameter only work on arma::Cube + * @param seed seed of the random device + */ + TrainTestSplit(double testRatio, + size_t slice = 1, + arma::arma_rng::seed_type seed = 0) : + seed(seed), + slice(slice), + testRatio(testRatio) + {} + + /** + *Split training data and test data, please define + *ARMA_USE_CXX11 to enable move of c++11 + *@param input input data want to split + *@param label input label want to split + *@param testRatio the ratio of test data + *@param seed seed of the random device + *@code + *arma::mat input = loadData(); + *arma::Row label = loadLabel(); + *arma::mat trainData; + *arma::mat testData; + *arma::Row trainLabel; + *arma::Row testLabel; + *std::random_device rd; + *TrainTestSplit tts(0.25); + *tts.Split(input, label, trainData, testData, trainLabel, + * testLabel); + *@endcode + */ + template + void Split(T const &input, + arma::Row const &inputLabel, + T &trainData, + T &testData, + arma::Row &trainLabel, + arma::Row &testLabel) + { + size_t const testSize = + static_cast(ExtractSize(input) * testRatio); + size_t const trainSize = ExtractSize(input) - testSize; + + ResizeData(input, trainData, trainSize); + ResizeData(input, testData, testSize); + trainLabel.set_size(trainSize); + testLabel.set_size(testSize); + + std::vector permutation(ExtractSize(input)); + std::iota(std::begin(permutation), std::end(permutation), 0); + + std::mt19937 gen(seed); + std::shuffle(std::begin(permutation), std::end(permutation), gen); + + for(size_t i = 0; i != trainSize; ++i) + { + ExtractData(input, trainData, permutation[i], i); + trainLabel(i) = inputLabel(permutation[i]); + } + + for(size_t i = 0; i != testSize; ++i) + { + ExtractData(input, testData, + permutation[i + trainSize], i); + testLabel(i) = inputLabel(permutation[i + trainSize]); + } + } -template -inline -arma::Cube createData(arma::Cube const &input, - size_t dataSize) -{ - return arma::Cube(input.n_rows, input.n_cols, dataSize); -} + /** + *Overload of Split, if you do not like to pass in + *so many param, you could call this api instead + *@param input input data want to split + *@param label input label want to split + *@return They are trainData, testData, trainLabel and + *testLabel + */ + template + std::tuple, arma::Row> + Split(T const &input, + arma::Row const &inputLabel) + { + T trainData; + T testData; + arma::Row trainLabel; + arma::Row testLabel; -template -inline -void extractData(arma::Mat const &input, arma::Mat &output, - size_t inputIndex, size_t outputIndex) -{ - output.col(outputIndex) = input.col(inputIndex); -} + Split(input, inputLabel, trainData, testData, + trainLabel, testLabel); -template -inline -void extractData(arma::Cube const &input, arma::Cube &output, - size_t inputIndex, size_t outputIndex) -{ - output.slice(outputIndex) = input.slice(inputIndex); -} + return std::make_tuple(trainData, testData, + trainLabel, testLabel); + } -template -inline -size_t extractSize(arma::Mat const &input) -{ - return input.n_cols; -} + void Seed(arma::arma_rng::seed_type value) + { + seed = value; + } + arma::arma_rng::seed_type Seed() const + { + return seed; + } -template -inline -size_t extractSize(arma::Cube const &input) -{ - return input.n_slices; -} + size_t Slice() const + { + return slice; + } + void Slice(size_t value) + { + slice = value; + } + + void TestRatio(double value) + { + testRatio = value; + } + double TestRatio() const + { + return testRatio; + } + + +private: + template + void ExtractData(arma::Mat const &input, arma::Mat &output, + size_t inputIndex, size_t outputIndex) const + { + output.col(outputIndex) = input.col(inputIndex); + } + + template + void ExtractData(arma::Cube const &input, arma::Cube &output, + size_t inputIndex, size_t outputIndex) const + { + output.slice(outputIndex) = input.slice(inputIndex); + } + + template + size_t ExtractSize(arma::Mat const &input) const + { + return input.n_cols; + } + + template + size_t ExtractSize(arma::Cube const &input) const + { + return input.n_slices; + } + + template + void ResizeData(arma::Mat const &input, + arma::Mat &output, + size_t dataSize) const + { + output.set_size(input.n_rows, dataSize); + } + + template + void ResizeData(arma::Cube const &input, + arma::Cube &output, + size_t dataSize) const + { + output.set_size(input.n_rows, input.n_cols, dataSize); + } -} + arma::arma_rng::seed_type seed; + size_t slice; + double testRatio; +}; /** *Split training data and test data, please define @@ -76,7 +207,7 @@ size_t extractSize(arma::Cube const &input) *auto trainTest = TrainTestSplit(trainData, label, 0.25, rd()); *@endcode */ -template +/*template std::tuple, arma::Row> TrainTestSplit(T const &input, @@ -113,7 +244,7 @@ TrainTestSplit(T const &input, } return std::make_tuple(trainData, testData, trainLabel, testLabel); -} +}*/ } // namespace util } // namespace mlpack From 8407475b1d7e1ab8ced448d1733d78fd8ccf4373 Mon Sep 17 00:00:00 2001 From: stereomatchingkiss Date: Sat, 27 Feb 2016 11:19:24 +0800 Subject: [PATCH 06/21] use arma shuffle and linspace to replace old solution --- src/mlpack/core/util/split_data.hpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/mlpack/core/util/split_data.hpp b/src/mlpack/core/util/split_data.hpp index bc9fe373369..17bbc5a3a50 100644 --- a/src/mlpack/core/util/split_data.hpp +++ b/src/mlpack/core/util/split_data.hpp @@ -71,23 +71,23 @@ class TrainTestSplit trainLabel.set_size(trainSize); testLabel.set_size(testSize); - std::vector permutation(ExtractSize(input)); - std::iota(std::begin(permutation), std::end(permutation), 0); - - std::mt19937 gen(seed); - std::shuffle(std::begin(permutation), std::end(permutation), gen); + using Col = arma::Col; + arma_rng::set_seed(seed); + Col const sequence = arma::linspace(0, ExtractSize(input) - 1, + ExtractSize(input)); + arma::Col const order = arma::shuffle(sequence); for(size_t i = 0; i != trainSize; ++i) { - ExtractData(input, trainData, permutation[i], i); - trainLabel(i) = inputLabel(permutation[i]); + ExtractData(input, trainData, order[i], i); + trainLabel(i) = inputLabel(order[i]); } for(size_t i = 0; i != testSize; ++i) { ExtractData(input, testData, - permutation[i + trainSize], i); - testLabel(i) = inputLabel(permutation[i + trainSize]); + order[i + trainSize], i); + testLabel(i) = inputLabel(order[i + trainSize]); } } From 2fb7ee4610f82c15d64e0d6fc5bdcce118ffd727 Mon Sep 17 00:00:00 2001 From: stereomatchingkiss Date: Sat, 27 Feb 2016 11:30:45 +0800 Subject: [PATCH 07/21] support multiple slice --- src/mlpack/core/util/split_data.hpp | 68 ++++++----------------------- 1 file changed, 13 insertions(+), 55 deletions(-) diff --git a/src/mlpack/core/util/split_data.hpp b/src/mlpack/core/util/split_data.hpp index 17bbc5a3a50..7789e90500f 100644 --- a/src/mlpack/core/util/split_data.hpp +++ b/src/mlpack/core/util/split_data.hpp @@ -25,6 +25,7 @@ class TrainTestSplit * @param slice indicate how many slice(depth) per image, * this parameter only work on arma::Cube * @param seed seed of the random device + * @warning slice should not less than 1 */ TrainTestSplit(double testRatio, size_t slice = 1, @@ -32,7 +33,11 @@ class TrainTestSplit seed(seed), slice(slice), testRatio(testRatio) - {} + { + if(slice < 1){ + throw std::out_of_range("The range of slice should not less than 1"); + } + } /** *Split training data and test data, please define @@ -132,6 +137,9 @@ class TrainTestSplit } void Slice(size_t value) { + if(value < 1){ + throw std::out_of_range("The range of slice should not less than 1"); + } slice = value; } @@ -157,7 +165,10 @@ class TrainTestSplit void ExtractData(arma::Cube const &input, arma::Cube &output, size_t inputIndex, size_t outputIndex) const { - output.slice(outputIndex) = input.slice(inputIndex); + outputIndex *= slice; + inputIndex *= slice; + output.slices(outputIndex, outputIndex + slice - 1) = + input.slices(inputIndex, inputIndex + slice - 1); } template @@ -193,59 +204,6 @@ class TrainTestSplit double testRatio; }; -/** - *Split training data and test data, please define - *ARMA_USE_CXX11 to enable move of c++11 - *@param input input data want to split - *@param label input label want to split - *@param testRatio the ratio of test data - *@param seed seed of the random device - *@code - *arma::mat trainData = loadData(); - *arma::Row label = loadLabel(); - *std::random_device rd; - *auto trainTest = TrainTestSplit(trainData, label, 0.25, rd()); - *@endcode - */ -/*template -std::tuple, arma::Row> -TrainTestSplit(T const &input, - arma::Row const &label, - double testRatio, - unsigned int seed = 0) -{ - size_t const testSize = - static_cast(details::extractSize(input) * testRatio); - size_t const trainSize = details::extractSize(input) - testSize; - - T trainData = details::createData(input, trainSize); - T testData = details::createData(input, testSize); - arma::Row trainLabel(trainSize); - arma::Row testLabel(testSize); - - std::vector permutation(details::extractSize(input)); - std::iota(std::begin(permutation), std::end(permutation), 0); - - std::mt19937 gen(seed); - std::shuffle(std::begin(permutation), std::end(permutation), gen); - - for(size_t i = 0; i != trainSize; ++i) - { - details::extractData(input, trainData, permutation[i], i); - trainLabel(i) = label(permutation[i]); - } - - for(size_t i = 0; i != testSize; ++i) - { - details::extractData(input, testData, - permutation[i + trainSize], i); - testLabel(i) = label(permutation[i + trainSize]); - } - - return std::make_tuple(trainData, testData, trainLabel, testLabel); -}*/ - } // namespace util } // namespace mlpack From 0172424186220b40215164b35bf26e48085604dd Mon Sep 17 00:00:00 2001 From: stereomatchingkiss Date: Sun, 28 Feb 2016 09:54:31 +0800 Subject: [PATCH 08/21] need to specify slice number explicitly --- src/mlpack/core/util/split_data.hpp | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/mlpack/core/util/split_data.hpp b/src/mlpack/core/util/split_data.hpp index 7789e90500f..f7c84614b96 100644 --- a/src/mlpack/core/util/split_data.hpp +++ b/src/mlpack/core/util/split_data.hpp @@ -28,7 +28,7 @@ class TrainTestSplit * @warning slice should not less than 1 */ TrainTestSplit(double testRatio, - size_t slice = 1, + size_t slice, arma::arma_rng::seed_type seed = 0) : seed(seed), slice(slice), @@ -69,8 +69,7 @@ class TrainTestSplit { size_t const testSize = static_cast(ExtractSize(input) * testRatio); - size_t const trainSize = ExtractSize(input) - testSize; - + size_t const trainSize = ExtractSize(input) - testSize; ResizeData(input, trainData, trainSize); ResizeData(input, testData, testSize); trainLabel.set_size(trainSize); @@ -79,7 +78,7 @@ class TrainTestSplit using Col = arma::Col; arma_rng::set_seed(seed); Col const sequence = arma::linspace(0, ExtractSize(input) - 1, - ExtractSize(input)); + ExtractSize(input)); arma::Col const order = arma::shuffle(sequence); for(size_t i = 0; i != trainSize; ++i) @@ -180,7 +179,7 @@ class TrainTestSplit template size_t ExtractSize(arma::Cube const &input) const { - return input.n_slices; + return input.n_slices / slice; } template @@ -196,7 +195,7 @@ class TrainTestSplit arma::Cube &output, size_t dataSize) const { - output.set_size(input.n_rows, input.n_cols, dataSize); + output.set_size(input.n_rows, input.n_cols, dataSize * slice); } arma::arma_rng::seed_type seed; From 11e039001ceacf0440ea2e728f601a06a763ef93 Mon Sep 17 00:00:00 2001 From: stereomatchingkiss Date: Sun, 28 Feb 2016 09:58:04 +0800 Subject: [PATCH 09/21] test multi slice and use new api to do the test --- src/mlpack/tests/split_data_test.cpp | 75 +++++++++++++++++++++------- 1 file changed, 57 insertions(+), 18 deletions(-) diff --git a/src/mlpack/tests/split_data_test.cpp b/src/mlpack/tests/split_data_test.cpp index bb0866fde1b..1bc2c7fb8c3 100644 --- a/src/mlpack/tests/split_data_test.cpp +++ b/src/mlpack/tests/split_data_test.cpp @@ -40,42 +40,81 @@ void compareData(arma::cube const &inputData, arma::cube const &compareData, } } -BOOST_AUTO_TEST_CASE(SplitDataSplitResultMat) +void compareData(arma::cube const &inputData, arma::cube const &compareData, + arma::Row const &inputLabel, size_t slice) { - arma::mat trainData(2,10); - trainData.randu(); - arma::Row labels(trainData.n_cols); - for(size_t i = 0; i != labels.n_cols; ++i){ - labels(i) = i; + for(size_t i = 0; i != inputLabel.size(); ++i){ + size_t const inputIndex = inputLabel(i)*slice; + arma::cube const &lhs = inputData.slices(inputIndex, + inputIndex+slice-1); + arma::cube const &rhs = compareData.slices(i*slice, i+slice-1); + for(size_t j = 0; j != lhs.size(); ++j){ + BOOST_REQUIRE_CLOSE(lhs(j), rhs(j), 1e-5); + } } +} + +BOOST_AUTO_TEST_CASE(SplitDataSplitResultMat) +{ + arma::mat input(2,10); + input.randu(); + using Labels = arma::Row; + Labels const labels = + arma::linspace(0, input.n_cols-1, + input.n_cols); - auto const value = util::TrainTestSplit(trainData, labels, 0.2); + util::TrainTestSplit tts(0.2, 1); + auto const value = tts.Split(input, labels); BOOST_REQUIRE(std::get<0>(value).n_cols == 8); BOOST_REQUIRE(std::get<1>(value).n_cols == 2); BOOST_REQUIRE(std::get<2>(value).n_cols == 8); BOOST_REQUIRE(std::get<3>(value).n_cols == 2); - compareData(trainData, std::get<0>(value), std::get<2>(value)); - compareData(trainData, std::get<1>(value), std::get<3>(value)); + compareData(input, std::get<0>(value), std::get<2>(value)); + compareData(input, std::get<1>(value), std::get<3>(value)); } BOOST_AUTO_TEST_CASE(SplitDataSplitResultCube) { - arma::cube trainData(2,2,10); - trainData.randu(); - arma::Row labels(trainData.n_slices); - for(size_t i = 0; i != labels.n_cols; ++i){ - labels(i) = i; - } + arma::cube input(2,2,10); + input.randu(); + using Labels = arma::Row; + Labels const labels = + arma::linspace(0, input.n_slices-1, + input.n_slices); - auto const value = util::TrainTestSplit(trainData, labels, 0.2); + util::TrainTestSplit tts(0.2, 1); + auto const value = tts.Split(input, labels); BOOST_REQUIRE(std::get<0>(value).n_slices == 8); BOOST_REQUIRE(std::get<1>(value).n_slices == 2); BOOST_REQUIRE(std::get<2>(value).n_cols == 8); BOOST_REQUIRE(std::get<3>(value).n_cols == 2); - compareData(trainData, std::get<0>(value), std::get<2>(value)); - compareData(trainData, std::get<1>(value), std::get<3>(value)); + compareData(input, std::get<0>(value), std::get<2>(value)); + compareData(input, std::get<1>(value), std::get<3>(value)); +} + +BOOST_AUTO_TEST_CASE(SplitDataSplitResultCubeMultiSlice) +{ + size_t const slice = 3; + arma::cube input(2,2,slice*2); + input.randu(); + using Labels = arma::Row; + Labels const labels = + arma::linspace(0, input.n_slices/slice-1, + input.n_slices/slice); + + util::TrainTestSplit tts(0.5, slice); + auto const value = tts.Split(input, labels); + BOOST_REQUIRE(std::get<0>(value).n_slices == 3); + BOOST_REQUIRE(std::get<1>(value).n_slices == 3); + BOOST_REQUIRE(std::get<2>(value).n_cols == 1); + BOOST_REQUIRE(std::get<3>(value).n_cols == 1); + + compareData(input, std::get<0>(value), + std::get<2>(value), slice); + compareData(input, std::get<1>(value), + std::get<3>(value), slice); } BOOST_AUTO_TEST_SUITE_END(); From ce379138e9d033384980dc9089e8f33224fcb982 Mon Sep 17 00:00:00 2001 From: stereomatchingkiss Date: Sun, 28 Feb 2016 10:43:44 +0800 Subject: [PATCH 10/21] fix bug--miss namespce --- src/mlpack/core/util/split_data.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mlpack/core/util/split_data.hpp b/src/mlpack/core/util/split_data.hpp index f7c84614b96..73359f53379 100644 --- a/src/mlpack/core/util/split_data.hpp +++ b/src/mlpack/core/util/split_data.hpp @@ -76,7 +76,7 @@ class TrainTestSplit testLabel.set_size(testSize); using Col = arma::Col; - arma_rng::set_seed(seed); + arma::arma_rng::set_seed(seed); Col const sequence = arma::linspace(0, ExtractSize(input) - 1, ExtractSize(input)); arma::Col const order = arma::shuffle(sequence); From 7af98adfdf96d000d5dca726d6b1a836a542f2b4 Mon Sep 17 00:00:00 2001 From: stereomatchingkiss Date: Thu, 3 Mar 2016 01:13:12 +0800 Subject: [PATCH 11/21] make the type of arma::Row become generic --- src/mlpack/core/util/split_data.hpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/mlpack/core/util/split_data.hpp b/src/mlpack/core/util/split_data.hpp index 73359f53379..c6413ddda51 100644 --- a/src/mlpack/core/util/split_data.hpp +++ b/src/mlpack/core/util/split_data.hpp @@ -59,13 +59,13 @@ class TrainTestSplit * testLabel); *@endcode */ - template + template void Split(T const &input, - arma::Row const &inputLabel, + arma::Row const &inputLabel, T &trainData, T &testData, - arma::Row &trainLabel, - arma::Row &testLabel) + arma::Row &trainLabel, + arma::Row &testLabel) { size_t const testSize = static_cast(ExtractSize(input) * testRatio); @@ -103,16 +103,16 @@ class TrainTestSplit *@return They are trainData, testData, trainLabel and *testLabel */ - template + template std::tuple, arma::Row> + arma::Row, arma::Row> Split(T const &input, - arma::Row const &inputLabel) + arma::Row const &inputLabel) { T trainData; T testData; - arma::Row trainLabel; - arma::Row testLabel; + arma::Row trainLabel; + arma::Row testLabel; Split(input, inputLabel, trainData, testData, trainLabel, testLabel); From 5a8c50d46a46d91edbaa5fea3652024ffec3f6db Mon Sep 17 00:00:00 2001 From: stereomatchingkiss Date: Tue, 15 Mar 2016 09:01:37 +0800 Subject: [PATCH 12/21] remove support of cube --- src/mlpack/core/util/split_data.hpp | 94 ++++++----------------------- 1 file changed, 20 insertions(+), 74 deletions(-) diff --git a/src/mlpack/core/util/split_data.hpp b/src/mlpack/core/util/split_data.hpp index c6413ddda51..0b1f6699573 100644 --- a/src/mlpack/core/util/split_data.hpp +++ b/src/mlpack/core/util/split_data.hpp @@ -22,21 +22,14 @@ class TrainTestSplit /** * @brief TrainTestSplit * @param testRatio the ratio of test data - * @param slice indicate how many slice(depth) per image, - * this parameter only work on arma::Cube * @param seed seed of the random device * @warning slice should not less than 1 */ - TrainTestSplit(double testRatio, - size_t slice, + TrainTestSplit(double testRatio, arma::arma_rng::seed_type seed = 0) : - seed(seed), - slice(slice), + seed(seed), testRatio(testRatio) - { - if(slice < 1){ - throw std::out_of_range("The range of slice should not less than 1"); - } + { } /** @@ -60,37 +53,36 @@ class TrainTestSplit *@endcode */ template - void Split(T const &input, + void Split(arma::Mat const &input, arma::Row const &inputLabel, - T &trainData, - T &testData, + arma::Mat &trainData, + arma::Mat &testData, arma::Row &trainLabel, arma::Row &testLabel) { size_t const testSize = - static_cast(ExtractSize(input) * testRatio); - size_t const trainSize = ExtractSize(input) - testSize; - ResizeData(input, trainData, trainSize); - ResizeData(input, testData, testSize); + static_cast(input.n_cols * testRatio); + size_t const trainSize = input.n_cols - testSize; + trainData.set_size(input.n_rows, trainSize); + testData.set_size(input.n_rows, testSize); trainLabel.set_size(trainSize); testLabel.set_size(testSize); using Col = arma::Col; arma::arma_rng::set_seed(seed); - Col const sequence = arma::linspace(0, ExtractSize(input) - 1, - ExtractSize(input)); + Col const sequence = arma::linspace(0, input.n_cols - 1, + input.n_cols); arma::Col const order = arma::shuffle(sequence); for(size_t i = 0; i != trainSize; ++i) - { - ExtractData(input, trainData, order[i], i); + { + trainData.col(i) = input.col(order[i]); trainLabel(i) = inputLabel(order[i]); } for(size_t i = 0; i != testSize; ++i) { - ExtractData(input, testData, - order[i + trainSize], i); + testData.col(i) = input.col(order[i + trainSize]); testLabel(i) = inputLabel(order[i + trainSize]); } } @@ -104,13 +96,13 @@ class TrainTestSplit *testLabel */ template - std::tuple, arma::Mat, arma::Row, arma::Row> - Split(T const &input, + Split(arma::Mat const &input, arma::Row const &inputLabel) { - T trainData; - T testData; + arma::Mat trainData; + arma::Mat testData; arma::Row trainLabel; arma::Row testLabel; @@ -151,53 +143,7 @@ class TrainTestSplit return testRatio; } - -private: - template - void ExtractData(arma::Mat const &input, arma::Mat &output, - size_t inputIndex, size_t outputIndex) const - { - output.col(outputIndex) = input.col(inputIndex); - } - - template - void ExtractData(arma::Cube const &input, arma::Cube &output, - size_t inputIndex, size_t outputIndex) const - { - outputIndex *= slice; - inputIndex *= slice; - output.slices(outputIndex, outputIndex + slice - 1) = - input.slices(inputIndex, inputIndex + slice - 1); - } - - template - size_t ExtractSize(arma::Mat const &input) const - { - return input.n_cols; - } - - template - size_t ExtractSize(arma::Cube const &input) const - { - return input.n_slices / slice; - } - - template - void ResizeData(arma::Mat const &input, - arma::Mat &output, - size_t dataSize) const - { - output.set_size(input.n_rows, dataSize); - } - - template - void ResizeData(arma::Cube const &input, - arma::Cube &output, - size_t dataSize) const - { - output.set_size(input.n_rows, input.n_cols, dataSize * slice); - } - +private: arma::arma_rng::seed_type seed; size_t slice; double testRatio; From 860256202d7cba23dd3500a05e36ecd351f1d673 Mon Sep 17 00:00:00 2001 From: stereomatchingkiss Date: Tue, 15 Mar 2016 09:02:19 +0800 Subject: [PATCH 13/21] remove support of cube --- src/mlpack/tests/split_data_test.cpp | 71 +--------------------------- 1 file changed, 1 insertion(+), 70 deletions(-) diff --git a/src/mlpack/tests/split_data_test.cpp b/src/mlpack/tests/split_data_test.cpp index 1bc2c7fb8c3..5d9fcbc5779 100644 --- a/src/mlpack/tests/split_data_test.cpp +++ b/src/mlpack/tests/split_data_test.cpp @@ -28,32 +28,6 @@ void compareData(arma::mat const &inputData, arma::mat const &compareData, } } -void compareData(arma::cube const &inputData, arma::cube const &compareData, - arma::Row const &inputLabel) -{ - for(size_t i = 0; i != compareData.n_slices; ++i){ - arma::mat const &lhsMat = inputData.slice(inputLabel(i)); - arma::mat const &rhsMat = compareData.slice(i); - for(size_t j = 0; j != lhsMat.size(); ++j){ - BOOST_REQUIRE_CLOSE(lhsMat(j), rhsMat(j), 1e-5); - } - } -} - -void compareData(arma::cube const &inputData, arma::cube const &compareData, - arma::Row const &inputLabel, size_t slice) -{ - for(size_t i = 0; i != inputLabel.size(); ++i){ - size_t const inputIndex = inputLabel(i)*slice; - arma::cube const &lhs = inputData.slices(inputIndex, - inputIndex+slice-1); - arma::cube const &rhs = compareData.slices(i*slice, i+slice-1); - for(size_t j = 0; j != lhs.size(); ++j){ - BOOST_REQUIRE_CLOSE(lhs(j), rhs(j), 1e-5); - } - } -} - BOOST_AUTO_TEST_CASE(SplitDataSplitResultMat) { arma::mat input(2,10); @@ -63,7 +37,7 @@ BOOST_AUTO_TEST_CASE(SplitDataSplitResultMat) arma::linspace(0, input.n_cols-1, input.n_cols); - util::TrainTestSplit tts(0.2, 1); + util::TrainTestSplit tts(0.2); auto const value = tts.Split(input, labels); BOOST_REQUIRE(std::get<0>(value).n_cols == 8); BOOST_REQUIRE(std::get<1>(value).n_cols == 2); @@ -74,47 +48,4 @@ BOOST_AUTO_TEST_CASE(SplitDataSplitResultMat) compareData(input, std::get<1>(value), std::get<3>(value)); } -BOOST_AUTO_TEST_CASE(SplitDataSplitResultCube) -{ - arma::cube input(2,2,10); - input.randu(); - using Labels = arma::Row; - Labels const labels = - arma::linspace(0, input.n_slices-1, - input.n_slices); - - util::TrainTestSplit tts(0.2, 1); - auto const value = tts.Split(input, labels); - BOOST_REQUIRE(std::get<0>(value).n_slices == 8); - BOOST_REQUIRE(std::get<1>(value).n_slices == 2); - BOOST_REQUIRE(std::get<2>(value).n_cols == 8); - BOOST_REQUIRE(std::get<3>(value).n_cols == 2); - - compareData(input, std::get<0>(value), std::get<2>(value)); - compareData(input, std::get<1>(value), std::get<3>(value)); -} - -BOOST_AUTO_TEST_CASE(SplitDataSplitResultCubeMultiSlice) -{ - size_t const slice = 3; - arma::cube input(2,2,slice*2); - input.randu(); - using Labels = arma::Row; - Labels const labels = - arma::linspace(0, input.n_slices/slice-1, - input.n_slices/slice); - - util::TrainTestSplit tts(0.5, slice); - auto const value = tts.Split(input, labels); - BOOST_REQUIRE(std::get<0>(value).n_slices == 3); - BOOST_REQUIRE(std::get<1>(value).n_slices == 3); - BOOST_REQUIRE(std::get<2>(value).n_cols == 1); - BOOST_REQUIRE(std::get<3>(value).n_cols == 1); - - compareData(input, std::get<0>(value), - std::get<2>(value), slice); - compareData(input, std::get<1>(value), - std::get<3>(value), slice); -} - BOOST_AUTO_TEST_SUITE_END(); From 0fe417273cbd509ffc0a2ddded78b880bfd130bb Mon Sep 17 00:00:00 2001 From: stereomatchingkiss Date: Wed, 13 Apr 2016 05:22:40 +0800 Subject: [PATCH 14/21] remove useless data and function --- src/mlpack/core/util/split_data.hpp | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/src/mlpack/core/util/split_data.hpp b/src/mlpack/core/util/split_data.hpp index 0b1f6699573..dd4b0df5db0 100644 --- a/src/mlpack/core/util/split_data.hpp +++ b/src/mlpack/core/util/split_data.hpp @@ -120,19 +120,7 @@ class TrainTestSplit arma::arma_rng::seed_type Seed() const { return seed; - } - - size_t Slice() const - { - return slice; - } - void Slice(size_t value) - { - if(value < 1){ - throw std::out_of_range("The range of slice should not less than 1"); - } - slice = value; - } + } void TestRatio(double value) { @@ -144,8 +132,7 @@ class TrainTestSplit } private: - arma::arma_rng::seed_type seed; - size_t slice; + arma::arma_rng::seed_type seed; double testRatio; }; From 595ff7061fc3efc327553dbe18ae603d5e93f326 Mon Sep 17 00:00:00 2001 From: stereomatchingkiss Date: Wed, 13 Apr 2016 05:32:10 +0800 Subject: [PATCH 15/21] 1 : remove seed variable 2 : do not store testRatio, pass in by function --- src/mlpack/core/util/split_data.hpp | 51 +++++------------------------ 1 file changed, 8 insertions(+), 43 deletions(-) diff --git a/src/mlpack/core/util/split_data.hpp b/src/mlpack/core/util/split_data.hpp index dd4b0df5db0..45386640b69 100644 --- a/src/mlpack/core/util/split_data.hpp +++ b/src/mlpack/core/util/split_data.hpp @@ -18,27 +18,13 @@ namespace util { */ class TrainTestSplit { -public: - /** - * @brief TrainTestSplit - * @param testRatio the ratio of test data - * @param seed seed of the random device - * @warning slice should not less than 1 - */ - TrainTestSplit(double testRatio, - arma::arma_rng::seed_type seed = 0) : - seed(seed), - testRatio(testRatio) - { - } - +public: /** *Split training data and test data, please define *ARMA_USE_CXX11 to enable move of c++11 *@param input input data want to split *@param label input label want to split - *@param testRatio the ratio of test data - *@param seed seed of the random device + *@param testRatio the ratio of test data *@code *arma::mat input = loadData(); *arma::Row label = loadLabel(); @@ -58,7 +44,8 @@ class TrainTestSplit arma::Mat &trainData, arma::Mat &testData, arma::Row &trainLabel, - arma::Row &testLabel) + arma::Row &testLabel, + double testRatio) { size_t const testSize = static_cast(input.n_cols * testRatio); @@ -68,8 +55,7 @@ class TrainTestSplit trainLabel.set_size(trainSize); testLabel.set_size(testSize); - using Col = arma::Col; - arma::arma_rng::set_seed(seed); + using Col = arma::Col; Col const sequence = arma::linspace(0, input.n_cols - 1, input.n_cols); arma::Col const order = arma::shuffle(sequence); @@ -99,7 +85,8 @@ class TrainTestSplit std::tuple, arma::Mat, arma::Row, arma::Row> Split(arma::Mat const &input, - arma::Row const &inputLabel) + arma::Row const &inputLabel, + double testRatio) { arma::Mat trainData; arma::Mat testData; @@ -107,33 +94,11 @@ class TrainTestSplit arma::Row testLabel; Split(input, inputLabel, trainData, testData, - trainLabel, testLabel); + trainLabel, testLabel, testRatio); return std::make_tuple(trainData, testData, trainLabel, testLabel); } - - void Seed(arma::arma_rng::seed_type value) - { - seed = value; - } - arma::arma_rng::seed_type Seed() const - { - return seed; - } - - void TestRatio(double value) - { - testRatio = value; - } - double TestRatio() const - { - return testRatio; - } - -private: - arma::arma_rng::seed_type seed; - double testRatio; }; } // namespace util From 75e946e5fa8ebea610e7f74f5d6a68a2b70d2a97 Mon Sep 17 00:00:00 2001 From: stereomatchingkiss Date: Wed, 13 Apr 2016 05:47:22 +0800 Subject: [PATCH 16/21] change class to function --- src/mlpack/core/util/split_data.hpp | 150 ++++++++++++++-------------- 1 file changed, 74 insertions(+), 76 deletions(-) diff --git a/src/mlpack/core/util/split_data.hpp b/src/mlpack/core/util/split_data.hpp index 45386640b69..32ccaa60762 100644 --- a/src/mlpack/core/util/split_data.hpp +++ b/src/mlpack/core/util/split_data.hpp @@ -15,91 +15,89 @@ namespace util { /** *Split training data and test data, please define *ARMA_USE_CXX11 to enable move of c++11 + *@param input input data want to split + *@param label input label want to split + *@param trainData training data split by input + *@param testData test data split by input + *@param trainLabel train label split by input + *@param testLabel test label split by input + *@param testRatio the ratio of test data + *@code + *arma::mat input = loadData(); + *arma::Row label = loadLabel(); + *arma::mat trainData; + *arma::mat testData; + *arma::Row trainLabel; + *arma::Row testLabel; + *std::random_device rd; + *TrainTestSplit tts(0.25); + *TrainTestSplit(input, label, trainData, + * testData, trainLabel, testLabel); + *@endcode */ -class TrainTestSplit +template +void TrainTestSplit(arma::Mat const &input, + arma::Row const &inputLabel, + arma::Mat &trainData, + arma::Mat &testData, + arma::Row &trainLabel, + arma::Row &testLabel, + double const testRatio) { -public: - /** - *Split training data and test data, please define - *ARMA_USE_CXX11 to enable move of c++11 - *@param input input data want to split - *@param label input label want to split - *@param testRatio the ratio of test data - *@code - *arma::mat input = loadData(); - *arma::Row label = loadLabel(); - *arma::mat trainData; - *arma::mat testData; - *arma::Row trainLabel; - *arma::Row testLabel; - *std::random_device rd; - *TrainTestSplit tts(0.25); - *tts.Split(input, label, trainData, testData, trainLabel, - * testLabel); - *@endcode - */ - template - void Split(arma::Mat const &input, - arma::Row const &inputLabel, - arma::Mat &trainData, - arma::Mat &testData, - arma::Row &trainLabel, - arma::Row &testLabel, - double testRatio) - { - size_t const testSize = - static_cast(input.n_cols * testRatio); - size_t const trainSize = input.n_cols - testSize; - trainData.set_size(input.n_rows, trainSize); - testData.set_size(input.n_rows, testSize); - trainLabel.set_size(trainSize); - testLabel.set_size(testSize); - - using Col = arma::Col; - Col const sequence = arma::linspace(0, input.n_cols - 1, - input.n_cols); - arma::Col const order = arma::shuffle(sequence); + size_t const testSize = + static_cast(input.n_cols * testRatio); + size_t const trainSize = input.n_cols - testSize; + trainData.set_size(input.n_rows, trainSize); + testData.set_size(input.n_rows, testSize); + trainLabel.set_size(trainSize); + testLabel.set_size(testSize); - for(size_t i = 0; i != trainSize; ++i) - { - trainData.col(i) = input.col(order[i]); - trainLabel(i) = inputLabel(order[i]); - } + using Col = arma::Col; + Col const sequence = arma::linspace(0, input.n_cols - 1, + input.n_cols); + arma::Col const order = arma::shuffle(sequence); - for(size_t i = 0; i != testSize; ++i) - { - testData.col(i) = input.col(order[i + trainSize]); - testLabel(i) = inputLabel(order[i + trainSize]); - } + for(size_t i = 0; i != trainSize; ++i) + { + trainData.col(i) = input.col(order[i]); + trainLabel(i) = inputLabel(order[i]); } - /** - *Overload of Split, if you do not like to pass in - *so many param, you could call this api instead - *@param input input data want to split - *@param label input label want to split - *@return They are trainData, testData, trainLabel and - *testLabel - */ - template - std::tuple, arma::Mat, - arma::Row, arma::Row> - Split(arma::Mat const &input, - arma::Row const &inputLabel, - double testRatio) + for(size_t i = 0; i != testSize; ++i) { - arma::Mat trainData; - arma::Mat testData; - arma::Row trainLabel; - arma::Row testLabel; + testData.col(i) = input.col(order[i + trainSize]); + testLabel(i) = inputLabel(order[i + trainSize]); + } +} - Split(input, inputLabel, trainData, testData, - trainLabel, testLabel, testRatio); +/** + *Overload of Split, if you do not like to pass in + *so many param, you could call this api instead + *@param input input data want to split + *@param label input label want to split + *@return They are trainData, testData, trainLabel and + *testLabel + */ +template +std::tuple, arma::Mat, +arma::Row, arma::Row> +TrainTestSplit(arma::Mat const &input, + arma::Row const &inputLabel, + double const testRatio) +{ + arma::Mat trainData; + arma::Mat testData; + arma::Row trainLabel; + arma::Row testLabel; - return std::make_tuple(trainData, testData, - trainLabel, testLabel); - } -}; + TrainTestSplit(input, inputLabel, + trainData, testData, + trainLabel, testLabel, + testRatio); + + return std::make_tuple(trainData, testData, + trainLabel, testLabel); +} } // namespace util } // namespace mlpack From 7723421c6415787fe60cd2ace74e6ff05039e2e3 Mon Sep 17 00:00:00 2001 From: stereomatchingkiss Date: Wed, 13 Apr 2016 06:16:41 +0800 Subject: [PATCH 17/21] refine test case --- src/mlpack/tests/split_data_test.cpp | 54 ++++++++++++++++------------ 1 file changed, 31 insertions(+), 23 deletions(-) diff --git a/src/mlpack/tests/split_data_test.cpp b/src/mlpack/tests/split_data_test.cpp index 5d9fcbc5779..8ec45911c2d 100644 --- a/src/mlpack/tests/split_data_test.cpp +++ b/src/mlpack/tests/split_data_test.cpp @@ -16,36 +16,44 @@ using namespace arma; BOOST_AUTO_TEST_SUITE(SplitDataTest); -void compareData(arma::mat const &inputData, arma::mat const &compareData, +/** + * compare the data after train test split + * @param inputData The original data set before split + * @param compareData The data want to compare with the inputData, + * it could be train data or test data + * @param inputLabel The label of the compareData + */ +void CompareData(arma::mat const &inputData, arma::mat const &compareData, arma::Row const &inputLabel) { - for(size_t i = 0; i != compareData.n_cols; ++i){ - arma::mat const &lhsCol = inputData.col(inputLabel(i)); - arma::mat const &rhsCol = compareData.col(i); - for(size_t j = 0; j != lhsCol.n_rows; ++j){ - BOOST_REQUIRE_CLOSE(lhsCol(j), rhsCol(j), 1e-5); - } + for(size_t i = 0; i != compareData.n_cols; ++i){ + arma::mat const &lhsCol = inputData.col(inputLabel(i)); + arma::mat const &rhsCol = compareData.col(i); + for(size_t j = 0; j != lhsCol.n_rows; ++j){ + BOOST_REQUIRE_CLOSE(lhsCol(j), rhsCol(j), 1e-5); } + } } BOOST_AUTO_TEST_CASE(SplitDataSplitResultMat) { - arma::mat input(2,10); - input.randu(); - using Labels = arma::Row; - Labels const labels = - arma::linspace(0, input.n_cols-1, - input.n_cols); - - util::TrainTestSplit tts(0.2); - auto const value = tts.Split(input, labels); - BOOST_REQUIRE(std::get<0>(value).n_cols == 8); - BOOST_REQUIRE(std::get<1>(value).n_cols == 2); - BOOST_REQUIRE(std::get<2>(value).n_cols == 8); - BOOST_REQUIRE(std::get<3>(value).n_cols == 2); - - compareData(input, std::get<0>(value), std::get<2>(value)); - compareData(input, std::get<1>(value), std::get<3>(value)); + arma::mat input(2,10); + input.randu(); + using Labels = arma::Row; + //set the labels range same as the col, so the CompareData + //can compare the data after TrainTestSplit are valid or not + Labels const labels = + arma::linspace(0, input.n_cols-1, + input.n_cols); + + auto const value = util::TrainTestSplit(input, labels, 0.2); + BOOST_REQUIRE(std::get<0>(value).n_cols == 8); + BOOST_REQUIRE(std::get<1>(value).n_cols == 2); + BOOST_REQUIRE(std::get<2>(value).n_cols == 8); + BOOST_REQUIRE(std::get<3>(value).n_cols == 2); + + CompareData(input, std::get<0>(value), std::get<2>(value)); + CompareData(input, std::get<1>(value), std::get<3>(value)); } BOOST_AUTO_TEST_SUITE_END(); From b5790a0030b99f4999264f10a48584b6670dafcf Mon Sep 17 00:00:00 2001 From: stereomatchingkiss Date: Wed, 13 Apr 2016 07:40:40 +0800 Subject: [PATCH 18/21] refine comments --- src/mlpack/core/util/split_data.hpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/mlpack/core/util/split_data.hpp b/src/mlpack/core/util/split_data.hpp index 32ccaa60762..fc695c7a089 100644 --- a/src/mlpack/core/util/split_data.hpp +++ b/src/mlpack/core/util/split_data.hpp @@ -29,8 +29,7 @@ namespace util { *arma::mat testData; *arma::Row trainLabel; *arma::Row testLabel; - *std::random_device rd; - *TrainTestSplit tts(0.25); + *arma::arma_rng::set_seed(100); //set the seed if you like *TrainTestSplit(input, label, trainData, * testData, trainLabel, testLabel); *@endcode @@ -77,6 +76,11 @@ void TrainTestSplit(arma::Mat const &input, *@param label input label want to split *@return They are trainData, testData, trainLabel and *testLabel + *@code + *arma::mat input = loadData(); + *arma::Row label = loadLabel(); + *auto splitResult = TrainTestSplit(input, label, 0.2); + *@endcode */ template std::tuple, arma::Mat, From 6f213a3d077ddde4281e072e2a975398c67a253f Mon Sep 17 00:00:00 2001 From: stereomatchingkiss Date: Sat, 16 Apr 2016 12:27:03 +0800 Subject: [PATCH 19/21] add ARMA_USE_CXX11 if the compiler is msvc --- src/mlpack/prereqs.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/mlpack/prereqs.hpp b/src/mlpack/prereqs.hpp index 5b93936e527..b81b8c4929b 100644 --- a/src/mlpack/prereqs.hpp +++ b/src/mlpack/prereqs.hpp @@ -77,6 +77,7 @@ // it's part of the C++11 standard. #ifdef _MSC_VER #pragma warning(disable : 4519) + #define ARMA_USE_CXX11 #endif #endif From c92dd2fc6561b1cf40efb744dcc9233bc3675c91 Mon Sep 17 00:00:00 2001 From: stereomatchingkiss Date: Sat, 16 Apr 2016 13:26:21 +0800 Subject: [PATCH 20/21] remove suggestions of add ARMA_USE_CXX11 --- src/mlpack/core/util/split_data.hpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/mlpack/core/util/split_data.hpp b/src/mlpack/core/util/split_data.hpp index fc695c7a089..04fc6f01afc 100644 --- a/src/mlpack/core/util/split_data.hpp +++ b/src/mlpack/core/util/split_data.hpp @@ -13,8 +13,7 @@ namespace mlpack { namespace util { /** - *Split training data and test data, please define - *ARMA_USE_CXX11 to enable move of c++11 + *Split training data and test data *@param input input data want to split *@param label input label want to split *@param trainData training data split by input From 54c7f22a6ddaf68cca491ad098cc27324740a8fa Mon Sep 17 00:00:00 2001 From: stereomatchingkiss Date: Wed, 20 Apr 2016 14:46:09 +0800 Subject: [PATCH 21/21] 1 : fix style 2 : remove useless header 3 : reduce temporary variable --- src/mlpack/core/util/split_data.hpp | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/src/mlpack/core/util/split_data.hpp b/src/mlpack/core/util/split_data.hpp index 04fc6f01afc..1ddd3a7b939 100644 --- a/src/mlpack/core/util/split_data.hpp +++ b/src/mlpack/core/util/split_data.hpp @@ -3,11 +3,7 @@ #include -#include -#include -#include #include -#include namespace mlpack { namespace util { @@ -34,26 +30,25 @@ namespace util { *@endcode */ template -void TrainTestSplit(arma::Mat const &input, - arma::Row const &inputLabel, +void TrainTestSplit(const arma::Mat &input, + const arma::Row &inputLabel, arma::Mat &trainData, arma::Mat &testData, arma::Row &trainLabel, arma::Row &testLabel, - double const testRatio) + const double testRatio) { size_t const testSize = static_cast(input.n_cols * testRatio); - size_t const trainSize = input.n_cols - testSize; + const size_t trainSize = input.n_cols - testSize; trainData.set_size(input.n_rows, trainSize); testData.set_size(input.n_rows, testSize); trainLabel.set_size(trainSize); testLabel.set_size(testSize); - using Col = arma::Col; - Col const sequence = arma::linspace(0, input.n_cols - 1, - input.n_cols); - arma::Col const order = arma::shuffle(sequence); + const arma::Col order = + arma::shuffle(arma::linspace>(0, input.n_cols - 1, + input.n_cols)); for(size_t i = 0; i != trainSize; ++i) { @@ -84,9 +79,9 @@ void TrainTestSplit(arma::Mat const &input, template std::tuple, arma::Mat, arma::Row, arma::Row> -TrainTestSplit(arma::Mat const &input, - arma::Row const &inputLabel, - double const testRatio) +TrainTestSplit(const arma::Mat &input, + const arma::Row &inputLabel, + const double testRatio) { arma::Mat trainData; arma::Mat testData;