This repository has been archived by the owner on Jun 26, 2021. It is now read-only.

Commit

Merge pull request #115 from justusschock/debug_mode
Introduce Debug mode
justusschock committed Jun 7, 2019
2 parents 3b73d6d + 2e4daeb commit b5713dd
Showing 554 changed files with 351 additions and 23 deletions.
Binary files not shown.
48 changes: 47 additions & 1 deletion delira/__init__.py
@@ -6,14 +6,16 @@
warnings.simplefilter('default', DeprecationWarning)
warnings.simplefilter('ignore', ImportWarning)

# to register new pssible backends, they have to be added to this list.
# to register new possible backends, they have to be added to this list.
# each backend should consist of a tuple of length 2 with the first entry
# being the package import name and the second being the backend abbreviation.
# E.g. TensorFlow's package is named 'tensorflow' but if the package is found,
# it will be considered as 'tf' later on
__POSSIBLE_BACKENDS = [("torch", "torch"), ("tensorflow", "tf")]
__BACKENDS = []

__DEBUG_MODE = False


def _determine_backends():

@@ -75,3 +77,47 @@ def get_backends():
if not __BACKENDS:
_determine_backends()
return __BACKENDS


# Functions to get and set the internal __DEBUG_MODE variable. This variable
# currently only defines whether to use multiprocessing or not. At the moment
# this is only used inside the BaseDataManager, which either returns a
# MultiThreadedAugmenter or a SingleThreadedAugmenter depending on the current
# debug mode.
# All other functions using multiprocessing should be aware of this and
# implement a fallback without multiprocessing
# (even if this slows down things a lot!).

def get_current_debug_mode():
"""
Getter function for the current debug mode
Returns
-------
bool
current debug mode
"""
return __DEBUG_MODE


def switch_debug_mode():
"""
Toggles the current debug mode
"""
set_debug_mode(not get_current_debug_mode())


def set_debug_mode(mode: bool):
"""
Sets a new debug mode
Parameters
----------
mode : bool
the new debug mode
"""
global __DEBUG_MODE
__DEBUG_MODE = mode
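
A minimal usage sketch of the debug-mode API added above (assuming the ``delira`` package at this commit is importable; everything beyond the three functions shown in the diff is illustrative):

import delira

# debug mode is off by default, so data managers keep using multiprocessing
assert delira.get_current_debug_mode() is False

# switch debug mode on before building any data managers; get_batchgen()
# will then wrap a SingleThreadedAugmenter instead of spawning processes
delira.set_debug_mode(True)
assert delira.get_current_debug_mode() is True

# toggle back to the previous state
delira.switch_debug_mode()
assert delira.get_current_debug_mode() is False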
234 changes: 224 additions & 10 deletions delira/data_loading/data_manager.py
@@ -2,19 +2,234 @@
import logging

import numpy as np
from batchgenerators.dataloading import SlimDataLoaderBase, \
MultiThreadedAugmenter
import typing
import inspect
from batchgenerators.dataloading import MultiThreadedAugmenter, \
SingleThreadedAugmenter, SlimDataLoaderBase
from batchgenerators.transforms import AbstractTransform

from .data_loader import BaseDataLoader
from .dataset import AbstractDataset, BaseCacheDataset, BaseLazyDataset
from .load_utils import default_load_fn_2d
from .sampler import SequentialSampler, AbstractSampler
from ..utils.decorators import make_deprecated
from delira import get_current_debug_mode

logger = logging.getLogger(__name__)


class Augmenter(object):
"""
Class wrapping ``MultiThreadedAugmenter`` and ``SingleThreadedAugmenter``
to provide a uniform API and to disable multiprocessing/multithreading
inside the dataloading pipeline
"""

def __init__(self, data_loader: BaseDataLoader, transforms,
n_process_augmentation=None, num_cached_per_queue=2,
seeds=None, **kwargs):
"""
Parameters
----------
data_loader : :class:`BaseDataLoader`
the dataloader providing the actual data
transforms : Callable or None
the transforms to use. Can be single callable or None
n_process_augmentation : int
the number of processes to use for augmentation (only necessary if
not in debug mode)
num_cached_per_queue : int
the number of samples to cache per queue (only necessary if not in
debug mode)
seeds : int or list
the seeds for each process (only necessary if not in debug mode)
**kwargs :
additional keyword arguments
"""
# don't use multiprocessing in debug mode
if get_current_debug_mode():
augmenter = SingleThreadedAugmenter(data_loader, transforms)

else:
assert isinstance(n_process_augmentation, int)
# no seeds are given -> use default seed of 1
if seeds is None:
seeds = 1

# only an int is given as seed -> replicate it for each process
if isinstance(seeds, int):
seeds = [seeds] * n_process_augmentation
augmenter = MultiThreadedAugmenter(
data_loader, transforms,
num_processes=n_process_augmentation,
num_cached_per_queue=num_cached_per_queue,
seeds=seeds,
**kwargs)

self._augmenter = augmenter

@property
def __iter__(self):
"""
Property returning the augmenter's ``__iter__``
Returns
-------
Callable
the augmenter's ``__iter__``
"""
return self._augmenter.__iter__

@property
def __next__(self):
"""
Property returning the augmenter's ``__next__``
Returns
-------
Callable
the augmenter's ``__next__``
"""
return self._augmenter.__next__

@property
def next(self):
"""
Property returning the augmenter's ``next``
Returns
-------
Callable
the augmenter's ``next``
"""
return self._augmenter.next

@staticmethod
def __identity_fn(*args, **kwargs):
"""
Helper function accepting arbitrary args and kwargs and returning
without doing anything
Parameters
----------
*args
positional arguments
**kwargs
keyword arguments
"""
return

def _fn_checker(self, function_name):
"""
Checks if the internal augmenter has a given attribute and returns it.
Otherwise it returns ``__identity_fn``
Parameters
----------
function_name : str
the function name to check for
Returns
-------
Callable
either the function corresponding to the given function name or
``__identity_fn``
"""
# same as:
# if hasattr(self._augmenter, function_name):
# return getattr(self._augmenter, function_name)
# else:
# return self.__identity_fn
# but one less getattr call, because hasattr also calls getattr and
# handles AttributeError
try:
return getattr(self._augmenter, function_name)
except AttributeError:
return self.__identity_fn

@property
def _start(self):
"""
Property to provide uniform API of ``_start``
Returns
-------
Callable
either the augmenter's ``_start`` method (if available) or
``__identity_fn`` (if not available)
"""
return self._fn_checker("_start")

def restart(self):
"""
Method to provide uniform API of ``restart``
Returns
-------
Callable
either the augmenter's ``restart`` method (if available) or
``__identity_fn`` (if not available)
"""
return self._fn_checker("restart")

@property
def _finish(self):
"""
Property to provide uniform API of ``_finish``
Returns
-------
Callable
either the augmenter's ``_finish`` method (if available) or
``__identity_fn`` (if not available)
"""
return self._fn_checker("_finish")

@property
def num_batches(self):
"""
Property returning the number of batches
Returns
-------
int
number of batches
"""
if isinstance(self._augmenter, MultiThreadedAugmenter):
return self._augmenter.generator.num_batches

return self._augmenter.data_loader.num_batches

@property
def num_processes(self):
"""
Property returning the number of processes to use for loading and
augmentation
Returns
-------
int
number of processes to use for loading and
augmentation
"""
if isinstance(self._augmenter, MultiThreadedAugmenter):
return self._augmenter.num_processes

return 1

def __del__(self):
"""
Function defining what to do if the object is deleted
"""
del self._augmenter


class BaseDataManager(object):
"""
Class to Handle Data
@@ -129,7 +344,7 @@ def get_batchgen(self, seed=1):
Returns
-------
MultiThreadedAugmenter
Augmenter
Batchgenerator
Raises
@@ -147,13 +362,10 @@ def get_batchgen(self, seed=1):
sampler=self.sampler
)

return MultiThreadedAugmenter(
data_loader,
self.transforms,
self.n_process_augmentation,
num_cached_per_queue=2,
seeds=self.n_process_augmentation * [seed]
)
return Augmenter(data_loader, self.transforms,
self.n_process_augmentation,
num_cached_per_queue=2,
seeds=self.n_process_augmentation * [seed])

def get_subset(self, indices):
"""
@@ -319,6 +531,7 @@ def n_process_augmentation(self, new_process_number):
Setter for number of augmentation processes, casts to int before
setting the attribute
Parameters
----------
new_process_number : int, Any
@@ -389,6 +602,7 @@ def data_loader_cls(self, new_loader_cls):

assert inspect.isclass(new_loader_cls) and issubclass(
new_loader_cls, SlimDataLoaderBase)

self._data_loader_cls = new_loader_cls

@property
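
The ``Augmenter`` wrapper above dispatches optional methods such as ``_start``, ``restart`` and ``_finish`` through ``_fn_checker``, falling back to a no-op when the wrapped augmenter does not implement them. A self-contained sketch of that fallback pattern (all names here are hypothetical and not delira API):

class OptionalDelegate:
    """Sketch of the fallback pattern used by ``Augmenter._fn_checker``."""

    def __init__(self, wrapped):
        self._wrapped = wrapped

    @staticmethod
    def _identity_fn(*args, **kwargs):
        # accepts anything and does nothing; used when the wrapped object
        # does not implement the requested method
        return None

    def _fn_checker(self, name):
        # try/getattr instead of hasattr + getattr: one lookup less, since
        # hasattr internally performs the same getattr and catches the error
        try:
            return getattr(self._wrapped, name)
        except AttributeError:
            return self._identity_fn


class _Multi:        # stands in for MultiThreadedAugmenter
    def _start(self):
        print("starting worker processes")


class _Single:       # stands in for SingleThreadedAugmenter
    pass             # no _start method available


OptionalDelegate(_Multi())._fn_checker("_start")()   # prints the message
OptionalDelegate(_Single())._fn_checker("_start")()  # silently does nothing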
8 changes: 4 additions & 4 deletions delira/training/base_trainer.py
@@ -3,7 +3,7 @@
import pickle
import typing

from batchgenerators.dataloading import MultiThreadedAugmenter
from ..data_loading.data_manager import Augmenter

from .predictor import Predictor
from .callbacks import AbstractCallback
@@ -280,14 +280,14 @@ def _at_epoch_end(self, metrics_val, val_score_key, epoch, is_best,
self.save_state(os.path.join(self.save_path,
"checkpoint_best"))

def _train_single_epoch(self, batchgen: MultiThreadedAugmenter, epoch,
def _train_single_epoch(self, batchgen: Augmenter, epoch,
verbose=False):
"""
Trains the network a single epoch
Parameters
----------
batchgen : MultiThreadedAugmenter
batchgen : :class:`Augmenter`
Generator yielding the training batches
epoch : int
current epoch
@@ -296,7 +296,7 @@ def _train_single_epoch(self, batchgen: MultiThreadedAugmenter, epoch,

metrics, losses = [], []

n_batches = batchgen.generator.num_batches * batchgen.num_processes
n_batches = batchgen.num_batches * batchgen.num_processes
if verbose:
iterable = tqdm(
enumerate(batchgen),
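
The trainer no longer reaches into ``batchgen.generator``: because ``Augmenter`` exposes ``num_batches`` and ``num_processes`` for both the single- and the multi-threaded backend, the total batch count can be computed uniformly (the predictor change below relies on the same properties). A small illustrative helper, with a hypothetical function name:

def total_batches(batchgen):
    # In debug mode num_processes is 1 and num_batches comes from the wrapped
    # data loader; otherwise num_batches is the per-process count of the
    # MultiThreadedAugmenter's generator, so multiplying by num_processes
    # yields the overall number of batches per epoch.
    return batchgen.num_batches * batchgen.num_processes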
2 changes: 1 addition & 1 deletion delira/training/predictor.py
@@ -196,7 +196,7 @@ def predict_data_mgr(self, datamgr, batchsize=None, metrics={},
if not lazy_gen:
predictions_all, metric_vals = [], {k: [] for k in metrics.keys()}

n_batches = batchgen.generator.num_batches * batchgen.num_processes
n_batches = batchgen.num_batches * batchgen.num_processes

if verbose:
iterable = tqdm(enumerate(batchgen), unit=' sample',
