Skip to content

Commit

Permalink
Merge pull request #289 from nccgroup/refactoring/async-support
Browse files Browse the repository at this point in the history
Parallelize async loops and factorize async code
  • Loading branch information
misg committed Mar 29, 2019
2 parents c012f86 + bb6a3a5 commit 7694eae
Show file tree
Hide file tree
Showing 15 changed files with 172 additions and 144 deletions.
21 changes: 15 additions & 6 deletions ScoutSuite/providers/aws/facade/elbv2.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import asyncio

from ScoutSuite.providers.aws.facade.utils import AWSFacadeUtils
from ScoutSuite.providers.aws.facade.basefacade import AWSBaseFacade
from ScoutSuite.providers.aws.utils import ec2_classic
from ScoutSuite.providers.utils import run_concurrently

from asyncio import Lock


class ELBv2Facade(AWSBaseFacade):
regional_load_balancers_cache_locks = {}
Expand All @@ -15,7 +15,7 @@ async def get_load_balancers(self, region: str, vpc: str):
return [load_balancer for load_balancer in self.load_balancers_cache[region] if load_balancer['VpcId'] == vpc]

async def cache_load_balancers(self, region):
async with self.regional_load_balancers_cache_locks.setdefault(region, Lock()):
async with self.regional_load_balancers_cache_locks.setdefault(region, asyncio.Lock()):
if region in self.load_balancers_cache:
return

Expand All @@ -27,11 +27,20 @@ async def cache_load_balancers(self, region):
load_balancer['VpcId'] =\
load_balancer['VpcId'] if 'VpcId' in load_balancer and load_balancer['VpcId'] else ec2_classic

async def get_load_balancer_attributes(self, region: str, load_balancer_arn: str):
if len(self.load_balancers_cache[region]) == 0:
return
tasks = {
asyncio.ensure_future(
self.get_and_set_load_balancer_attributes(region, load_balancer)
) for load_balancer in self.load_balancers_cache[region]
}
await asyncio.wait(tasks)

async def get_and_set_load_balancer_attributes(self, region: str, load_balancer: dict):
elbv2_client = AWSFacadeUtils.get_client('elbv2', self.session, region)
return await run_concurrently(
load_balancer['attributes'] = await run_concurrently(
lambda: elbv2_client.describe_load_balancer_attributes(
LoadBalancerArn=load_balancer_arn)['Attributes']
LoadBalancerArn=load_balancer['LoadBalancerArn'])['Attributes']
)

async def get_listeners(self, region: str, load_balancer_arn: str):
Expand Down
21 changes: 10 additions & 11 deletions ScoutSuite/providers/aws/resources/elbv2/load_balancers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from ScoutSuite.providers.aws.resources.resources import AWSCompositeResources
from ScoutSuite.providers.utils import get_non_provider_id

from .listeners import Listeners


Expand All @@ -10,15 +9,18 @@ class LoadBalancers(AWSCompositeResources):
]

async def fetch_all(self, **kwargs):
raw_loads_balancers = await self.facade.elbv2.get_load_balancers(self.scope['region'], self.scope['vpc'])
# TODO: parallelize the following loop which is async:
for raw_load_balancer in raw_loads_balancers:
id, load_balancer = await self._parse_load_balancer(raw_load_balancer)
raw_load_balancers = await self.facade.elbv2.get_load_balancers(self.scope['region'], self.scope['vpc'])
for raw_load_balancer in raw_load_balancers:
id, load_balancer = self._parse_load_balancer(raw_load_balancer)
self[id] = load_balancer
await self._fetch_children(
parent=load_balancer, scope={'region': self.scope['region'], 'load_balancer_arn': load_balancer['arn']})

async def _parse_load_balancer(self, load_balancer):
await self._fetch_children_of_all_resources(
resources=self,
scopes={load_balancer_id: {'region': self.scope['region'], 'load_balancer_arn': load_balancer['arn']}
for (load_balancer_id, load_balancer) in self.items()}
)

