-
Notifications
You must be signed in to change notification settings - Fork 129
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #201 from mottodora/stratified-splitter
Stratified Splitter
- Loading branch information
Showing
4 changed files
with
597 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,7 @@ | ||
from chainer_chemistry.dataset.splitters import base_splitter # NOQA | ||
from chainer_chemistry.dataset.splitters import random_splitter # NOQA | ||
from chainer_chemistry.dataset.splitters import stratified_splitter # NOQA | ||
|
||
from chainer_chemistry.dataset.splitters.base_splitter import BaseSplitter # NOQA | ||
from chainer_chemistry.dataset.splitters.random_splitter import RandomSplitter # NOQA | ||
from chainer_chemistry.dataset.splitters.stratified_splitter import StratifiedSplitter # NOQA |
214 changes: 214 additions & 0 deletions
214
chainer_chemistry/dataset/splitters/stratified_splitter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,214 @@ | ||
import numpy | ||
import pandas | ||
|
||
from chainer_chemistry.dataset.splitters.base_splitter import BaseSplitter | ||
from chainer_chemistry.datasets.numpy_tuple_dataset import NumpyTupleDataset | ||
|
||
|
||
# Refer to scikit-learn | ||
# https://git.io/fPMmB | ||
def _approximate_mode(class_counts, n_draws): | ||
n_class = len(class_counts) | ||
continuous = class_counts * n_draws / class_counts.sum() | ||
floored = numpy.floor(continuous) | ||
assert n_draws // n_class == floored.sum() // n_class | ||
n_remainder = int(n_draws - floored.sum()) | ||
remainder = continuous - floored | ||
inds = numpy.argsort(remainder)[::-1] | ||
inds = inds[:n_remainder] | ||
floored[inds] += 1 | ||
assert n_draws == floored.sum() | ||
return floored.astype(numpy.int) | ||
|
||
|
||
class StratifiedSplitter(BaseSplitter): | ||
"""Class for doing stratified data splits.""" | ||
|
||
def _split(self, dataset, frac_train=0.8, frac_valid=0.1, frac_test=0.1, | ||
labels=None, **kwargs): | ||
numpy.testing.assert_almost_equal(frac_train + frac_valid + frac_test, | ||
1.) | ||
|
||
seed = kwargs.get('seed', None) | ||
label_axis = kwargs.get('label_axis', -1) | ||
task_index = kwargs.get('task_index', 0) | ||
n_bin = kwargs.get('n_bin', 10) | ||
task_type = kwargs.get('task_type', 'auto') | ||
if task_type not in ['classification', 'regression', 'auto']: | ||
raise ValueError("{} is invalid. Please use 'classification'," | ||
"'regression' or 'auto'".format(task_type)) | ||
|
||
rng = numpy.random.RandomState(seed) | ||
|
||
if isinstance(labels, list): | ||
labels = numpy.array(labels) | ||
elif labels is None: | ||
if not isinstance(dataset, NumpyTupleDataset): | ||
raise ValueError("Please assign label dataset.") | ||
labels = dataset.features[:, label_axis] | ||
|
||
if labels.ndim == 1: | ||
labels = labels | ||
else: | ||
labels = labels[:, task_index] | ||
|
||
if task_type == 'auto': | ||
if labels.dtype.kind == 'i': | ||
task_type = 'classification' | ||
elif labels.dtype.kind == 'f': | ||
task_type = 'regression' | ||
else: | ||
raise ValueError | ||
|
||
if task_type == 'classification': | ||
classes, labels = numpy.unique(labels, return_inverse=True) | ||
elif task_type == 'regression': | ||
classes = numpy.arange(n_bin) | ||
labels = pandas.qcut(labels, n_bin, labels=False) | ||
else: | ||
raise ValueError | ||
|
||
n_classes = classes.shape[0] | ||
n_total_valid = int(numpy.floor(frac_valid * len(dataset))) | ||
n_total_test = int(numpy.floor(frac_test * len(dataset))) | ||
|
||
class_counts = numpy.bincount(labels) | ||
class_indices = numpy.split(numpy.argsort(labels, | ||
kind='mergesort'), | ||
numpy.cumsum(class_counts)[:-1]) | ||
|
||
# n_total_train is the remainder: n - n_total_valid - n_total_test | ||
n_valid_samples = _approximate_mode(class_counts, n_total_valid) | ||
class_counts = class_counts - n_valid_samples | ||
n_test_samples = _approximate_mode(class_counts, n_total_test) | ||
|
||
train_index = [] | ||
valid_index = [] | ||
test_index = [] | ||
|
||
for i in range(n_classes): | ||
n_valid = n_valid_samples[i] | ||
n_test = n_test_samples[i] | ||
|
||
perm = rng.permutation(len(class_indices[i])) | ||
class_perm_index = class_indices[i][perm] | ||
|
||
class_valid_index = class_perm_index[:n_valid] | ||
class_test_index = class_perm_index[n_valid:n_valid+n_test] | ||
class_train_index = class_perm_index[n_valid+n_test:] | ||
|
||
train_index.extend(class_train_index) | ||
valid_index.extend(class_valid_index) | ||
test_index.extend(class_test_index) | ||
|
||
assert n_total_valid == len(valid_index) | ||
assert n_total_test == len(test_index) | ||
|
||
return numpy.array(train_index), numpy.array(valid_index),\ | ||
numpy.array(test_index), | ||
|
||
def train_valid_test_split(self, dataset, labels=None, label_axis=-1, | ||
task_index=0, frac_train=0.8, frac_valid=0.1, | ||
frac_test=0.1, converter=None, | ||
return_index=True, seed=None, task_type='auto', | ||
n_bin=10, **kwargs): | ||
"""Generate indices by stratified splittting dataset into train, valid | ||
and test set. | ||
Args: | ||
dataset(NumpyTupleDataset, numpy.ndarray): | ||
Dataset. | ||
labels(numpy.ndarray): | ||
Target label. If `None`, this function assumes that dataset is | ||
an instance of `NumpyTupleDataset`. | ||
labels_axis(int): | ||
Dataset feature axis in NumpyTupleDataset. | ||
task_index(int): | ||
Target task index in dataset for stratification. | ||
seed (int): | ||
Random seed. | ||
frac_train(float): | ||
Fraction of dataset put into training data. | ||
frac_valid(float): | ||
Fraction of dataset put into validation data. | ||
return_index(bool): | ||
If `True`, this function returns only indexes. If `False`, this | ||
function returns splitted dataset. | ||
Returns: | ||
SplittedDataset(tuple): | ||
splitted dataset or indexes | ||
.. admonition:: Example | ||
>>> from chainer_chemistry.datasets import NumpyTupleDataset | ||
>>> from chainer_chemistry.dataset.splitters \ | ||
>>> import StratifiedSplitter | ||
>>> a = numpy.random.random((10, 10)) | ||
>>> b = numpy.random.random((10, 8)) | ||
>>> c = numpy.random.random((10, 1)) | ||
>>> d = NumpyTupleDataset(a, b, c) | ||
>>> splitter = StratifiedSplitter() | ||
>>> train, valid, test = | ||
splitter.train_valid_test_split(dataset, return_index=False) | ||
>>> print(len(train), len(valid)) | ||
8, 1, 1 | ||
""" | ||
|
||
return super(StratifiedSplitter, self)\ | ||
.train_valid_test_split(dataset, frac_train, frac_valid, frac_test, | ||
converter, return_index, seed=seed, | ||
label_axis=label_axis, task_type=task_type, | ||
task_index=task_index, n_bin=n_bin, | ||
labels=labels, **kwargs) | ||
|
||
def train_valid_split(self, dataset, labels=None, label_axis=-1, | ||
task_index=0, frac_train=0.9, frac_valid=0.1, | ||
converter=None, return_index=True, seed=None, | ||
task_type='auto', n_bin=10, **kwargs): | ||
"""Generate indices by stratified splittting dataset into train and | ||
valid set. | ||
Args: | ||
dataset(NumpyTupleDataset, numpy.ndarray): | ||
Dataset. | ||
labels(numpy.ndarray): | ||
Target label. If `None`, this function assumes that dataset is | ||
an instance of `NumpyTupleDataset`. | ||
labels_axis(int): | ||
Dataset feature axis in NumpyTupleDataset. | ||
task_index(int): | ||
Target task index in dataset for stratification. | ||
seed (int): | ||
Random seed. | ||
frac_train(float): | ||
Fraction of dataset put into training data. | ||
frac_valid(float): | ||
Fraction of dataset put into validation data. | ||
return_index(bool): | ||
If `True`, this function returns only indexes. If `False`, this | ||
function returns splitted dataset. | ||
Returns: | ||
SplittedDataset(tuple): | ||
splitted dataset or indexes | ||
.. admonition:: Example | ||
>>> from chainer_chemistry.datasets import NumpyTupleDataset | ||
>>> from chainer_chemistry.dataset.splitters \ | ||
>>> import StratifiedSplitter | ||
>>> a = numpy.random.random((10, 10)) | ||
>>> b = numpy.random.random((10, 8)) | ||
>>> c = numpy.random.random((10, 1)) | ||
>>> d = NumpyTupleDataset(a, b, c) | ||
>>> splitter = StratifiedSplitter() | ||
>>> train, valid = | ||
splitter.train_valid_split(dataset, return_index=False) | ||
>>> print(len(train), len(valid)) | ||
9, 1 | ||
""" | ||
|
||
return super(StratifiedSplitter, self)\ | ||
.train_valid_split(dataset, frac_train, frac_valid, converter, | ||
return_index, seed=seed, label_axis=label_axis, | ||
task_type=task_type, task_index=task_index, | ||
n_bin=n_bin, labels=labels, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.