-
Notifications
You must be signed in to change notification settings - Fork 129
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Stratified Splitter #201
Merged
Merged
Stratified Splitter #201
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
11243eb
add stratified splitter
mottodora 6cc5891
add unit tests about classification task
mottodora ab9aaa9
support regression dataset
mottodora 6328bbe
add unit tests for regression dataset
mottodora 4318564
[feature] seed fix
mottodora 1d6b05d
add documents
mottodora 8f03961
fix
mottodora c7998c4
fix bugs
mottodora ee840b0
apply comments
mottodora a6e79e5
support ndarray dataset
mottodora f3739b9
add documents
mottodora dba27a5
apply comments
mottodora b06c333
remove unnecessary files
mottodora 5944803
add some tests
mottodora File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,7 @@ | ||
from chainer_chemistry.dataset.splitters import base_splitter # NOQA | ||
from chainer_chemistry.dataset.splitters import random_splitter # NOQA | ||
from chainer_chemistry.dataset.splitters import stratified_splitter # NOQA | ||
|
||
from chainer_chemistry.dataset.splitters.base_splitter import BaseSplitter # NOQA | ||
from chainer_chemistry.dataset.splitters.random_splitter import RandomSplitter # NOQA | ||
from chainer_chemistry.dataset.splitters.stratified_splitter import StratifiedSplitter # NOQA |
214 changes: 214 additions & 0 deletions
214
chainer_chemistry/dataset/splitters/stratified_splitter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,214 @@ | ||
import numpy | ||
import pandas | ||
|
||
from chainer_chemistry.dataset.splitters.base_splitter import BaseSplitter | ||
from chainer_chemistry.datasets.numpy_tuple_dataset import NumpyTupleDataset | ||
|
||
|
||
# Refer to scikit-learn | ||
# https://git.io/fPMmB | ||
def _approximate_mode(class_counts, n_draws): | ||
n_class = len(class_counts) | ||
continuous = class_counts * n_draws / class_counts.sum() | ||
floored = numpy.floor(continuous) | ||
assert n_draws // n_class == floored.sum() // n_class | ||
n_remainder = int(n_draws - floored.sum()) | ||
remainder = continuous - floored | ||
inds = numpy.argsort(remainder)[::-1] | ||
inds = inds[:n_remainder] | ||
floored[inds] += 1 | ||
assert n_draws == floored.sum() | ||
return floored.astype(numpy.int) | ||
|
||
|
||
class StratifiedSplitter(BaseSplitter): | ||
"""Class for doing stratified data splits.""" | ||
|
||
def _split(self, dataset, frac_train=0.8, frac_valid=0.1, frac_test=0.1, | ||
labels=None, **kwargs): | ||
numpy.testing.assert_almost_equal(frac_train + frac_valid + frac_test, | ||
1.) | ||
|
||
seed = kwargs.get('seed', None) | ||
label_axis = kwargs.get('label_axis', -1) | ||
task_index = kwargs.get('task_index', 0) | ||
n_bin = kwargs.get('n_bin', 10) | ||
task_type = kwargs.get('task_type', 'auto') | ||
if task_type not in ['classification', 'regression', 'auto']: | ||
raise ValueError("{} is invalid. Please use 'classification'," | ||
"'regression' or 'auto'".format(task_type)) | ||
|
||
rng = numpy.random.RandomState(seed) | ||
|
||
if isinstance(labels, list): | ||
labels = numpy.array(labels) | ||
elif labels is None: | ||
if not isinstance(dataset, NumpyTupleDataset): | ||
raise ValueError("Please assign label dataset.") | ||
labels = dataset.features[:, label_axis] | ||
|
||
if labels.ndim == 1: | ||
labels = labels | ||
else: | ||
labels = labels[:, task_index] | ||
|
||
if task_type == 'auto': | ||
if labels.dtype.kind == 'i': | ||
task_type = 'classification' | ||
elif labels.dtype.kind == 'f': | ||
task_type = 'regression' | ||
else: | ||
raise ValueError | ||
|
||
if task_type == 'classification': | ||
classes, labels = numpy.unique(labels, return_inverse=True) | ||
elif task_type == 'regression': | ||
classes = numpy.arange(n_bin) | ||
labels = pandas.qcut(labels, n_bin, labels=False) | ||
else: | ||
raise ValueError | ||
|
||
n_classes = classes.shape[0] | ||
n_total_valid = int(numpy.floor(frac_valid * len(dataset))) | ||
n_total_test = int(numpy.floor(frac_test * len(dataset))) | ||
|
||
class_counts = numpy.bincount(labels) | ||
class_indices = numpy.split(numpy.argsort(labels, | ||
kind='mergesort'), | ||
numpy.cumsum(class_counts)[:-1]) | ||
|
||
# n_total_train is the remainder: n - n_total_valid - n_total_test | ||
n_valid_samples = _approximate_mode(class_counts, n_total_valid) | ||
class_counts = class_counts - n_valid_samples | ||
n_test_samples = _approximate_mode(class_counts, n_total_test) | ||
|
||
train_index = [] | ||
valid_index = [] | ||
test_index = [] | ||
|
||
for i in range(n_classes): | ||
n_valid = n_valid_samples[i] | ||
n_test = n_test_samples[i] | ||
|
||
perm = rng.permutation(len(class_indices[i])) | ||
class_perm_index = class_indices[i][perm] | ||
|
||
class_valid_index = class_perm_index[:n_valid] | ||
class_test_index = class_perm_index[n_valid:n_valid+n_test] | ||
class_train_index = class_perm_index[n_valid+n_test:] | ||
|
||
train_index.extend(class_train_index) | ||
valid_index.extend(class_valid_index) | ||
test_index.extend(class_test_index) | ||
|
||
assert n_total_valid == len(valid_index) | ||
assert n_total_test == len(test_index) | ||
|
||
return numpy.array(train_index), numpy.array(valid_index),\ | ||
numpy.array(test_index), | ||
|
||
def train_valid_test_split(self, dataset, labels=None, label_axis=-1, | ||
task_index=0, frac_train=0.8, frac_valid=0.1, | ||
frac_test=0.1, converter=None, | ||
return_index=True, seed=None, task_type='auto', | ||
n_bin=10, **kwargs): | ||
"""Generate indices by stratified splittting dataset into train, valid | ||
and test set. | ||
|
||
Args: | ||
dataset(NumpyTupleDataset, numpy.ndarray): | ||
Dataset. | ||
labels(numpy.ndarray): | ||
Target label. If `None`, this function assumes that dataset is | ||
an instance of `NumpyTupleDataset`. | ||
labels_axis(int): | ||
Dataset feature axis in NumpyTupleDataset. | ||
task_index(int): | ||
Target task index in dataset for stratification. | ||
seed (int): | ||
Random seed. | ||
frac_train(float): | ||
Fraction of dataset put into training data. | ||
frac_valid(float): | ||
Fraction of dataset put into validation data. | ||
return_index(bool): | ||
If `True`, this function returns only indexes. If `False`, this | ||
function returns splitted dataset. | ||
|
||
Returns: | ||
SplittedDataset(tuple): | ||
splitted dataset or indexes | ||
|
||
.. admonition:: Example | ||
>>> from chainer_chemistry.datasets import NumpyTupleDataset | ||
>>> from chainer_chemistry.dataset.splitters \ | ||
>>> import StratifiedSplitter | ||
>>> a = numpy.random.random((10, 10)) | ||
>>> b = numpy.random.random((10, 8)) | ||
>>> c = numpy.random.random((10, 1)) | ||
>>> d = NumpyTupleDataset(a, b, c) | ||
>>> splitter = StratifiedSplitter() | ||
>>> train, valid, test = | ||
splitter.train_valid_test_split(dataset, return_index=False) | ||
>>> print(len(train), len(valid)) | ||
8, 1, 1 | ||
""" | ||
|
||
return super(StratifiedSplitter, self)\ | ||
.train_valid_test_split(dataset, frac_train, frac_valid, frac_test, | ||
converter, return_index, seed=seed, | ||
label_axis=label_axis, task_type=task_type, | ||
task_index=task_index, n_bin=n_bin, | ||
labels=labels, **kwargs) | ||
|
||
def train_valid_split(self, dataset, labels=None, label_axis=-1, | ||
task_index=0, frac_train=0.9, frac_valid=0.1, | ||
converter=None, return_index=True, seed=None, | ||
task_type='auto', n_bin=10, **kwargs): | ||
"""Generate indices by stratified splittting dataset into train and | ||
valid set. | ||
|
||
Args: | ||
dataset(NumpyTupleDataset, numpy.ndarray): | ||
Dataset. | ||
labels(numpy.ndarray): | ||
Target label. If `None`, this function assumes that dataset is | ||
an instance of `NumpyTupleDataset`. | ||
labels_axis(int): | ||
Dataset feature axis in NumpyTupleDataset. | ||
task_index(int): | ||
Target task index in dataset for stratification. | ||
seed (int): | ||
Random seed. | ||
frac_train(float): | ||
Fraction of dataset put into training data. | ||
frac_valid(float): | ||
Fraction of dataset put into validation data. | ||
return_index(bool): | ||
If `True`, this function returns only indexes. If `False`, this | ||
function returns splitted dataset. | ||
|
||
Returns: | ||
SplittedDataset(tuple): | ||
splitted dataset or indexes | ||
|
||
.. admonition:: Example | ||
>>> from chainer_chemistry.datasets import NumpyTupleDataset | ||
>>> from chainer_chemistry.dataset.splitters \ | ||
>>> import StratifiedSplitter | ||
>>> a = numpy.random.random((10, 10)) | ||
>>> b = numpy.random.random((10, 8)) | ||
>>> c = numpy.random.random((10, 1)) | ||
>>> d = NumpyTupleDataset(a, b, c) | ||
>>> splitter = StratifiedSplitter() | ||
>>> train, valid = | ||
splitter.train_valid_split(dataset, return_index=False) | ||
>>> print(len(train), len(valid)) | ||
9, 1 | ||
""" | ||
|
||
return super(StratifiedSplitter, self)\ | ||
.train_valid_split(dataset, frac_train, frac_valid, converter, | ||
return_index, seed=seed, label_axis=label_axis, | ||
task_type=task_type, task_index=task_index, | ||
n_bin=n_bin, labels=labels, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please put comment
# n_total_train is the remainder: n - n_total_valid - n_total_test
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wrote comments in other place.