diff --git a/backend/api_handler.py b/backend/api_handler.py index 2fb4f9717..74559b1ac 100644 --- a/backend/api_handler.py +++ b/backend/api_handler.py @@ -1,7 +1,6 @@ import json import logging import os -import datetime from argparse import Namespace from time import perf_counter @@ -11,16 +10,19 @@ ) from dataall.base.api import bootstrap as bootstrap_schema, get_executable_schema -from dataall.base.services.service_provider_factory import ServiceProviderFactory +from dataall.base.utils.api_handler_utils import ( + extract_groups, + attach_tenant_policy_for_groups, + check_reauth, + validate_and_block_if_maintenance_window, +) from dataall.core.tasks.service_handlers import Worker from dataall.base.aws.sqs import SqsQueue -from dataall.base.aws.parameter_store import ParameterStoreManager from dataall.base.context import set_context, dispose_context, RequestContext -from dataall.core.permissions.services.tenant_policy_service import TenantPolicyService from dataall.base.db import get_engine -from dataall.core.permissions.services.tenant_permissions import TENANT_ALL from dataall.base.loader import load_modules, ImportMode + logger = logging.getLogger() logger.setLevel(os.environ.get('LOG_LEVEL', 'INFO')) log = logging.getLogger(__name__) @@ -32,7 +34,6 @@ load_modules(modes={ImportMode.API}) SCHEMA = bootstrap_schema() TYPE_DEFS = gql(SCHEMA.gql(with_directives=False)) -REAUTH_TTL = int(os.environ.get('REAUTH_TTL', '5')) ENVNAME = os.getenv('envname', 'local') ENGINE = get_engine(envname=ENVNAME) Worker.queue = SqsQueue.send @@ -60,27 +61,6 @@ def adapted(obj, info, **kwargs): print(f'Lambda Context ' f'Initialization took: {end - start:.3f} sec') -def get_cognito_groups(claims): - if not claims: - raise ValueError( - 'Received empty claims. ' 'Please verify authorizer configuration', - claims, - ) - groups = list() - saml_groups = claims.get('custom:saml.groups', '') - if len(saml_groups): - groups: list = saml_groups.replace('[', '').replace(']', '').replace(', ', ',').split(',') - cognito_groups = claims.get('cognito:groups', '') - if len(cognito_groups): - groups.extend(cognito_groups.split(',')) - return groups - - -def get_custom_groups(user_id): - service_provider = ServiceProviderFactory.get_service_provider_instance() - return service_provider.get_groups_for_user(user_id) - - def handler(event, context): """Sample pure Lambda function @@ -131,31 +111,11 @@ def handler(event, context): if 'user_id' in event['requestContext']['authorizer']: user_id = event['requestContext']['authorizer']['user_id'] log.debug('username is %s', username) - try: - groups = [] - if os.environ.get('custom_auth', None): - groups.extend(get_custom_groups(user_id)) - else: - groups.extend(get_cognito_groups(claims)) - log.debug('groups are %s', ','.join(groups)) - with ENGINE.scoped_session() as session: - for group in groups: - policy = TenantPolicyService.find_tenant_policy(session, group, TenantPolicyService.TENANT_NAME) - if not policy: - print(f'No policy found for Team {group}. Attaching TENANT_ALL permissions') - TenantPolicyService.attach_group_tenant_policy( - session=session, - group=group, - permissions=TENANT_ALL, - tenant_name=TenantPolicyService.TENANT_NAME, - ) - - except Exception as e: - print(f'Error managing groups due to: {e}') - groups = [] - set_context(RequestContext(ENGINE, username, groups, user_id)) + groups: list = extract_groups(user_id=user_id, claims=claims) + attach_tenant_policy_for_groups(groups=groups) + set_context(RequestContext(ENGINE, username, groups, user_id)) app_context = { 'engine': ENGINE, 'username': username, @@ -163,50 +123,18 @@ def handler(event, context): 'schema': SCHEMA, } - # Determine if there are any Operations that Require ReAuth From SSM Parameter - try: - reauth_apis = ParameterStoreManager.get_parameter_value( - region=os.getenv('AWS_REGION', 'eu-west-1'), parameter_path=f'/dataall/{ENVNAME}/reauth/apis' - ).split(',') - except Exception: - log.info('No ReAuth APIs Found in SSM') - reauth_apis = None + query = json.loads(event.get('body')) + + maintenance_window_validation_response = validate_and_block_if_maintenance_window(query=query, groups=groups) + if maintenance_window_validation_response is not None: + return maintenance_window_validation_response + reauth_validation_response = check_reauth(query=query, auth_time=claims['auth_time'], username=username) + if reauth_validation_response is not None: + return reauth_validation_response + else: raise Exception(f'Could not initialize user context from event {event}') - query = json.loads(event.get('body')) - - # If The Operation is a ReAuth Operation - Ensure A Non-Expired Session or Return Error - if reauth_apis and query.get('operationName', None) in reauth_apis: - now = datetime.datetime.now(datetime.timezone.utc) - try: - auth_time_datetime = datetime.datetime.fromtimestamp(int(claims['auth_time']), tz=datetime.timezone.utc) - if auth_time_datetime + datetime.timedelta(minutes=REAUTH_TTL) < now: - raise Exception('ReAuth') - except Exception as e: - log.info(f'ReAuth Required for User {username} on Operation {query.get("operationName", "")}, Error: {e}') - response = { - 'data': {query.get('operationName', 'operation'): None}, - 'errors': [ - { - 'message': f"ReAuth Required To Perform This Action {query.get('operationName', '')}", - 'locations': None, - 'path': [query.get('operationName', '')], - 'extensions': {'code': 'REAUTH'}, - } - ], - } - return { - 'statusCode': 401, - 'headers': { - 'content-type': 'application/json', - 'Access-Control-Allow-Origin': '*', - 'Access-Control-Allow-Headers': '*', - 'Access-Control-Allow-Methods': '*', - }, - 'body': json.dumps(response), - } - success, response = graphql_sync(schema=executable_schema, data=query, context_value=app_context) dispose_context() diff --git a/backend/dataall/base/aws/parameter_store.py b/backend/dataall/base/aws/parameter_store.py index cda2acfd8..40b13eeda 100644 --- a/backend/dataall/base/aws/parameter_store.py +++ b/backend/dataall/base/aws/parameter_store.py @@ -37,6 +37,18 @@ def get_parameter_value(AwsAccountId=None, region=None, parameter_path=None): raise Exception(e) return parameter_value + @staticmethod + def get_parameters_by_path(AwsAccountId=None, region=None, parameter_path=None): + if not parameter_path: + raise Exception('Parameter name is None') + try: + parameter_values = ParameterStoreManager.client(AwsAccountId, region).get_parameters_by_path( + Path=parameter_path + )['Parameters'] + except ClientError as e: + raise Exception(e) + return parameter_values + @staticmethod def update_parameter(AwsAccountId, region, parameter_name, parameter_value): if not parameter_name: diff --git a/backend/dataall/base/utils/api_handler_utils.py b/backend/dataall/base/utils/api_handler_utils.py new file mode 100644 index 000000000..1413a151c --- /dev/null +++ b/backend/dataall/base/utils/api_handler_utils.py @@ -0,0 +1,181 @@ +import datetime +import json +import os +import logging + +from graphql import parse, utilities, OperationType, GraphQLSyntaxError +from dataall.base.aws.parameter_store import ParameterStoreManager +from dataall.base.db import get_engine +from dataall.base.services.service_provider_factory import ServiceProviderFactory +from dataall.core.permissions.services.tenant_permissions import TENANT_ALL +from dataall.core.permissions.services.tenant_policy_service import TenantPolicyService +from dataall.modules.maintenance.api.enums import MaintenanceModes, MaintenanceStatus +from dataall.modules.maintenance.services.maintenance_service import MaintenanceService +from dataall.base.config import config +from dataall.core.permissions.services.tenant_policy_service import TenantPolicyValidationService + +logger = logging.getLogger() +logger.setLevel(os.environ.get('LOG_LEVEL', 'INFO')) +log = logging.getLogger(__name__) + +ENVNAME = os.getenv('envname', 'local') +REAUTH_TTL = int(os.environ.get('REAUTH_TTL', '5')) +# ALLOWED OPERATIONS WHEN A USER IS NOT DATAALL ADMIN AND NO-ACCESS MODE IS SELECTED +MAINTENANCE_ALLOWED_OPERATIONS_WHEN_NO_ACCESS = [ + item.casefold() for item in ['getGroupsForUser', 'getMaintenanceWindowStatus'] +] +ENGINE = get_engine(envname=ENVNAME) + + +def get_cognito_groups(claims): + if not claims: + raise ValueError( + 'Received empty claims. ' 'Please verify authorizer configuration', + claims, + ) + groups = list() + saml_groups = claims.get('custom:saml.groups', '') + translation_table = str.maketrans({'[': None, ']': None, ', ': ','}) + if len(saml_groups): + groups = saml_groups.translate(translation_table).split(',') + cognito_groups = claims.get('cognito:groups', '') + if len(cognito_groups): + groups.extend(cognito_groups.split(',')) + return groups + + +def get_custom_groups(user_id): + service_provider = ServiceProviderFactory.get_service_provider_instance() + return service_provider.get_groups_for_user(user_id) + + +def send_unauthorized_response(operation='', message='', extension=None): + response = { + 'data': {operation: None}, + 'errors': [ + { + 'message': message, + 'locations': None, + 'path': [operation], + } + ], + } + if extension is not None: + response['errors'][0]['extensions'] = extension + return { + 'statusCode': 401, + 'headers': { + 'content-type': 'application/json', + 'Access-Control-Allow-Origin': '*', + 'Access-Control-Allow-Headers': '*', + 'Access-Control-Allow-Methods': '*', + }, + 'body': json.dumps(response), + } + + +def extract_groups(user_id, claims): + groups = [] + try: + if os.environ.get('custom_auth', None): + groups.extend(get_custom_groups(user_id)) + else: + groups.extend(get_cognito_groups(claims)) + log.debug('groups are %s', ','.join(groups)) + return groups + except Exception as e: + log.exception(f'Error managing groups due to: {e}') + return groups + + +def attach_tenant_policy_for_groups(groups=None): + if groups is None: + groups = [] + with ENGINE.scoped_session() as session: + for group in groups: + policy = TenantPolicyService.find_tenant_policy(session, group, TenantPolicyService.TENANT_NAME) + if not policy: + log.info(f'No policy found for Team {group}. Attaching TENANT_ALL permissions') + TenantPolicyService.attach_group_tenant_policy( + session=session, + group=group, + permissions=TENANT_ALL, + tenant_name=TenantPolicyService.TENANT_NAME, + ) + + +def check_reauth(query, auth_time, username): + # Determine if there are any Operations that Require ReAuth From SSM Parameter + try: + reauth_apis = ParameterStoreManager.get_parameter_value( + region=os.getenv('AWS_REGION', 'eu-west-1'), parameter_path=f'/dataall/{ENVNAME}/reauth/apis' + ).split(',') + except Exception: + log.info('No ReAuth APIs Found in SSM') + reauth_apis = None + + # If The Operation is a ReAuth Operation - Ensure A Non-Expired Session or Return Error + if reauth_apis and query.get('operationName', None) in reauth_apis: + now = datetime.datetime.now(datetime.timezone.utc) + try: + auth_time_datetime = datetime.datetime.fromtimestamp(int(auth_time), tz=datetime.timezone.utc) + if auth_time_datetime + datetime.timedelta(minutes=REAUTH_TTL) < now: + raise Exception('ReAuth') + except Exception as e: + log.info(f'ReAuth Required for User {username} on Operation {query.get("operationName", "")}, Error: {e}') + return send_unauthorized_response( + operation=query.get('operationName', 'operation'), + message=f"ReAuth Required To Perform This Action {query.get('operationName', '')}", + extension={'code': 'REAUTH'}, + ) + + +def validate_and_block_if_maintenance_window(query, groups, blocked_for_mode_enum=None): + """ + When the maintenance module is set to active, checks + - If the maintenance mode is enabled + - Based on the maintenance mode, actions which can be taken by user can be modified + - READ-ONLY -> Block All Mutation calls and allow query graphql calls + - NO-ACCESS -> Block All graphql query call irrespective of type + - Check if the user belongs to the DAAdministrators group + @param query: graphql query dict containing operation, query, variables + @param groups: user groups + @param blocked_for_mode_enum: sets the mode for blocking only specific modes. When set to None, both graphql types ( Query and Mutation ) will be blocked. When a specific mode is set, blocking will only occure for that mode + @return: error response if maintenance window is blocking gql calls else None + """ + if config.get_property('modules.maintenance.active'): + maintenance_mode = MaintenanceService._get_maintenance_window_mode(engine=ENGINE) + maintenance_status = MaintenanceService.get_maintenance_window_status().status + isAdmin = TenantPolicyValidationService.is_tenant_admin(groups) + + if ( + (maintenance_mode == MaintenanceModes.NOACCESS.value) + and (maintenance_status is not MaintenanceStatus.INACTIVE.value) + and not isAdmin + and (blocked_for_mode_enum is None or blocked_for_mode_enum == MaintenanceModes.NOACCESS) + ): + if query.get('operationName', '').casefold() not in MAINTENANCE_ALLOWED_OPERATIONS_WHEN_NO_ACCESS: + return send_unauthorized_response( + operation=query.get('operationName', 'operation'), + message='Access Restricted: data.all is currently undergoing maintenance, and your actions are temporarily blocked.', + ) + elif ( + (maintenance_mode == MaintenanceModes.READONLY.value) + and (maintenance_status is not MaintenanceStatus.INACTIVE.value) + and not isAdmin + and (blocked_for_mode_enum is None or blocked_for_mode_enum == MaintenanceModes.READONLY) + ): + # If its mutation then block and return + try: + parsed_query_document = parse(query.get('query', '')) + graphQL_operation_type = utilities.get_operation_ast(parsed_query_document) + if graphQL_operation_type.operation == OperationType.MUTATION: + return send_unauthorized_response( + operation=query.get('operationName', 'operation'), + message='Access Restricted: data.all is currently undergoing maintenance, and your actions are temporarily blocked.', + ) + except GraphQLSyntaxError as e: + log.error( + f'Error occured while parsing query when validating for {maintenance_mode} maintenance mode due to - {e}' + ) + raise e diff --git a/backend/dataall/core/stacks/aws/ecs.py b/backend/dataall/core/stacks/aws/ecs.py index 8da1e1118..866d21db5 100644 --- a/backend/dataall/core/stacks/aws/ecs.py +++ b/backend/dataall/core/stacks/aws/ecs.py @@ -86,10 +86,13 @@ def run_ecs_task( raise e @staticmethod - def is_task_running(cluster_name, started_by): + def is_task_running(cluster_name, started_by=None): try: client = boto3.client('ecs') - running_tasks = client.list_tasks(cluster=cluster_name, startedBy=started_by, desiredStatus='RUNNING') + if started_by is None: + running_tasks = client.list_tasks(cluster=cluster_name, desiredStatus='RUNNING') + else: + running_tasks = client.list_tasks(cluster=cluster_name, startedBy=started_by, desiredStatus='RUNNING') if running_tasks and running_tasks.get('taskArns'): return True return False diff --git a/backend/dataall/modules/maintenance/__init__.py b/backend/dataall/modules/maintenance/__init__.py new file mode 100644 index 000000000..29596caf8 --- /dev/null +++ b/backend/dataall/modules/maintenance/__init__.py @@ -0,0 +1,21 @@ +"""Contains the code related to Maintenance Window Activity""" + +import logging +from typing import Set + +from dataall.base.loader import ImportMode, ModuleInterface + +log = logging.getLogger(__name__) + + +class MaintenanceApiModuleInterface(ModuleInterface): + """Implements ModuleInterface for Maintenance GraphQl lambda""" + + @staticmethod + def is_supported(modes: Set[ImportMode]) -> bool: + return ImportMode.API in modes + + def __init__(self): + import dataall.modules.maintenance.api + + log.info('API of maintenance window activity has been imported') diff --git a/backend/dataall/modules/maintenance/api/__init__.py b/backend/dataall/modules/maintenance/api/__init__.py new file mode 100644 index 000000000..a2690b057 --- /dev/null +++ b/backend/dataall/modules/maintenance/api/__init__.py @@ -0,0 +1,5 @@ +"""The package defines the schema for Maintenance Module""" + +from dataall.modules.maintenance.api import mutations, queries, types, resolvers, enums + +__all__ = ['types', 'queries', 'mutations', 'resolvers', 'enums'] diff --git a/backend/dataall/modules/maintenance/api/enums.py b/backend/dataall/modules/maintenance/api/enums.py new file mode 100644 index 000000000..36c328913 --- /dev/null +++ b/backend/dataall/modules/maintenance/api/enums.py @@ -0,0 +1,18 @@ +"""Contains the enums used in maintenance module""" + +from dataall.base.api import GraphQLEnumMapper + + +class MaintenanceModes(GraphQLEnumMapper): + """Describes the Maintenance Modes""" + + READONLY = 'READ-ONLY' + NOACCESS = 'NO-ACCESS' + + +class MaintenanceStatus(GraphQLEnumMapper): + """Describe the various statuses for maintenance""" + + PENDING = 'PENDING' + INACTIVE = 'INACTIVE' + ACTIVE = 'ACTIVE' diff --git a/backend/dataall/modules/maintenance/api/mutations.py b/backend/dataall/modules/maintenance/api/mutations.py new file mode 100644 index 000000000..da83e9e50 --- /dev/null +++ b/backend/dataall/modules/maintenance/api/mutations.py @@ -0,0 +1,18 @@ +"""The module defines GraphQL mutations for the Maintenance Window Activity""" + +from dataall.base.api import gql +from dataall.modules.maintenance.api.resolvers import start_maintenance_window, stop_maintenance_window + + +startMaintenanceWindow = gql.MutationField( + name='startMaintenanceWindow', + args=[gql.Argument(name='mode', type=gql.String)], + type=gql.Boolean, + resolver=start_maintenance_window, +) + +stopMaintenanceWindow = gql.MutationField( + name='stopMaintenanceWindow', + type=gql.Boolean, + resolver=stop_maintenance_window, +) diff --git a/backend/dataall/modules/maintenance/api/queries.py b/backend/dataall/modules/maintenance/api/queries.py new file mode 100644 index 000000000..11c623cf0 --- /dev/null +++ b/backend/dataall/modules/maintenance/api/queries.py @@ -0,0 +1,9 @@ +"""The module defines GraphQL queries for the Maintenance Activity>""" + +from dataall.base.api import gql +from dataall.modules.maintenance.api.resolvers import get_maintenance_window_status + + +getMaintenanceWindowStatus = gql.QueryField( + name='getMaintenanceWindowStatus', type=gql.Ref('Maintenance'), resolver=get_maintenance_window_status +) diff --git a/backend/dataall/modules/maintenance/api/resolvers.py b/backend/dataall/modules/maintenance/api/resolvers.py new file mode 100644 index 000000000..88597f324 --- /dev/null +++ b/backend/dataall/modules/maintenance/api/resolvers.py @@ -0,0 +1,18 @@ +from dataall.base.api.context import Context +from dataall.modules.maintenance.api.enums import MaintenanceModes +from dataall.modules.maintenance.api.types import Maintenance +from dataall.modules.maintenance.services.maintenance_service import MaintenanceService + + +def start_maintenance_window(context: Context, source: Maintenance, mode: str): + if mode not in [item.value for item in list(MaintenanceModes)]: + raise Exception('Mode is not conforming to the MaintenanceModes enum') + return MaintenanceService.start_maintenance_window(mode=mode) + + +def stop_maintenance_window(context: Context, source: Maintenance): + return MaintenanceService.stop_maintenance_window() + + +def get_maintenance_window_status(context: Context, source: Maintenance): + return MaintenanceService.get_maintenance_window_status() diff --git a/backend/dataall/modules/maintenance/api/types.py b/backend/dataall/modules/maintenance/api/types.py new file mode 100644 index 000000000..422f664fe --- /dev/null +++ b/backend/dataall/modules/maintenance/api/types.py @@ -0,0 +1,9 @@ +"""Defines the object types of the Maintenance activity""" + +from dataall.base.api import gql + + +Maintenance = gql.ObjectType( + name='Maintenance', + fields=[gql.Field(name='status', type=gql.NonNullableType(gql.String)), gql.Field(name='mode', type=gql.String)], +) diff --git a/tests/core/permissions/db/__init__.py b/backend/dataall/modules/maintenance/aws/__init__.py similarity index 100% rename from tests/core/permissions/db/__init__.py rename to backend/dataall/modules/maintenance/aws/__init__.py diff --git a/backend/dataall/modules/maintenance/aws/event_bridge.py b/backend/dataall/modules/maintenance/aws/event_bridge.py new file mode 100644 index 000000000..0b489a029 --- /dev/null +++ b/backend/dataall/modules/maintenance/aws/event_bridge.py @@ -0,0 +1,28 @@ +import logging + +import boto3 + +logger = logging.getLogger(__name__) + + +class EventBridge: + def __init__(self, region=None): + self.client = boto3.client('events', region_name=region) + + def enable_scheduled_ecs_tasks(self, list_of_tasks): + logger.info('Enabling ecs tasks') + try: + for ecs_task in list_of_tasks: + self.client.enable_rule(Name=ecs_task) + except Exception as e: + logger.error(f'Error while re-enabling scheduled ecs tasks due to {e}') + raise e + + def disable_scheduled_ecs_tasks(self, list_of_tasks): + logger.info('Disabling ecs tasks') + try: + for ecs_task in list_of_tasks: + self.client.disable_rule(Name=ecs_task) + except Exception as e: + logger.error(f'Error while disabling scheduled ecs tasks due to {e}') + raise e diff --git a/backend/dataall/modules/maintenance/db/__init__.py b/backend/dataall/modules/maintenance/db/__init__.py new file mode 100644 index 000000000..86631d191 --- /dev/null +++ b/backend/dataall/modules/maintenance/db/__init__.py @@ -0,0 +1 @@ +"""Contains a code to that interacts with the database""" diff --git a/backend/dataall/modules/maintenance/db/maintenance_models.py b/backend/dataall/modules/maintenance/db/maintenance_models.py new file mode 100644 index 000000000..d20870dda --- /dev/null +++ b/backend/dataall/modules/maintenance/db/maintenance_models.py @@ -0,0 +1,13 @@ +"""ORM models for maintenance activity""" + +from sqlalchemy import Column, String + +from dataall.base.db import Base + + +class Maintenance(Base): + """ORM Model for maintenance window""" + + __tablename__ = 'maintenance' + status = Column(String, nullable=False, primary_key=True) + mode = Column(String, default='', nullable=True) diff --git a/backend/dataall/modules/maintenance/db/maintenance_repository.py b/backend/dataall/modules/maintenance/db/maintenance_repository.py new file mode 100644 index 000000000..1a47e4a36 --- /dev/null +++ b/backend/dataall/modules/maintenance/db/maintenance_repository.py @@ -0,0 +1,23 @@ +""" +DAO layer that encapsulates the logic and interaction with the database for maintenance +""" + +import logging +from dataall.modules.maintenance.db.maintenance_models import Maintenance + +log = logging.getLogger(__name__) + + +class MaintenanceRepository: + def __init__(self, session): + self._session = session + + def save_maintenance_status_and_mode(self, maintenance_status: str, maintenance_mode: str): + log.debug(f'Saving maintenance status and mode as {maintenance_status} , {maintenance_mode} respectively') + maintenance_record = self._session.query(Maintenance).one() + maintenance_record.status = maintenance_status + maintenance_record.mode = maintenance_mode + self._session.commit() + + def get_maintenance_record(self): + return self._session.query(Maintenance).one() diff --git a/backend/dataall/modules/maintenance/services/__init__.py b/backend/dataall/modules/maintenance/services/__init__.py new file mode 100644 index 000000000..dbb93f498 --- /dev/null +++ b/backend/dataall/modules/maintenance/services/__init__.py @@ -0,0 +1,8 @@ +""" +Contains the code needed for service layer. +The service layer is a layer where all business logic is aggregated +""" + +from dataall.modules.maintenance.services import maintenance_service + +__all__ = ['maintenance_service'] diff --git a/backend/dataall/modules/maintenance/services/maintenance_service.py b/backend/dataall/modules/maintenance/services/maintenance_service.py new file mode 100644 index 000000000..4fb270859 --- /dev/null +++ b/backend/dataall/modules/maintenance/services/maintenance_service.py @@ -0,0 +1,149 @@ +""" +A service layer for maintenance activity +Defines functions and business logic to be performed for maintenance window +""" + +import logging +import os + +from dataall.modules.maintenance.aws.event_bridge import EventBridge +from dataall.base.aws.parameter_store import ParameterStoreManager +from dataall.base.context import get_context +from dataall.core.permissions.services.tenant_policy_service import TenantPolicyValidationService +from dataall.modules.maintenance.api.enums import MaintenanceStatus +from dataall.modules.maintenance.db.maintenance_repository import MaintenanceRepository +from dataall.core.stacks.aws.ecs import Ecs + +logger = logging.getLogger(__name__) + + +class MaintenanceService: + @staticmethod + def start_maintenance_window(mode: str = None): + """ + Start maintenance window by performing following actions + 1. Perform validation to check if the user belongs to the DAAdministrators group + 2. Put the maintenance window status to PENDING and update the maintenance mode + 3. Get all the ECS Scheduled tasks and disable the schedule for them + @param mode: mode to set for maintenance window + @return: returns True if successful or False + """ + # Check from the context if the groups contains the DAAAdminstrators group + groups = get_context().groups if get_context().groups is not None else [] + if not TenantPolicyValidationService.is_tenant_admin(groups): + raise Exception('Only data.all admin group members can start maintenance window') + + logger.info('Putting data.all into maintenance') + try: + with get_context().db_engine.scoped_session() as session: + maintenance_record = MaintenanceRepository(session).get_maintenance_record() + if ( + maintenance_record.status == MaintenanceStatus.PENDING.value + or maintenance_record.status == MaintenanceStatus.ACTIVE.value + ): + logger.error( + 'Maintenance window already in PENDING or ACTIVE state. Cannot start maintenance window. Stop the maintenance window and start again' + ) + return False + MaintenanceRepository(session).save_maintenance_status_and_mode( + maintenance_status=MaintenanceStatus.PENDING.value, maintenance_mode=mode + ) + # Disable scheduled ECS tasks + # Get all the SSM Params related to the scheduled tasks + ecs_scheduled_rules_list = MaintenanceService._get_ecs_rules() + event_bridge = EventBridge(region=os.getenv('AWS_REGION', 'eu-west-1')) + event_bridge.disable_scheduled_ecs_tasks(ecs_scheduled_rules_list) + return True + except Exception as e: + logger.error(f'Error occurred while starting maintenance window due to {e}') + return False + + @staticmethod + def stop_maintenance_window(): + """ + Stop maintenance window by performing following actions + 1. Perform validation to check if the user belongs to the DAAdministrators group + 2. Update the RDS table by changing the status to INACTIVE and mode to '-' + 3. Enable all data.all related ECS scheduled tasks + @return: return True if successful or False + """ + + # Check from the context if the groups contains the DAAAdminstrators group + groups = get_context().groups if get_context().groups is not None else [] + + if not TenantPolicyValidationService.is_tenant_admin(groups): + raise Exception('Only data.all admin group members can stop maintenance window') + logger.info('Stopping maintenance mode') + try: + with get_context().db_engine.scoped_session() as session: + maintenance_record = MaintenanceRepository(session).get_maintenance_record() + if maintenance_record.status == MaintenanceStatus.INACTIVE.value: + logger.error('Maintenance window already in INACTIVE state. Cannot stop maintenance window') + return False + MaintenanceRepository(session).save_maintenance_status_and_mode( + maintenance_status=MaintenanceStatus.INACTIVE.value, maintenance_mode='' + ) + # Enable scheduled ECS tasks + ecs_scheduled_rules_list = MaintenanceService._get_ecs_rules() + event_bridge = EventBridge(region=os.getenv('AWS_REGION', 'eu-west-1')) + event_bridge.enable_scheduled_ecs_tasks(ecs_scheduled_rules_list) + return True + except Exception as e: + logger.error(f'Error occurred while stopping maintenance window due to {e}') + return False + + @staticmethod + def get_maintenance_window_status(): + """ + Get the status of maintenance window + Maintenance record is returned after checking if all ECS tasks in the data.all created cluster have completed. + @return: Maintenance object containing status and mode + """ + logger.info('Checking maintenance window status') + try: + with get_context().db_engine.scoped_session() as session: + maintenance_record = MaintenanceRepository(session).get_maintenance_record() + if maintenance_record.status == MaintenanceStatus.PENDING.value: + # Check if ECS tasks are running + ecs_cluster_name = ParameterStoreManager.get_parameter_value( + region=os.getenv('AWS_REGION', 'eu-west-1'), + parameter_path=f"/dataall/{os.getenv('envname', 'local')}/ecs/cluster/name", + ) + if Ecs.is_task_running(cluster_name=ecs_cluster_name): + logger.info(f'Current maintenance window status - {maintenance_record.status}') + return maintenance_record + else: + logger.info( + 'All pending ECS tasks have completed running. Setting Maintenance Status to ACTIVE' + ) + maintenance_record.status = MaintenanceStatus.ACTIVE.value + session.commit() + return maintenance_record + else: + logger.info(f'Current maintenance window status - {maintenance_record.status}') + return maintenance_record + except Exception as e: + logger.error(f'Error while getting maintenance window status due to {e}') + raise e + + # Fetches the mode of maintenance window + @staticmethod + def _get_maintenance_window_mode(engine): + logger.info('Fetching mode of maintenance window') + try: + with engine.scoped_session() as session: + maintenance_record = MaintenanceRepository(session).get_maintenance_record() + logger.debug(f'Current maintenance window mode - {maintenance_record.mode}') + return maintenance_record.mode + except Exception as e: + logger.error(f'Error while getting maintenance window mode due to {e}') + raise e + + @staticmethod + def _get_ecs_rules(): + ecs_scheduled_rules = ParameterStoreManager.get_parameters_by_path( + region=os.getenv('AWS_REGION', 'eu-west-1'), + parameter_path=f"/dataall/{os.getenv('envname', 'local')}/ecs/ecs_scheduled_tasks/rule", + ) + logger.debug(ecs_scheduled_rules) + return [item['Value'] for item in ecs_scheduled_rules] diff --git a/backend/local_graphql_server.py b/backend/local_graphql_server.py index b92e86ca9..44ed6bf1f 100644 --- a/backend/local_graphql_server.py +++ b/backend/local_graphql_server.py @@ -5,6 +5,7 @@ from ariadne.constants import PLAYGROUND_HTML from flask import Flask, request, jsonify from flask_cors import CORS +from graphql import parse from dataall.base.api import get_executable_schema from dataall.core.tasks.service_handlers import Worker @@ -128,6 +129,10 @@ def graphql_server(): data = request.get_json() print('*** Request ***', request.data) + query = parse(data) + print('***** Printing Query ****** \n\n') + print(query) + context = request_context(request.headers, mock=True) logger.debug(context) diff --git a/backend/migrations/versions/b833ad41db68_maintenance_window_schema.py b/backend/migrations/versions/b833ad41db68_maintenance_window_schema.py new file mode 100644 index 000000000..231834e57 --- /dev/null +++ b/backend/migrations/versions/b833ad41db68_maintenance_window_schema.py @@ -0,0 +1,74 @@ +"""maintenance_window_schema + +Revision ID: b833ad41db68 +Revises: 194608b1ff7f +Create Date: 2024-04-16 19:30:05.226603 + +""" + +import os + +from alembic import op +import sqlalchemy as sa +from sqlalchemy import Column, String, orm + +from dataall.base.db import get_engine, has_table, Base + +# revision identifiers, used by Alembic. +revision = 'b833ad41db68' +down_revision = '458572580709' +branch_labels = None +depends_on = None + + +class Maintenance(Base): + __tablename__ = 'maintenance' + status = Column(String, nullable=False, primary_key=True) + mode = Column(String, default='', nullable=True) + + +def upgrade(): + # Upgrade scripts does the following : + # 1. Creates the maintenance table with two columns : status and mode + # 2. Creates a single record in maintenance table with status : INACTIVE and mode: '' ( Blank ) + try: + envname = os.getenv('envname', 'local') + print('ENVNAME', envname) + engine = get_engine(envname=envname).engine + + bind = op.get_bind() + session = orm.Session(bind=bind) + + # Create the maintenance table + if not has_table('maintenance', engine): + print('Creating maintenance table') + + op.create_table( + 'maintenance', + sa.Column('status', sa.String(), nullable=False, primary_key=True), + sa.Column('mode', sa.String(), nullable=True, default=''), + ) + + maintenance_record: [Maintenance] = Maintenance(status='INACTIVE', mode='') + session.add(maintenance_record) + print('Commiting single row to the maintenance table') + session.commit() + + except Exception as e: + print('Failed to create migration for maintenance table') + raise e + + +def downgrade(): + # Script for deleting the maintenance table + try: + envname = os.getenv('envname', 'local') + print('ENVNAME', envname) + engine = get_engine(envname=envname).engine + print('Starting downgrade of maintenance') + if has_table('maintenance', engine=engine): + print('Dropping table maintenance') + op.drop_table('maintenance') + except Exception as e: + print('Failed to downgrade maintenance table') + raise e diff --git a/backend/migrations/versions/d059eead99c2_rename_dataset_table_as_s3_dataset.py b/backend/migrations/versions/d059eead99c2_rename_dataset_table_as_s3_dataset.py index 0c3dcdcba..4b6630f55 100644 --- a/backend/migrations/versions/d059eead99c2_rename_dataset_table_as_s3_dataset.py +++ b/backend/migrations/versions/d059eead99c2_rename_dataset_table_as_s3_dataset.py @@ -12,7 +12,7 @@ # revision identifiers, used by Alembic. revision = 'd059eead99c2' -down_revision = '458572580709' +down_revision = 'b833ad41db68' branch_labels = None depends_on = None diff --git a/backend/search_handler.py b/backend/search_handler.py index 5a3a1318c..7985be272 100644 --- a/backend/search_handler.py +++ b/backend/search_handler.py @@ -1,10 +1,16 @@ import json import os +from dataall.base.context import RequestContext, set_context +from dataall.base.db import get_engine from dataall.base.searchproxy import connect, run_query +from dataall.base.utils.api_handler_utils import validate_and_block_if_maintenance_window, extract_groups +from dataall.modules.maintenance.api.enums import MaintenanceModes + ENVNAME = os.getenv('envname', 'local') es = connect(envname=ENVNAME) +ENGINE = get_engine(envname=ENVNAME) def handler(event, context): @@ -21,21 +27,47 @@ def handler(event, context): }, } elif event['httpMethod'] == 'POST': - body = event.get('body') - print(body) - success = True - try: - response = run_query(es, 'dataall-index', body) - except Exception: - success = False - response = {} - return { - 'statusCode': 200 if success else 400, - 'headers': { - 'content-type': 'application/json', - 'Access-Control-Allow-Origin': '*', - 'Access-Control-Allow-Headers': '*', - 'Access-Control-Allow-Methods': '*', - }, - 'body': json.dumps(response), - } + if 'authorizer' in event['requestContext']: + if 'claims' not in event['requestContext']['authorizer']: + claims = event['requestContext']['authorizer'] + else: + claims = event['requestContext']['authorizer']['claims'] + + username = claims['email'] + + # Needed for custom groups + user_id = claims['email'] + if 'user_id' in event['requestContext']['authorizer']: + user_id = event['requestContext']['authorizer']['user_id'] + + groups: list = extract_groups(user_id, claims) + + set_context(RequestContext(ENGINE, username, groups, user_id)) + + # Check if maintenance window is enabled AND if the maintenance mode is NO-ACCESS + maintenance_window_validation_response = validate_and_block_if_maintenance_window( + query={'operationName': 'OpensearchIndex'}, + groups=groups, + blocked_for_mode_enum=MaintenanceModes.NOACCESS, + ) + if maintenance_window_validation_response is not None: + return maintenance_window_validation_response + + body = event.get('body') + print(body) + success = True + try: + response = run_query(es, 'dataall-index', body) + except Exception: + success = False + response = {} + return { + 'statusCode': 200 if success else 400, + 'headers': { + 'content-type': 'application/json', + 'Access-Control-Allow-Origin': '*', + 'Access-Control-Allow-Headers': '*', + 'Access-Control-Allow-Methods': '*', + }, + 'body': json.dumps(response), + } diff --git a/config.json b/config.json index 4c2d2094e..a3e68f5b2 100644 --- a/config.json +++ b/config.json @@ -42,6 +42,9 @@ }, "dashboards": { "active": true + }, + "maintenance": { + "active" : true } }, "core": { diff --git a/deploy/stacks/backend_stack.py b/deploy/stacks/backend_stack.py index 54dc80a1f..348f55ea7 100644 --- a/deploy/stacks/backend_stack.py +++ b/deploy/stacks/backend_stack.py @@ -305,6 +305,7 @@ def __init__( lambdas=[ self.lambda_api_stack.aws_handler, self.lambda_api_stack.api_handler, + self.lambda_api_stack.elasticsearch_proxy_handler, ], ecs_security_groups=self.ecs_stack.ecs_security_groups, prod_sizing=prod_sizing, diff --git a/deploy/stacks/container.py b/deploy/stacks/container.py index 16314f30b..baff554db 100644 --- a/deploy/stacks/container.py +++ b/deploy/stacks/container.py @@ -617,6 +617,15 @@ def set_scheduled_task( rule_name=scheduled_task_id, security_groups=[security_group], ) + + # Add the rule of the scheduled task to parameter store + ssm.StringParameter( + self, + f'ECSTaskRule-{scheduled_task_id}', + parameter_name=f'/dataall/{self._envname}/ecs/ecs_scheduled_tasks/rule/{scheduled_task_id}', + string_value=scheduled_task.event_rule.rule_name, + ) + return scheduled_task, task @property diff --git a/deploy/stacks/lambda_api.py b/deploy/stacks/lambda_api.py index fcf5e8c31..fb8789154 100644 --- a/deploy/stacks/lambda_api.py +++ b/deploy/stacks/lambda_api.py @@ -96,7 +96,9 @@ def __init__( self.esproxy_dlq = self.set_dlq(f'{resource_prefix}-{envname}-esproxy-dlq') esproxy_sg = self.create_lambda_sgs(envname, 'esproxy', resource_prefix, vpc) - + esproxy_env = {'envname': envname, 'LOG_LEVEL': 'INFO'} + if custom_auth: + esproxy_env['custom_auth'] = custom_auth.get('provider', None) self.elasticsearch_proxy_handler = _lambda.DockerImageFunction( self, 'ElasticSearchProxyHandler', @@ -113,7 +115,7 @@ def __init__( security_groups=[esproxy_sg], memory_size=1664 if prod_sizing else 256, timeout=Duration.minutes(15), - environment={'envname': envname, 'LOG_LEVEL': 'INFO'}, + environment=esproxy_env, environment_encryption=lambda_env_key, dead_letter_queue_enabled=True, dead_letter_queue=self.esproxy_dlq, @@ -448,6 +450,10 @@ def create_function_role(self, envname, resource_prefix, fn_name, pivot_role_nam f'arn:aws:aoss:{self.region}:{self.account}:collection/*', ], ), + iam.PolicyStatement( + actions=['events:EnableRule', 'events:DisableRule'], + resources=[f'arn:aws:events:{self.region}:{self.account}:rule/dataall*'], + ), ], ) role = iam.Role( diff --git a/frontend/src/authentication/components/AuthGuard.js b/frontend/src/authentication/components/AuthGuard.js index 8410407ab..063c80eb1 100644 --- a/frontend/src/authentication/components/AuthGuard.js +++ b/frontend/src/authentication/components/AuthGuard.js @@ -1,18 +1,59 @@ import PropTypes from 'prop-types'; -import { useState } from 'react'; +import { useEffect, useState } from 'react'; import { Navigate, useLocation } from 'react-router-dom'; import { Login } from '../views/Login'; import { useAuth } from '../hooks'; import { + isModuleEnabled, + ModuleNames, RegexToValidateWindowPathName, WindowPathLengthThreshold -} from '../../utils'; +} from 'utils'; +import { useClient, useGroups } from 'services'; +import { LoadingScreen, NoAccessMaintenanceWindow } from 'design'; +import { getMaintenanceStatus } from '../../modules/Maintenance/services'; +import { + PENDING_STATUS, + ACTIVE_STATUS +} from '../../modules/Maintenance/components/MaintenanceViewer'; +import { SET_ERROR, useDispatch } from 'globalErrors'; export const AuthGuard = (props) => { const { children } = props; const auth = useAuth(); const location = useLocation(); const [requestedLocation, setRequestedLocation] = useState(null); + const [isNoAccessMaintenance, setNoAccessMaintenanceFlag] = useState(null); + const client = useClient(); + const groups = useGroups(); + const dispatch = useDispatch(); + + const checkMaintenanceMode = async () => { + const response = await client.query(getMaintenanceStatus()); + if (!response.errors && response.data.getMaintenanceWindowStatus != null) { + if ( + [PENDING_STATUS, ACTIVE_STATUS].includes( + response.data.getMaintenanceWindowStatus.status + ) && + response.data.getMaintenanceWindowStatus.mode === 'NO-ACCESS' && + !groups.includes('DAAdministrators') + ) { + setNoAccessMaintenanceFlag(true); + } else { + setNoAccessMaintenanceFlag(false); + } + } + }; + + useEffect(async () => { + // Check if the maintenance window is enabled and has NO-ACCESS Status + // If yes then display a blank screen with a message that data.all is in maintenance mode ( Check use of isNoAccessMaintenance state ) + if (isModuleEnabled(ModuleNames.MAINTENANCE) === true) { + if (client) { + checkMaintenanceMode().catch((e) => dispatch({ type: SET_ERROR, e })); + } + } + }, [client, groups]); if (!auth.isAuthenticated) { if (location.pathname !== requestedLocation) { @@ -31,6 +72,17 @@ export const AuthGuard = (props) => { return ; } + if ( + isNoAccessMaintenance == null && + isModuleEnabled(ModuleNames.MAINTENANCE) === true + ) { + return ; + } + + if (isNoAccessMaintenance === true) { + return ; + } + if (requestedLocation && location.pathname !== requestedLocation) { setRequestedLocation(null); return ; diff --git a/frontend/src/design/components/NoAccessMaintenanceWindow.js b/frontend/src/design/components/NoAccessMaintenanceWindow.js new file mode 100644 index 000000000..7036a8d11 --- /dev/null +++ b/frontend/src/design/components/NoAccessMaintenanceWindow.js @@ -0,0 +1,36 @@ +import { Box, Typography } from '@mui/material'; +import React from 'react'; +import config from '../../generated/config.json'; +import { SanitizedHTML } from './SanitizedHTML'; + +export const NoAccessMaintenanceWindow = () => ( + + {config.modules.maintenance.custom_maintenance_text !== undefined ? ( + + + + ) : ( + + data.all is in maintenance mode. Please contact data.all administrators + for any assistance. + + )} + +); diff --git a/frontend/src/design/components/index.js b/frontend/src/design/components/index.js index b9c460dd7..cd01380bb 100644 --- a/frontend/src/design/components/index.js +++ b/frontend/src/design/components/index.js @@ -29,3 +29,4 @@ export * from './defaults'; export * from './layout'; export * from './popovers'; export * from './SanitizedHTML'; +export * from './NoAccessMaintenanceWindow'; diff --git a/frontend/src/design/components/layout/DefaultNavbar.js b/frontend/src/design/components/layout/DefaultNavbar.js index b851d267e..a861011fb 100644 --- a/frontend/src/design/components/layout/DefaultNavbar.js +++ b/frontend/src/design/components/layout/DefaultNavbar.js @@ -1,5 +1,5 @@ -import React from 'react'; -import { AppBar, Box, IconButton, Toolbar } from '@mui/material'; +import React, { useEffect, useState } from 'react'; +import { AppBar, Box, IconButton, Toolbar, Typography } from '@mui/material'; import { makeStyles } from '@mui/styles'; import { Menu } from '@mui/icons-material'; import PropTypes from 'prop-types'; @@ -7,6 +7,15 @@ import { AccountPopover, NotificationsPopover } from '../popovers'; import { Logo } from '../Logo'; import { SettingsDrawer } from '../SettingsDrawer'; import { ModuleNames, isModuleEnabled } from 'utils'; +import config from '../../../generated/config.json'; +import { + PENDING_STATUS, + ACTIVE_STATUS +} from '../../../modules/Maintenance/components/MaintenanceViewer'; +import { useClient } from 'services'; +import { getMaintenanceStatus } from '../../../modules/Maintenance/services'; +import { SET_ERROR, useDispatch } from 'globalErrors'; +import { SanitizedHTML } from '../SanitizedHTML'; const useStyles = makeStyles((theme) => ({ appBar: { @@ -17,9 +26,58 @@ const useStyles = makeStyles((theme) => ({ export const DefaultNavbar = ({ openDrawer, onOpenDrawerChange }) => { const classes = useStyles(); + const [isMaintenance, setMaintenanceFlag] = useState(false); + const dispatch = useDispatch(); + const client = useClient(); + + const _getMaintenanceStatus = async () => { + const response = await client.query(getMaintenanceStatus()); + if (!response.errors && response.data.getMaintenanceWindowStatus !== null) { + if ( + response.data.getMaintenanceWindowStatus.status === ACTIVE_STATUS || + response.data.getMaintenanceWindowStatus.status === PENDING_STATUS + ) { + setMaintenanceFlag(true); + } + } else { + const error = response.errors + ? response.errors[0].message + : 'Could not fetch status of maintenance window'; + dispatch({ type: SET_ERROR, error }); + } + }; + + useEffect(async () => { + if (client && isModuleEnabled(ModuleNames.MAINTENANCE)) { + _getMaintenanceStatus().catch((err) => + dispatch({ type: SET_ERROR, err }) + ); + } + }, [client]); return ( + {isModuleEnabled(ModuleNames.MAINTENANCE) && isMaintenance ? ( + + {config.modules.maintenance.custom_maintenance_text !== undefined ? ( + + + + ) : ( + + data.all is in maintenance mode. You can still navigate inside + data.all but during this period, please do not make any + modifications to any data.all assets ( datasets, environment, etc + ). + + )} + + ) : ( + <> + )} + {!openDrawer && ( { const { settings } = useSettings(); const [currentTab, setCurrentTab] = useState('teams'); @@ -90,6 +96,7 @@ const AdministrationView = () => { {currentTab === 'teams' && } {currentTab === 'dashboard' && } + {currentTab === 'maintenance' && } diff --git a/frontend/src/modules/Maintenance/components/MaintenanceViewer.js b/frontend/src/modules/Maintenance/components/MaintenanceViewer.js new file mode 100644 index 000000000..2feb3713e --- /dev/null +++ b/frontend/src/modules/Maintenance/components/MaintenanceViewer.js @@ -0,0 +1,438 @@ +import { + Box, + Button, + Card, + CardHeader, + CircularProgress, + Dialog, + Divider, + Grid, + IconButton, + MenuItem, + TextField, + Typography +} from '@mui/material'; +import React, { useCallback, useEffect, useState } from 'react'; +import { Article, CancelRounded, SystemUpdate } from '@mui/icons-material'; +import { LoadingButton } from '@mui/lab'; +import { Label } from 'design'; +import { + getMaintenanceStatus, + stopMaintenanceWindow, + startMaintenanceWindow +} from '../services'; +import { useClient } from 'services'; +import { SET_ERROR, useDispatch } from 'globalErrors'; +import { useSnackbar } from 'notistack'; + +const maintenanceModes = [ + { value: 'READ-ONLY', label: 'Read-Only' }, + { value: 'NO-ACCESS', label: 'No-Access' } +]; + +const START_MAINTENANCE = 'Start Maintenance'; +const END_MAINTENANCE = 'End Maintenance'; +export const PENDING_STATUS = 'PENDING'; +export const ACTIVE_STATUS = 'ACTIVE'; +export const INACTIVE_STATUS = 'INACTIVE'; + +export const MaintenanceConfirmationPopUp = (props) => { + const { + popUp, + setPopUp, + confirmedMode, + setConfirmedMode, + maintenanceButtonText, + setMaintenanceButtonText, + setDropDownStatus, + setMaintenanceWindowStatus + } = props; + const client = useClient(); + const dispatch = useDispatch(); + const { enqueueSnackbar } = useSnackbar(); + + const handlePopUpModal = async () => { + if (maintenanceButtonText === START_MAINTENANCE) { + if (!client) { + dispatch({ + type: SET_ERROR, + error: 'Client not initialized for starting maintenance window' + }); + } + const response = await client.mutate( + startMaintenanceWindow({ mode: confirmedMode }) + ); + if (!response.errors && response.data.startMaintenanceWindow != null) { + const respData = response.data.startMaintenanceWindow; + if (respData === true) { + setMaintenanceButtonText(END_MAINTENANCE); + setMaintenanceWindowStatus(PENDING_STATUS); + setDropDownStatus(false); + enqueueSnackbar( + 'Maintenance Window Started. Please check the status', + { + anchorOrigin: { + horizontal: 'right', + vertical: 'top' + }, + variant: 'success' + } + ); + } else { + enqueueSnackbar('Could not start maintenance window', { + anchorOrigin: { + horizontal: 'right', + vertical: 'top' + }, + variant: 'success' + }); + } + } else { + const error = response.errors + ? response.errors[0].message + : 'Something went wrong while starting maintenance window. Please check gql logs'; + dispatch({ type: SET_ERROR, error }); + } + } else if (maintenanceButtonText === END_MAINTENANCE) { + const response = await client.mutate(stopMaintenanceWindow()); + if ( + !response.errors && + response.data.stopMaintenanceWindow != null && + response.data.stopMaintenanceWindow === true + ) { + setMaintenanceButtonText(START_MAINTENANCE); + // Unfreeze the dropdown menu + setDropDownStatus(true); + setConfirmedMode(''); + setMaintenanceWindowStatus(INACTIVE_STATUS); + enqueueSnackbar('Maintenance Window Stopped', { + anchorOrigin: { + horizontal: 'right', + vertical: 'top' + }, + variant: 'success' + }); + } else { + const error = response.errors + ? response.errors[0].message + : 'Something went wrong while stopping maintenance window. Please check gql logs'; + dispatch({ type: SET_ERROR, error }); + } + } + setPopUp(false); + }; + + return ( + + + + + Are you sure you want to {maintenanceButtonText.toLowerCase()}? + + } + /> + + + + + + + + + ); +}; + +export const MaintenanceViewer = () => { + const client = useClient(); + const [refreshing, setRefreshing] = useState(false); + const [updating, setUpdating] = useState(false); + const [mode, setMode] = useState(''); + const [popUp, setPopUp] = useState(false); + const [confirmedMode, setConfirmedMode] = useState(''); + const [maintenanceButtonText, setMaintenanceButtonText] = + useState(START_MAINTENANCE); + const [maintenanceWindowStatus, setMaintenanceWindowStatus] = + useState(INACTIVE_STATUS); + const [dropDownStatus, setDropDownStatus] = useState(false); + const [refreshingTimer, setRefreshingTimer] = useState(''); + const { enqueueSnackbar, closeSnackbar } = useSnackbar(); + const dispatch = useDispatch(); + + const refreshMaintenanceView = async () => { + setUpdating(true); + setRefreshing(true); + _getMaintenanceWindowStatus() + .then((data) => { + setMaintenanceWindowStatus(data.status); + if (data.status === INACTIVE_STATUS) { + setMaintenanceButtonText(START_MAINTENANCE); + setConfirmedMode(''); + setDropDownStatus(true); + clearInterval(refreshingTimer); + } else { + setMaintenanceButtonText(END_MAINTENANCE); + setConfirmedMode( + maintenanceModes.find((obj) => obj.value === data.mode).label + ); + setDropDownStatus(false); + } + setUpdating(false); + setRefreshing(false); + }) + .catch((e) => dispatch({ type: SET_ERROR, e })); + }; + + const _getMaintenanceWindowStatus = async () => { + if (client) { + const response = await client.query(getMaintenanceStatus()); + if ( + !response.errors && + response.data.getMaintenanceWindowStatus !== null + ) { + return response.data.getMaintenanceWindowStatus; + } else { + const error = response.errors + ? response.errors[0].message + : 'Could not fetch status of maintenance window'; + dispatch({ type: SET_ERROR, error }); + } + } + }; + + const startMaintenanceWindow = () => { + // Check if proper maintenance mode is selected + if ( + !maintenanceModes.map((obj) => obj.value).includes(mode) && + maintenanceButtonText === START_MAINTENANCE + ) { + dispatch({ + type: SET_ERROR, + error: 'Please select correct maintenance mode' + }); + } + setConfirmedMode(mode); + setPopUp(true); + }; + + const refreshStatus = async () => { + closeSnackbar(); + const response = await client.query(getMaintenanceStatus()); + if (!response.errors && response.data.getMaintenanceWindowStatus !== null) { + const maintenanceStatusData = response.data.getMaintenanceWindowStatus; + setMaintenanceWindowStatus(maintenanceStatusData.status); + if ( + maintenanceStatusData.status === INACTIVE_STATUS || + maintenanceStatusData.status === ACTIVE_STATUS + ) { + clearInterval(refreshingTimer); + } else { + enqueueSnackbar( + + + + + + + + Maintenance Window Status is being updated + + + + , + { + key: new Date().getTime() + Math.random(), + anchorOrigin: { + horizontal: 'right', + vertical: 'top' + }, + variant: 'info', + persist: true, + action: (key) => ( + { + closeSnackbar(key); + }} + > + + + ) + } + ); + } + } else { + const error = response.errors + ? response.errors[0].message + : 'Maintenance Status not found. Something went wrong'; + dispatch({ type: SET_ERROR, error }); + } + }; + + const initializeMaintenanceView = useCallback(async () => { + setRefreshing(true); + const response = await client.query(getMaintenanceStatus()); + if (!response.errors && response.data.getMaintenanceWindowStatus !== null) { + const maintenanceStatusData = response.data.getMaintenanceWindowStatus; + if ( + maintenanceStatusData.status === PENDING_STATUS || + maintenanceStatusData.status === ACTIVE_STATUS + ) { + setMaintenanceButtonText(END_MAINTENANCE); + setMaintenanceWindowStatus(maintenanceStatusData.status); + setConfirmedMode( + maintenanceModes.find( + (obj) => obj.value === maintenanceStatusData.mode + ).label + ); + setDropDownStatus(false); + } else if (maintenanceStatusData.status === INACTIVE_STATUS) { + setMaintenanceButtonText(START_MAINTENANCE); + setConfirmedMode(''); + setDropDownStatus(true); + } + } else { + const error = response.errors + ? response.errors[0].message + : 'Maintenance Status not found. Something went wrong'; + dispatch({ type: SET_ERROR, error }); + } + setRefreshing(false); + }, [client]); + + useEffect(() => { + if (client) { + initializeMaintenanceView().catch((e) => + dispatch({ type: SET_ERROR, e }) + ); + const setTimer = setInterval(() => { + refreshStatus().catch((e) => + dispatch({ type: SET_ERROR, error: e.message }) + ); + }, [10000]); + setRefreshingTimer(setTimer); + return () => clearInterval(setTimer); + } + }, [client]); + + return ( + + {refreshing ? ( + + ) : ( + + + Create a Maintenance Window} /> + + + + + { + setMode(event.target.value); + }} + select + value={mode} + variant="outlined" + disabled={!dropDownStatus} + > + {maintenanceModes.map((group) => ( + + {group.label} + + ))} + + + + + } + sx={{ m: 1 }} + variant="contained" + > + Refresh + + + + + + + Maintenance window status :{' '} + {maintenanceWindowStatus === ACTIVE_STATUS ? ( + + ) : maintenanceWindowStatus === PENDING_STATUS ? ( + + ) : maintenanceWindowStatus === INACTIVE_STATUS ? ( + + ) : ( + <> - + )} + + + | + + + Current maintenance mode : {confirmedMode} + + + + + + Note - For safe deployments, please deploy when the status is{' '} + + + + + + + )} + + ); +}; diff --git a/frontend/src/modules/Maintenance/index.js b/frontend/src/modules/Maintenance/index.js new file mode 100644 index 000000000..58db6fd95 --- /dev/null +++ b/frontend/src/modules/Maintenance/index.js @@ -0,0 +1,5 @@ +export const MaintenanceModule = { + moduleDefinition: true, + name: 'maintenance', + isEnvironmentModule: false +}; diff --git a/frontend/src/modules/Maintenance/services/getMaintenanceStatus.js b/frontend/src/modules/Maintenance/services/getMaintenanceStatus.js new file mode 100644 index 000000000..3995cadc5 --- /dev/null +++ b/frontend/src/modules/Maintenance/services/getMaintenanceStatus.js @@ -0,0 +1,12 @@ +import { gql } from 'apollo-boost'; + +export const getMaintenanceStatus = () => ({ + query: gql` + query getMaintenanceWindowStatus { + getMaintenanceWindowStatus { + status + mode + } + } + ` +}); diff --git a/frontend/src/modules/Maintenance/services/index.js b/frontend/src/modules/Maintenance/services/index.js new file mode 100644 index 000000000..ca55e15e2 --- /dev/null +++ b/frontend/src/modules/Maintenance/services/index.js @@ -0,0 +1,3 @@ +export * from './getMaintenanceStatus'; +export * from './stopMaintenanceWindow'; +export * from './startMaintenanceWindow'; diff --git a/frontend/src/modules/Maintenance/services/startMaintenanceWindow.js b/frontend/src/modules/Maintenance/services/startMaintenanceWindow.js new file mode 100644 index 000000000..1005ba007 --- /dev/null +++ b/frontend/src/modules/Maintenance/services/startMaintenanceWindow.js @@ -0,0 +1,10 @@ +import { gql } from 'apollo-boost'; + +export const startMaintenanceWindow = ({ mode }) => ({ + variables: { mode }, + mutation: gql` + mutation startMaintenanceWindow($mode: String) { + startMaintenanceWindow(mode: $mode) + } + ` +}); diff --git a/frontend/src/modules/Maintenance/services/stopMaintenanceWindow.js b/frontend/src/modules/Maintenance/services/stopMaintenanceWindow.js new file mode 100644 index 000000000..22cd88ea0 --- /dev/null +++ b/frontend/src/modules/Maintenance/services/stopMaintenanceWindow.js @@ -0,0 +1,9 @@ +import { gql } from 'apollo-boost'; + +export const stopMaintenanceWindow = () => ({ + mutation: gql` + mutation stopMaintenanceWindow { + stopMaintenanceWindow + } + ` +}); diff --git a/frontend/src/modules/index.js b/frontend/src/modules/index.js index e91124e64..ef1a656da 100644 --- a/frontend/src/modules/index.js +++ b/frontend/src/modules/index.js @@ -9,3 +9,4 @@ export * from './Pipelines'; export * from './S3_Datasets'; export * from './Shares'; export * from './Worksheets'; +export * from './Maintenance'; diff --git a/frontend/src/services/hooks/useClient.js b/frontend/src/services/hooks/useClient.js index 061b98f0e..80967bb86 100644 --- a/frontend/src/services/hooks/useClient.js +++ b/frontend/src/services/hooks/useClient.js @@ -75,6 +75,10 @@ export const useClient = () => { if (extensions?.code === 'REAUTH') { setReAuth(operation); } + // Dispatch to show message when a 4xx network error is returned + if (networkError) { + dispatch({ type: SET_ERROR, error: `${message}` }); + } } ); } diff --git a/frontend/src/services/hooks/useGroups.js b/frontend/src/services/hooks/useGroups.js index 87d7919ed..09afa1e5a 100644 --- a/frontend/src/services/hooks/useGroups.js +++ b/frontend/src/services/hooks/useGroups.js @@ -16,12 +16,9 @@ export const useGroups = () => { ) { setGroups(['Engineers', 'Scientists', 'DAAdministrators']); } else if (process.env.REACT_APP_CUSTOM_AUTH) { - if (!auth.user) { - dispatch({ - type: SET_ERROR, - error: 'Cannot Set User Groups as the User is not defined' - }); - } + // Returning when auth.user is not present + // Not dispatching error as useGroups is triggered in auth guard when the user is not authenticated + if (!auth.user) return; // return if the client is null, and then trigger this when the client is present if (client == null) return; const response = await client.query(getGroupsForUser(auth.user.short_id)); diff --git a/tests/modules/maintenance/__init__.py b/tests/modules/maintenance/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/modules/maintenance/test_mainenance_gql.py b/tests/modules/maintenance/test_mainenance_gql.py new file mode 100644 index 000000000..7f475b2c2 --- /dev/null +++ b/tests/modules/maintenance/test_mainenance_gql.py @@ -0,0 +1,143 @@ +from unittest.mock import MagicMock +from dataall.base.config import config + +import pytest + +from dataall.modules.maintenance.db.maintenance_models import Maintenance + + +@pytest.fixture(scope='module') +def mock_ecs_client(module_mocker): + module_mocker.patch( + 'dataall.modules.maintenance.services.maintenance_service.ParameterStoreManager.get_parameters_by_path', + return_value=[{'item': 'task1', 'Value': 'task1'}, {'item': 'task2', 'Value': 'task2'}], + ) + mock_events = MagicMock() + module_mocker.patch( + 'dataall.modules.maintenance.services.maintenance_service.EventBridge', return_value=mock_events + ) + mock_events().disable_scheduled_ecs_tasks.return_value = True + yield mock_events + + +@pytest.fixture(scope='function') +def init_maintenance_record(db): + with db.scoped_session() as session: + maintenance_record = Maintenance(status='INACTIVE', mode='') + session.add(maintenance_record) + session.commit() + yield + with db.scoped_session() as session: + maintenance_record = session.query(Maintenance).one() + session.delete(maintenance_record) + session.commit() + + +def test_start_maintenance_window(db, client, mock_ecs_client, init_maintenance_record): + response = client.query( + """ + mutation startMaintenanceWindow($mode: String!){ + startMaintenanceWindow(mode: $mode) + } + """, + mode='READ-ONLY', + username='alice', + groups=['DAAdministrators', 'Engineers'], + ) + + assert response + assert response.data.startMaintenanceWindow is True + + with db.scoped_session() as session: + maintenance_record = session.query(Maintenance).one() + assert maintenance_record.status == 'PENDING' + assert maintenance_record.mode == 'READ-ONLY' + + +def test_start_maintenance_window_with_team_not_a_data_admin(client, mock_ecs_client, init_maintenance_record): + response = client.query( + """ + mutation startMaintenanceWindow($mode: String!){ + startMaintenanceWindow(mode: $mode) + } + """, + mode='READ-ONLY', + username='alice', + groups=['Engineers'], + ) + + assert response + assert 'Only data.all admin group members can start maintenance window' in response.errors[0]['message'] + + +def test_stop_maintenance_window(db, client, mock_ecs_client, init_maintenance_record): + # Initialize the maintenance window with ACTIVE status and READ-ONLY mode + with db.scoped_session() as session: + maintenance_record = session.query(Maintenance).one() + maintenance_record.mode = 'READ-ONLY' + maintenance_record.status = 'ACTIVE' + session.add(maintenance_record) + session.commit() + + response = client.query( + """ + mutation stopMaintenanceWindow{ + stopMaintenanceWindow + } + """, + username='alice', + groups=['DAAdministrators', 'Engineers'], + ) + + assert response + assert response.data.stopMaintenanceWindow is True + + +def test_stop_maintenance_window_no_dataall_admin(db, client, mock_ecs_client, init_maintenance_record): + # Initialize the maintenance window with ACTIVE status and READ-ONLY mode + with db.scoped_session() as session: + maintenance_record = session.query(Maintenance).one() + maintenance_record.mode = 'READ-ONLY' + maintenance_record.status = 'ACTIVE' + session.add(maintenance_record) + session.commit() + + response = client.query( + """ + mutation stopMaintenanceWindow{ + stopMaintenanceWindow + } + """, + username='alice', + groups=['Engineers'], + ) + + assert response + assert 'Only data.all admin group members can stop maintenance window' in response.errors[0]['message'] + + +def test_get_maintenance_window_status(db, client, mock_ecs_client, init_maintenance_record): + # Initialize the maintenance window with ACTIVE status and READ-ONLY mode + with db.scoped_session() as session: + maintenance_record = session.query(Maintenance).one() + maintenance_record.mode = 'READ-ONLY' + maintenance_record.status = 'ACTIVE' + session.add(maintenance_record) + session.commit() + + response = client.query( + """ + query getMaintenanceWindowStatus{ + getMaintenanceWindowStatus{ + status, + mode + } + } + """, + username='alice', + groups=['DAAdministrators', 'Engineers'], + ) + + assert response + assert response.data.getMaintenanceWindowStatus.status == 'ACTIVE' + assert response.data.getMaintenanceWindowStatus.mode == 'READ-ONLY'