Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stratified Splitter #201

Merged
merged 14 commits into from
Jul 3, 2018
2 changes: 2 additions & 0 deletions chainer_chemistry/dataset/splitters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from chainer_chemistry.dataset.splitters import base_splitter # NOQA
from chainer_chemistry.dataset.splitters import random_splitter # NOQA
from chainer_chemistry.dataset.splitters import stratified_splitter # NOQA

from chainer_chemistry.dataset.splitters.base_splitter import BaseSplitter # NOQA
from chainer_chemistry.dataset.splitters.random_splitter import RandomSplitter # NOQA
from chainer_chemistry.dataset.splitters.stratified_splitter import StratifiedSplitter # NOQA
212 changes: 212 additions & 0 deletions chainer_chemistry/dataset/splitters/stratified_splitter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
import numpy
import pandas

from chainer_chemistry.dataset.splitters.base_splitter import BaseSplitter
from chainer_chemistry.datasets.numpy_tuple_dataset import NumpyTupleDataset


def _approximate_mode(class_counts, n_draws):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comment: url you referred.

n_class = len(class_counts)
continuous = class_counts * n_draws / class_counts.sum()
floored = numpy.floor(continuous)
assert n_draws // n_class == floored.sum() // n_class
n_remainder = int(n_draws - floored.sum())
remainder = continuous - floored
inds = numpy.argsort(remainder)[::-1]
inds = inds[:n_remainder]
floored[inds] += 1
assert n_draws == floored.sum()
return floored.astype(numpy.int)


class StratifiedSplitter(BaseSplitter):
"""Class for doing stratified data splits."""

def _split(self, dataset, frac_train=0.8, frac_valid=0.1, frac_test=0.1,
labels=None, **kwargs):
numpy.testing.assert_almost_equal(frac_train + frac_valid + frac_test,
1.)

seed = kwargs.get('seed', None)
label_axis = kwargs.get('label_axis', -1)
task_index = kwargs.get('task_index', 0)
n_bin = kwargs.get('n_bin', 10)
task_type = kwargs.get('task_type', 'infer')
if task_type not in ['classification', 'regression', 'infer']:
raise ValueError("{} is invalid. Please use 'classification',"
"'regression' or 'infer'".format(task_type))

rng = numpy.random.RandomState(seed)

if labels is None:
if not isinstance(dataset, NumpyTupleDataset):
raise ValueError("Please assign label dataset.")
labels = dataset.features[:, label_axis]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if isinstance(labels, list):
labels = numpy.array(labels)

if len(labels.shape) == 1:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

labels.ndim

labels = labels
else:
labels = labels[:, task_index]

if task_type == 'infer':
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

auto

if labels.dtype.kind == 'i':
task_type = 'classification'
elif labels.dtype.kind == 'f':
task_type = 'regression'
else:
raise ValueError

if task_type == 'classification':
classes, labels = numpy.unique(labels, return_inverse=True)
elif task_type == 'regression':
classes = numpy.arange(n_bin)
labels = pandas.qcut(labels, n_bin, labels=False)
else:
raise ValueError

