Skip to content

Commit

Permalink
Add ability to query for subnets/supernets by cidr
Browse files Browse the repository at this point in the history
  • Loading branch information
gmjosack committed Jan 11, 2015
1 parent 40c6f2a commit 8901ced
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
25 changes: 24 additions & 1 deletion nsot/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ def after_create(self, user_id):
)

def networks(self, include_networks=True, include_ips=False, root=False,
subnets_of=None, supernets_of=None,
attribute_name=None, attribute_value=None):
""" Helper method for grabbing Networks.
Expand All @@ -292,6 +293,8 @@ def networks(self, include_networks=True, include_ips=False, root=False,
address networks
include_ips: Whether the response should include ip addresses
root: Only return networks at the root.
subnets_of: Only return subnets of the given CIDR
supernets_of: Only return supernets of the given CIDR
attribute_name: Filter to networks that contain this attribute name
attribute_value: Filter to networks that contain this attribute value
"""
Expand All @@ -302,8 +305,10 @@ def networks(self, include_networks=True, include_ips=False, root=False,
if attribute_value is not None and attribute_name is None:
raise ValueError("attribute_value requires attribute_name to be set.")

query = self.session.query(Network)
if all([subnets_of, supernets_of]):
raise ValueError("subnets_of and supernets_of are mutually exclusive.")

query = self.session.query(Network)

if attribute_name is not None:
query = query.outerjoin(Network.attr_idx).filter(
Expand All @@ -323,6 +328,24 @@ def networks(self, include_networks=True, include_ips=False, root=False,
if root:
query = query.filter(Network.parent_id == None)

if subnets_of is not None:
subnets_of = ipaddress.ip_network(unicode(subnets_of))
query = query.filter(
Network.ip_version == subnets_of.version,
Network.prefix_length > subnets_of.prefixlen,
Network.network_address >= subnets_of.network_address.packed,
Network.broadcast_address <= subnets_of.broadcast_address.packed
)

if supernets_of is not None:
supernets_of = ipaddress.ip_network(unicode(supernets_of))
query = query.filter(
Network.ip_version == supernets_of.version,
Network.prefix_length < supernets_of.prefixlen,
Network.network_address <= supernets_of.network_address.packed,
Network.broadcast_address >= supernets_of.broadcast_address.packed
)

return query


Expand Down
15 changes: 15 additions & 0 deletions tests/model_tests/test_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,21 @@ def test_retrieve_networks(session, admin, site):
include_networks=False, include_ips=True
)) == sorted([ip])

assert sorted(site.networks(
subnets_of="10.0.0.0/10"
)) == sorted([net_24])

assert sorted(site.networks(
subnets_of="10.0.0.0/10", include_ips=True
)) == sorted([net_24, ip])

assert sorted(site.networks(
supernets_of="10.0.0.0/10"
)) == sorted([net_8])

with pytest.raises(ValueError):
site.networks(subnets_of="10.0.0.0/10", supernets_of="10.0.0.0/10")

with pytest.raises(ValueError):
assert site.networks(attribute_value="foo")

Expand Down

0 comments on commit 8901ced

Please sign in to comment.