Skip to content

Commit

Permalink
Generate firewall rules with common properties
Browse files Browse the repository at this point in the history
Network security group (NSG) records from Azure and firewall records
from GCP are in very different formats. In both types of records, each
record may contain multiple firewall rules. This change ensures that
both AzureCloud and GCPCloud plugin generates a new record for each
firewall rule found in NSG/firewall record.

Further, this change ensures that the `com` bucket, i.e.,
`record['com']` is populated with firewall rule properties that we care
about in common notation, where `record` denotes a firewall rule record.
  • Loading branch information
susam committed Apr 3, 2019
1 parent 6969fe7 commit 515e320
Show file tree
Hide file tree
Showing 5 changed files with 1,103 additions and 50 deletions.
193 changes: 189 additions & 4 deletions cloudmarker/clouds/azurecloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,16 +127,17 @@ def _get_doc(iterator, azure_record_type, subscription_id):
# Dictionary to map Azure record types to common record types.
record_type_map = {
'virtual_machine': 'compute',
'nsg': 'firewall_rule',
}

try:
for i, v in enumerate(iterator):
raw_doc = v.as_dict()
doc = {
'raw': v.as_dict(),
'raw': raw_doc,
'ext': {
'cloud_type': 'azure',
'record_type': azure_record_type,
'subscription_id': subscription_id
'subscription_id': subscription_id,
},
'com': {
'cloud_type': 'azure',
Expand All @@ -146,10 +147,194 @@ def _get_doc(iterator, azure_record_type, subscription_id):

_log.info('Found %s #%d; subscription_id: %s; name: %s',
azure_record_type, i, subscription_id,
doc['raw']['name'])
doc['raw'].get('name'))

yield doc

# For every security rule found in an NSG, generate a
# separate security rule (firewall rule) record to maintain
# parity with separate records for separate firewall rules
# in GCP.
if azure_record_type == 'nsg':
yield from _get_normalized_firewall_rules(raw_doc,
subscription_id)

except CloudError as e:
_log.error('Failed to fetch details for %s; subscription_id: %s; '
'error: %s: %s',
azure_record_type, subscription_id, type(e).__name__, e)


def _get_normalized_firewall_rules(nsg_doc, subscription_id):
"""Split a network security group (NSG) into multiple firewall rules.
An Azure NSG record contains a top-level key named
``security_rules`` whose value is a list of security rules.
In order to make it easier to write event plugins to detect security
issues in an NSG, we generate a new firewall rule record for each
security rule found in the NSG.
Arguments:
nsg_doc (dict): Raw NSG record as retrieved from Azure.
subscription_id (str): Subscription ID (for logging purpose only).
Yields:
dict: A normalized firewall rule record with ``com`` bucket
populated with firewall rule properties in common notation.
"""
security_rules = nsg_doc.get('security_rules')
nsg_name = nsg_doc.get('name')

if security_rules is None:
_log.warning('Found NSG without security_rules; name: %s', nsg_name)
return

for i, security_rule in enumerate(security_rules):
record = {
'raw': security_rule,
'ext': {
'record_type': 'security_rule',
'subscription_id': subscription_id,
'nsg_id': nsg_doc.get('id'),
'security_rule_id': security_rule.get('id'),
},
'com': {
'cloud_type': 'azure',
'record_type': 'firewall_rule',
'reference': security_rule.get('id'),

'enabled':
_get_normalized_firewall_state(security_rule),

'direction':
_get_normalized_firewall_direction(security_rule),

'access':
_get_normalized_firewall_access(security_rule),

'source_addresses':
_get_normalized_firewall_source_addresses(security_rule),

'protocol':
_get_normalized_firewall_protocol(security_rule),

'destination_ports':
_get_normalized_firewall_destination_ports(security_rule),
}
}
_log.info('Found document security_rule #%d; '
'subscription_id: %s; name: %s',
i, subscription_id, security_rule.get('name'))
yield record


def _get_normalized_firewall_state(security_rule):
rule_name = security_rule.get('name')
state = security_rule.get('provisioning_state')

if state is None:
_log.warning('Found security rule without provisioning_state; '
'name: %s', rule_name)
return None

return state.lower() == 'succeeded'


def _get_normalized_firewall_direction(security_rule):
rule_name = security_rule.get('name')
direction = security_rule.get('direction')

if direction is None:
_log.warning('Found security rule without direction; name: %s',
rule_name)
return None

direction = direction.lower()

if direction == 'inbound':
return 'in'

if direction == 'outbound':
return 'out'

_log.warning('Found unknown direction in security rule; '
'direction: %s; name: %s', direction, rule_name)
return direction


def _get_normalized_firewall_access(security_rule):
rule_name = security_rule.get('name')
access = security_rule.get('access')

if access is None:
_log.warning('Found security rule without access; name: %s',
rule_name)
return None

access = access.lower()

if access in ('allow', 'deny'):
return access

_log.warning('Found unknown access in security rule; '
'access: %s; name: %s', access, rule_name)
return access


def _get_normalized_firewall_source_addresses(security_rule):
all_prefixes = []

prefix = security_rule.get('source_address_prefix')
if prefix is not None:
all_prefixes.append(prefix)

prefixes = security_rule.get('source_address_prefixes')
if prefixes is not None:
all_prefixes.extend(prefixes)

source_addresses = []
for prefix in all_prefixes:
if prefix in ('*', 'Internet'):
source_addresses.append('0.0.0.0/0')
else:
source_addresses.append(prefix)

return source_addresses


def _get_normalized_firewall_protocol(security_rule):
rule_name = security_rule.get('name')
protocol = security_rule.get('protocol')

if protocol is None:
_log.warning('Found security rule without protocol; name: %s',
rule_name)
return None

protocol = protocol.lower()
if protocol == '*':
return 'all'
return protocol


def _get_normalized_firewall_destination_ports(security_rule):
all_ports = []

port = security_rule.get('destination_port_range')
if port is not None:
all_ports.append(port)

ports = security_rule.get('destination_port_ranges')
if ports is not None:
all_ports.extend(ports)

destination_ports = []
for port in all_ports:
if port == '*':
destination_ports.append('0-65535')
else:
destination_ports.append(port)

return destination_ports

0 comments on commit 515e320

Please sign in to comment.