Skip to content
This repository has been archived by the owner on Jun 26, 2021. It is now read-only.

Commit

Permalink
Merge pull request #187 from delira-dev/load_sample_docs
Browse files Browse the repository at this point in the history
Fix docstring of LoadSampleLabel
  • Loading branch information
justusschock committed Aug 30, 2019
2 parents a97890d + d18a684 commit eddfc96
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 14 deletions.
33 changes: 22 additions & 11 deletions delira/data_loading/load_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,8 @@ class LoadSample:
def __init__(self,
sample_ext: dict,
sample_fn: collections.abc.Callable,
dtype=None, normalize=(), norm_fn=norm_range('-1,1'),
dtype: dict = None, normalize: tuple = (),
norm_fn=norm_range('-1,1'),
**kwargs):
"""
Expand All @@ -162,16 +163,16 @@ def __init__(self,
Defines the data _sample_ext. The dict key defines the position of
the sample inside the returned data dict, while the list defines
the the files which should be loaded inside the data dict.
sample_fn : callable
sample_fn : function
function to load a single sample
dtype : dict
defines the data type which should be used for the respective key
normalize : iterable of hashable
list of hashable which should be normalized. Can contain
entire keys of extension (normalizes each element individually)
or provide the file name which should be normalized
norm_fn : callable
callable to normalize input. Default: normalize range to [-1, 1]
norm_fn : function
function to normalize input. Default: normalize range to [-1, 1]
kwargs :
variable number of keyword arguments passed to load function
Expand All @@ -198,7 +199,7 @@ def __init__(self,
self._norm_fn = norm_fn
self._kwargs = kwargs

def __call__(self, path):
def __call__(self, path) -> dict:
"""
Load sample from multiple files
Expand Down Expand Up @@ -239,8 +240,10 @@ class LoadSampleLabel(LoadSample):
def __init__(self,
sample_ext: dict,
sample_fn: collections.abc.Callable,
label_ext: collections.abc.Iterable,
label_ext: str,
label_fn: collections.abc.Callable,
dtype: dict = None, normalize: tuple = (),
norm_fn=norm_range('-1,1'),
sample_kwargs=None, **kwargs):
"""
Load sample and label from folder
Expand All @@ -252,15 +255,21 @@ def __init__(self,
the sample inside the returned data dict, while the list defines
the the files which should be loaded inside the data dict.
Passed to LoadSample.
sample_fn : callable
sample_fn : function
function to load a single sample
Passed to LoadSample.
label_ext : str
extension for label
label_fn: function
functions which returns the label inside a dict
args :
variable number of positional arguments passed to LoadSample
dtype : dict
defines the data type which should be used for the respective key
normalize : iterable of hashable
list of hashable which should be normalized. Can contain
entire keys of extension (normalizes each element individually)
or provide the file name which should be normalized
norm_fn : function
function to normalize input. Default: normalize range to [-1, 1]
sample_kwargs :
additional keyword arguments passed to LoadSample
kwargs :
Expand All @@ -273,12 +282,14 @@ def __init__(self,
if sample_kwargs is None:
sample_kwargs = {}

super().__init__(sample_ext, sample_fn, **sample_kwargs)
super().__init__(sample_ext=sample_ext, sample_fn=sample_fn,
dtype=dtype, normalize=normalize, norm_fn=norm_fn,
**sample_kwargs)
self._label_ext = label_ext
self._label_fn = label_fn
self._label_kwargs = kwargs

def __call__(self, path):
def __call__(self, path) -> dict:
"""
Loads a sample and a label
Expand Down
6 changes: 3 additions & 3 deletions tests/data_loading/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def load_dummy_data(path):
'data2': ['data', 'data', 'data']},
load_dummy_data,
dtype={'seg': 'uint8'},
normalize=['data2'])
normalize=('data2',))
sample = sample_fn('load')
assert not np.isclose(np.mean(sample['data']), 0)
assert not np.isclose(np.mean(sample['seg']), 0)
Expand All @@ -197,7 +197,7 @@ def load_dummy_data(path):
# check different normalization function
sample_fn = LoadSample({'data': ['data', 'data', 'data']},
load_dummy_data,
normalize=['data'],
normalize=('data',),
norm_fn=norm_zero_mean_unit_std)
sample = sample_fn('load')
assert np.isclose(np.mean(sample['data']), 0)
Expand All @@ -208,7 +208,7 @@ def load_dummy_data(path):
{'data': ['data', 'data', 'data'], 'seg': ['data'],
'data2': ['data', 'data', 'data']}, load_dummy_data,
'label', load_dummy_label,
sample_kwargs={'dtype': {'seg': 'uint8'}, 'normalize': ['data2']})
dtype={'seg': 'uint8'}, normalize=('data2',))
sample = sample_fn('load')
assert not np.isclose(np.mean(sample['data']), 0)
assert not np.isclose(np.mean(sample['seg']), 0)
Expand Down

0 comments on commit eddfc96

Please sign in to comment.