def _parse_load_balancer(self, load_balancer):
load_balancer['arn'] = load_balancer.pop('LoadBalancerArn')
load_balancer['name'] = load_balancer.pop('LoadBalancerName')
load_balancer['security_groups'] = []
Expand All @@ -28,7 +30,4 @@ async def _parse_load_balancer(self, load_balancer):
load_balancer['security_groups'].append({'GroupId': sg})
load_balancer.pop('SecurityGroups')

load_balancer['attributes'] =\
await self.facade.elbv2.get_load_balancer_attributes(self.scope['region'], load_balancer['arn'])

return get_non_provider_id(load_balancer['name']), load_balancer
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,17 @@ class ClusterParameterGroups(AWSCompositeResources):

async def fetch_all(self, **kwargs):
raw_parameter_groups = await self.facade.redshift.get_cluster_parameter_groups(self.scope['region'])
# TODO: parallelize this async loop:
for raw_parameter_group in raw_parameter_groups:
id, parameter_group = self._parse_parameter_group(raw_parameter_group)
await self._fetch_children(
parent=parameter_group,
scope={'region': self.scope['region'], 'parameter_group_name': parameter_group['name']}
)
self[id] = parameter_group

await self._fetch_children_of_all_resources(
resources=self,
scopes={parameter_group_id: {'region': self.scope['region'],
'parameter_group_name': parameter_group['name']}
for (parameter_group_id, parameter_group) in self.items()}
)

def _parse_parameter_group(self, raw_parameter_group):
name = raw_parameter_group.pop('ParameterGroupName')
id = get_non_provider_id(name)
Expand Down
16 changes: 4 additions & 12 deletions ScoutSuite/providers/aws/resources/regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,10 @@ async def fetch_all(self, credentials, regions=None, partition_name='aws'):
'name': region
}

# TODO: make a refactoring of the following:
if len(self['regions']) == 0:
return
tasks = {
asyncio.ensure_future(
self._fetch_children(
self['regions'][region],
{'region': region, 'owner_id': account_id}
)
) for region in self['regions']
}
await asyncio.wait(tasks)
await self._fetch_children_of_all_resources(
resources=self['regions'],
scopes={region: {'region': region, 'owner_id': account_id} for region in self['regions']}
)

self._set_counts()

Expand Down
62 changes: 41 additions & 21 deletions ScoutSuite/providers/aws/resources/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ class AWSResources(Resources, metaclass=abc.ABCMeta):

def __init__(self, facade, scope: dict):
"""
:param scope: The scope holds the scope in which the resource is located. This usually means \
at least a region, but can also contain a VPC id, an owner id, etc. It should be \
used when fetching the data through the facade.
:param scope: The scope holds the scope in which the resource is located. This usually means at least a region,
but can also contain a VPC id, an owner id, etc. It should be used when fetching the data through
the facade.
"""

self.scope = scope
Expand All @@ -23,29 +23,49 @@ def __init__(self, facade, scope: dict):

class AWSCompositeResources(AWSResources, CompositeResources, metaclass=abc.ABCMeta):

"""This class represents a collection of AWSResources. Classes extending AWSCompositeResources should \
define a "_children" attribute which consists of a list of tuples describing the children. The tuples \
are expected to respect the following format: (<child_class>, <child_name>). The child_name is used by \
indicates the name under which the child will be stored in the parent object.
"""This class represents a collection of AWSResources. Classes extending AWSCompositeResources should define a
"_children" attribute which consists of a list of tuples describing the children. The tuples are expected to
respect the following format: (<child_class>, <child_name>). The child_name is used by indicates the name under
which the child will be stored in the parent object.
"""

async def _fetch_children(self, parent: object, scope: dict):
"""This method calls fetch_all on each child defined in "_children" and stores the fetched resources \
in the parent under the key associated with the child. It also creates a "<child_name>_count" entry \
for each child.

:param parent: The object in which the children should be stored
:param scope: The scope passed to the children constructors
async def _fetch_children_of_all_resources(self, resources: dict, scopes: dict):
""" This method iterates through a collection of resources and fetches all children of each resource, in a
concurrent way.
:param resources: list of (composite) resources
:param scopes: dict that maps resource parent keys to scopes (dict) that should be used to retrieve children
of each resource.
"""
if len(resources) == 0:
return

