diff --git a/Scout.py b/Scout.py index 83cb235dd..1f1800adf 100755 --- a/Scout.py +++ b/Scout.py @@ -4,10 +4,14 @@ import sys import asyncio +from concurrent.futures import ThreadPoolExecutor + from ScoutSuite.__main__ import main if __name__ == "__main__": loop = asyncio.get_event_loop() + # TODO: make max_workers parameterizable (through the thread_config cli argument) + loop.set_default_executor(ThreadPoolExecutor(max_workers=10)) loop.run_until_complete(main()) loop.close() sys.exit() diff --git a/ScoutSuite/providers/aws/facade/awslambda.py b/ScoutSuite/providers/aws/facade/awslambda.py index 04c3649d3..ed5d51262 100644 --- a/ScoutSuite/providers/aws/facade/awslambda.py +++ b/ScoutSuite/providers/aws/facade/awslambda.py @@ -2,5 +2,5 @@ class LambdaFacade: - def get_functions(self, region): - return AWSFacadeUtils.get_all_pages('lambda', region, 'list_functions', 'Functions') + async def get_functions(self, region): + return await AWSFacadeUtils.get_all_pages('lambda', region, 'list_functions', 'Functions') diff --git a/ScoutSuite/providers/aws/facade/cloudtrail.py b/ScoutSuite/providers/aws/facade/cloudtrail.py index 8836e7f63..58bafc2c3 100644 --- a/ScoutSuite/providers/aws/facade/cloudtrail.py +++ b/ScoutSuite/providers/aws/facade/cloudtrail.py @@ -1,13 +1,20 @@ from ScoutSuite.providers.aws.facade.utils import AWSFacadeUtils +from ScoutSuite.providers.utils import run_concurrently class CloudTrailFacade: - def get_trails(self, region): + async def get_trails(self, region): client = AWSFacadeUtils.get_client('cloudtrail', region) - trails = client.describe_trails()['trailList'] + trails = await run_concurrently( + lambda: client.describe_trails()['trailList'] + ) for trail in trails: - trail.update(client.get_trail_status(Name=trail['TrailARN'])) - trail['EventSelectors'] = client.get_event_selectors(TrailName=trail['TrailARN'])['EventSelectors'] + trail.update(await run_concurrently( + lambda: client.get_trail_status(Name=trail['TrailARN']) + )) + trail['EventSelectors'] = await run_concurrently( + lambda: client.get_event_selectors(TrailName=trail['TrailARN'])['EventSelectors'] + ) - return trails \ No newline at end of file + return trails diff --git a/ScoutSuite/providers/aws/facade/ec2.py b/ScoutSuite/providers/aws/facade/ec2.py index a884a1a70..e9939f601 100644 --- a/ScoutSuite/providers/aws/facade/ec2.py +++ b/ScoutSuite/providers/aws/facade/ec2.py @@ -1,23 +1,25 @@ import boto3 import base64 -import itertools from ScoutSuite.providers.aws.facade.utils import AWSFacadeUtils +from ScoutSuite.providers.utils import run_concurrently class EC2Facade: - def get_instance_user_data(self, region: str, instance_id: str): + async def get_instance_user_data(self, region: str, instance_id: str): ec2_client = AWSFacadeUtils.get_client('ec2', region) - user_data_response = ec2_client.describe_instance_attribute(Attribute='userData', InstanceId=instance_id) + user_data_response = await run_concurrently( + lambda: ec2_client.describe_instance_attribute(Attribute='userData', InstanceId=instance_id)) if 'Value' not in user_data_response['UserData'].keys(): return None return base64.b64decode(user_data_response['UserData']['Value']).decode('utf-8') - def get_instances(self, region, vpc): + async def get_instances(self, region, vpc): filters = [{'Name': 'vpc-id', 'Values': [vpc]}] - reservations = AWSFacadeUtils.get_all_pages('ec2', region, 'describe_instances', 'Reservations', Filters=filters) + reservations =\ + await AWSFacadeUtils.get_all_pages('ec2', region, 'describe_instances', 'Reservations', Filters=filters) instances = [] for reservation in reservations: @@ -27,36 +29,39 @@ def get_instances(self, region, vpc): return instances - def get_security_groups(self, region, vpc): + async def get_security_groups(self, region, vpc): filters = [{'Name': 'vpc-id', 'Values': [vpc]}] - return AWSFacadeUtils.get_all_pages('ec2', region, 'describe_security_groups', 'SecurityGroups', Filters=filters) + return await AWSFacadeUtils.get_all_pages( + 'ec2', region, 'describe_security_groups', 'SecurityGroups', Filters=filters) - def get_vpcs(self, region): - ec2_client = boto3.client('ec2', region_name=region) - return ec2_client.describe_vpcs()['Vpcs'] + async def get_vpcs(self, region): + ec2_client = await run_concurrently(lambda: boto3.client('ec2', region_name=region)) + return await run_concurrently(lambda: ec2_client.describe_vpcs()['Vpcs']) - def get_images(self, region, owner_id): + async def get_images(self, region, owner_id): filters = [{'Name': 'owner-id', 'Values': [owner_id]}] - response = AWSFacadeUtils.get_client('ec2', region) \ - .describe_images(Filters=filters) + client = AWSFacadeUtils.get_client('ec2', region) + response = await run_concurrently(lambda: client.describe_images(Filters=filters)) return response['Images'] - def get_network_interfaces(self, region, vpc): + async def get_network_interfaces(self, region, vpc): filters = [{'Name': 'vpc-id', 'Values': [vpc]}] - return AWSFacadeUtils.get_all_pages('ec2', region, 'describe_network_interfaces', 'NetworkInterfaces', Filters=filters) + return await AWSFacadeUtils.get_all_pages( + 'ec2', region, 'describe_network_interfaces', 'NetworkInterfaces', Filters=filters) - def get_volumes(self, region): - return AWSFacadeUtils.get_all_pages('ec2', region, 'describe_volumes', 'Volumes') + async def get_volumes(self, region): + return await AWSFacadeUtils.get_all_pages('ec2', region, 'describe_volumes', 'Volumes') - def get_snapshots(self, region, owner_id): + async def get_snapshots(self, region, owner_id): filters = [{'Name': 'owner-id', 'Values': [owner_id]}] - snapshots = AWSFacadeUtils.get_all_pages('ec2', region, 'describe_snapshots', 'Snapshots', Filters=filters) + snapshots = await AWSFacadeUtils.get_all_pages( + 'ec2', region, 'describe_snapshots', 'Snapshots', Filters=filters) ec2_client = AWSFacadeUtils.get_client('ec2', region) for snapshot in snapshots: - snapshot['CreateVolumePermissions'] = ec2_client.describe_snapshot_attribute( + snapshot['CreateVolumePermissions'] = await run_concurrently(lambda: ec2_client.describe_snapshot_attribute( Attribute='createVolumePermission', - SnapshotId=snapshot['SnapshotId'])['CreateVolumePermissions'] + SnapshotId=snapshot['SnapshotId'])['CreateVolumePermissions']) return snapshots diff --git a/ScoutSuite/providers/aws/facade/facade.py b/ScoutSuite/providers/aws/facade/facade.py index b4404d29b..0daf11683 100644 --- a/ScoutSuite/providers/aws/facade/facade.py +++ b/ScoutSuite/providers/aws/facade/facade.py @@ -4,24 +4,25 @@ from ScoutSuite.providers.aws.facade.cloudtrail import CloudTrailFacade from ScoutSuite.providers.aws.facade.ec2 import EC2Facade from ScoutSuite.providers.aws.facade.efs import EFSFacade +from ScoutSuite.providers.utils import run_concurrently class AWSFacade(object): + def __init__(self): self.ec2 = EC2Facade() self.awslambda = LambdaFacade() self.cloudtrail = CloudTrailFacade() self.efs = EFSFacade() - async def build_region_list(self, service: str, chosen_regions=None, partition_name='aws'): service = 'ec2containerservice' if service == 'ecs' else service - available_services = Session().get_available_services() + available_services = await run_concurrently(lambda: Session().get_available_services()) if service not in available_services: raise Exception('Service ' + service + ' is not available.') - regions = Session().get_available_regions(service, partition_name) + regions = await run_concurrently(lambda: Session().get_available_regions(service, partition_name)) if chosen_regions: return list((Counter(regions) & Counter(chosen_regions)).elements()) diff --git a/ScoutSuite/providers/aws/facade/utils.py b/ScoutSuite/providers/aws/facade/utils.py index 85ffb2bc0..6034494cb 100644 --- a/ScoutSuite/providers/aws/facade/utils.py +++ b/ScoutSuite/providers/aws/facade/utils.py @@ -1,26 +1,32 @@ -from typing import Callable import boto3 +from ScoutSuite.providers.utils import run_concurrently + + # TODO: Add docstrings class AWSFacadeUtils: _clients = {} - + @staticmethod - def get_all_pages(service: str, region: str, paginator_name: str, response_key: str, **paginator_args): - pages = AWSFacadeUtils.get_client(service, region) \ - .get_paginator(paginator_name) \ - .paginate(**paginator_args) - - return AWSFacadeUtils._get_from_all_pages(pages, response_key) + async def get_all_pages(service: str, region: str, paginator_name: str, response_key: str, **paginator_args): + client = AWSFacadeUtils.get_client(service, region) + # Building a paginator doesn't require any API call so no need to do it concurrently: + paginator = client.get_paginator(paginator_name).paginate(**paginator_args) + + # Getting all pages from a paginator requires API calls so we need to do it concurrently: + return await run_concurrently(lambda: AWSFacadeUtils._get_all_pages_from_paginator(paginator, response_key)) @staticmethod - def _get_from_all_pages(pages: [], key:str): + def _get_all_pages_from_paginator(paginator, key): resources = [] - for page in pages: + # There's an API call hidden behind each iteration: + for page in paginator: resources.extend(page[key]) return resources @staticmethod def get_client(service: str, region: str): - return AWSFacadeUtils._clients.setdefault((service, region), boto3.client(service, region_name=region)) + # TODO: investigate the use of a mutex to avoid useless creation of a same type of client among threads: + client = boto3.client(service, region_name=region) + return AWSFacadeUtils._clients.setdefault((service, region), client) diff --git a/ScoutSuite/providers/aws/resources/awslambda/service.py b/ScoutSuite/providers/aws/resources/awslambda/service.py index 7ec5fbb35..f50c4cbc1 100644 --- a/ScoutSuite/providers/aws/resources/awslambda/service.py +++ b/ScoutSuite/providers/aws/resources/awslambda/service.py @@ -1,18 +1,17 @@ from ScoutSuite.providers.aws.resources.regions import Regions from ScoutSuite.providers.aws.resources.resources import AWSResources -from ScoutSuite.providers.aws.facade.facade import AWSFacade class RegionalLambdas(AWSResources): async def fetch_all(self, **kwargs): - raw_functions = self.facade.awslambda.get_functions(self.scope['region']) + raw_functions = await self.facade.awslambda.get_functions(self.scope['region']) for raw_function in raw_functions: name, resource = self._parse_function(raw_function) self[name] = resource def _parse_function(self, raw_function): raw_function['name'] = raw_function.pop('FunctionName') - return (raw_function['name'], raw_function) + return raw_function['name'], raw_function class Lambdas(Regions): diff --git a/ScoutSuite/providers/aws/resources/cloudtrail/service.py b/ScoutSuite/providers/aws/resources/cloudtrail/service.py index 206ebd5cb..a2a962c12 100644 --- a/ScoutSuite/providers/aws/resources/cloudtrail/service.py +++ b/ScoutSuite/providers/aws/resources/cloudtrail/service.py @@ -1,13 +1,13 @@ from ScoutSuite.providers.aws.resources.regions import Regions from ScoutSuite.providers.aws.resources.resources import AWSResources -from ScoutSuite.providers.aws.facade.facade import AWSFacade from ScoutSuite.providers.utils import get_non_provider_id import time + class Trails(AWSResources): async def fetch_all(self, **kwargs): - raw_trails = self.facade.cloudtrail.get_trails(self.scope['region']) + raw_trails = await self.facade.cloudtrail.get_trails(self.scope['region']) for raw_trail in raw_trails: name, resource = self._parse_trail(raw_trail) self[name] = resource diff --git a/ScoutSuite/providers/aws/resources/ec2/ami.py b/ScoutSuite/providers/aws/resources/ec2/ami.py index 7ba4c7f6a..f5312b8ca 100644 --- a/ScoutSuite/providers/aws/resources/ec2/ami.py +++ b/ScoutSuite/providers/aws/resources/ec2/ami.py @@ -1,10 +1,9 @@ from ScoutSuite.providers.aws.resources.resources import AWSResources -from ScoutSuite.providers.aws.facade.facade import AWSFacade class AmazonMachineImages(AWSResources): async def fetch_all(self, **kwargs): - raw_images = self.facade.ec2.get_images(self.scope['region'], self.scope['owner_id']) + raw_images = await self.facade.ec2.get_images(self.scope['region'], self.scope['owner_id']) for raw_image in raw_images: name, resource = self._parse_image(raw_image) self[name] = resource diff --git a/ScoutSuite/providers/aws/resources/ec2/instances.py b/ScoutSuite/providers/aws/resources/ec2/instances.py index c5b9bf709..f173cd63f 100644 --- a/ScoutSuite/providers/aws/resources/ec2/instances.py +++ b/ScoutSuite/providers/aws/resources/ec2/instances.py @@ -1,23 +1,22 @@ from ScoutSuite.providers.aws.resources.resources import AWSResources -from ScoutSuite.providers.aws.facade.facade import AWSFacade from ScoutSuite.providers.aws.aws import get_name -from ScoutSuite.providers.aws.utils import ec2_classic, get_keys +from ScoutSuite.providers.aws.utils import get_keys class EC2Instances(AWSResources): async def fetch_all(self, **kwargs): - raw_instances = self.facade.ec2.get_instances(self.scope['region'], self.scope['vpc']) + raw_instances = await self.facade.ec2.get_instances(self.scope['region'], self.scope['vpc']) for raw_instance in raw_instances: - name, resource = self._parse_instance(raw_instance) + name, resource = await self._parse_instance(raw_instance) self[name] = resource - def _parse_instance(self, raw_instance): + async def _parse_instance(self, raw_instance): instance = {} id = raw_instance['InstanceId'] instance['id'] = id instance['reservation_id'] = raw_instance['ReservationId'] instance['monitoring_enabled'] = raw_instance['Monitoring']['State'] == 'enabled' - instance['user_data'] = self.facade.ec2.get_instance_user_data(self.scope['region'], id) + instance['user_data'] = await self.facade.ec2.get_instance_user_data(self.scope['region'], id) get_name(raw_instance, instance, 'InstanceId') get_keys(raw_instance, instance, ['KeyName', 'LaunchTime', 'InstanceType', 'State', 'IamInstanceProfile', 'SubnetId']) diff --git a/ScoutSuite/providers/aws/resources/ec2/networkinterfaces.py b/ScoutSuite/providers/aws/resources/ec2/networkinterfaces.py index 63408ddab..ce5bc3a03 100644 --- a/ScoutSuite/providers/aws/resources/ec2/networkinterfaces.py +++ b/ScoutSuite/providers/aws/resources/ec2/networkinterfaces.py @@ -1,14 +1,13 @@ from ScoutSuite.providers.aws.resources.resources import AWSResources -from ScoutSuite.providers.aws.facade.facade import AWSFacade class NetworkInterfaces(AWSResources): async def fetch_all(self, **kwargs): - raw_security_groups = self.facade.ec2.get_network_interfaces(self.scope['region'], self.scope['vpc']) + raw_security_groups = await self.facade.ec2.get_network_interfaces(self.scope['region'], self.scope['vpc']) for raw_security_groups in raw_security_groups: name, resource = self._parse_network_interface(raw_security_groups) self[name] = resource - def _parse_network_interface(self, raw_network_interace): - raw_network_interace['name'] = raw_network_interace['NetworkInterfaceId'] - return raw_network_interace['NetworkInterfaceId'], raw_network_interace + def _parse_network_interface(self, raw_network_interface): + raw_network_interface['name'] = raw_network_interface['NetworkInterfaceId'] + return raw_network_interface['NetworkInterfaceId'], raw_network_interface diff --git a/ScoutSuite/providers/aws/resources/ec2/securitygroups.py b/ScoutSuite/providers/aws/resources/ec2/securitygroups.py index c59b87129..d5009f407 100644 --- a/ScoutSuite/providers/aws/resources/ec2/securitygroups.py +++ b/ScoutSuite/providers/aws/resources/ec2/securitygroups.py @@ -1,7 +1,4 @@ from ScoutSuite.providers.aws.resources.resources import AWSResources -from ScoutSuite.providers.aws.facade.facade import AWSFacade -from ScoutSuite.providers.aws.aws import get_name -from ScoutSuite.providers.aws.utils import ec2_classic, get_keys from ScoutSuite.utils import manage_dictionary from ScoutSuite.core.fs import load_data @@ -10,7 +7,7 @@ class SecurityGroups(AWSResources): icmp_message_types_dict = load_data('icmp_message_types.json', 'icmp_message_types') async def fetch_all(self, **kwargs): - raw_security_groups = self.facade.ec2.get_security_groups(self.scope['region'], self.scope['vpc']) + raw_security_groups = await self.facade.ec2.get_security_groups(self.scope['region'], self.scope['vpc']) for raw_security_groups in raw_security_groups: name, resource = self._parse_security_group(raw_security_groups) self[name] = resource diff --git a/ScoutSuite/providers/aws/resources/ec2/snapshots.py b/ScoutSuite/providers/aws/resources/ec2/snapshots.py index 1606bb8b9..9c36ca5ce 100644 --- a/ScoutSuite/providers/aws/resources/ec2/snapshots.py +++ b/ScoutSuite/providers/aws/resources/ec2/snapshots.py @@ -1,18 +1,14 @@ from ScoutSuite.providers.aws.resources.resources import AWSResources -from ScoutSuite.providers.aws.facade.facade import AWSFacade from ScoutSuite.providers.aws.aws import get_name class Snapshots(AWSResources): async def fetch_all(self, **kwargs): - raw_snapshots = self.facade.ec2.get_snapshots(self.scope['region'], self.scope['owner_id']) + raw_snapshots = await self.facade.ec2.get_snapshots(self.scope['region'], self.scope['owner_id']) for raw_snapshot in raw_snapshots: name, resource = self._parse_snapshot(raw_snapshot) self[name] = resource - async def get_resources_from_api(self): - return self.facade.ec2.get_snapshots(self.scope['region'], self.scope['owner_id']) - def _parse_snapshot(self, raw_snapshot): raw_snapshot['id'] = raw_snapshot.pop('SnapshotId') raw_snapshot['name'] = get_name(raw_snapshot, raw_snapshot, 'id') diff --git a/ScoutSuite/providers/aws/resources/ec2/volumes.py b/ScoutSuite/providers/aws/resources/ec2/volumes.py index 16970fcf5..c36367b8a 100644 --- a/ScoutSuite/providers/aws/resources/ec2/volumes.py +++ b/ScoutSuite/providers/aws/resources/ec2/volumes.py @@ -1,16 +1,15 @@ from ScoutSuite.providers.aws.resources.resources import AWSResources -from ScoutSuite.providers.aws.facade.facade import AWSFacade from ScoutSuite.providers.aws.aws import get_name class Volumes(AWSResources): async def fetch_all(self, **kwargs): - raw_volumes = self.facade.ec2.get_volumes(self.scope['region']) + raw_volumes = await self.facade.ec2.get_volumes(self.scope['region']) for raw_volume in raw_volumes: - name, resource = self._parse_volumes(raw_volume) + name, resource = self._parse_volume(raw_volume) self[name] = resource - def _parse_volumes(self, raw_volume): + def _parse_volume(self, raw_volume): raw_volume['id'] = raw_volume.pop('VolumeId') raw_volume['name'] = get_name(raw_volume, raw_volume, 'id') return raw_volume['id'], raw_volume diff --git a/ScoutSuite/providers/aws/resources/ec2/vpcs.py b/ScoutSuite/providers/aws/resources/ec2/vpcs.py index 7b5c6ceeb..b8b520059 100644 --- a/ScoutSuite/providers/aws/resources/ec2/vpcs.py +++ b/ScoutSuite/providers/aws/resources/ec2/vpcs.py @@ -1,5 +1,6 @@ +import asyncio + from ScoutSuite.providers.aws.resources.resources import AWSCompositeResources -from ScoutSuite.providers.aws.facade.facade import AWSFacade from ScoutSuite.providers.aws.resources.ec2.instances import EC2Instances from ScoutSuite.providers.aws.resources.ec2.securitygroups import SecurityGroups from ScoutSuite.providers.aws.resources.ec2.networkinterfaces import NetworkInterfaces @@ -13,14 +14,23 @@ class Vpcs(AWSCompositeResources): ] async def fetch_all(self, **kwargs): - vpcs = self.facade.ec2.get_vpcs(self.scope['region']) + vpcs = await self.facade.ec2.get_vpcs(self.scope['region']) for vpc in vpcs: name, resource = self._parse_vpc(vpc) self[name] = resource - for vpc in self: - scope = {'region': self.scope['region'], 'vpc': vpc} - await self._fetch_children(self[vpc], scope=scope) + # TODO: make a refactoring of the following: + if len(self) == 0: + return + tasks = { + asyncio.ensure_future( + self._fetch_children( + self[vpc], + {'region': self.scope['region'], 'vpc': vpc} + ) + ) for vpc in self + } + await asyncio.wait(tasks) def _parse_vpc(self, vpc): return vpc['VpcId'], {} diff --git a/ScoutSuite/providers/aws/resources/regions.py b/ScoutSuite/providers/aws/resources/regions.py index 1f0b3a7c0..faf59865f 100644 --- a/ScoutSuite/providers/aws/resources/regions.py +++ b/ScoutSuite/providers/aws/resources/regions.py @@ -1,7 +1,10 @@ +import abc +import asyncio + from ScoutSuite.providers.aws.aws import get_aws_account_id from ScoutSuite.providers.aws.resources.resources import AWSCompositeResources from ScoutSuite.providers.aws.facade.facade import AWSFacade -import abc + class Regions(AWSCompositeResources, metaclass=abc.ABCMeta): def __init__(self, service): @@ -10,8 +13,8 @@ def __init__(self, service): self.facade = AWSFacade() async def fetch_all(self, credentials, regions=None, partition_name='aws'): - self['regions'] = {} + account_id = get_aws_account_id(credentials) for region in await self.facade.build_region_list(self.service, regions, partition_name): self['regions'][region] = { 'id': region, @@ -19,7 +22,18 @@ async def fetch_all(self, credentials, regions=None, partition_name='aws'): 'name': region } - await self._fetch_children(self['regions'][region], {'region': region, 'owner_id': get_aws_account_id(credentials)}) + # TODO: make a refactoring of the following: + if len(self['regions']) == 0: + return + tasks = { + asyncio.ensure_future( + self._fetch_children( + self['regions'][region], + {'region': region, 'owner_id': account_id} + ) + ) for region in self['regions'] + } + await asyncio.wait(tasks) self._set_counts() diff --git a/ScoutSuite/providers/aws/resources/resources.py b/ScoutSuite/providers/aws/resources/resources.py index 1741885ff..95361a738 100644 --- a/ScoutSuite/providers/aws/resources/resources.py +++ b/ScoutSuite/providers/aws/resources/resources.py @@ -1,9 +1,10 @@ - """This module provides implementations for Resources and CompositeResources for AWS.""" +import abc +import asyncio + from ScoutSuite.providers.base.configs.resources import Resources, CompositeResources from ScoutSuite.providers.aws.facade.facade import AWSFacade -import abc class AWSResources(Resources, metaclass=abc.ABCMeta): @@ -38,12 +39,14 @@ async def _fetch_children(self, parent: object, scope: dict): :param scope: The scope passed to the children constructors """ - for child_class, child_name in self._children: - child = child_class(scope) - await child.fetch_all() - + children = [(child_class(scope), child_name) for (child_class, child_name) in self._children] + # fetch all children concurrently: + await asyncio.wait({asyncio.ensure_future(child.fetch_all()) for (child, _) in children}) + # update parent content: + for child, child_name in children: if parent.get(child_name) is None: parent[child_name] = {} parent[child_name].update(child) parent[child_name + '_count'] = len(child) + diff --git a/ScoutSuite/providers/azure/facade/sqldatabase.py b/ScoutSuite/providers/azure/facade/sqldatabase.py new file mode 100644 index 000000000..937585108 --- /dev/null +++ b/ScoutSuite/providers/azure/facade/sqldatabase.py @@ -0,0 +1,54 @@ +from azure.mgmt.sql import SqlManagementClient +from ScoutSuite.providers.utils import run_concurrently + + +class SQLDatabaseFacade: + def __init__(self, credentials, subscription_id): + self._client = SqlManagementClient(credentials, subscription_id) + + async def get_database_blob_auditing_policies(self, resource_group_name, server_name, database_name): + return await run_concurrently( + lambda: self._client.database_blob_auditing_policies.get( + resource_group_name, server_name, database_name) + ) + + async def get_database_threat_detection_policies(self, resource_group_name, server_name, database_name): + return await run_concurrently( + lambda: self._client.database_threat_detection_policies.get( + resource_group_name, server_name, database_name) + ) + + async def get_databases(self, resource_group_name, server_name): + return await run_concurrently( + lambda: self._client.databases.list_by_server(resource_group_name, server_name) + ) + + async def get_database_replication_links(self, resource_group_name, server_name, database_name): + return await run_concurrently( + lambda: list(self._client.replication_links.list_by_database( + resource_group_name, server_name, database_name)) + ) + + async def get_server_azure_ad_administrators(self, resource_group_name, server_name): + return await run_concurrently( + lambda: self._client.server_azure_ad_administrators.get(resource_group_name, server_name) + ) + + async def get_server_blob_auditing_policies(self, resource_group_name, server_name): + return await run_concurrently( + lambda: self._client.server_blob_auditing_policies.get(resource_group_name, server_name) + ) + + async def get_server_security_alert_policies(self, resource_group_name, server_name): + return await run_concurrently( + lambda: self._client.server_security_alert_policies.get(resource_group_name, server_name) + ) + + async def get_servers(self): + return await run_concurrently(self._client.servers.list) + + async def get_database_transparent_data_encryptions(self, resource_group_name, server_name, database_name): + return await run_concurrently( + lambda: self._client.transparent_data_encryptions.get( + resource_group_name, server_name, database_name) + ) diff --git a/ScoutSuite/providers/azure/resources/resources.py b/ScoutSuite/providers/azure/resources/resources.py index 6ea29637c..5523bd4e7 100644 --- a/ScoutSuite/providers/azure/resources/resources.py +++ b/ScoutSuite/providers/azure/resources/resources.py @@ -1,4 +1,5 @@ import abc +import asyncio from ScoutSuite.providers.base.configs.resources import CompositeResources @@ -7,9 +8,11 @@ class AzureCompositeResources(CompositeResources, metaclass=abc.ABCMeta): async def _fetch_children(self, parent, **kwargs): - for child_class, name in self._children: - child = child_class(**kwargs) - await child.fetch_all() + children = [(child_class(**kwargs), child_name) for (child_class, child_name) in self._children] + # fetch all children concurrently: + await asyncio.wait({asyncio.ensure_future(child.fetch_all()) for (child, _) in children}) + # update parent content: + for child, name in children: if name: parent[name] = child else: diff --git a/ScoutSuite/providers/azure/resources/sqldatabase/database_blob_auditing_policies.py b/ScoutSuite/providers/azure/resources/sqldatabase/database_blob_auditing_policies.py index dd28d89f4..d90f75e0f 100644 --- a/ScoutSuite/providers/azure/resources/sqldatabase/database_blob_auditing_policies.py +++ b/ScoutSuite/providers/azure/resources/sqldatabase/database_blob_auditing_policies.py @@ -9,9 +9,8 @@ def __init__(self, resource_group_name, server_name, database_name, facade): self.database_name = database_name self.facade = facade - # TODO: make it really async. async def fetch_all(self): - policies = self.facade.database_blob_auditing_policies.get( + policies = await self.facade.get_database_blob_auditing_policies( self.resource_group_name, self.server_name, self.database_name) self._parse_policies(policies) diff --git a/ScoutSuite/providers/azure/resources/sqldatabase/database_threat_detection_policies.py b/ScoutSuite/providers/azure/resources/sqldatabase/database_threat_detection_policies.py index 424ed2c1f..9fa2b3225 100644 --- a/ScoutSuite/providers/azure/resources/sqldatabase/database_threat_detection_policies.py +++ b/ScoutSuite/providers/azure/resources/sqldatabase/database_threat_detection_policies.py @@ -9,9 +9,8 @@ def __init__(self, resource_group_name, server_name, database_name, facade): self.database_name = database_name self.facade = facade - # TODO: make it really async. async def fetch_all(self): - policies = self.facade.database_threat_detection_policies.get( + policies = await self.facade.get_database_threat_detection_policies( self.resource_group_name, self.server_name, self.database_name) self._parse_policies(policies) diff --git a/ScoutSuite/providers/azure/resources/sqldatabase/databases.py b/ScoutSuite/providers/azure/resources/sqldatabase/databases.py index 81f02e272..c51de1c92 100644 --- a/ScoutSuite/providers/azure/resources/sqldatabase/databases.py +++ b/ScoutSuite/providers/azure/resources/sqldatabase/databases.py @@ -1,3 +1,5 @@ +import asyncio + from ScoutSuite.providers.azure.resources.resources import AzureCompositeResources from .database_blob_auditing_policies import DatabaseBlobAuditingPolicies @@ -19,20 +21,29 @@ def __init__(self, resource_group_name, server_name, facade): self.server_name = server_name self.facade = facade - # TODO: make it really async. async def fetch_all(self): - for db in self.facade.databases.list_by_server(self.resource_group_name, self.server_name): + for db in await self.facade.get_databases(self.resource_group_name, self.server_name): # We do not want to scan 'master' database which is auto-generated by Azure and read-only: if db.name == 'master': continue self[db.name] = { 'id': db.name, + 'name': db.name } - await self._fetch_children( - parent=self[db.name], - resource_group_name=self.resource_group_name, - server_name=self.server_name, - database_name=db.name, - facade=self.facade - ) + + # TODO: make a refactoring of the following: + if len(self) == 0: + return + tasks = { + asyncio.ensure_future( + self._fetch_children( + parent=db, + resource_group_name=self.resource_group_name, + server_name=self.server_name, + database_name=db['name'], + facade=self.facade + ) + ) for db in self.values() + } + await asyncio.wait(tasks) diff --git a/ScoutSuite/providers/azure/resources/sqldatabase/replication_links.py b/ScoutSuite/providers/azure/resources/sqldatabase/replication_links.py index 57f5bab72..bffe27ca8 100644 --- a/ScoutSuite/providers/azure/resources/sqldatabase/replication_links.py +++ b/ScoutSuite/providers/azure/resources/sqldatabase/replication_links.py @@ -9,10 +9,9 @@ def __init__(self, resource_group_name, server_name, database_name, facade): self.database_name = database_name self.facade = facade - # TODO: make it really async. async def fetch_all(self): - links = list(self.facade.replication_links.list_by_database( - self.resource_group_name, self.server_name, self.database_name)) + links = await self.facade.get_database_replication_links( + self.resource_group_name, self.server_name, self.database_name) self._parse_links(links) def _parse_links(self, links): diff --git a/ScoutSuite/providers/azure/resources/sqldatabase/server_azure_ad_administrators.py b/ScoutSuite/providers/azure/resources/sqldatabase/server_azure_ad_administrators.py index fc67436a9..1552fb98d 100644 --- a/ScoutSuite/providers/azure/resources/sqldatabase/server_azure_ad_administrators.py +++ b/ScoutSuite/providers/azure/resources/sqldatabase/server_azure_ad_administrators.py @@ -10,10 +10,9 @@ def __init__(self, resource_group_name, server_name, facade): self.server_name = server_name self.facade = facade - # TODO: make it really async. async def fetch_all(self): try: - self.facade.server_azure_ad_administrators.get(self.resource_group_name, self.server_name) + await self.facade.get_server_azure_ad_administrators(self.resource_group_name, self.server_name) self['ad_admin_configured'] = True except CloudError: # no ad admin configured returns a 404 error self['ad_admin_configured'] = False diff --git a/ScoutSuite/providers/azure/resources/sqldatabase/server_blob_auditing_policies.py b/ScoutSuite/providers/azure/resources/sqldatabase/server_blob_auditing_policies.py index 7a0a5ff55..b24e1df90 100644 --- a/ScoutSuite/providers/azure/resources/sqldatabase/server_blob_auditing_policies.py +++ b/ScoutSuite/providers/azure/resources/sqldatabase/server_blob_auditing_policies.py @@ -8,9 +8,8 @@ def __init__(self, resource_group_name, server_name, facade): self.server_name = server_name self.facade = facade - # TODO: make it really async. async def fetch_all(self): - policies = self.facade.server_blob_auditing_policies.get( + policies = await self.facade.get_server_blob_auditing_policies( self.resource_group_name, self.server_name) self._parse_policies(policies) diff --git a/ScoutSuite/providers/azure/resources/sqldatabase/server_security_alert_policies.py b/ScoutSuite/providers/azure/resources/sqldatabase/server_security_alert_policies.py index 59ef96d2a..5d34faed5 100644 --- a/ScoutSuite/providers/azure/resources/sqldatabase/server_security_alert_policies.py +++ b/ScoutSuite/providers/azure/resources/sqldatabase/server_security_alert_policies.py @@ -8,9 +8,8 @@ def __init__(self, resource_group_name, server_name, facade): self.server_name = server_name self.facade = facade - # TODO: make it really async. async def fetch_all(self): - policies = self.facade.server_security_alert_policies.get( + policies = await self.facade.get_server_security_alert_policies( self.resource_group_name, self.server_name) self._parse_policies(policies) diff --git a/ScoutSuite/providers/azure/resources/sqldatabase/servers.py b/ScoutSuite/providers/azure/resources/sqldatabase/servers.py index 1408edeaa..f2f62e807 100644 --- a/ScoutSuite/providers/azure/resources/sqldatabase/servers.py +++ b/ScoutSuite/providers/azure/resources/sqldatabase/servers.py @@ -1,8 +1,9 @@ -from azure.mgmt.sql import SqlManagementClient +import asyncio from ScoutSuite.providers.azure.resources.resources import AzureCompositeResources from ScoutSuite.providers.azure.utils import get_resource_group_name from ScoutSuite.providers.utils import get_non_provider_id +from ScoutSuite.providers.azure.facade.sqldatabase import SQLDatabaseFacade from .databases import Databases from .server_azure_ad_administrators import ServerAzureAdAdministrators @@ -18,24 +19,34 @@ class Servers(AzureCompositeResources): (ServerSecurityAlertPolicies, 'threat_detection') ] - # TODO: make it really async. async def fetch_all(self, credentials, **kwargs): # TODO: build that facade somewhere else: - facade = SqlManagementClient(credentials.credentials, credentials.subscription_id) + facade = SQLDatabaseFacade(credentials.credentials, credentials.subscription_id) self['servers'] = {} - for server in facade.servers.list(): + for server in await facade.get_servers(): id = get_non_provider_id(server.id) resource_group_name = get_resource_group_name(server.id) self['servers'][id] = { 'id': id, - 'name': server.name + 'name': server.name, + 'resource_group_name': resource_group_name } - await self._fetch_children( - parent=self['servers'][id], - resource_group_name=resource_group_name, - server_name=server.name, - facade=facade) + + # TODO: make a refactoring of the following: + if len(self['servers']) == 0: + return + tasks = { + asyncio.ensure_future( + self._fetch_children( + parent=server, + resource_group_name=server['resource_group_name'], + server_name=server['name'], + facade=facade + ) + ) for server in self['servers'].values() + } + await asyncio.wait(tasks) self['servers_count'] = len(self['servers']) diff --git a/ScoutSuite/providers/azure/resources/sqldatabase/transparent_data_encryptions.py b/ScoutSuite/providers/azure/resources/sqldatabase/transparent_data_encryptions.py index b4aa4c076..4fe946a4a 100644 --- a/ScoutSuite/providers/azure/resources/sqldatabase/transparent_data_encryptions.py +++ b/ScoutSuite/providers/azure/resources/sqldatabase/transparent_data_encryptions.py @@ -9,9 +9,8 @@ def __init__(self, resource_group_name, server_name, database_name, facade): self.database_name = database_name self.facade = facade - # TODO: make it really async. async def fetch_all(self): - encryptions = self.facade.transparent_data_encryptions.get( + encryptions = await self.facade.get_database_transparent_data_encryptions( self.resource_group_name, self.server_name, self.database_name) self._parse_encryptions(encryptions) diff --git a/ScoutSuite/providers/utils.py b/ScoutSuite/providers/utils.py index 30759aa82..b74128f6a 100644 --- a/ScoutSuite/providers/utils.py +++ b/ScoutSuite/providers/utils.py @@ -1,4 +1,5 @@ from hashlib import sha1 +import asyncio def get_non_provider_id(name): @@ -12,3 +13,7 @@ def get_non_provider_id(name): name_hash = sha1() name_hash.update(name.encode('utf-8')) return name_hash.hexdigest() + + +def run_concurrently(func): + return asyncio.get_event_loop().run_in_executor(executor=None, func=func)