-
Notifications
You must be signed in to change notification settings - Fork 129
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Stratified Splitter #201
Stratified Splitter #201
Changes from 11 commits
11243eb
6cc5891
ab9aaa9
6328bbe
4318564
1d6b05d
8f03961
c7998c4
ee840b0
a6e79e5
f3739b9
dba27a5
b06c333
5944803
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,7 @@ | ||
from chainer_chemistry.dataset.splitters import base_splitter # NOQA | ||
from chainer_chemistry.dataset.splitters import random_splitter # NOQA | ||
from chainer_chemistry.dataset.splitters import stratified_splitter # NOQA | ||
|
||
from chainer_chemistry.dataset.splitters.base_splitter import BaseSplitter # NOQA | ||
from chainer_chemistry.dataset.splitters.random_splitter import RandomSplitter # NOQA | ||
from chainer_chemistry.dataset.splitters.stratified_splitter import StratifiedSplitter # NOQA |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,212 @@ | ||
import numpy | ||
import pandas | ||
|
||
from chainer_chemistry.dataset.splitters.base_splitter import BaseSplitter | ||
from chainer_chemistry.datasets.numpy_tuple_dataset import NumpyTupleDataset | ||
|
||
|
||
def _approximate_mode(class_counts, n_draws): | ||
n_class = len(class_counts) | ||
continuous = class_counts * n_draws / class_counts.sum() | ||
floored = numpy.floor(continuous) | ||
assert n_draws // n_class == floored.sum() // n_class | ||
n_remainder = int(n_draws - floored.sum()) | ||
remainder = continuous - floored | ||
inds = numpy.argsort(remainder)[::-1] | ||
inds = inds[:n_remainder] | ||
floored[inds] += 1 | ||
assert n_draws == floored.sum() | ||
return floored.astype(numpy.int) | ||
|
||
|
||
class StratifiedSplitter(BaseSplitter):
    """Class for doing stratified data splits.

    Splits a dataset into train/valid/test subsets so that the label
    distribution of each subset approximates that of the whole dataset.
    Integer labels are treated as classes; floating point labels are binned
    into quantiles (regression stratification).
    """

    def _split(self, dataset, frac_train=0.8, frac_valid=0.1, frac_test=0.1,
               labels=None, **kwargs):
        numpy.testing.assert_almost_equal(frac_train + frac_valid + frac_test,
                                          1.)

        seed = kwargs.get('seed', None)
        label_axis = kwargs.get('label_axis', -1)
        task_index = kwargs.get('task_index', 0)
        n_bin = kwargs.get('n_bin', 10)
        task_type = kwargs.get('task_type', 'infer')
        if task_type not in ['classification', 'regression', 'infer']:
            raise ValueError("{} is invalid. Please use 'classification',"
                             "'regression' or 'infer'".format(task_type))

        rng = numpy.random.RandomState(seed)

        if labels is None:
            if not isinstance(dataset, NumpyTupleDataset):
                raise ValueError("Please assign label dataset.")
            labels = dataset.features[:, label_axis]

        # Accept plain Python lists (and any array-like) as labels, not only
        # ndarrays: `.ndim` / `.dtype` below require an ndarray.
        labels = numpy.asarray(labels)
        if labels.ndim != 1:
            # Multi-task labels: stratify on the selected task column only.
            labels = labels[:, task_index]

        if task_type == 'infer':
            # Infer task type from label dtype: integer -> classification,
            # floating point -> regression.
            if labels.dtype.kind == 'i':
                task_type = 'classification'
            elif labels.dtype.kind == 'f':
                task_type = 'regression'
            else:
                raise ValueError(
                    "Could not infer task type from label dtype '{}'. "
                    "Please specify task_type explicitly.".format(
                        labels.dtype))

        if task_type == 'classification':
            # Map arbitrary class values to contiguous integers 0..K-1.
            classes, labels = numpy.unique(labels, return_inverse=True)
        elif task_type == 'regression':
            # Bin continuous targets into `n_bin` quantile classes so that
            # a stratified split is well defined for regression data.
            classes = numpy.arange(n_bin)
            labels = pandas.qcut(labels, n_bin, labels=False)
        else:
            raise ValueError("Unknown task_type '{}'.".format(task_type))

        n_classes = classes.shape[0]
        n_total_valid = int(numpy.floor(frac_valid * len(dataset)))
        n_total_test = int(numpy.floor(frac_test * len(dataset)))

        # Group sample indices by class; mergesort is stable, so samples of
        # the same class keep their original relative order.
        class_counts = numpy.bincount(labels)
        class_indices = numpy.split(numpy.argsort(labels,
                                                  kind='mergesort'),
                                    numpy.cumsum(class_counts)[:-1])

        # Decide how many samples each class contributes to valid and test
        # so that each subset mirrors the overall class distribution.
        n_valid_samples = _approximate_mode(class_counts, n_total_valid)
        class_counts = class_counts - n_valid_samples
        n_test_samples = _approximate_mode(class_counts, n_total_test)

        train_index = []
        valid_index = []
        test_index = []

        for i in range(n_classes):
            n_valid = n_valid_samples[i]
            n_test = n_test_samples[i]

            # Shuffle within the class, then carve out valid/test/train.
            perm = rng.permutation(len(class_indices[i]))
            class_perm_index = class_indices[i][perm]

            class_valid_index = class_perm_index[:n_valid]
            class_test_index = class_perm_index[n_valid:n_valid + n_test]
            class_train_index = class_perm_index[n_valid + n_test:]

            train_index.extend(class_train_index)
            valid_index.extend(class_valid_index)
            test_index.extend(class_test_index)

        assert n_total_valid == len(valid_index)
        assert n_total_test == len(test_index)

        return rng.permutation(train_index),\
            rng.permutation(valid_index),\
            rng.permutation(test_index),

    def train_valid_test_split(self, dataset, labels=None, label_axis=-1,
                               task_index=0, frac_train=0.8, frac_valid=0.1,
                               frac_test=0.1, converter=None,
                               return_index=True, seed=None, task_type='infer',
                               n_bin=10, **kwargs):
        """Generate indices by stratified splitting dataset into train, valid
        and test set.

        Args:
            dataset(NumpyTupleDataset, numpy.ndarray):
                Dataset.
            labels(numpy.ndarray):
                Target label. If `None`, this function assumes that dataset is
                an instance of `NumpyTupleDataset`.
            label_axis(int):
                Dataset feature axis in NumpyTupleDataset.
            task_index(int):
                Target task index in dataset for stratification.
            frac_train(float):
                Fraction of dataset put into training data.
            frac_valid(float):
                Fraction of dataset put into validation data.
            frac_test(float):
                Fraction of dataset put into test data.
            converter:
                Converter used to build the split datasets when
                `return_index` is `False`.
            return_index(bool):
                If `True`, this function returns only indexes. If `False`, this
                function returns splitted dataset.
            seed (int):
                Random seed.
            task_type(str):
                One of 'classification', 'regression' or 'infer'.
            n_bin(int):
                Number of quantile bins used for regression stratification.

        Returns:
            SplittedDataset(tuple):
                splitted dataset or indexes

        .. admonition:: Example
            >>> from chainer_chemistry.datasets import NumpyTupleDataset
            >>> from chainer_chemistry.dataset.splitters \
            >>>     import StratifiedSplitter
            >>> a = numpy.random.random((10, 10))
            >>> b = numpy.random.random((10, 8))
            >>> c = numpy.random.random((10, 1))
            >>> d = NumpyTupleDataset(a, b, c)
            >>> splitter = StratifiedSplitter()
            >>> train, valid, test = \
            >>>     splitter.train_valid_test_split(d, return_index=False)
            >>> print(len(train), len(valid), len(test))
            8 1 1
        """
        return super(StratifiedSplitter, self)\
            .train_valid_test_split(dataset, frac_train, frac_valid, frac_test,
                                    converter, return_index, seed=seed,
                                    label_axis=label_axis, task_type=task_type,
                                    task_index=task_index, n_bin=n_bin,
                                    labels=labels, **kwargs)

    def train_valid_split(self, dataset, labels=None, label_axis=-1,
                          task_index=0, frac_train=0.9, frac_valid=0.1,
                          converter=None, return_index=True, seed=None,
                          task_type='infer', n_bin=10, **kwargs):
        """Generate indices by stratified splitting dataset into train and
        valid set.

        Args:
            dataset(NumpyTupleDataset, numpy.ndarray):
                Dataset.
            labels(numpy.ndarray):
                Target label. If `None`, this function assumes that dataset is
                an instance of `NumpyTupleDataset`.
            label_axis(int):
                Dataset feature axis in NumpyTupleDataset.
            task_index(int):
                Target task index in dataset for stratification.
            frac_train(float):
                Fraction of dataset put into training data.
            frac_valid(float):
                Fraction of dataset put into validation data.
            converter:
                Converter used to build the split datasets when
                `return_index` is `False`.
            return_index(bool):
                If `True`, this function returns only indexes. If `False`, this
                function returns splitted dataset.
            seed (int):
                Random seed.
            task_type(str):
                One of 'classification', 'regression' or 'infer'.
            n_bin(int):
                Number of quantile bins used for regression stratification.

        Returns:
            SplittedDataset(tuple):
                splitted dataset or indexes

        .. admonition:: Example
            >>> from chainer_chemistry.datasets import NumpyTupleDataset
            >>> from chainer_chemistry.dataset.splitters \
            >>>     import StratifiedSplitter
            >>> a = numpy.random.random((10, 10))
            >>> b = numpy.random.random((10, 8))
            >>> c = numpy.random.random((10, 1))
            >>> d = NumpyTupleDataset(a, b, c)
            >>> splitter = StratifiedSplitter()
            >>> train, valid = \
            >>>     splitter.train_valid_split(d, return_index=False)
            >>> print(len(train), len(valid))
            9 1
        """
        return super(StratifiedSplitter, self)\
            .train_valid_split(dataset, frac_train, frac_valid, converter,
                               return_index, seed=seed, label_axis=label_axis,
                               task_type=task_type, task_index=task_index,
                               n_bin=n_bin, labels=labels, **kwargs)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
comment: url you referred.