tasks = {
asyncio.ensure_future(
self._fetch_children(resource_parent=resource_parent, scope=scopes[resource_parent_key])
) for (resource_parent_key, resource_parent) in resources.items()
}
await asyncio.wait(tasks)

async def _fetch_children(self, resource_parent: object, scope: dict):
"""This method fetches all children of a given resource (the so called 'resource_parent') by calling fetch_all
method on each child defined in '_children' and then stores the fetched resources in `resource_parent` under
the key associated with the child. It also creates a "<child_name>_count" entry for each child.
:param resource_parent: The resource in which the children will be stored.
:param scope: The scope passed to the children constructors.
"""

children = [(child_class(self.facade, scope), child_name) for (child_class, child_name) in self._children]
# fetch all children concurrently:
await asyncio.wait({asyncio.ensure_future(child.fetch_all()) for (child, _) in children})
# update parent content:
# Fetch all children concurrently:
await asyncio.wait(
{asyncio.ensure_future(child.fetch_all()) for (child, _) in children}
)
# Update parent content:
for child, child_name in children:
if parent.get(child_name) is None:
parent[child_name] = {}
if resource_parent.get(child_name) is None:
resource_parent[child_name] = {}

parent[child_name].update(child)
parent[child_name + '_count'] = len(child)
resource_parent[child_name].update(child)
resource_parent[child_name + '_count'] = len(child)

11 changes: 6 additions & 5 deletions ScoutSuite/providers/aws/resources/ses/identities.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@ class Identities(AWSCompositeResources):

async def fetch_all(self, **kwargs):
raw_identities = await self.facade.ses.get_identities(self.scope['region'])
# TODO: parallelize the following async loop:
for raw_identity in raw_identities:
id, identity = self._parse_identity(raw_identity)
await self._fetch_children(
parent=identity,
scope={'region': self.scope['region'], 'identity_name': identity['name']}
)
self[id] = identity

await self._fetch_children_of_all_resources(
resources=self,
scopes={identity_id: {'region': self.scope['region'], 'identity_name': identity['name']}
for (identity_id, identity) in self.items()}
)

def _parse_identity(self, raw_identity):
identity_name, dkim_attributes = raw_identity
identity = {}
Expand Down
17 changes: 10 additions & 7 deletions ScoutSuite/providers/aws/resources/sns/topics.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,20 @@ class Topics(AWSCompositeResources):

async def fetch_all(self, **kwargs):
raw_topics = await self.facade.sns.get_topics(self.scope['region'])
# TODO: parallelize this async loop:
for raw_topic in raw_topics:
topic_name, topic = self._parse_topic(raw_topic)
await self._fetch_children(
parent=topic,
scope={'region': self.scope['region'], 'topic_name': topic_name}
)
# Fix subscriptions count:
topic['subscriptions_count'] = topic['subscriptions'].pop('subscriptions_count')
self[topic_name] = topic

await self._fetch_children_of_all_resources(
resources=self,
scopes={topic_id: {'region': self.scope['region'], 'topic_name': topic['name']}
for (topic_id, topic) in self.items()}
)

# Fix subscriptions count:
for topic in self.values():
topic['subscriptions_count'] = topic['subscriptions'].pop('subscriptions_count')

def _parse_topic(self, raw_topic):
raw_topic['arn'] = raw_topic.pop('TopicArn')
raw_topic['name'] = raw_topic['arn'].split(':')[-1]
Expand Down
27 changes: 10 additions & 17 deletions ScoutSuite/providers/aws/resources/vpcs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import asyncio

from ScoutSuite.providers.aws.resources.resources import AWSCompositeResources


Expand All @@ -14,21 +12,16 @@ def __init__(self, facade, scope: dict, add_ec2_classic=False):
self.add_ec2_classic = add_ec2_classic

async def fetch_all(self, **kwargs):
vpcs = await self.facade.ec2.get_vpcs(self.scope['region'])
for vpc in vpcs:
name, resource = self._parse_vpc(vpc)
self[name] = resource

