This repository has been archived by the owner on Jun 26, 2021. It is now read-only.

Commit 9d25a68

Merge pull request #134 from justusschock/data_loading_fix

Data loading fix

justusschock committed Jun 12, 2019
2 parents 74e0d6d + dd72548
Showing 8 changed files with 421 additions and 261 deletions.
34 changes: 15 additions & 19 deletions delira/data_loading/data_manager.py
@@ -61,6 +61,12 @@ def __init__(self, data_loader: BaseDataLoader, transforms,
# only an int is given as seed -> replicate it for each process
if isinstance(seeds, int):
seeds = [seeds] * n_process_augmentation

# avoid same seeds for all processes
if any([seeds[0] == _seed for _seed in seeds[1:]]):
for idx in range(len(seeds)):
seeds[idx] = seeds[idx] + idx

augmenter = MultiThreadedAugmenter(
data_loader, transforms,
num_processes=n_process_augmentation,
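
The seed handling in the hunk above first replicates a single int seed, then offsets duplicates by the process index so no two augmentation processes share a seed. A minimal standalone sketch of that logic (the seed and process count are example values):

    # Sketch of the seed de-duplication above; values are illustrative.
    n_process_augmentation = 3
    seeds = 42

    # only an int is given as seed -> replicate it for each process
    if isinstance(seeds, int):
        seeds = [seeds] * n_process_augmentation

    # avoid same seeds for all processes
    if any(seeds[0] == _seed for _seed in seeds[1:]):
        for idx in range(len(seeds)):
            seeds[idx] = seeds[idx] + idx

    print(seeds)  # [42, 43, 44]
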
@@ -200,7 +206,7 @@ def num_batches(self):
number of batches
"""
if isinstance(self._augmenter, MultiThreadedAugmenter):
return self._augmenter.generator.num_batches
return self._augmenter.generator.num_batches # * self.num_processes

return self._augmenter.data_loader.num_batches

@@ -693,8 +699,7 @@ def n_samples(self):
@property
def n_batches(self):
"""
Returns Number of Batches based on batchsize,
number of samples and number of processes
Returns Number of Batches based on batchsize and number of samples
Returns
-------
@@ -708,20 +713,11 @@ def n_batches(self):
"""
assert self.n_samples > 0
if self.n_process_augmentation == 1:
n_batches = int(np.floor(self.n_samples / self.batch_size))
elif self.n_process_augmentation > 1:
if (self.n_samples / self.batch_size) < \
self.n_process_augmentation:
self.n_process_augmentation = 1
logger.warning(
'Too few samples for n_process_augmentation={}. '
'Forcing n_process_augmentation={} '
'instead'.format(
self.n_process_augmentation, 1))
n_batches = int(np.floor(
self.n_samples / self.batch_size / self.n_process_augmentation)
)
else:
raise ValueError('Invalid value for n_process_augmentation')

n_batches = self.n_samples // self.batch_size

truncated_batch = self.n_samples % self.batch_size

n_batches += int(bool(truncated_batch))

return n_batches
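
The rewritten n_batches drops the per-process division entirely: it is plain ceiling division, with one extra truncated batch whenever the sample count is not a multiple of the batch size. A quick standalone check (sample and batch sizes are made up):

    # Ceiling-division behaviour of the new n_batches; example values.
    n_samples, batch_size = 10, 3

    n_batches = n_samples // batch_size        # 3 full batches
    truncated_batch = n_samples % batch_size   # 1 sample left over

    n_batches += int(bool(truncated_batch))    # one extra, smaller batch

    print(n_batches)  # 4 == ceil(10 / 3)
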
37 changes: 36 additions & 1 deletion delira/data_loading/sampler/abstract_sampler.py
@@ -10,7 +10,8 @@ class AbstractSampler(object):
"""

def __init__(self, indices=None):
pass
self._num_samples = len(indices)
self._global_index = 0

@classmethod
def from_dataset(cls, dataset: AbstractDataset, **kwargs):
@@ -31,6 +32,40 @@ def from_dataset(cls, dataset: AbstractDataset, **kwargs):
indices = list(range(len(dataset)))
return cls(indices, **kwargs)

def _check_batchsize(self, n_indices):
"""
Checks if the batchsize is valid (and truncates batches if necessary).
Will also raise StopIteration once enough batches have been sampled
Parameters
----------
n_indices : int
number of indices to sample
Returns
-------
int
number of indices to sample (truncated if necessary)
Raises
------
StopIteration
if enough batches sampled
"""

if self._global_index >= self._num_samples:
self._global_index = 0
raise StopIteration

else:
# truncate batch if necessary
if n_indices + self._global_index > self._num_samples:
n_indices = self._num_samples - self._global_index

self._global_index += n_indices
return n_indices

@abstractmethod
def _get_indices(self, n_indices):
"""
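
With the state now living in AbstractSampler, a subclass's _get_indices only needs to call _check_batchsize and can rely on batch truncation and the end-of-epoch StopIteration. A hypothetical toy subclass illustrating that contract (not part of the commit):

    # Hypothetical sequential sampler built on the new base-class bookkeeping.
    class SequentialToySampler(AbstractSampler):
        def __init__(self, indices):
            super().__init__(indices)
            self._indices = list(range(len(indices)))

        def _get_indices(self, n_indices):
            # may shrink the last batch; raises StopIteration at epoch end
            n_indices = self._check_batchsize(n_indices)
            # _global_index was already advanced by _check_batchsize
            return self._indices[self._global_index - n_indices:self._global_index]

    sampler = SequentialToySampler(list(range(5)))
    print(sampler._get_indices(2))  # [0, 1]
    print(sampler._get_indices(2))  # [2, 3]
    print(sampler._get_indices(2))  # [4]  (truncated)
    # a fourth call raises StopIteration and resets the epoch
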
23 changes: 3 additions & 20 deletions delira/data_loading/sampler/lambda_sampler.py
@@ -22,11 +22,10 @@ def __init__(self, indices, sampling_fn):
and the number of indices to return
"""
super().__init__()
super().__init__(indices)
self._indices = list(range(len(indices)))

self._sampling_fn = sampling_fn
self._global_index = 0

def _get_indices(self, n_indices):
"""
@@ -42,27 +41,11 @@ def _get_indices(self, n_indices):
list
list of sampled indices
Raises
------
StopIteration
Maximum number of indices sampled
"""

if self._global_index >= len(self._indices):
self._global_index = 0
raise StopIteration

new_global_idx = self._global_index + n_indices

# If we reach end, make batch smaller
if new_global_idx >= len(self._indices):
new_global_idx = len(self._indices)

samples = self._sampling_fn(self._indices,
new_global_idx - self._global_index)
n_indices = self._check_batchsize(n_indices)

self._global_index = new_global_idx
samples = self._sampling_fn(self._indices, n_indices)
return samples

def __len__(self):
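
After this change a sampling_fn receives the already-truncated batch size and never has to track epoch state itself. A hedged usage sketch (the import path and the reversed-order strategy are assumptions for illustration):

    # Assumed import path; adjust if the package layout differs.
    from delira.data_loading.sampler import LambdaSampler

    def reversed_order(indices, n_indices):
        # hand back the last n_indices positions, newest first
        return list(reversed(indices))[:n_indices]

    sampler = LambdaSampler(list(range(10)), reversed_order)
    print(sampler._get_indices(4))  # [9, 8, 7, 6]
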
125 changes: 71 additions & 54 deletions delira/data_loading/sampler/random_sampler.py
@@ -23,9 +23,8 @@ def __init__(self, indices):
corresponding class
"""
super().__init__()
super().__init__(indices)
self._indices = list(range(len(indices)))
self._global_index = 0

def _get_indices(self, n_indices):
"""
@@ -47,21 +46,11 @@ def _get_indices(self, n_indices):
If maximal number of samples is reached
"""
if self._global_index >= len(self._indices):
self._global_index = 0
raise StopIteration

new_global_idx = self._global_index + n_indices

# If we reach end, make batch smaller
if new_global_idx >= len(self._indices):
new_global_idx = len(self._indices)
n_indices = self._check_batchsize(n_indices)

indices = choice(self._indices,
size=new_global_idx - self._global_index)
# indices = choices(
# self._indices, k=new_global_idx - self._global_index)
self._global_index = new_global_idx
size=n_indices)

return indices

def __len__(self):
@@ -93,12 +82,12 @@ def __init__(self, indices, shuffle_batch=True):
sampled indices will be shuffled
"""
super().__init__()
super().__init__(indices)

self._num_indices = 0
self._num_samples = 0
_indices = {}
for idx, class_idx in enumerate(indices):
self._num_indices += 1
self._num_samples += 1
class_idx = int(class_idx)
if class_idx in _indices.keys():
_indices[class_idx].append(idx)
@@ -158,9 +147,7 @@ def _get_indices(self, n_indices):
If maximal number of samples is reached
"""
if self._global_index >= self._num_indices:
self._global_index = 0
raise StopIteration
n_indices = self._check_batchsize(n_indices)

samples_per_class = int(n_indices / self._n_classes)

@@ -178,14 +165,13 @@ _samples.append(choice(idx_list, size=1))
_samples.append(choice(idx_list, size=1))

_samples = concatenate(_samples)
self._global_index += n_indices
if self._shuffle:
shuffle(_samples)

return _samples

def __len__(self):
return self._num_indices
return self._num_samples
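
PrevalenceRandomSampler keeps batches class-balanced: an equal share of random draws per class, then single extra draws until the requested size is reached. A rough sketch of that composition outside the class (the label layout is made up):

    from numpy.random import choice

    # made-up mapping from class label to dataset indices
    indices_per_class = {0: [0, 1, 2, 3], 1: [4, 5], 2: [6, 7, 8]}

    n_indices = 7
    samples_per_class = n_indices // len(indices_per_class)  # 2 per class

    batch = []
    for idx_list in indices_per_class.values():
        # numpy's choice samples with replacement by default
        batch.extend(choice(idx_list, size=samples_per_class))

    # fill up with single draws until the batch is full
    for idx_list in indices_per_class.values():
        if len(batch) >= n_indices:
            break
        batch.extend(choice(idx_list, size=1))

    print(len(batch))  # 7
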


class StoppingPrevalenceRandomSampler(AbstractSampler):
@@ -212,7 +198,7 @@ def __init__(self, indices, shuffle_batch=True):
if True: indices will be sampled in a sequential way per class and
sampled indices will be shuffled
"""
super().__init__()
super().__init__(indices)

_indices = {}
_global_idxs = {}
@@ -233,7 +219,7 @@ ordered_dict[k] = _indices[k]
ordered_dict[k] = _indices[k]
length = min(length, len(_indices[k]))

self._length = length
self._num_samples = length

self._indices = ordered_dict
self._n_classes = len(_indices.keys())
@@ -261,6 +247,58 @@ def from_dataset(cls, dataset: AbstractDataset, **kwargs):
labels = [dataset[idx]['label'] for idx in indices]
return cls(labels, **kwargs)

def _check_batchsize(self, n_indices):
"""
Checks if batchsize is valid for all classes
Parameters
----------
n_indices : int
the number of samples to return
Returns
-------
dict
number of samples per class to return
"""
n_indices = super()._check_batchsize(n_indices)

samples_per_class = n_indices // self._n_classes
remaining = n_indices % self._n_classes

samples = {}

try:

# sample the same number of samples for each class
for key, idx_list in self._indices.items():
if self._global_idxs[key] >= len(idx_list):
raise StopIteration

# truncate if necessary
samples[key] = min(
samples_per_class,
len(self._indices[key]) - self._global_idxs[key])

self._global_idxs[key] += samples[key]

# fill up the remainder, one extra sample per class
while remaining:
for key, idx_list in self._indices.items():
if not remaining:
break
samples[key] += 1
remaining -= 1

except StopIteration as e:
# set all global indices to 0
for key in self._global_idxs.keys():
self._global_idxs[key] = 0

raise e

# return outside any finally block so StopIteration can propagate
return samples

def _get_indices(self, n_indices):
"""
Actual Sampling
@@ -278,40 +316,19 @@
-------
list: list of sampled indices
"""
samples_per_class = int(n_indices / self._n_classes)
_samples = []

for key, idx_list in self._indices.items():
if self._global_idxs[key] >= len(idx_list):
self._global_idxs[key] = 0
raise StopIteration

new_global_idx = self._global_idxs[key] + samples_per_class

if new_global_idx >= len(idx_list):
new_global_idx = len(idx_list)

_samples.append(choice(idx_list, size=samples_per_class))

self._global_idxs[key] = new_global_idx
n_indices = self._check_batchsize(n_indices)

for key, idx_list in self._indices.items():
if len(_samples) >= n_indices:
break
samples = []

if self._global_idxs[key] >= len(idx_list):
self._global_idxs[key] = 0
for key, _n_indices in n_indices.items():
samples.append(choice(self._indices[key], size=_n_indices))

new_global_idx = self._global_idxs[key] + 1
samples = concatenate(samples)

_samples.append(choice(idx_list, size=1))
self._global_idxs[key] = new_global_idx

_samples = concatenate(_samples)
if self._shuffle:
shuffle(_samples)
shuffle(samples)

return _samples
return samples

def __len__(self):
return self._length
return self._num_samples
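
The overridden _check_batchsize of StoppingPrevalenceRandomSampler thus returns a dict of per-class counts rather than a single int, and _get_indices simply draws that many indices per class. The split itself is short division plus remainder distribution; a toy illustration (class labels and batch size are made up):

    # Per-class split performed by the overridden _check_batchsize; toy values.
    n_indices = 5
    classes = [0, 1, 2]

    samples = {c: n_indices // len(classes) for c in classes}  # 1 each
    remaining = n_indices % len(classes)                       # 2 left over

    # hand out the remainder one class at a time
    for c in classes:
        if not remaining:
            break
        samples[c] += 1
        remaining -= 1

    print(samples)  # {0: 2, 1: 2, 2: 1}
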
