Skip to content

Commit

Permalink
Proper labeling of folds for an augmented dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
kboone committed May 18, 2019
1 parent 849a2e2 commit 0a62644
Showing 1 changed file with 50 additions and 26 deletions.
76 changes: 50 additions & 26 deletions avocado/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,42 +183,66 @@ def from_objects(cls, name, objects, **kwargs):

return dataset

def label_folds(self, num_folds=None, random_state=1):
    """Separate the dataset into groups for k-folding

    This is only applicable to training datasets that have assigned
    categories.

    If the dataset is an augmented dataset, we ensure that the
    augmentations of the same object stay in the same fold.

    Parameters
    ----------
    num_folds : int (optional)
        The number of folds to use. Default: settings['num_folds']
    random_state : int (optional)
        The random number initializer to use for splitting the folds.

    Returns
    -------
    fold_indices : `pandas.Series`
        A pandas Series where each element is an integer representing the
        assigned fold for each object.

    Raises
    ------
    AvocadoException
        If the metadata has no 'category' column, so the objects cannot be
        stratified into folds.
    """
    if num_folds is None:
        num_folds = settings['num_folds']

    if 'category' not in self.metadata:
        raise AvocadoException(
            "Dataset %s does not have labeled categories! Can't separate "
            "into folds." % self.name
        )

    if 'reference_object_id' in self.metadata:
        # We are operating on an augmented dataset. Use original objects to
        # determine the folds. Rows with a null 'reference_object_id' are
        # the ones treated as originals here.
        is_augmented = True
        reference_mask = self.metadata['reference_object_id'].isna()
        reference_metadata = self.metadata[reference_mask]
    else:
        is_augmented = False
        reference_metadata = self.metadata

    # Stratify on the category of the reference objects only, and record
    # which fold each reference object id was assigned to.
    reference_classes = reference_metadata['category']
    folds = StratifiedKFold(n_splits=num_folds, shuffle=True,
                            random_state=random_state)
    fold_map = {}
    for fold_number, (_, fold_val) in \
            enumerate(folds.split(reference_classes, reference_classes)):
        for object_id in reference_metadata.index[fold_val]:
            fold_map[object_id] = fold_number

    if is_augmented:
        # Augmented rows inherit the fold of the object they reference;
        # reference rows are looked up by their own index instead.
        fold_indices = self.metadata['reference_object_id'].map(fold_map)
        fold_indices[reference_mask] = \
            self.metadata.index.to_series().map(fold_map)
    else:
        fold_indices = self.metadata.index.to_series().map(fold_map)

    fold_indices = fold_indices.astype(int)

    return fold_indices

def get_object(self, index=None, category=None, object_id=None):
"""Parse keywords to pull a specific object out of the dataset
Expand Down

0 comments on commit 0a62644

Please sign in to comment.