if len(self) == 0:
return

tasks = {
asyncio.ensure_future(
self._fetch_children(self[vpc], {'region': self.scope['region'], 'vpc': vpc})
) for vpc in self
}

await asyncio.wait(tasks)
raw_vpcs = await self.facade.ec2.get_vpcs(self.scope['region'])
for raw_vpc in raw_vpcs:
vpc_id, vpc = self._parse_vpc(raw_vpc)
self[vpc_id] = vpc

await self._fetch_children_of_all_resources(
resources=self,
scopes={vpc_id: {'region': self.scope['region'], 'vpc': vpc_id}
for vpc_id in self}
)

def _parse_vpc(self, vpc):
return vpc['VpcId'], {}
2 changes: 1 addition & 1 deletion ScoutSuite/providers/azure/resources/network/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class Networks(AzureCompositeResources):
]

async def fetch_all(self, credentials, **kwargs):
await self._fetch_children(parent=self, facade=self.facade)
await self._fetch_children(resource_parent=self, facade=self.facade)

self['network_security_groups_count'] = len(self['network_security_groups'])
self['network_watchers_count'] = len(self['network_watchers'])
49 changes: 42 additions & 7 deletions ScoutSuite/providers/azure/resources/resources.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,58 @@
"""This module provides implementations for Resources and CompositeResources for Azure."""

import abc
import asyncio

from ScoutSuite.providers.base.configs.resources import CompositeResources
from ScoutSuite.providers.azure.facade.facade import AzureFacade


# TODO: add docstrings.
class AzureCompositeResources(CompositeResources, metaclass=abc.ABCMeta):

"""This class represents a collection of composite Resources (resources that include nested resources referred as
their children). Classes extending AzureCompositeResources have to define a '_children' attribute which consists of
a list of tuples describing the children. The tuples are expected to respect the following format:
(<child_class>, <child_name>). 'child_name' is used to indicate the name under which the child resources will be
stored in the parent object.
"""

def __init__(self, facade: AzureFacade):
self.facade = facade

async def _fetch_children_of_all_resources(self, resources: dict, kwargs: dict):
"""This method iterates through a collection of resources and fetches all children of each resource, in a
concurrent way.
:param resources: list of (composite) resources
:param kwargs: dict that maps resource parent keys to each kwargs (dict) used to retrieve child resources.
"""
if len(resources) == 0:
return

tasks = {
asyncio.ensure_future(
self._fetch_children(resource_parent=resource_parent, **kwargs[resource_parent_key])
) for (resource_parent_key, resource_parent) in resources.items()
}
await asyncio.wait(tasks)

async def _fetch_children(self, resource_parent, **kwargs):
"""This method fetches all children of a given resource (the so called 'resource_parent') by calling fetch_all
method on each child defined in '_children' and then stores the fetched resources in `resource_parent` under
the key associated with the child.
:param resource_parent: The resource in which the children will be stored.
:param kwargs: parameters that depend on the type of child resources, used to fetch them.
"""

async def _fetch_children(self, parent, **kwargs):
children = [(child_class(**kwargs), child_name) for (child_class, child_name) in self._children]
# fetch all children concurrently:
await asyncio.wait({asyncio.ensure_future(child.fetch_all()) for (child, _) in children})
# update parent content:
# Fetch all children concurrently:
await asyncio.wait(
{asyncio.ensure_future(child.fetch_all()) for (child, _) in children}
)
# Update parent content:
for child, name in children:
if name:
parent[name] = child
resource_parent[name] = child
else:
parent.update(child)
resource_parent.update(child)
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class SecurityCenter(AzureCompositeResources):
]

async def fetch_all(self, credentials, **kwargs):
await self._fetch_children(parent=self, facade=self.facade)
await self._fetch_children(resource_parent=self, facade=self.facade)

self['auto_provisioning_settings_count'] = len(
self['auto_provisioning_settings'])
Expand Down
Loading

0 comments on commit 7694eae

Please sign in to comment.