Skip to content

Commit

Permalink
ControlFlowCallback ddp support (#1341)
Browse files Browse the repository at this point in the history
* fix

* codestyle

* fix
  • Loading branch information
Scitator committed Oct 30, 2021
1 parent 77dca7f commit b6a4a7f
Showing 1 changed file with 103 additions and 133 deletions.
236 changes: 103 additions & 133 deletions catalyst/callbacks/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,139 +10,108 @@
from catalyst.core.runner import IRunner


def _filter_fn_from_epochs(
epochs: Union[int, float, Sequence[int]], reverse_condition: bool
) -> FILTER_FN:
"""Build ``filter_fn`` from epochs for ``ControlFlowCallback``
Args:
epochs: epochs description
reverse_condition: indicator to use reversed
condition in filter function
Raises:
ValueError: if passed object with unexpected type
Returns:
filter function which accepts 3 arguments - stage (str),
epoch (int), loader (str) and return ``True`` if
need to disable callback
"""
if isinstance(epochs, (int, float)):
epochs = int(epochs)
if reverse_condition:
filter_fn = lambda stage, epoch, loader: epoch % epochs != 0
else:
filter_fn = lambda stage, epoch, loader: epoch % epochs == 0
elif isinstance(epochs, (list, tuple)):
epochs = sorted(set(epochs))
if reverse_condition:
filter_fn = lambda stage, epoch, loader: epoch not in epochs
else:
filter_fn = lambda stage, epoch, loader: epoch in epochs
else:
raise ValueError("'epochs' should be int/float/Sequence[int]! " f"(got {type(epochs)})")
return filter_fn


def _filter_fn_from_loaders(loaders: LOADERS, reverse_condition: bool) -> FILTER_FN:
"""Build ``filter_fn`` from loaders for ``ControlFlowCallback``.
Args:
loaders (str/Sequence[str]/Mapping[str, int/Sequence[str]]):
loaders description
reverse_condition: indicator to use reversed
condition in filter function
Raises:
ValueError: if can't build filter_fn from mappings
ValueError: if passed object with unexpected type
Returns:
filter function which accepts 3 arguments - stage (str),
epoch (int), loader (str) and return ``True`` if
need to disable callback
"""
if isinstance(loaders, str):
loaders = [loaders]

# sequence of loaders
if isinstance(loaders, (list, tuple)):
loaders = sorted(set(loaders)) # ignore duplicates
if reverse_condition:
filter_fn = lambda stage, epoch, loader: loader not in loaders
else:
filter_fn = lambda stage, epoch, loader: loader in loaders
# loader: ignore epoch or epochs
elif isinstance(loaders, (dict, OrderedDict)):
ignore_list = {}
for loader, epochs in loaders.items():
if isinstance(epochs, (int, float)):
ignore_list[loader] = [int(epochs)]
else:
try:
ignore_list[loader] = []
for num in sorted(set(epochs)):
to_add = int(num)
ignore_list[loader].append(to_add)
except (ValueError, TypeError):
raise ValueError(
"'ignore_list' should be a dict where "
"keys is a int/float/List[int]/Tuple[int]!"
)
if reverse_condition:
filter_fn = lambda stage, epoch, loader: epoch not in (
ignore_list.get(loader) or {} # {loader: [epoch]}.get(loader)
class _EpochFilterFn:
def __init__(self, epochs: Union[int, float, Sequence[int]], reverse_condition: bool):
if not isinstance(epochs, (int, float, list, tuple)):
raise ValueError(
"'epochs' should be int/float/Sequence[int]! " f"(got {type(epochs)})"
)
else:
filter_fn = lambda stage, epoch, loader: epoch in (ignore_list.get(loader) or {})
else:
raise ValueError(
"'loaders' type should be one of - str, "
"Sequence[str], Mapping[str, int] or "
"Mapping[str, Sequence[int]]! "
f"(got {type(loaders)})"
)
return filter_fn

self.epochs = epochs
self.reverse_condition = reverse_condition

# extra conditions precomputing
if isinstance(self.epochs, (int, float)):
self.epochs = int(self.epochs)
elif isinstance(self.epochs, (list, tuple)):
self.epochs = sorted(set(self.epochs))

def __call__(self, stage, epoch, loader):
if isinstance(self.epochs, (int, float)):
if self.reverse_condition:
return epoch % self.epochs != 0
else:
return epoch % self.epochs == 0
elif isinstance(self.epochs, (list, tuple)):
if self.reverse_condition:
return epoch not in self.epochs
else:
return epoch in self.epochs

def _filter_fn_from_arg(filter_fn: Union[str, FILTER_FN]) -> FILTER_FN:
"""Check if filter function from argumets
can be used with ``ControlFlowCallback``.

Args:
filter_fn (str or Callable): filter function to check
Raises:
ValueError: if ``filter_fn`` is a string and can not be
interpreted as python code then an error will be raised
ValueError: if passed not callable object then will be
raised an error
ValueError: will be raised error if filter function do not
have three arguments
Returns:
filter function which accepts 3 arguments - stage (str),
epoch (int), loader (str) and return ``True`` if
need to disable callback
"""
if isinstance(filter_fn, str):
# lambda function from string
try:
filter_fn = eval(filter_fn)
except (ValueError, SyntaxError):
class _LoaderFilterFn:
def __init__(self, loaders: LOADERS, reverse_condition: bool):
if isinstance(loaders, str):
loaders = [loaders]
if not isinstance(loaders, (list, tuple, dict, OrderedDict)):
raise ValueError(
"'filter_fn' should be a valid "
"python lambda function with "
"three arguments - 'stage', 'epoch' and 'loader'!"
"'loaders' type should be one of - str, "
"Sequence[str], Mapping[str, int] or "
"Mapping[str, Sequence[int]]! "
f"(got {type(loaders)})"
)
if not callable(filter_fn):
raise ValueError("'filter_fn' should be a callable!")
if filter_fn.__code__.co_argcount != 3:
raise ValueError(
"Filter function should have three arguments - " "'stage', 'epoch' and 'loader'!"
)
return filter_fn
self.loaders = loaders
self.reverse_condition = reverse_condition

# extra conditions precomputing
if isinstance(self.loaders, (list, tuple)):
self.loaders = sorted(set(self.loaders)) # ignore duplicates
elif isinstance(self.loaders, (dict, OrderedDict)):
ignore_list = {}
for loader, epochs in self.loaders.items():
if isinstance(epochs, (int, float)):
ignore_list[loader] = [int(epochs)]
else:
try:
ignore_list[loader] = []
for num in sorted(set(epochs)):
to_add = int(num)
ignore_list[loader].append(to_add)
except (ValueError, TypeError):
raise ValueError(
"'ignore_list' should be a dict where "
"keys is a int/float/List[int]/Tuple[int]!"
)
self._ignore_list = ignore_list

def __call__(self, stage, epoch, loader):
# sequence of loaders
if isinstance(self.loaders, (list, tuple)):
if self.reverse_condition:
return loader not in self.loaders
else:
return loader in self.loaders
# loader: ignore epoch or epochs
elif isinstance(self.loaders, (dict, OrderedDict)):
if self.reverse_condition:
return epoch not in (
self._ignore_list.get(loader) or {} # {loader: [epoch]}.get(loader)
)
else:
return epoch in (self._ignore_list.get(loader) or {})


class _ArgsFilterFn:
def __init__(self, filter_fn: Union[str, FILTER_FN]):
if isinstance(filter_fn, str):
# lambda function from string
try:
filter_fn = eval(filter_fn)
except (ValueError, SyntaxError):
raise ValueError(
"'filter_fn' should be a valid "
"python lambda function with "
"three arguments - 'stage', 'epoch' and 'loader'!"
)
if not callable(filter_fn):
raise ValueError("'filter_fn' should be a callable!")
if filter_fn.__code__.co_argcount != 3:
raise ValueError(
"Filter function should have three arguments - " "'stage', 'epoch' and 'loader'!"
)
self.filter_fn = filter_fn

def __call__(self, stage, epoch, loader):
return self.filter_fn(stage, epoch, loader)


class ControlFlowCallback(CallbackWrapper):
Expand Down Expand Up @@ -352,16 +321,17 @@ def __init__(
# loader parameters
self.filter_fn = None

# due to ddp-setup, we have to wrap everything with classes
if epochs is not None:
self.filter_fn = _filter_fn_from_epochs(epochs, False)
self.filter_fn = _EpochFilterFn(epochs, False)
elif ignore_epochs is not None:
self.filter_fn = _filter_fn_from_epochs(ignore_epochs, True)
self.filter_fn = _EpochFilterFn(ignore_epochs, True)
elif loaders is not None:
self.filter_fn = _filter_fn_from_loaders(loaders, False)
self.filter_fn = _LoaderFilterFn(loaders, False)
elif ignore_loaders is not None:
self.filter_fn = _filter_fn_from_loaders(ignore_loaders, True)
self.filter_fn = _LoaderFilterFn(ignore_loaders, True)
elif filter_fn is not None:
self.filter_fn = _filter_fn_from_arg(filter_fn)
self.filter_fn = _ArgsFilterFn(filter_fn)

def on_loader_start(self, runner: "IRunner") -> None:
"""
Expand Down

0 comments on commit b6a4a7f

Please sign in to comment.