Skip to content

Commit

Permalink
Do not store failed connections in host attributes (#497)
Browse files Browse the repository at this point in the history
* Tests for failed connections #350

* Do not store failed connections in host.connections, fix #350

* Replace sentinel object UNESTABLISHED_CONNECTION with None

* Use variable name conn_obj instead of connection not to confuse mypy
  • Loading branch information
dmfigol committed Mar 24, 2020
1 parent f7bc048 commit cbbcf83
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 37 deletions.
11 changes: 2 additions & 9 deletions nornir/core/connections.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, NoReturn, Optional, Type
from typing import Any, Dict, Optional, Type


from nornir.core.configuration import Config
Expand All @@ -23,7 +23,7 @@ class ConnectionPlugin(ABC):
__slots__ = ("connection", "state")

def __init__(self) -> None:
self.connection: Any = UnestablishedConnection()
self.connection: Any = None
self.state: Dict[str, Any] = {}

@abstractmethod
Expand All @@ -49,13 +49,6 @@ def close(self) -> None:
pass


class UnestablishedConnection(object):
def close(self) -> NoReturn:
raise ValueError("Connection not established")

disconnect = close


class Connections(Dict[str, ConnectionPlugin]):
available: Dict[str, Type[ConnectionPlugin]] = {}

Expand Down
63 changes: 35 additions & 28 deletions nornir/core/inventory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@

from nornir.core import deserializer
from nornir.core.configuration import Config
from nornir.core.connections import ConnectionPlugin, Connections
from nornir.core.connections import (
ConnectionPlugin,
Connections,
)
from nornir.core.exceptions import ConnectionAlreadyOpen, ConnectionNotOpen


Expand Down Expand Up @@ -336,39 +339,43 @@ def open_connection(
Returns:
An already established connection
"""
if connection in self.connections:
raise ConnectionAlreadyOpen(connection)
conn_name = connection
existing_conn = self.connections.get(conn_name)
if existing_conn is not None:
raise ConnectionAlreadyOpen(conn_name)

self.connections[connection] = self.connections.get_plugin(connection)()
plugin = self.connections.get_plugin(conn_name)
conn_obj = plugin()
if default_to_host_attributes:
conn_params = self.get_connection_parameters(connection)
self.connections[connection].open(
hostname=hostname if hostname is not None else conn_params.hostname,
username=username if username is not None else conn_params.username,
password=password if password is not None else conn_params.password,
port=port if port is not None else conn_params.port,
platform=platform if platform is not None else conn_params.platform,
extras=extras if extras is not None else conn_params.extras,
configuration=configuration,
)
else:
self.connections[connection].open(
hostname=hostname,
username=username,
password=password,
port=port,
platform=platform,
extras=extras,
configuration=configuration,
)
return self.connections[connection]
conn_params = self.get_connection_parameters(conn_name)
hostname = hostname if hostname is not None else conn_params.hostname
username = username if username is not None else conn_params.username
password = password if password is not None else conn_params.password
port = port if port is not None else conn_params.port
platform = platform if platform is not None else conn_params.platform
extras = extras if extras is not None else conn_params.extras

conn_obj.open(
hostname=hostname,
username=username,
password=password,
port=port,
platform=platform,
extras=extras,
configuration=configuration,
)
self.connections[conn_name] = conn_obj
return connection

def close_connection(self, connection: str) -> None:
""" Close the connection"""
if connection not in self.connections:
raise ConnectionNotOpen(connection)
conn_name = connection
if conn_name not in self.connections:
raise ConnectionNotOpen(conn_name)

self.connections.pop(connection).close()
conn_obj = self.connections.pop(conn_name)
if conn_obj is not None:
conn_obj.close()

def close_connections(self) -> None:
# Decouple deleting dictionary elements from iterating over connections dict
Expand Down
38 changes: 38 additions & 0 deletions tests/core/test_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,31 @@ class AnotherDummyConnectionPlugin(DummyConnectionPlugin):
pass


class FailedConnection(Exception):
pass


class FailedConnectionPlugin(ConnectionPlugin):
name = "fail"

def open(
self,
hostname: Optional[str],
username: Optional[str],
password: Optional[str],
port: Optional[int],
platform: Optional[str],
extras: Optional[Dict[str, Any]] = None,
configuration: Optional[Config] = None,
) -> None:
raise FailedConnection(
f"Failed to open connection to {self.hostname}:{self.port}"
)

def close(self) -> None:
pass


def open_and_close_connection(task):
task.host.open_connection("dummy", task.nornir.config)
assert "dummy" in task.host.connections
Expand Down Expand Up @@ -69,6 +94,10 @@ def close_not_opened_connection(task):
assert "dummy" not in task.host.connections


def failed_connection(task):
task.host.open_connection(FailedConnectionPlugin.name, task.nornir.config)


def a_task(task):
task.host.get_connection("dummy", task.nornir.config)

Expand All @@ -86,6 +115,7 @@ def setup_class(cls):
Connections.register("dummy", DummyConnectionPlugin)
Connections.register("dummy2", DummyConnectionPlugin)
Connections.register("dummy_no_overrides", DummyConnectionPlugin)
Connections.register(FailedConnectionPlugin.name, FailedConnectionPlugin)

def test_open_and_close_connection(self, nornir):
nr = nornir.filter(name="dev2.group_1")
Expand All @@ -105,6 +135,14 @@ def test_close_not_opened_connection(self, nornir):
assert len(r) == 1
assert not r.failed

def test_failed_connection(self, nornir):
nr = nornir.filter(name="dev2.group_1")
nr.run(task=failed_connection, num_workers=1)
assert (
FailedConnectionPlugin.name
not in nornir.inventory.hosts["dev2.group_1"].connections
)

def test_context_manager(self, nornir):
with nornir.filter(name="dev2.group_1") as nr:
nr.run(task=a_task)
Expand Down

0 comments on commit cbbcf83

Please sign in to comment.