Adding an option to get stratified Train-Test splits #2662

Closed
Abilityguy opened this issue Oct 6, 2020 · 8 comments
Comments

@Abilityguy
Contributor

Abilityguy commented Oct 6, 2020

What is the desired addition or change?

Add an option to get stratified train-test splits in mlpack. While going through the docs, I noticed we don't have such an option yet. I brought this up on IRC a few weeks ago and decided to take it up and open an issue for it.

What is the motivation for this feature?

In a stratified train-test split, the proportion of each label in the dataset is maintained across the train set and the test set. For datasets with imbalanced classes, it is desirable to reflect this class imbalance in both the train and test sets.

If applicable, describe how this feature would be implemented.

I referred to the implementation of split in mlpack/src/mlpack/core/data/split_data.hpp and wrote a sample template. Let me know what you guys think about this.

template<typename T, typename U>
std::tuple<arma::Mat<T>, arma::Mat<T>, arma::Row<U>, arma::Row<U>>
StratifiedSplit(const arma::Mat<T>& input,
                const arma::Row<U>& inputLabel,
                const double testRatio)
{
  // The outputs start empty and grow by one stratum per label.
  arma::Mat<T> trainData, testData;
  arma::Row<U> trainLabel, testLabel;

  // Get the unique labels in the dataset.
  const arma::Row<U> uniqueLabel = arma::unique(inputLabel);

  // Iterate through the unique labels.
  for (const U l : uniqueLabel)
  {
    // Get the indexes of the points with label 'l' only.
    const arma::uvec uniqueIndexes = arma::find(inputLabel == l);

    const size_t testSize =
        static_cast<size_t>(uniqueIndexes.n_elem * testRatio);
    const size_t trainSize = uniqueIndexes.n_elem - testSize;

    // Divide the points with label 'l' between the train and test strata,
    // then concatenate each stratum onto the train and test sets. This way
    // the two sets are built iteratively. The size checks keep subvec() from
    // being called with an empty range.
    if (trainSize > 0)
    {
      const arma::uvec trainIndexes = uniqueIndexes.subvec(0, trainSize - 1);
      trainData = arma::join_rows(trainData, input.cols(trainIndexes));
      trainLabel = arma::join_rows(trainLabel, inputLabel.cols(trainIndexes));
    }
    if (testSize > 0)
    {
      const arma::uvec testIndexes =
          uniqueIndexes.subvec(trainSize, uniqueIndexes.n_elem - 1);
      testData = arma::join_rows(testData, input.cols(testIndexes));
      testLabel = arma::join_rows(testLabel, inputLabel.cols(testIndexes));
    }
  }

  return std::make_tuple(std::move(trainData),
                         std::move(testData),
                         std::move(trainLabel),
                         std::move(testLabel));
}

Additional information?

I integrated this with mlpack_preprocess_split and ran tests on a few datasets.

Dataset 1: covertype dataset (https://www.mlpack.org/datasets/covertype-small.data.csv.gz)

Dataset size - 54 x 100000
Test ratio - 0.3

Label-wise splits:
Dataset -   36307: 49000: 6133: 481: 1663: 2990: 3426
Train Set - 25415: 34300: 4294: 337: 1165: 2093: 2399
Test Set -  10892: 14700: 1839: 144:  498:  897: 1027

[INFO ] Program timers:
[INFO ]   loading_data: 0.458566s
[INFO ]   saving_data: 2.796008s
[INFO ]   total_time: 0.641220s

Dataset 2: MNIST train dataset from Kaggle (https://www.kaggle.com/c/digit-recognizer/data)

Dataset size - 784 x 42000
Test ratio - 0.2

Label-wise splits:
Dataset -   4132: 4684: 4177: 4351: 4072: 3795: 4137: 4401: 4063: 4188
Train Set - 3306: 3748: 3342: 3481: 3258: 3036: 3310: 3521: 3251: 3351
Test Set -   826:  936:  835:  870:  814:  759:  827:  880:  812:  837

[INFO ] Program timers:
[INFO ]   loading_data: 2.912084s
[INFO ]   saving_data: 19.503108s
[INFO ]   total_time: 3.998145s

I was able to plot images from the train set so it seems the data is being split correctly.
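
The label-wise counts reported above are consistent with truncating `count * testRatio` for each label, which is what the template does. A minimal sketch of that arithmetic follows; the helper names `StratumTestSize` and `TestSizes` are made up for illustration and are not part of mlpack.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper: test-set size for one label, using the same
// truncating cast as the template in this issue.
std::size_t StratumTestSize(const std::size_t count, const double testRatio)
{
  return static_cast<std::size_t>(count * testRatio);
}

// Hypothetical helper: test-set sizes for every label count in a dataset.
std::vector<std::size_t> TestSizes(const std::vector<std::size_t>& counts,
                                   const double testRatio)
{
  std::vector<std::size_t> sizes;
  sizes.reserve(counts.size());
  for (const std::size_t count : counts)
    sizes.push_back(StratumTestSize(count, testRatio));
  return sizes;
}
```

For example, the first covertype label has 36307 points, so a test ratio of 0.3 gives a test stratum of 10892 points and a train stratum of 36307 - 10892 = 25415, matching the table.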

Should I go ahead and make a PR on this?
Let me know if you guys have any suggestions or changes on this.

@Abilityguy Abilityguy changed the title Adding an option to get stratified train test splits Adding an option to get stratified Train-Test splits Oct 6, 2020
@zoq
Member

zoq commented Oct 7, 2020

Thanks for opening the detailed issue, feel free to open a PR on this one. My only concern is that it might be more efficient to build an index vector first and then use cols at the end, instead of using join_rows, which could be expensive.
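
One way to read the index-vector suggestion is sketched below. This is only an illustration under stated assumptions, not the code from the eventual PR: `StratifiedIndexes` is a hypothetical helper, labels are assumed to be `int`, and plain `std::vector` stands in for `arma::uvec` so the sketch has no dependencies.

```cpp
#include <cstddef>
#include <map>
#include <utility>
#include <vector>

// Hypothetical helper: returns (trainIndexes, testIndexes) for a stratified
// split. No data is copied here; only column indexes are collected.
std::pair<std::vector<std::size_t>, std::vector<std::size_t>>
StratifiedIndexes(const std::vector<int>& labels, const double testRatio)
{
  // Group the column indexes by label; std::map keeps the labels sorted.
  std::map<int, std::vector<std::size_t>> byLabel;
  for (std::size_t i = 0; i < labels.size(); ++i)
    byLabel[labels[i]].push_back(i);

  std::vector<std::size_t> trainIndexes, testIndexes;
  trainIndexes.reserve(labels.size());
  for (const auto& entry : byLabel)
  {
    const std::vector<std::size_t>& indexes = entry.second;
    const std::size_t testSize =
        static_cast<std::size_t>(indexes.size() * testRatio);
    const std::size_t trainSize = indexes.size() - testSize;

    // The first trainSize points of the stratum train, the rest test.
    trainIndexes.insert(trainIndexes.end(), indexes.begin(),
                        indexes.begin() + trainSize);
    testIndexes.insert(testIndexes.end(), indexes.begin() + trainSize,
                       indexes.end());
  }
  return {trainIndexes, testIndexes};
}
```

With Armadillo, the two index vectors would presumably be converted to `arma::uvec` and passed to `input.cols(...)` and `inputLabel.cols(...)` once each, so every output is allocated a single time instead of being regrown by `join_rows` for every label.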

@Abilityguy
Contributor Author

The index vector idea makes sense. I will change the implementation to reflect that. Thanks @zoq.

@zoq
Member

zoq commented Oct 8, 2020

Sounds good, thanks.

@rcurtin
Member

rcurtin commented Oct 10, 2020

@Abilityguy looks like you did most of the work already! 😄 As a fun challenge, I think that this can be done most efficiently with only a single pass over the label set and a single pass over the dataset. 👍
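
The thread does not say how the single-pass version was eventually written, so the following is only one possible sketch. It assumes `int` labels, a test ratio in [0, 1], and a dataset stored as a `std::vector` of columns in place of `arma::Mat`; `SplitIndexes`, `SinglePassStratify`, and `Gather` are hypothetical names. The idea is that a point goes to the test set exactly when the running count `c` of its label makes `floor(c * testRatio)` increase, so each label ends up with `floor(n * testRatio)` test points without knowing `n` in advance.

```cpp
#include <cstddef>
#include <unordered_map>
#include <vector>

// Hypothetical result type: column indexes of the two sets.
struct SplitIndexes
{
  std::vector<std::size_t> train;
  std::vector<std::size_t> test;
};

// One pass over the labels. A point joins the test set exactly when its
// label's running count makes the truncated value of count * testRatio grow.
SplitIndexes SinglePassStratify(const std::vector<int>& labels,
                                const double testRatio)
{
  std::unordered_map<int, std::size_t> seen;  // Running count per label.
  SplitIndexes result;
  for (std::size_t i = 0; i < labels.size(); ++i)
  {
    const std::size_t count = ++seen[labels[i]];
    const std::size_t before =
        static_cast<std::size_t>((count - 1) * testRatio);
    const std::size_t after = static_cast<std::size_t>(count * testRatio);
    (after > before ? result.test : result.train).push_back(i);
  }
  return result;
}

// Copies the listed columns into a new dataset. Calling this once for the
// train indexes and once for the test indexes reads each column one time.
template<typename T>
std::vector<std::vector<T>> Gather(const std::vector<std::vector<T>>& columns,
                                   const std::vector<std::size_t>& indexes)
{
  std::vector<std::vector<T>> out;
  out.reserve(indexes.size());
  for (const std::size_t i : indexes)
    out.push_back(columns[i]);
  return out;
}
```

Unlike the template in the issue, this spreads the test points through each stratum instead of taking them from its tail; the per-label counts stay the same.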

@Abilityguy
Contributor Author

@rcurtin challenge accepted! 😂
Not completely sure about the idea right now but I think it's possible. Will make an attempt 👍!

@mlpack-bot

mlpack-bot bot commented Nov 9, 2020

This issue has been automatically marked as stale because it has not had any recent activity. It will be closed in 7 days if no further activity occurs. Thank you for your contributions! 👍

@mlpack-bot mlpack-bot bot added the s: stale label Nov 9, 2020
@Abilityguy
Contributor Author

Keep open.

@mlpack-bot mlpack-bot bot removed the s: stale label Nov 9, 2020
@rcurtin
Member

rcurtin commented Nov 15, 2020

Thanks @Abilityguy! I forgot that this issue was open and needed to be closed after the PR was merged. :)
