Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable mypy for nornir.core.task #418

Merged
merged 1 commit into from
Aug 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion nornir/core/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class GlobalState(object):

__slots__ = "dry_run", "failed_hosts"

def __init__(self, dry_run: bool = None, failed_hosts: Set[str] = None) -> None:
def __init__(self, dry_run: bool = False, failed_hosts: Set[str] = None) -> None:
self.dry_run = dry_run
self.failed_hosts = failed_hosts or set()

Expand Down
60 changes: 36 additions & 24 deletions nornir/core/task.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import logging
import traceback
from typing import Any, Optional, TYPE_CHECKING
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union

from nornir.core.exceptions import NornirExecutionError
from nornir.core.exceptions import NornirSubTaskError

if TYPE_CHECKING:
from nornir.core.inventory import Host
from nornir.core import Nornir


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -36,17 +37,23 @@ class Task(object):
severity_level (logging.LEVEL): Severity level associated to the task
"""

def __init__(self, task, name=None, severity_level=logging.INFO, **kwargs):
def __init__(
self,
task: Callable[..., Any],
name: str = None,
severity_level: int = logging.INFO,
**kwargs: str
):
self.name = name or task.__name__
self.task = task
self.params = kwargs
self.results = MultiResult(self.name)
self.severity_level = severity_level

def __repr__(self):
def __repr__(self) -> str:
return self.name

def start(self, host, nornir):
def start(self, host: "Host", nornir: "Nornir") -> "MultiResult":
"""
Run the task for the given host.

Expand Down Expand Up @@ -94,7 +101,7 @@ def start(self, host, nornir):
self.results.insert(0, r)
return self.results

def run(self, task, **kwargs):
def run(self, task: Callable[..., Any], **kwargs: Any) -> "MultiResult":
"""
This is a utility method to call a task from within a task. For instance:

Expand All @@ -115,8 +122,8 @@ def grouped_tasks(task):

if "severity_level" not in kwargs:
kwargs["severity_level"] = self.severity_level
task = Task(task, **kwargs)
r = task.start(self.host, self.nornir)
run_task = Task(task, **kwargs)
r = run_task.start(self.host, self.nornir)
self.results.append(r[0] if len(r) == 1 else r)

if r.failed:
Expand All @@ -125,11 +132,16 @@ def grouped_tasks(task):

return r

def is_dry_run(self, override: bool = None) -> bool:
def is_dry_run(self, override: Optional[bool] = None) -> bool:
"""
Returns whether current task is a dry_run or not.
"""
# if override is not None:
# return override

# return self.nornir.data.dry_run
return override if override is not None else self.nornir.data.dry_run
# return cast(bool, override if override is not None else self.nornir.data.dry_run)


class Result(object):
Expand Down Expand Up @@ -157,7 +169,7 @@ class Result(object):

def __init__(
self,
host: "Host",
host: Union["Host", None],
result: Any = None,
changed: bool = False,
diff: str = "",
Expand All @@ -181,43 +193,43 @@ def __init__(
for k, v in kwargs.items():
setattr(self, k, v)

def __repr__(self):
def __repr__(self) -> str:
return '{}: "{}"'.format(self.__class__.__name__, self.name)

def __str__(self):
def __str__(self) -> str:
if self.exception:
return str(self.exception)

else:
return str(self.result)


class AggregatedResult(dict):
class AggregatedResult(Dict[str, Any]):
"""
It basically is a dict-like object that aggregates the results for all devices.
You can access each individual result by doing ``my_aggr_result["hostname_of_device"]``.
"""

def __init__(self, name, **kwargs):
def __init__(self, name: str, **kwargs: str):
self.name = name
super().__init__(**kwargs)

def __repr__(self):
def __repr__(self) -> str:
return "{} ({}): {}".format(
self.__class__.__name__, self.name, super().__repr__()
)

@property
def failed(self):
def failed(self) -> bool:
"""If ``True`` at least a host failed."""
return any([h.failed for h in self.values()])

@property
def failed_hosts(self):
def failed_hosts(self) -> Dict[str, "MultiResult"]:
"""Hosts that failed during the execution of the task."""
return {h: r for h, r in self.items() if r.failed}

def raise_on_error(self):
def raise_on_error(self) -> None:
"""
Raises:
:obj:`nornir.core.exceptions.NornirExecutionError`: When at least a task failed
Expand All @@ -226,32 +238,32 @@ def raise_on_error(self):
raise NornirExecutionError(self)


class MultiResult(list):
class MultiResult(List[Any]):
"""
It is basically is a list-like object that gives you access to the results of all subtasks for
a particular device/task.
"""

def __init__(self, name):
def __init__(self, name: str):
self.name = name

def __getattr__(self, name):
def __getattr__(self, name: str) -> Any:
return getattr(self[0], name)

def __repr__(self):
def __repr__(self) -> str:
return "{}: {}".format(self.__class__.__name__, super().__repr__())

@property
def failed(self):
def failed(self) -> bool:
"""If ``True`` at least a task failed."""
return any([h.failed for h in self])

@property
def changed(self):
def changed(self) -> bool:
"""If ``True`` at least a task changed the system."""
return any([h.changed for h in self])

def raise_on_error(self):
def raise_on_error(self) -> None:
"""
Raises:
:obj:`nornir.core.exceptions.NornirExecutionError`: When at least a task failed
Expand Down
3 changes: 0 additions & 3 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,5 @@ ignore_errors = True
[mypy-nornir.core.inventory]
ignore_errors = True

[mypy-nornir.core.task]
ignore_errors = True

[mypy-tests.*]
ignore_errors = True