Skip to content

Commit

Permalink
Turn k_fold function into KFold class
Browse files Browse the repository at this point in the history
  • Loading branch information
cbarrick committed Jun 3, 2018
1 parent f98a293 commit 14191f4
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 33 deletions.
2 changes: 1 addition & 1 deletion toys/model_selection/__init__.py
@@ -1,4 +1,4 @@
from .grid_search import combinations
from .grid_search import GridSearchCV

from .cross_val import k_fold
from .cross_val import KFold
59 changes: 32 additions & 27 deletions toys/model_selection/cross_val.py
Expand Up @@ -4,36 +4,41 @@
from toys.typing import CrossValSplitter, Dataset


def k_fold(k=3, shuffle=True):
'''Returns a splitter function for k-fold cross validation.
K-folding partitions a dataset into k subsets of roughly equal size.
If ``shuffle`` is true, the elements of each partition are chosen at
random. Otherwise each partition is a continuous subset of the dataset.
Arguments:
k (int):
The number of folds. Must be at least 2.
shuffle (bool):
Whether to shuffle the indices before splitting.
Returns:
cv (CrossValSplitter):
A function which takes a dataset and returns an iterator over pairs
of lists of indices, ``(train, test)``, where ``train`` indexes the
training instances of the fold and ``test`` indexes the testing
instances.
class KFold(CrossValSplitter):
'''A splitter for simple k-fold cross validation.
K-folding partitions a dataset into k subsets of roughly equal size. A
"fold" is a pair of datasets, ``(train, test)``, where ``test`` is one of
the partitions and ``train`` is the concatenation of the remaining
partitions.
Instances of this class are functions which apply k-folding to datasets.
They return an iterator over all folds of the datasets.
'''
assert 1 < k, 'The number of folds must be at least 2.'

def cv(dataset):
def __init__(self, k=3, shuffle=True):
'''Initialize a KFold.
If ``shuffle`` is true, the elements of each partition are chosen at
random. Otherwise each partition is a continuous subset of the dataset.
Arguments:
k (int):
The number of folds. Must be at least 2.
shuffle (bool):
Whether to shuffle the indices before splitting.
'''
if k < 2:
raise ValueError('The number of folds must be at least 2.')

self.k = k
self.shuffle = shuffle

def __call__(self, dataset):
indices = np.arange(len(dataset))
if shuffle: np.random.shuffle(indices)
splits = np.array_split(indices, k)
if self.shuffle: np.random.shuffle(indices)
splits = np.array_split(indices, self.k)
for test in splits:
train = [s for s in splits if s is not test]
train = np.concatenate(train)
yield train, test

return cv
yield toys.subset(dataset, train), toys.subset(dataset, test)
8 changes: 3 additions & 5 deletions toys/model_selection/grid_search.py
Expand Up @@ -12,7 +12,7 @@
from toys.parsers import parse_metric
from toys.typing import CrossValSplitter, Estimator, Metric, Model, ParamGrid

from .cross_val import k_fold
from .cross_val import KFold


logger = getLogger(__name__)
Expand Down Expand Up @@ -109,14 +109,12 @@ def fit(self, *datasets, estimator=None, param_grid=None, cv=3, metric='f_score'
logger.warn('multiprocessing is not yet supported')

if not callable(cv):
cv = k_fold(cv)
cv = KFold(cv)

def jobs():
for train, test in cv(dataset):
for params in combinations(param_grid):
train_set = toys.subset(dataset, train)
test_set = toys.subset(dataset, test)
yield params, train_set, test_set
yield params, train, test

def run(job):
(params, train_set, test_set) = job
Expand Down

0 comments on commit 14191f4

Please sign in to comment.