diff --git a/suzieq/poller/controller/source/base_source.py b/suzieq/poller/controller/source/base_source.py index eff1a6108a..e4efbf5c53 100644 --- a/suzieq/poller/controller/source/base_source.py +++ b/suzieq/poller/controller/source/base_source.py @@ -14,6 +14,8 @@ from suzieq.poller.controller.utils.inventory_utils import read_inventory from suzieq.shared.exceptions import InventorySourceError +_DEFAULT_PORTS = {'http': 80, 'https': 443, 'ssh': 22} + class SourceModel(InventoryPluginModel): """Model for inventory source validation @@ -233,14 +235,17 @@ def set_device(self, inventory: Dict[str, Dict]): devtype = self._device.get('devtype') for node in inventory.values(): + transport_tmp = node.get('transport') or transport or 'ssh' + ignore_known_hosts_tmp = node.get('ignore_known_hosts') node.update({ 'jump_host': node.get('jump_host') or jump_host, 'jump_host_key_file': node.get('jump_host_key_file') or jump_host_key_file, - 'ignore_known_hosts': node.get('ignore_known_hosts') - or ignore_known_hosts, - 'transport': node.get('transport') or transport or 'ssh', - 'port': node.get('port') or port or 22, + 'ignore_known_hosts': ignore_known_hosts_tmp if + ignore_known_hosts_tmp is not None else ignore_known_hosts, + 'transport': transport_tmp, + 'port': node.get('port') or port or + _DEFAULT_PORTS.get(transport_tmp), 'devtype': node.get('devtype') or devtype, 'slow_host': node.get('slow_host', '') or slow_host, 'per_cmd_auth': ((node.get('per_cmd_auth', '') != '') diff --git a/tests/unit/poller/controller/sources/test_devices.py b/tests/unit/poller/controller/sources/test_devices.py index 9d763f5f3c..1dc1137ccc 100644 --- a/tests/unit/poller/controller/sources/test_devices.py +++ b/tests/unit/poller/controller/sources/test_devices.py @@ -9,14 +9,63 @@ from suzieq.shared.utils import PollerTransport -_INVENTORY = [{ +_INVENTORY = { 'native-ns.192.168.123.123.443': { 'address': '192.168.123.123', 'hostname': None, 'namespace': 'native-ns', 'port': 443, + 'transport': 'https' + }, + 'native-ns.192.168.123.164.443': + { + 'address': '192.168.123.164', + 'devtype': 'eos', + 'hostname': None, + 'namespace': 'native-ns', + 'port': 443, + 'ignore_known_hosts': False + }, + 'native-ns.192.168.123.111.443': + { + 'address': '192.168.123.111', + 'hostname': None, + 'namespace': 'native-ns', + 'transport': 'https' + }, + 'native-ns.192.168.123.110.22': + { + 'address': '192.168.123.110', + 'hostname': None, + 'namespace': 'native-ns', + 'transport': 'ssh' + }, + 'native-ns.192.168.123.143.443': + { + 'address': '192.168.123.143', + 'hostname': None, + 'namespace': 'native-ns', + 'transport': 'http', + 'port': 443 + }, + 'native-ns.192.168.123.171.22': + { + 'address': '192.168.123.171', + 'hostname': None, + 'namespace': 'native-ns', 'transport': 'http' + } +} + +_RESULT_INVENTORY = { + 'native-ns.192.168.123.123.443': + { + 'address': '192.168.123.123', + 'hostname': None, + 'namespace': 'native-ns', + 'port': 443, + 'transport': 'https' }, 'native-ns.192.168.123.164.443': { @@ -25,9 +74,42 @@ 'hostname': None, 'namespace': 'native-ns', 'port': 443, + 'transport': 'ssh', 'ignore_known_hosts': False + }, + 'native-ns.192.168.123.111.443': + { + 'address': '192.168.123.111', + 'hostname': None, + 'namespace': 'native-ns', + 'transport': 'https', + 'port': 443 + }, + 'native-ns.192.168.123.110.22': + { + 'address': '192.168.123.110', + 'hostname': None, + 'namespace': 'native-ns', + 'transport': 'ssh', + 'port': 22 + }, + 'native-ns.192.168.123.143.443': + { + 'address': '192.168.123.143', + 'hostname': None, + 'namespace': 'native-ns', + 'transport': 'http', + 'port': 443 + }, + 'native-ns.192.168.123.171.22': + { + 'address': '192.168.123.171', + 'hostname': None, + 'namespace': 'native-ns', + 'transport': 'http', + 'port': 80 } -}] +} def set_inventory_mock(self, inventory: Dict): @@ -51,8 +133,9 @@ def set_inventory_mock(self, inventory: Dict): @pytest.mark.poller_unit_tests @pytest.mark.controller_unit_tests @pytest.mark.asyncio -@pytest.mark.parametrize('inventory', _INVENTORY) -async def test_devices_set(inventory: Dict): +@pytest.mark.parametrize('inventory, result_inventory', + [(_INVENTORY, _RESULT_INVENTORY)]) +async def test_devices_set(inventory: Dict, result_inventory: Dict): """Test devices are correctly set Args: @@ -66,7 +149,9 @@ async def test_devices_set(inventory: Dict): 'jump-host-key-file': None, 'devtype': 'panos', 'transport': PollerTransport.ssh, - + 'slow_host': False, + 'per_cmd_auth': True, + 'retries-on-auth-fail': 0 } } @@ -80,14 +165,14 @@ async def test_devices_set(inventory: Dict): # emulate what the function Source.set_device should do exp_inv = {} - for key, node in inventory.items(): + for key, node in result_inventory.items(): exp_inv[key] = node.copy() for k, v in config['device'].items(): k = k.replace('-', '_') if k not in exp_inv[key]: exp_inv[key][k] = v - assert inv == exp_inv + assert inv == exp_inv, 'inventory do not match' @pytest.mark.controller_device