n_classes = classes.shape[0]
n_total_valid = int(numpy.floor(frac_valid * len(dataset)))
n_total_test = int(numpy.floor(frac_test * len(dataset)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please put comment
# n_total_train is the remainder: n - n_total_valid - n_total_test

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wrote comments in other place.


class_counts = numpy.bincount(labels)
class_indices = numpy.split(numpy.argsort(labels,
kind='mergesort'),
numpy.cumsum(class_counts)[:-1])

n_valid_samples = _approximate_mode(class_counts, n_total_valid)
class_counts = class_counts - n_valid_samples
n_test_samples = _approximate_mode(class_counts, n_total_test)

train_index = []
valid_index = []
test_index = []

for i in range(n_classes):
n_valid = n_valid_samples[i]
n_test = n_test_samples[i]

perm = rng.permutation(len(class_indices[i]))
class_perm_index = class_indices[i][perm]

class_valid_index = class_perm_index[:n_valid]
class_test_index = class_perm_index[n_valid:n_valid+n_test]
class_train_index = class_perm_index[n_valid+n_test:]

train_index.extend(class_train_index)
valid_index.extend(class_valid_index)
test_index.extend(class_test_index)

assert n_total_valid == len(valid_index)
assert n_total_test == len(test_index)

return rng.permutation(train_index),\
rng.permutation(valid_index),\
rng.permutation(test_index),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is ok to just return array.


def train_valid_test_split(self, dataset, labels=None, label_axis=-1,
task_index=0, frac_train=0.8, frac_valid=0.1,
frac_test=0.1, converter=None,
return_index=True, seed=None, task_type='infer',
n_bin=10, **kwargs):
"""Generate indices by stratified splittting dataset into train, valid
and test set.

Args:
dataset(NumpyTupleDataset, numpy.ndarray):
Dataset.
labels(numpy.ndarray):
Target label. If `None`, this function assumes that dataset is
an instance of `NumpyTupleDataset`.
labels_axis(int):
Dataset feature axis in NumpyTupleDataset.
task_index(int):
Target task index in dataset for stratification.
seed (int):
Random seed.
frac_train(float):
Fraction of dataset put into training data.
frac_valid(float):
Fraction of dataset put into validation data.
return_index(bool):
If `True`, this function returns only indexes. If `False`, this
function returns splitted dataset.

Returns:
SplittedDataset(tuple):
splitted dataset or indexes

.. admonition:: Example
>>> from chainer_chemistry.datasets import NumpyTupleDataset
>>> from chainer_chemistry.dataset.splitters \
>>> import StratifiedSplitter
>>> a = numpy.random.random((10, 10))
>>> b = numpy.random.random((10, 8))
>>> c = numpy.random.random((10, 1))
>>> d = NumpyTupleDataset(a, b, c)
>>> splitter = StratifiedSplitter()
>>> splitter.train_valid_test_split()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove this line

>>> train, valid, test =
splitter.train_valid_test_split(dataset, return_index=False)
>>> print(len(train), len(valid))
8, 1, 1
"""

return super(StratifiedSplitter, self)\
.train_valid_test_split(dataset, frac_train, frac_valid, frac_test,
converter, return_index, seed=seed,
label_axis=label_axis, task_type=task_type,
task_index=task_index, n_bin=n_bin,
labels=labels, **kwargs)

def train_valid_split(self, dataset, labels=None, label_axis=-1,
task_index=0, frac_train=0.9, frac_valid=0.1,
converter=None, return_index=True, seed=None,
task_type='infer', n_bin=10, **kwargs):
"""Generate indices by stratified splittting dataset into train and
valid set.

Args:
dataset(NumpyTupleDataset, numpy.ndarray):
Dataset.
labels(numpy.ndarray):
Target label. If `None`, this function assumes that dataset is
an instance of `NumpyTupleDataset`.
labels_axis(int):
Dataset feature axis in NumpyTupleDataset.
task_index(int):
Target task index in dataset for stratification.
seed (int):
Random seed.
frac_train(float):
Fraction of dataset put into training data.
frac_valid(float):
Fraction of dataset put into validation data.
return_index(bool):
If `True`, this function returns only indexes. If `False`, this
function returns splitted dataset.

Returns:
SplittedDataset(tuple):
splitted dataset or indexes

.. admonition:: Example
>>> from chainer_chemistry.datasets import NumpyTupleDataset
>>> from chainer_chemistry.dataset.splitters \
>>> import StratifiedSplitter
>>> a = numpy.random.random((10, 10))
>>> b = numpy.random.random((10, 8))
>>> c = numpy.random.random((10, 1))
>>> d = NumpyTupleDataset(a, b, c)
>>> splitter = StratifiedSplitter()
>>> splitter.train_valid_split()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove this line

>>> train, valid =
splitter.train_valid_split(dataset, return_index=False)
>>> print(len(train), len(valid))
9, 1
"""

return super(StratifiedSplitter, self)\
.train_valid_split(dataset, frac_train, frac_valid, converter,
return_index, seed=seed, label_axis=label_axis,
task_type=task_type, task_index=task_index,
n_bin=n_bin, labels=labels, **kwargs)
3 changes: 1 addition & 2 deletions docs/source/dataset.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,5 +88,4 @@ Splitters
:nosignatures:

chainer_chemistry.dataset.splitters.RandomSplitter
.. autosummary:: chainer_chemistry.dataset.splitters.RandomSplitter
:methods:
chainer_chemistry.dataset.splitters.StratifiedSplitter
Loading