Merge pull request #201 from mottodora/stratified-splitter
Stratified Splitter
corochann committed Jul 3, 2018
2 parents 63ab672 + 5944803 commit 33b6f16
Showing 4 changed files with 597 additions and 2 deletions.
2 changes: 2 additions & 0 deletions chainer_chemistry/dataset/splitters/__init__.py
@@ -1,5 +1,7 @@
from chainer_chemistry.dataset.splitters import base_splitter # NOQA
from chainer_chemistry.dataset.splitters import random_splitter # NOQA
from chainer_chemistry.dataset.splitters import stratified_splitter # NOQA

from chainer_chemistry.dataset.splitters.base_splitter import BaseSplitter # NOQA
from chainer_chemistry.dataset.splitters.random_splitter import RandomSplitter # NOQA
from chainer_chemistry.dataset.splitters.stratified_splitter import StratifiedSplitter # NOQA
214 changes: 214 additions & 0 deletions chainer_chemistry/dataset/splitters/stratified_splitter.py
@@ -0,0 +1,214 @@
import numpy
import pandas

from chainer_chemistry.dataset.splitters.base_splitter import BaseSplitter
from chainer_chemistry.datasets.numpy_tuple_dataset import NumpyTupleDataset


# Refer to scikit-learn
# https://git.io/fPMmB
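# A small worked example for intuition (illustrative values, not from the
# source): with class_counts = numpy.array([5, 3, 2]) and n_draws = 7,
# `continuous` is [3.5, 2.1, 1.4] and flooring gives [3, 2, 1]; the one
# leftover draw goes to the class with the largest fractional remainder
# (0.5, class 0), so the result is [4, 2, 1].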
def _approximate_mode(class_counts, n_draws):
n_class = len(class_counts)
continuous = class_counts * n_draws / class_counts.sum()
floored = numpy.floor(continuous)
assert n_draws // n_class == floored.sum() // n_class
n_remainder = int(n_draws - floored.sum())
remainder = continuous - floored
inds = numpy.argsort(remainder)[::-1]
inds = inds[:n_remainder]
floored[inds] += 1
assert n_draws == floored.sum()
return floored.astype(int)


class StratifiedSplitter(BaseSplitter):
"""Class for doing stratified data splits."""

def _split(self, dataset, frac_train=0.8, frac_valid=0.1, frac_test=0.1,
labels=None, **kwargs):
numpy.testing.assert_almost_equal(frac_train + frac_valid + frac_test,
1.)

seed = kwargs.get('seed', None)
label_axis = kwargs.get('label_axis', -1)
task_index = kwargs.get('task_index', 0)
n_bin = kwargs.get('n_bin', 10)
task_type = kwargs.get('task_type', 'auto')
if task_type not in ['classification', 'regression', 'auto']:
raise ValueError("{} is invalid. Please use 'classification',"
"'regression' or 'auto'".format(task_type))

rng = numpy.random.RandomState(seed)

if isinstance(labels, list):
labels = numpy.array(labels)
elif labels is None:
if not isinstance(dataset, NumpyTupleDataset):
raise ValueError("Please assign label dataset.")
labels = dataset.features[:, label_axis]

if labels.ndim > 1:
labels = labels[:, task_index]

if task_type == 'auto':
if labels.dtype.kind == 'i':
task_type = 'classification'
elif labels.dtype.kind == 'f':
task_type = 'regression'
else:
raise ValueError("Cannot infer task_type from labels of "
"dtype {}".format(labels.dtype))

if task_type == 'classification':
classes, labels = numpy.unique(labels, return_inverse=True)
elif task_type == 'regression':
classes = numpy.arange(n_bin)
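# pandas.qcut bins continuous labels into n_bin quantile-based classes,
# e.g. pandas.qcut([0.1, 0.4, 0.5, 0.9], 2, labels=False) -> [0, 0, 1, 1],
# so regression targets can be stratified like class labels.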
labels = pandas.qcut(labels, n_bin, labels=False)
else:
raise ValueError

n_classes = classes.shape[0]
n_total_valid = int(numpy.floor(frac_valid * len(dataset)))
n_total_test = int(numpy.floor(frac_test * len(dataset)))

class_counts = numpy.bincount(labels)
class_indices = numpy.split(numpy.argsort(labels,
kind='mergesort'),
numpy.cumsum(class_counts)[:-1])
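# The stable mergesort groups sample indices by class; splitting at the
# cumulative class counts yields one index array per class.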

# n_total_train is the remainder: n - n_total_valid - n_total_test
n_valid_samples = _approximate_mode(class_counts, n_total_valid)
class_counts = class_counts - n_valid_samples
n_test_samples = _approximate_mode(class_counts, n_total_test)

train_index = []
valid_index = []
test_index = []

for i in range(n_classes):
n_valid = n_valid_samples[i]
n_test = n_test_samples[i]

perm = rng.permutation(len(class_indices[i]))
class_perm_index = class_indices[i][perm]

class_valid_index = class_perm_index[:n_valid]
class_test_index = class_perm_index[n_valid:n_valid+n_test]
class_train_index = class_perm_index[n_valid+n_test:]

train_index.extend(class_train_index)
valid_index.extend(class_valid_index)
test_index.extend(class_test_index)

assert n_total_valid == len(valid_index)
assert n_total_test == len(test_index)

return numpy.array(train_index), numpy.array(valid_index),\
numpy.array(test_index),

def train_valid_test_split(self, dataset, labels=None, label_axis=-1,
task_index=0, frac_train=0.8, frac_valid=0.1,
frac_test=0.1, converter=None,
return_index=True, seed=None, task_type='auto',
n_bin=10, **kwargs):
"""Generate indices by stratified splittting dataset into train, valid
and test set.
Args:
dataset(NumpyTupleDataset, numpy.ndarray):
Dataset.
labels(numpy.ndarray):
Target label. If `None`, this function assumes that dataset is
an instance of `NumpyTupleDataset`.
label_axis(int):
Dataset feature axis used to extract labels from a NumpyTupleDataset.
task_index(int):
Target task index in dataset for stratification.
seed (int):
Random seed.
frac_train(float):
Fraction of dataset put into training data.
frac_valid(float):
Fraction of dataset put into validation data.
frac_test(float):
Fraction of dataset put into test data.
return_index(bool):
If `True`, this function returns only indices. If `False`, this
function returns split datasets.
Returns:
SplittedDataset(tuple):
split datasets or index arrays
.. admonition:: Example
>>> from chainer_chemistry.datasets import NumpyTupleDataset
>>> from chainer_chemistry.dataset.splitters import StratifiedSplitter
>>> a = numpy.random.random((10, 10))
>>> b = numpy.random.random((10, 8))
>>> c = numpy.random.random((10, 1))
>>> d = NumpyTupleDataset(a, b, c)
>>> splitter = StratifiedSplitter()
>>> train, valid, test = splitter.train_valid_test_split(
...     d, return_index=False)
>>> print(len(train), len(valid), len(test))
8 1 1
"""

return super(StratifiedSplitter, self)\
.train_valid_test_split(dataset, frac_train, frac_valid, frac_test,
converter, return_index, seed=seed,
label_axis=label_axis, task_type=task_type,
task_index=task_index, n_bin=n_bin,
labels=labels, **kwargs)

def train_valid_split(self, dataset, labels=None, label_axis=-1,
task_index=0, frac_train=0.9, frac_valid=0.1,
converter=None, return_index=True, seed=None,
task_type='auto', n_bin=10, **kwargs):
"""Generate indices by stratified splittting dataset into train and
valid set.
Args:
dataset(NumpyTupleDataset, numpy.ndarray):
Dataset.
labels(numpy.ndarray):
Target label. If `None`, this function assumes that dataset is
an instance of `NumpyTupleDataset`.
label_axis(int):
Dataset feature axis used to extract labels from a NumpyTupleDataset.
task_index(int):
Target task index in dataset for stratification.
seed (int):
Random seed.
frac_train(float):
Fraction of dataset put into training data.
frac_valid(float):
Fraction of dataset put into validation data.
return_index(bool):
If `True`, this function returns only indices. If `False`, this
function returns split datasets.
Returns:
SplittedDataset(tuple):
split datasets or index arrays
.. admonition:: Example
>>> from chainer_chemistry.datasets import NumpyTupleDataset
>>> from chainer_chemistry.dataset.splitters import StratifiedSplitter
>>> a = numpy.random.random((10, 10))
>>> b = numpy.random.random((10, 8))
>>> c = numpy.random.random((10, 1))
>>> d = NumpyTupleDataset(a, b, c)
>>> splitter = StratifiedSplitter()
>>> train, valid = splitter.train_valid_split(d, return_index=False)
>>> print(len(train), len(valid))
9 1
"""

return super(StratifiedSplitter, self)\
.train_valid_split(dataset, frac_train, frac_valid, converter,
return_index, seed=seed, label_axis=label_axis,
task_type=task_type, task_index=task_index,
n_bin=n_bin, labels=labels, **kwargs)
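
For reference, a minimal usage sketch (not part of this diff; the data values
are illustrative). Integer labels send the splitter down the classification
path, so indices are drawn per class and each split's class balance
approximates that of the full dataset:

import numpy
from chainer_chemistry.datasets import NumpyTupleDataset
from chainer_chemistry.dataset.splitters import StratifiedSplitter

# Ten samples: float features plus an integer class label (five 0s, five 1s).
x = numpy.random.random((10, 5))
y = numpy.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], dtype=numpy.int32)
dataset = NumpyTupleDataset(x, y)

splitter = StratifiedSplitter()
# return_index defaults to True, so index arrays come back.
train_idx, valid_idx, test_idx = splitter.train_valid_test_split(dataset)
print(len(train_idx), len(valid_idx), len(test_idx))  # 8 1 1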
3 changes: 1 addition & 2 deletions docs/source/dataset.rst
@@ -88,5 +88,4 @@ Splitters
:nosignatures:

 chainer_chemistry.dataset.splitters.RandomSplitter
-.. autosummary:: chainer_chemistry.dataset.splitters.RandomSplitter
-   :methods:
+chainer_chemistry.dataset.splitters.StratifiedSplitter
