Skip to content

Commit

Permalink
Fix add_host and add_group methods to re-initialize inventory (#384)
Browse files Browse the repository at this point in the history
* Fix add_host method and add test to test_inventory

* Fix add_group method and add tests

* Black reformatting

* Add _update_group_refs helper and tests
  • Loading branch information
brandomando authored and dbarrosop committed May 3, 2019
1 parent 8f9607c commit b03bf8d
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 11 deletions.
29 changes: 19 additions & 10 deletions nornir/core/inventory.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,17 @@ def filter(self, filter_obj=None, filter_func=None, *args, **kwargs):
def __len__(self):
return self.hosts.__len__()

def _update_group_refs(self, inventory_element: InventoryElement) -> None:
"""
Returns inventory_element with updated group references for the supplied
inventory element
"""
if hasattr(inventory_element, "groups"):
inventory_element.groups.refs = [
self.groups[p] for p in inventory_element.groups
]
return inventory_element

def children_of_group(self, group: Union[str, Group]) -> Set[Host]:
"""
Returns set of hosts that belongs to a group including those that belong
Expand All @@ -443,22 +454,20 @@ def add_host(self, name: str, **kwargs) -> None:
"""
Add a host to the inventory after initialization
"""
host = {
name: deserializer.inventory.InventoryElement.deserialize_host(
name=name, defaults=self.defaults, **kwargs
)
}
host_element = deserializer.inventory.InventoryElement.deserialize_host(
name=name, defaults=self.defaults, **kwargs
)
host = {name: self._update_group_refs(host_element)}
self.hosts.update(host)

def add_group(self, name: str, **kwargs) -> None:
"""
Add a group to the inventory after initialization
"""
group = {
name: deserializer.inventory.InventoryElement.deserialize_group(
name=name, defaults=self.defaults, **kwargs
)
}
group_element = deserializer.inventory.InventoryElement.deserialize_group(
name=name, defaults=self.defaults, **kwargs
)
group = {name: self._update_group_refs(group_element)}
self.groups.update(group)

def get_inventory_dict(self) -> Dict:
Expand Down
14 changes: 13 additions & 1 deletion tests/core/test_inventory.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,14 +255,19 @@ def test_add_host(self):
connection_options=h3_connection_options,
)
assert "h3" in inv.hosts
assert "g1" in inv.hosts["h3"].groups
assert "g1" in [i.name for i in inv.hosts["h3"].groups.refs]
assert "test_var" in inv.hosts["h3"].defaults.data.keys()
assert inv.hosts["h3"].defaults.data.get("test_var") == "test_value"
assert inv.hosts["h3"].platform == "TestPlatform"
assert (
inv.hosts["h3"].connection_options["netmiko"].extras["device_type"]
== "cisco_ios"
)
with pytest.raises(KeyError):
inv.add_host(name="h4", groups=["not_defined"])
# Test with one good and one undefined group
with pytest.raises(KeyError):
inv.add_host(name="h5", groups=["g1", "not_defined"])

def test_add_group(self):
connection_options = {"username": "test_user", "password": "test_pass"}
Expand All @@ -279,6 +284,7 @@ def test_add_group(self):
inv.add_group(
name="g3", username="test_user", connection_options=g3_connection_options
)
assert "g1" in [i.name for i in inv.groups["g2"].groups.refs]
assert "g3" in inv.groups
assert (
inv.groups["g3"].defaults.connection_options.get("username") == "test_user"
Expand All @@ -292,6 +298,12 @@ def test_add_group(self):
inv.groups["g3"].connection_options["netmiko"].extras["device_type"]
== "cisco_ios"
)
# Test with one undefined parent group
with pytest.raises(KeyError):
inv.add_group(name="g4", groups=["undefined"])
# Test with one defined and one undefined parent group
with pytest.raises(KeyError):
inv.add_group(name="g4", groups=["g1", "undefined"])

def test_get_inventory_dict(self):
inv = deserializer.Inventory.deserialize(**inv_dict)
Expand Down

0 comments on commit b03bf8d

Please sign in to comment.