diff --git a/backend/api_handler.py b/backend/api_handler.py index 46d902f6d..890235347 100644 --- a/backend/api_handler.py +++ b/backend/api_handler.py @@ -12,7 +12,9 @@ from dataall.api.Objects import bootstrap as bootstrap_schema, get_executable_schema from dataall.aws.handlers.service_handlers import Worker from dataall.aws.handlers.sqs import SqsQueue +from dataall.core.context import set_context, dispose_context, RequestContext from dataall.db import init_permissions, get_engine, api, permissions +from dataall.modules.loader import load_modules, ImportMode from dataall.searchproxy import connect logger = logging.getLogger() @@ -23,6 +25,7 @@ for name in ['boto3', 's3transfer', 'botocore', 'boto']: logging.getLogger(name).setLevel(logging.ERROR) +load_modules(modes=[ImportMode.API]) SCHEMA = bootstrap_schema() TYPE_DEFS = gql(SCHEMA.gql(with_directives=False)) ENVNAME = os.getenv('envname', 'local') @@ -42,7 +45,6 @@ def adapted(obj, info, **kwargs): username=info.context['username'], groups=info.context['groups'], schema=info.context['schema'], - cdkproxyurl=info.context['cdkproxyurl'], ), source=obj or None, **kwargs, @@ -135,14 +137,16 @@ def handler(event, context): print(f'Error managing groups due to: {e}') groups = [] + set_context(RequestContext(ENGINE, username, groups, ES)) + app_context = { 'engine': ENGINE, 'es': ES, 'username': username, 'groups': groups, 'schema': SCHEMA, - 'cdkproxyurl': None, } + else: raise Exception(f'Could not initialize user context from event {event}') @@ -150,6 +154,8 @@ def handler(event, context): success, response = graphql_sync( schema=executable_schema, data=query, context_value=app_context ) + + dispose_context() response = json.dumps(response) log.info('Lambda Response %s', response) diff --git a/backend/aws_handler.py b/backend/aws_handler.py index 56089ab34..872c8f433 100644 --- a/backend/aws_handler.py +++ b/backend/aws_handler.py @@ -4,6 +4,7 @@ from dataall.aws.handlers.service_handlers import Worker from dataall.db import get_engine +from dataall.modules.loader import load_modules, ImportMode logger = logging.getLogger() logger.setLevel(os.environ.get('LOG_LEVEL')) @@ -13,6 +14,8 @@ engine = get_engine(envname=ENVNAME) +load_modules(modes=[ImportMode.TASKS]) + def handler(event, context=None): """Processes messages received from sqs""" diff --git a/backend/cdkproxymain.py b/backend/cdkproxymain.py index 1a2ff2de9..602b580c8 100644 --- a/backend/cdkproxymain.py +++ b/backend/cdkproxymain.py @@ -10,6 +10,7 @@ import dataall.cdkproxy.cdk_cli_wrapper as wrapper from dataall.cdkproxy.stacks import StackManager from dataall import db +from dataall.modules.loader import load_modules, ImportMode print('\n'.join(sys.path)) @@ -20,7 +21,7 @@ f"Application started for envname= `{ENVNAME}` DH_DOCKER_VERSION:{os.environ.get('DH_DOCKER_VERSION')}" ) - +load_modules(modes=[ImportMode.CDK]) StackManager.registered_stacks() diff --git a/backend/dataall/api/Objects/Dashboard/resolvers.py b/backend/dataall/api/Objects/Dashboard/resolvers.py index a44800502..714b6c4b9 100644 --- a/backend/dataall/api/Objects/Dashboard/resolvers.py +++ b/backend/dataall/api/Objects/Dashboard/resolvers.py @@ -211,12 +211,6 @@ def get_dashboard_organization(context: Context, source: models.Dashboard, **kwa return org -def get_dashboard_environment(context: Context, source: models.Dashboard, **kwargs): - with context.engine.scoped_session() as session: - env = session.query(models.Environment).get(source.environmentUri) - return env - - def request_dashboard_share( context: 
Context, source: models.Dashboard, diff --git a/backend/dataall/api/Objects/Dashboard/schema.py b/backend/dataall/api/Objects/Dashboard/schema.py index a8db3f3bf..58b6a30cb 100644 --- a/backend/dataall/api/Objects/Dashboard/schema.py +++ b/backend/dataall/api/Objects/Dashboard/schema.py @@ -2,6 +2,8 @@ from .resolvers import * from ...constants import DashboardRole +from dataall.api.Objects.Environment.resolvers import resolve_environment + Dashboard = gql.ObjectType( name='Dashboard', fields=[ @@ -23,7 +25,7 @@ gql.Field( 'environment', type=gql.Ref('Environment'), - resolver=get_dashboard_environment, + resolver=resolve_environment, ), gql.Field( 'userRoleForDashboard', diff --git a/backend/dataall/api/Objects/DataPipeline/resolvers.py b/backend/dataall/api/Objects/DataPipeline/resolvers.py index d5db551bb..fb5fe2c90 100644 --- a/backend/dataall/api/Objects/DataPipeline/resolvers.py +++ b/backend/dataall/api/Objects/DataPipeline/resolvers.py @@ -52,7 +52,7 @@ def create_pipeline(context: Context, source, input=None): payload={'account': pipeline.AwsAccountId, 'region': pipeline.region}, ) - stack_helper.deploy_stack(context, pipeline.DataPipelineUri) + stack_helper.deploy_stack(pipeline.DataPipelineUri) return pipeline @@ -80,7 +80,7 @@ def update_pipeline(context: Context, source, DataPipelineUri: str, input: dict check_perm=True, ) if (pipeline.template == ""): - stack_helper.deploy_stack(context, pipeline.DataPipelineUri) + stack_helper.deploy_stack(pipeline.DataPipelineUri) return pipeline @@ -111,14 +111,6 @@ def get_pipeline(context: Context, source, DataPipelineUri: str = None): ) -def get_pipeline_env(context: Context, source: models.DataPipeline, **kwargs): - if not source: - return None - with context.engine.scoped_session() as session: - env = session.query(models.Environment).get(source.environmentUri) - return env - - def resolve_user_role(context: Context, source: models.DataPipeline): if not source: return None @@ -155,15 +147,6 @@ def list_pipeline_environments(context: Context, source: models.DataPipeline, fi ) -def get_pipeline_org(context: Context, source: models.DataPipeline, **kwargs): - if not source: - return None - with context.engine.scoped_session() as session: - env = session.query(models.Environment).get(source.environmentUri) - org = session.query(models.Organization).get(env.organizationUri) - return org - - def get_clone_url_http(context: Context, source: models.DataPipeline, **kwargs): if not source: return None @@ -249,7 +232,6 @@ def get_stack(context, source: models.DataPipeline, **kwargs): if not source: return None return stack_helper.get_stack_with_cfn_resources( - context=context, targetUri=source.DataPipelineUri, environmentUri=source.environmentUri, ) @@ -399,7 +381,6 @@ def delete_pipeline( if deleteFromAWS: stack_helper.delete_repository( - context=context, target_uri=DataPipelineUri, accountid=env.AwsAccountId, cdk_role_arn=env.CDKRoleArn, @@ -408,21 +389,17 @@ def delete_pipeline( ) if pipeline.devStrategy == "cdk-trunk": stack_helper.delete_stack( - context=context, target_uri=DataPipelineUri, accountid=env.AwsAccountId, cdk_role_arn=env.CDKRoleArn, region=env.region, - target_type='cdkpipeline', ) else: stack_helper.delete_stack( - context=context, target_uri=DataPipelineUri, accountid=env.AwsAccountId, cdk_role_arn=env.CDKRoleArn, region=env.region, - target_type='pipeline', ) return True diff --git a/backend/dataall/api/Objects/DataPipeline/schema.py b/backend/dataall/api/Objects/DataPipeline/schema.py index 72f00cac2..c3fd8c0c1 100644 --- 
a/backend/dataall/api/Objects/DataPipeline/schema.py +++ b/backend/dataall/api/Objects/DataPipeline/schema.py @@ -1,6 +1,8 @@ from ... import gql from .resolvers import * from ...constants import DataPipelineRole +from dataall.api.Objects.Environment.resolvers import resolve_environment +from dataall.api.Objects.Organization.resolvers import resolve_organization_by_env DataPipeline = gql.ObjectType( name='DataPipeline', @@ -16,10 +18,10 @@ gql.Field('repo', type=gql.String), gql.Field('SamlGroupName', type=gql.String), gql.Field( - 'organization', type=gql.Ref('Organization'), resolver=get_pipeline_org + 'organization', type=gql.Ref('Organization'), resolver=resolve_organization_by_env ), gql.Field( - 'environment', type=gql.Ref('Environment'), resolver=get_pipeline_env + 'environment', type=gql.Ref('Environment'), resolver=resolve_environment ), gql.Field( 'developmentEnvironments', diff --git a/backend/dataall/api/Objects/Dataset/resolvers.py b/backend/dataall/api/Objects/Dataset/resolvers.py index 93a89d1e1..a03b0647f 100644 --- a/backend/dataall/api/Objects/Dataset/resolvers.py +++ b/backend/dataall/api/Objects/Dataset/resolvers.py @@ -38,7 +38,7 @@ def create_dataset(context: Context, source, input=None): session=session, es=context.es, datasetUri=dataset.datasetUri ) - stack_helper.deploy_dataset_stack(context, dataset) + stack_helper.deploy_dataset_stack(dataset) dataset.userRoleForDataset = DatasetRole.Creator.value @@ -76,7 +76,7 @@ def import_dataset(context: Context, source, input=None): session=session, es=context.es, datasetUri=dataset.datasetUri ) - stack_helper.deploy_dataset_stack(context, dataset) + stack_helper.deploy_dataset_stack(dataset) dataset.userRoleForDataset = DatasetRole.Creator.value @@ -222,7 +222,7 @@ def update_dataset(context, source, datasetUri: str = None, input: dict = None): ) indexers.upsert_dataset(session, context.es, datasetUri) - stack_helper.deploy_dataset_stack(context, updated_dataset) + stack_helper.deploy_dataset_stack(updated_dataset) return updated_dataset @@ -493,7 +493,6 @@ def get_dataset_stack(context: Context, source: models.Dataset, **kwargs): if not source: return None return stack_helper.get_stack_with_cfn_resources( - context=context, targetUri=source.datasetUri, environmentUri=source.environmentUri, ) @@ -575,14 +574,12 @@ def delete_dataset( if deleteFromAWS: stack_helper.delete_stack( - context=context, target_uri=datasetUri, accountid=env.AwsAccountId, cdk_role_arn=env.CDKRoleArn, region=env.region, - target_type='dataset', ) - stack_helper.deploy_stack(context, dataset.environmentUri) + stack_helper.deploy_stack(dataset.environmentUri) return True diff --git a/backend/dataall/api/Objects/Environment/input_types.py b/backend/dataall/api/Objects/Environment/input_types.py index 0b87eec63..ad5b36c90 100644 --- a/backend/dataall/api/Objects/Environment/input_types.py +++ b/backend/dataall/api/Objects/Environment/input_types.py @@ -10,6 +10,14 @@ ], ) +ModifyEnvironmentParameterInput = gql.InputType( + name='ModifyEnvironmentParameterInput', + arguments=[ + gql.Argument('key', gql.String), + gql.Argument('value', gql.String) + ] +) + NewEnvironmentInput = gql.InputType( name='NewEnvironmentInput', arguments=[ @@ -21,7 +29,6 @@ gql.Argument('AwsAccountId', gql.NonNullableType(gql.String)), gql.Argument('region', gql.NonNullableType(gql.String)), gql.Argument('dashboardsEnabled', type=gql.Boolean), - gql.Argument('notebooksEnabled', type=gql.Boolean), gql.Argument('mlStudiosEnabled', type=gql.Boolean), gql.Argument('pipelinesEnabled', 
type=gql.Boolean), gql.Argument('warehousesEnabled', type=gql.Boolean), @@ -30,6 +37,8 @@ gql.Argument('publicSubnetIds', gql.ArrayType(gql.String)), gql.Argument('EnvironmentDefaultIAMRoleName', gql.String), gql.Argument('resourcePrefix', gql.String), + gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput)) + ], ) @@ -44,11 +53,11 @@ gql.Argument('privateSubnetIds', gql.ArrayType(gql.String)), gql.Argument('publicSubnetIds', gql.ArrayType(gql.String)), gql.Argument('dashboardsEnabled', type=gql.Boolean), - gql.Argument('notebooksEnabled', type=gql.Boolean), gql.Argument('mlStudiosEnabled', type=gql.Boolean), gql.Argument('pipelinesEnabled', type=gql.Boolean), gql.Argument('warehousesEnabled', type=gql.Boolean), gql.Argument('resourcePrefix', gql.String), + gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput)) ], ) diff --git a/backend/dataall/api/Objects/Environment/resolvers.py b/backend/dataall/api/Objects/Environment/resolvers.py index 49cf72797..60af060a7 100644 --- a/backend/dataall/api/Objects/Environment/resolvers.py +++ b/backend/dataall/api/Objects/Environment/resolvers.py @@ -5,7 +5,7 @@ import boto3 from botocore.config import Config from botocore.exceptions import ClientError -from sqlalchemy import and_ +from sqlalchemy import and_, exc from ..Organization.resolvers import * from ..Stack import stack_helper @@ -69,7 +69,7 @@ def create_environment(context: Context, source, input=None): target_uri=env.environmentUri, target_label=env.label, ) - stack_helper.deploy_stack(context, targetUri=env.environmentUri) + stack_helper.deploy_stack(targetUri=env.environmentUri) env.userRoleInEnvironment = EnvironmentPermission.Owner.value return env @@ -99,9 +99,7 @@ def update_environment( if input.get('dashboardsEnabled') or ( environment.resourcePrefix != previous_resource_prefix ): - stack_helper.deploy_stack( - context=context, targetUri=environment.environmentUri - ) + stack_helper.deploy_stack(targetUri=environment.environmentUri) return environment @@ -116,7 +114,7 @@ def invite_group(context: Context, source, input): check_perm=True, ) - stack_helper.deploy_stack(context=context, targetUri=environment.environmentUri) + stack_helper.deploy_stack(targetUri=environment.environmentUri) return environment @@ -153,7 +151,7 @@ def update_group_permissions(context, source, input): check_perm=True, ) - stack_helper.deploy_stack(context=context, targetUri=environment.environmentUri) + stack_helper.deploy_stack(targetUri=environment.environmentUri) return environment @@ -169,7 +167,7 @@ def remove_group(context: Context, source, environmentUri=None, groupUri=None): check_perm=True, ) - stack_helper.deploy_stack(context=context, targetUri=environment.environmentUri) + stack_helper.deploy_stack(targetUri=environment.environmentUri) return environment @@ -507,7 +505,6 @@ def generate_environment_access_token( def get_environment_stack(context: Context, source: models.Environment, **kwargs): return stack_helper.get_stack_with_cfn_resources( - context=context, targetUri=source.environmentUri, environmentUri=source.environmentUri, ) @@ -526,23 +523,27 @@ def delete_environment( ) environment = db.api.Environment.get_environment_by_uri(session, environmentUri) - db.api.Environment.delete_environment( - session, - username=context.username, - groups=context.groups, - uri=environmentUri, - data={'environment': environment}, - check_perm=True, - ) + try: + db.api.Environment.delete_environment( + session, + username=context.username, + groups=context.groups, + 
uri=environmentUri, + data={'environment': environment}, + check_perm=True, + ) + except exc.IntegrityError: + raise exceptions.EnvironmentResourcesFound( + action='Delete Environment', + message='Delete all environment related objects before proceeding', + ) if deleteFromAWS: stack_helper.delete_stack( - context=context, target_uri=environmentUri, accountid=environment.AwsAccountId, cdk_role_arn=environment.CDKRoleArn, region=environment.region, - target_type='environment', ) return True @@ -597,7 +598,7 @@ environment.subscriptionsConsumersTopicImported = False environment.subscriptionsEnabled = True session.commit() - stack_helper.deploy_stack(context=context, targetUri=environment.environmentUri) + stack_helper.deploy_stack(targetUri=environment.environmentUri) return True @@ -618,7 +619,7 @@ def disable_subscriptions(context: Context, source, environmentUri: str = None): environment.subscriptionsProducersTopicImported = False environment.subscriptionsEnabled = False session.commit() - stack_helper.deploy_stack(context=context, targetUri=environment.environmentUri) + stack_helper.deploy_stack(targetUri=environment.environmentUri) return True @@ -702,3 +703,19 @@ def get_pivot_role_name(context: Context, source, organizationUri=None): message='Pivot role name could not be found on AWS Secretsmanager', ) return pivot_role_name + + +def resolve_environment(context, source, **kwargs): + """Resolves the environment for an environmental resource""" + if not source: + return None + with context.engine.scoped_session() as session: + return session.query(models.Environment).get(source.environmentUri) + + +def resolve_parameters(context, source: models.Environment, **kwargs): + """Resolves the parameters for the environment""" + if not source: + return None + with context.engine.scoped_session() as session: + return Environment.get_environment_parameters(session, source.environmentUri) diff --git a/backend/dataall/api/Objects/Environment/schema.py b/backend/dataall/api/Objects/Environment/schema.py index 528f7b649..1c1bae604 100644 --- a/backend/dataall/api/Objects/Environment/schema.py +++ b/backend/dataall/api/Objects/Environment/schema.py @@ -42,6 +42,14 @@ ) +EnvironmentParameter = gql.ObjectType( + name='EnvironmentParameter', + fields=[ + gql.Field(name='key', type=gql.String), + gql.Field(name='value', type=gql.String), + ] +) + Environment = gql.ObjectType( name='Environment', fields=[ @@ -76,7 +84,6 @@ ), gql.Field('validated', type=gql.Boolean), gql.Field('dashboardsEnabled', type=gql.Boolean), - gql.Field('notebooksEnabled', type=gql.Boolean), gql.Field('mlStudiosEnabled', type=gql.Boolean), gql.Field('pipelinesEnabled', type=gql.Boolean), gql.Field('warehousesEnabled', type=gql.Boolean), @@ -95,6 +102,11 @@ type=gql.ArrayType(gql.Ref('Vpc')), resolver=resolve_vpc_list, ), + gql.Field( + name='parameters', + resolver=resolve_parameters, + type=gql.ArrayType(gql.Ref('EnvironmentParameter')), + ), ], ) diff --git a/backend/dataall/api/Objects/KeyValueTag/resolvers.py b/backend/dataall/api/Objects/KeyValueTag/resolvers.py index f2df4730d..3f9141862 100644 --- a/backend/dataall/api/Objects/KeyValueTag/resolvers.py +++ b/backend/dataall/api/Objects/KeyValueTag/resolvers.py @@ -27,5 +27,5 @@ def update_key_value_tags(context: Context, source, input=None): data=input, check_perm=True, ) - stack_helper.deploy_stack(context=context, targetUri=input['targetUri']) + stack_helper.deploy_stack(targetUri=input['targetUri']) return kv_tags diff --git
a/backend/dataall/api/Objects/Organization/resolvers.py b/backend/dataall/api/Objects/Organization/resolvers.py index f97f2849c..7c1bb7736 100644 --- a/backend/dataall/api/Objects/Organization/resolvers.py +++ b/backend/dataall/api/Objects/Organization/resolvers.py @@ -175,3 +175,16 @@ def list_organization_groups( data=filter, check_perm=True, ) + + +def resolve_organization_by_env(context, source, **kwargs): + """ + Resolves the organization for an environmental resource. + """ + if not source: + return None + with context.engine.scoped_session() as session: + env: models.Environment = session.query(models.Environment).get( + source.environmentUri + ) + return session.query(models.Organization).get(env.organizationUri) diff --git a/backend/dataall/api/Objects/RedshiftCluster/resolvers.py b/backend/dataall/api/Objects/RedshiftCluster/resolvers.py index 3ee0f17df..0fa854532 100644 --- a/backend/dataall/api/Objects/RedshiftCluster/resolvers.py +++ b/backend/dataall/api/Objects/RedshiftCluster/resolvers.py @@ -42,7 +42,7 @@ def create( ) cluster.CFNStackName = stack.name if stack else None - stack_helper.deploy_stack(context=context, targetUri=cluster.clusterUri) + stack_helper.deploy_stack(targetUri=cluster.clusterUri) cluster.userRoleForCluster = RedshiftClusterRole.Creator.value return cluster @@ -121,7 +121,7 @@ def import_cluster(context: Context, source, environmentUri: str, clusterInput: log.info('Updating imported cluster iam_roles') Worker.queue(engine=context.engine, task_ids=[redshift_assign_role_task.taskUri]) - stack_helper.deploy_stack(context=context, targetUri=cluster.clusterUri) + stack_helper.deploy_stack(targetUri=cluster.clusterUri) return cluster @@ -262,12 +262,10 @@ def delete( if deleteFromAWS: stack_helper.delete_stack( - context=context, target_uri=clusterUri, accountid=env.AwsAccountId, cdk_role_arn=env.CDKRoleArn, - region=env.region, - target_type='redshiftcluster', + region=env.region ) return True @@ -526,7 +524,6 @@ def resolve_stack(context: Context, source: models.RedshiftCluster, **kwargs): if not source: return None return stack_helper.get_stack_with_cfn_resources( - context=context, targetUri=source.clusterUri, environmentUri=source.environmentUri, ) diff --git a/backend/dataall/api/Objects/SagemakerNotebook/__init__.py b/backend/dataall/api/Objects/SagemakerNotebook/__init__.py deleted file mode 100644 index dfa46b264..000000000 --- a/backend/dataall/api/Objects/SagemakerNotebook/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from . import ( - input_types, - mutations, - queries, - resolvers, - schema, -) - -__all__ = ['resolvers', 'schema', 'input_types', 'queries', 'mutations'] diff --git a/backend/dataall/api/Objects/SagemakerNotebook/input_types.py b/backend/dataall/api/Objects/SagemakerNotebook/input_types.py deleted file mode 100644 index 7db8bfa24..000000000 --- a/backend/dataall/api/Objects/SagemakerNotebook/input_types.py +++ /dev/null @@ -1,38 +0,0 @@ -from ...
import gql - -NewSagemakerNotebookInput = gql.InputType( - name='NewSagemakerNotebookInput ', - arguments=[ - gql.Argument('label', gql.NonNullableType(gql.String)), - gql.Argument('description', gql.String), - gql.Argument('environmentUri', gql.NonNullableType(gql.String)), - gql.Argument('SamlAdminGroupName', gql.NonNullableType(gql.String)), - gql.Argument('tags', gql.ArrayType(gql.String)), - gql.Argument('topics', gql.String), - gql.Argument('VpcId', gql.String), - gql.Argument('SubnetId', gql.String), - gql.Argument('VolumeSizeInGB', gql.Integer), - gql.Argument('InstanceType', gql.String), - ], -) - -ModifySagemakerNotebookInput = gql.InputType( - name='ModifySagemakerNotebookInput', - arguments=[ - gql.Argument('label', gql.String), - gql.Argument('tags', gql.ArrayType(gql.String)), - gql.Argument('description', gql.String), - ], -) - -SagemakerNotebookFilter = gql.InputType( - name='SagemakerNotebookFilter', - arguments=[ - gql.Argument('term', gql.String), - gql.Argument('page', gql.Integer), - gql.Argument('pageSize', gql.Integer), - gql.Argument('sort', gql.String), - gql.Argument('limit', gql.Integer), - gql.Argument('offset', gql.Integer), - ], -) diff --git a/backend/dataall/api/Objects/SagemakerNotebook/mutations.py b/backend/dataall/api/Objects/SagemakerNotebook/mutations.py deleted file mode 100644 index 895239797..000000000 --- a/backend/dataall/api/Objects/SagemakerNotebook/mutations.py +++ /dev/null @@ -1,33 +0,0 @@ -from ... import gql -from .resolvers import * - -createSagemakerNotebook = gql.MutationField( - name='createSagemakerNotebook', - args=[gql.Argument(name='input', type=gql.Ref('NewSagemakerNotebookInput'))], - type=gql.Ref('SagemakerNotebook'), - resolver=create_notebook, -) - -startSagemakerNotebook = gql.MutationField( - name='startSagemakerNotebook', - args=[gql.Argument(name='notebookUri', type=gql.NonNullableType(gql.String))], - type=gql.String, - resolver=start_notebook, -) - -stopSagemakerNotebook = gql.MutationField( - name='stopSagemakerNotebook', - args=[gql.Argument(name='notebookUri', type=gql.NonNullableType(gql.String))], - type=gql.String, - resolver=stop_notebook, -) - -deleteSagemakerNotebook = gql.MutationField( - name='deleteSagemakerNotebook', - args=[ - gql.Argument(name='notebookUri', type=gql.NonNullableType(gql.String)), - gql.Argument(name='deleteFromAWS', type=gql.Boolean), - ], - type=gql.String, - resolver=delete_notebook, -) diff --git a/backend/dataall/api/Objects/SagemakerNotebook/queries.py b/backend/dataall/api/Objects/SagemakerNotebook/queries.py deleted file mode 100644 index 54cc54c50..000000000 --- a/backend/dataall/api/Objects/SagemakerNotebook/queries.py +++ /dev/null @@ -1,24 +0,0 @@ -from ... 
import gql -from .resolvers import * - -getSagemakerNotebook = gql.QueryField( - name='getSagemakerNotebook', - args=[gql.Argument(name='notebookUri', type=gql.NonNullableType(gql.String))], - type=gql.Ref('SagemakerNotebook'), - resolver=get_notebook, -) - - -listSagemakerNotebooks = gql.QueryField( - name='listSagemakerNotebooks', - args=[gql.Argument('filter', gql.Ref('SagemakerNotebookFilter'))], - type=gql.Ref('SagemakerNotebookSearchResult'), - resolver=list_notebooks, -) - -getSagemakerNotebookPresignedUrl = gql.QueryField( - name='getSagemakerNotebookPresignedUrl', - args=[gql.Argument(name='notebookUri', type=gql.NonNullableType(gql.String))], - type=gql.String, - resolver=get_notebook_presigned_url, -) diff --git a/backend/dataall/api/Objects/SagemakerNotebook/resolvers.py b/backend/dataall/api/Objects/SagemakerNotebook/resolvers.py deleted file mode 100644 index eb5f2c32f..000000000 --- a/backend/dataall/api/Objects/SagemakerNotebook/resolvers.py +++ /dev/null @@ -1,218 +0,0 @@ -from .... import db -from ..Stack import stack_helper -from ....api.constants import SagemakerNotebookRole -from ....api.context import Context -from ....aws.handlers.sagemaker import Sagemaker -from ....db import permissions, models -from ....db.api import ResourcePolicy, Notebook, KeyValueTag, Stack - - -def create_notebook(context: Context, source, input: dict = None): - with context.engine.scoped_session() as session: - - notebook = Notebook.create_notebook( - session=session, - username=context.username, - groups=context.groups, - uri=input['environmentUri'], - data=input, - check_perm=True, - ) - - Stack.create_stack( - session=session, - environment_uri=notebook.environmentUri, - target_type='notebook', - target_uri=notebook.notebookUri, - target_label=notebook.label, - ) - - stack_helper.deploy_stack(context=context, targetUri=notebook.notebookUri) - - return notebook - - -def list_notebooks(context, source, filter: dict = None): - if not filter: - filter = {} - with context.engine.scoped_session() as session: - return Notebook.paginated_user_notebooks( - session=session, - username=context.username, - groups=context.groups, - uri=None, - data=filter, - check_perm=True, - ) - - -def get_notebook(context, source, notebookUri: str = None): - with context.engine.scoped_session() as session: - return Notebook.get_notebook( - session=session, - username=context.username, - groups=context.groups, - uri=notebookUri, - data=None, - check_perm=True, - ) - - -def resolve_status(context, source: models.SagemakerNotebook, **kwargs): - if not source: - return None - return Sagemaker.get_notebook_instance_status( - AwsAccountId=source.AWSAccountId, - region=source.region, - NotebookInstanceName=source.NotebookInstanceName, - ) - - -def start_notebook(context, source: models.SagemakerNotebook, notebookUri: str = None): - with context.engine.scoped_session() as session: - ResourcePolicy.check_user_resource_permission( - session=session, - username=context.username, - groups=context.groups, - resource_uri=notebookUri, - permission_name=permissions.UPDATE_NOTEBOOK, - ) - notebook = Notebook.get_notebook( - session=session, - username=context.username, - groups=context.groups, - uri=notebookUri, - data=None, - check_perm=True, - ) - Sagemaker.start_instance( - notebook.AWSAccountId, notebook.region, notebook.NotebookInstanceName - ) - return 'Starting' - - -def stop_notebook(context, source: models.SagemakerNotebook, notebookUri: str = None): - with context.engine.scoped_session() as session: - 
ResourcePolicy.check_user_resource_permission( - session=session, - username=context.username, - groups=context.groups, - resource_uri=notebookUri, - permission_name=permissions.UPDATE_NOTEBOOK, - ) - notebook = Notebook.get_notebook( - session=session, - username=context.username, - groups=context.groups, - uri=notebookUri, - data=None, - check_perm=True, - ) - Sagemaker.stop_instance( - notebook.AWSAccountId, notebook.region, notebook.NotebookInstanceName - ) - return 'Stopping' - - -def get_notebook_presigned_url( - context, source: models.SagemakerNotebook, notebookUri: str = None -): - with context.engine.scoped_session() as session: - ResourcePolicy.check_user_resource_permission( - session=session, - username=context.username, - groups=context.groups, - resource_uri=notebookUri, - permission_name=permissions.GET_NOTEBOOK, - ) - notebook = Notebook.get_notebook( - session=session, - username=context.username, - groups=context.groups, - uri=notebookUri, - data=None, - check_perm=True, - ) - url = Sagemaker.presigned_url( - notebook.AWSAccountId, notebook.region, notebook.NotebookInstanceName - ) - return url - - -def delete_notebook( - context, - source: models.SagemakerNotebook, - notebookUri: str = None, - deleteFromAWS: bool = None, -): - with context.engine.scoped_session() as session: - ResourcePolicy.check_user_resource_permission( - session=session, - resource_uri=notebookUri, - permission_name=permissions.DELETE_NOTEBOOK, - groups=context.groups, - username=context.username, - ) - notebook = Notebook.get_notebook_by_uri(session, notebookUri) - env: models.Environment = db.api.Environment.get_environment_by_uri( - session, notebook.environmentUri - ) - - KeyValueTag.delete_key_value_tags(session, notebook.notebookUri, 'notebook') - - session.delete(notebook) - - ResourcePolicy.delete_resource_policy( - session=session, - resource_uri=notebook.notebookUri, - group=notebook.SamlAdminGroupName, - ) - - if deleteFromAWS: - stack_helper.delete_stack( - context=context, - target_uri=notebookUri, - accountid=env.AwsAccountId, - cdk_role_arn=env.CDKRoleArn, - region=env.region, - target_type='notebook', - ) - - return True - - -def resolve_environment(context, source, **kwargs): - if not source: - return None - with context.engine.scoped_session() as session: - return session.query(models.Environment).get(source.environmentUri) - - -def resolve_organization(context, source, **kwargs): - if not source: - return None - with context.engine.scoped_session() as session: - env: models.Environment = session.query(models.Environment).get( - source.environmentUri - ) - return session.query(models.Organization).get(env.organizationUri) - - -def resolve_user_role(context: Context, source: models.SagemakerNotebook): - if not source: - return None - if source.owner == context.username: - return SagemakerNotebookRole.Creator.value - elif context.groups and source.SamlAdminGroupName in context.groups: - return SagemakerNotebookRole.Admin.value - return SagemakerNotebookRole.NoPermission.value - - -def resolve_stack(context: Context, source: models.SagemakerNotebook, **kwargs): - if not source: - return None - return stack_helper.get_stack_with_cfn_resources( - context=context, - targetUri=source.notebookUri, - environmentUri=source.environmentUri, - ) diff --git a/backend/dataall/api/Objects/SagemakerNotebook/schema.py b/backend/dataall/api/Objects/SagemakerNotebook/schema.py deleted file mode 100644 index 61e5c6bb5..000000000 --- a/backend/dataall/api/Objects/SagemakerNotebook/schema.py +++ 
/dev/null @@ -1,54 +0,0 @@ -from ... import gql -from .resolvers import * - -SagemakerNotebook = gql.ObjectType( - name='SagemakerNotebook', - fields=[ - gql.Field(name='notebookUri', type=gql.ID), - gql.Field(name='environmentUri', type=gql.NonNullableType(gql.String)), - gql.Field(name='label', type=gql.String), - gql.Field(name='description', type=gql.String), - gql.Field(name='tags', type=gql.ArrayType(gql.String)), - gql.Field(name='name', type=gql.String), - gql.Field(name='owner', type=gql.String), - gql.Field(name='created', type=gql.String), - gql.Field(name='updated', type=gql.String), - gql.Field(name='SamlAdminGroupName', type=gql.String), - gql.Field(name='VpcId', type=gql.String), - gql.Field(name='SubnetId', type=gql.String), - gql.Field(name='InstanceType', type=gql.String), - gql.Field(name='RoleArn', type=gql.String), - gql.Field(name='VolumeSizeInGB', type=gql.Integer), - gql.Field( - name='userRoleForNotebook', - type=SagemakerNotebookRole.toGraphQLEnum(), - resolver=resolve_user_role, - ), - gql.Field( - name='NotebookInstanceStatus', type=gql.String, resolver=resolve_status - ), - gql.Field( - name='environment', - type=gql.Ref('Environment'), - resolver=resolve_environment, - ), - gql.Field( - name='organization', - type=gql.Ref('Organization'), - resolver=resolve_organization, - ), - gql.Field(name='stack', type=gql.Ref('Stack'), resolver=resolve_stack), - ], -) - -SagemakerNotebookSearchResult = gql.ObjectType( - name='SagemakerNotebookSearchResult', - fields=[ - gql.Field(name='count', type=gql.Integer), - gql.Field(name='page', type=gql.Integer), - gql.Field(name='pages', type=gql.Integer), - gql.Field(name='hasNext', type=gql.Boolean), - gql.Field(name='hasPrevious', type=gql.Boolean), - gql.Field(name='nodes', type=gql.ArrayType(SagemakerNotebook)), - ], -) diff --git a/backend/dataall/api/Objects/SagemakerStudio/resolvers.py b/backend/dataall/api/Objects/SagemakerStudio/resolvers.py index 32d6bffa2..a70c6de4e 100644 --- a/backend/dataall/api/Objects/SagemakerStudio/resolvers.py +++ b/backend/dataall/api/Objects/SagemakerStudio/resolvers.py @@ -69,9 +69,7 @@ def create_sagemaker_studio_user_profile(context: Context, source, input: dict = target_label=sm_user_profile.label, ) - stack_helper.deploy_stack( - context=context, targetUri=sm_user_profile.sagemakerStudioUserProfileUri - ) + stack_helper.deploy_stack(targetUri=sm_user_profile.sagemakerStudioUserProfileUri) return sm_user_profile @@ -209,41 +207,21 @@ def delete_sagemaker_studio_user_profile( if deleteFromAWS: stack_helper.delete_stack( - context=context, target_uri=sagemakerStudioUserProfileUri, accountid=env.AwsAccountId, cdk_role_arn=env.CDKRoleArn, - region=env.region, - target_type='notebook', + region=env.region ) return True -def resolve_environment(context, source, **kwargs): - if not source: - return None - with context.engine.scoped_session() as session: - return session.query(models.Environment).get(source.environmentUri) - - -def resolve_organization(context, source, **kwargs): - if not source: - return None - with context.engine.scoped_session() as session: - env: models.Environment = session.query(models.Environment).get( - source.environmentUri - ) - return session.query(models.Organization).get(env.organizationUri) - - def resolve_stack( context: Context, source: models.SagemakerStudioUserProfile, **kwargs ): if not source: return None return stack_helper.get_stack_with_cfn_resources( - context=context, targetUri=source.sagemakerStudioUserProfileUri, environmentUri=source.environmentUri, 
) diff --git a/backend/dataall/api/Objects/SagemakerStudio/schema.py b/backend/dataall/api/Objects/SagemakerStudio/schema.py index b19f1967b..9e07a620f 100644 --- a/backend/dataall/api/Objects/SagemakerStudio/schema.py +++ b/backend/dataall/api/Objects/SagemakerStudio/schema.py @@ -2,6 +2,10 @@ from .resolvers import * from ....api.constants import SagemakerStudioRole + +from dataall.api.Objects.Organization.resolvers import resolve_organization_by_env +from dataall.api.Objects.Environment.resolvers import resolve_environment + SagemakerStudio = gql.ObjectType( name='SagemakerStudio', fields=[ @@ -31,7 +35,7 @@ gql.Field( name='organization', type=gql.Ref('Organization'), - resolver=resolve_organization, + resolver=resolve_organization_by_env, ), gql.Field(name='stack', type=gql.Ref('Stack'), resolver=resolve_stack), ], @@ -98,7 +102,7 @@ gql.Field( name='organization', type=gql.Ref('Organization'), - resolver=resolve_organization, + resolver=resolve_organization_by_env, ), gql.Field(name='stack', type=gql.Ref('Stack'), resolver=resolve_stack), ], diff --git a/backend/dataall/api/Objects/ShareObject/resolvers.py b/backend/dataall/api/Objects/ShareObject/resolvers.py index f14642aeb..6bbb64bf4 100644 --- a/backend/dataall/api/Objects/ShareObject/resolvers.py +++ b/backend/dataall/api/Objects/ShareObject/resolvers.py @@ -281,16 +281,6 @@ def resolve_principal(context: Context, source: models.ShareObject, **kwargs): ) -def resolve_environment(context: Context, source: models.ShareObject, **kwargs): - if not source: - return None - with context.engine.scoped_session() as session: - environment = db.api.Environment.get_environment_by_uri( - session, source.environmentUri - ) - return environment - - def resolve_group(context: Context, source: models.ShareObject, **kwargs): if not source: return None diff --git a/backend/dataall/api/Objects/ShareObject/schema.py b/backend/dataall/api/Objects/ShareObject/schema.py index b045d6072..7a26154e3 100644 --- a/backend/dataall/api/Objects/ShareObject/schema.py +++ b/backend/dataall/api/Objects/ShareObject/schema.py @@ -1,4 +1,5 @@ from .resolvers import * +from dataall.api.Objects.Environment.resolvers import resolve_environment ShareableObject = gql.Union( name='ShareableObject', diff --git a/backend/dataall/api/Objects/Stack/resolvers.py b/backend/dataall/api/Objects/Stack/resolvers.py index 52988f163..8cd4b1edf 100644 --- a/backend/dataall/api/Objects/Stack/resolvers.py +++ b/backend/dataall/api/Objects/Stack/resolvers.py @@ -113,5 +113,5 @@ def update_stack( data={'targetType': targetType}, check_perm=True, ) - stack_helper.deploy_stack(context, stack.targetUri) + stack_helper.deploy_stack(stack.targetUri) return stack diff --git a/backend/dataall/api/Objects/Stack/stack_helper.py b/backend/dataall/api/Objects/Stack/stack_helper.py index ea2857ba9..659dd2ab0 100644 --- a/backend/dataall/api/Objects/Stack/stack_helper.py +++ b/backend/dataall/api/Objects/Stack/stack_helper.py @@ -9,9 +9,13 @@ from ....db import models from ....utils import Parameter +from dataall.core.config import config +from dataall.core.context import get_context -def get_stack_with_cfn_resources(context: Context, targetUri: str, environmentUri: str): - with context.engine.scoped_session() as session: + +def get_stack_with_cfn_resources(targetUri: str, environmentUri: str): + context = get_context() + with context.db_engine.scoped_session() as session: env: models.Environment = session.query(models.Environment).get(environmentUri) stack: models.Stack = 
db.api.Stack.find_stack_by_target_uri( session, target_uri=targetUri @@ -30,7 +34,7 @@ def get_stack_with_cfn_resources(context: Context, targetUri: str, environmentUr return stack cfn_task = save_describe_stack_task(session, env, stack, targetUri) - Worker.queue(engine=context.engine, task_ids=[cfn_task.taskUri]) + Worker.queue(engine=context.db_engine, task_ids=[cfn_task.taskUri]) return stack @@ -52,15 +56,16 @@ def save_describe_stack_task(session, environment, stack, target_uri): return cfn_task -def deploy_stack(context, targetUri): - with context.engine.scoped_session() as session: +def deploy_stack(targetUri): + context = get_context() + with context.db_engine.scoped_session() as session: stack: models.Stack = db.api.Stack.get_stack_by_target_uri( session, target_uri=targetUri ) envname = os.getenv('envname', 'local') if envname in ['local', 'pytest', 'dkrcompose']: - requests.post(f'{context.cdkproxyurl}/stack/{stack.stackUri}') + requests.post(f'{config.get_property("cdk_proxy_url")}/stack/{stack.stackUri}') else: cluster_name = Parameter().get_parameter( @@ -74,24 +79,25 @@ def deploy_stack(context, targetUri): ) session.add(task) session.commit() - Worker.queue(engine=context.engine, task_ids=[task.taskUri]) + Worker.queue(engine=context.db_engine, task_ids=[task.taskUri]) return stack -def deploy_dataset_stack(context, dataset: models.Dataset): +def deploy_dataset_stack(dataset: models.Dataset): """ Each dataset stack deployment triggers environment stack update to rebuild teams IAM roles data access policies """ - deploy_stack(context, dataset.datasetUri) - deploy_stack(context, dataset.environmentUri) + deploy_stack(dataset.datasetUri) + deploy_stack(dataset.environmentUri) def delete_stack( - context, target_uri, accountid, cdk_role_arn, region, target_type=None + target_uri, accountid, cdk_role_arn, region ): - with context.engine.scoped_session() as session: + context = get_context() + with context.db_engine.scoped_session() as session: stack: models.Stack = db.api.Stack.find_stack_by_target_uri( session, target_uri=target_uri ) @@ -109,14 +115,15 @@ def delete_stack( ) session.add(task) - Worker.queue(context.engine, [task.taskUri]) + Worker.queue(context.db_engine, [task.taskUri]) return True def delete_repository( - context, target_uri, accountid, cdk_role_arn, region, repo_name + target_uri, accountid, cdk_role_arn, region, repo_name ): - with context.engine.scoped_session() as session: + context = get_context() + with context.db_engine.scoped_session() as session: task = models.Task( targetUri=target_uri, action='repo.datapipeline.delete', @@ -128,5 +135,5 @@ def delete_repository( }, ) session.add(task) - Worker.queue(context.engine, [task.taskUri]) + Worker.queue(context.db_engine, [task.taskUri]) return True diff --git a/backend/dataall/api/Objects/__init__.py b/backend/dataall/api/Objects/__init__.py index 1239c4273..060f2ba6e 100644 --- a/backend/dataall/api/Objects/__init__.py +++ b/backend/dataall/api/Objects/__init__.py @@ -38,7 +38,6 @@ Notification, Vpc, Tenant, - SagemakerNotebook, KeyValueTag, Vote, ) @@ -92,7 +91,6 @@ def adapted(obj, info, **kwargs): username=info.context['username'], groups=info.context['groups'], schema=info.context['schema'], - cdkproxyurl=info.context['cdkproxyurl'], ), source=obj or None, **kwargs, diff --git a/backend/dataall/api/constants.py b/backend/dataall/api/constants.py index 746b9e0dd..ad712b4a4 100644 --- a/backend/dataall/api/constants.py +++ b/backend/dataall/api/constants.py @@ -110,13 +110,6 @@ class 
ScheduledQueryRole(GraphQLEnumMapper): NoPermission = '000' -class SagemakerNotebookRole(GraphQLEnumMapper): - Creator = '950' - Admin = '900' - Shared = '300' - NoPermission = '000' - - class SagemakerStudioRole(GraphQLEnumMapper): Creator = '950' Admin = '900' diff --git a/backend/dataall/api/context.py b/backend/dataall/api/context.py index d2b0a88f0..a210dc0a1 100644 --- a/backend/dataall/api/context.py +++ b/backend/dataall/api/context.py @@ -5,10 +5,8 @@ def __init__( es=None, username=None, groups=None, - cdkproxyurl=None, ): self.engine = engine self.es = es self.username = username self.groups = groups - self.cdkproxyurl = cdkproxyurl diff --git a/backend/dataall/aws/handlers/ecs.py b/backend/dataall/aws/handlers/ecs.py index 539bd9320..fa816ee2e 100644 --- a/backend/dataall/aws/handlers/ecs.py +++ b/backend/dataall/aws/handlers/ecs.py @@ -63,6 +63,7 @@ def run_share_management_ecs_task(envname, share_uri, handler): subnets, [ {'name': 'shareUri', 'value': share_uri}, + {'name': 'config_location', 'value': '/config.json'}, {'name': 'envname', 'value': envname}, {'name': 'handler', 'value': handler}, { @@ -119,6 +120,7 @@ def run_cdkproxy_task(stack_uri): subnets, [ {'name': 'stackUri', 'value': stack_uri}, + {'name': 'config_location', 'value': '/config.json'}, {'name': 'envname', 'value': envname}, { 'name': 'AWS_REGION', diff --git a/backend/dataall/aws/handlers/sagemaker.py b/backend/dataall/aws/handlers/sagemaker.py deleted file mode 100644 index 1d2c76b0e..000000000 --- a/backend/dataall/aws/handlers/sagemaker.py +++ /dev/null @@ -1,86 +0,0 @@ -import logging - -from .sts import SessionHelper -from botocore.exceptions import ClientError - -logger = logging.getLogger(__name__) - - -class Sagemaker: - @staticmethod - def client(AwsAccountId, region): - session = SessionHelper.remote_session(AwsAccountId) - return session.client('sagemaker', region_name=region) - - @staticmethod - def get_notebook_instance_status(AwsAccountId, region, NotebookInstanceName): - try: - client = Sagemaker.client(AwsAccountId, region) - response = client.describe_notebook_instance( - NotebookInstanceName=NotebookInstanceName - ) - return response.get('NotebookInstanceStatus', 'NOT FOUND') - except ClientError as e: - logger.error( - f'Could not retrieve instance {NotebookInstanceName} status due to: {e} ' - ) - return 'NOT FOUND' - - @staticmethod - def presigned_url(AwsAccountId, region, NotebookInstanceName): - try: - client = Sagemaker.client(AwsAccountId, region) - response = client.create_presigned_notebook_instance_url( - NotebookInstanceName=NotebookInstanceName - ) - return response['AuthorizedUrl'] - except ClientError as e: - raise e - - @staticmethod - def presigned_url_jupyterlab(AwsAccountId, region, NotebookInstanceName): - try: - client = Sagemaker.client(AwsAccountId, region) - response = client.create_presigned_notebook_instance_url( - NotebookInstanceName=NotebookInstanceName - ) - url_parts = response['AuthorizedUrl'].split('?authToken') - url = url_parts[0] + '/lab' + '?authToken' + url_parts[1] - return url - except ClientError as e: - raise e - - @staticmethod - def start_instance(AwsAccountId, region, NotebookInstanceName): - try: - client = Sagemaker.client(AwsAccountId, region) - status = Sagemaker.get_notebook_instance_status( - AwsAccountId, region, NotebookInstanceName - ) - client.start_notebook_instance(NotebookInstanceName=NotebookInstanceName) - return status - except ClientError as e: - return e - - @staticmethod - def stop_instance(AwsAccountId, region, 
NotebookInstanceName): - try: - client = Sagemaker.client(AwsAccountId, region) - client.stop_notebook_instance(NotebookInstanceName=NotebookInstanceName) - except ClientError as e: - raise e - - @staticmethod - def get_security_groups(AwsAccountId, region): - try: - session = SessionHelper.remote_session(AwsAccountId) - client = session.client('ec2', region_name=region) - response = client.describe_security_groups() - sgnames = [SG['GroupName'] for SG in response['SecurityGroups']] - sgindex = [ - i for i, s in enumerate(sgnames) if 'DefaultLinuxSecurityGroup' in s - ] - SecurityGroupIds = [response['SecurityGroups'][sgindex[0]]['GroupId']] - return SecurityGroupIds - except ClientError as e: - raise e diff --git a/backend/dataall/cdkproxy/app.py b/backend/dataall/cdkproxy/app.py index cf208f7fb..fde8fa2eb 100644 --- a/backend/dataall/cdkproxy/app.py +++ b/backend/dataall/cdkproxy/app.py @@ -7,16 +7,19 @@ from tabulate import tabulate from dataall.cdkproxy.stacks import instanciate_stack +from dataall.modules.loader import load_modules, ImportMode print(sys.version) logger = logging.getLogger('cdkapp process') logger.setLevel('INFO') +load_modules(modes=[ImportMode.CDK]) + class CdkRunner: @staticmethod def create(): logger.info('Creating Stack') app = App() # 1. Reading info from context # 1.1 Reading account from context diff --git a/backend/dataall/cdkproxy/cdk_cli_wrapper.py b/backend/dataall/cdkproxy/cdk_cli_wrapper.py index 8066d9350..67f0790d0 100644 --- a/backend/dataall/cdkproxy/cdk_cli_wrapper.py +++ b/backend/dataall/cdkproxy/cdk_cli_wrapper.py @@ -104,6 +104,7 @@ def deploy_cdk_stack(engine: Engine, stackid: str, app_path: str = None, path: s 'PYTHONPATH': python_path, 'CURRENT_AWS_ACCOUNT': this_aws_account, 'envname': os.environ.get('envname', 'local'), + 'config_location': "/config.json" } if creds: env.update( diff --git a/backend/dataall/cdkproxy/requirements.txt b/backend/dataall/cdkproxy/requirements.txt index f2da84ebe..486357963 100644 --- a/backend/dataall/cdkproxy/requirements.txt +++ b/backend/dataall/cdkproxy/requirements.txt @@ -17,4 +17,5 @@ werkzeug==2.2.3 constructs>=10.0.0,<11.0.0 git-remote-codecommit==1.16 aws-ddk==0.5.1 -aws-ddk-core==0.5.1 \ No newline at end of file +aws-ddk-core==0.5.1 +deprecated==1.2.13 \ No newline at end of file diff --git a/backend/dataall/cdkproxy/stacks/__init__.py b/backend/dataall/cdkproxy/stacks/__init__.py index 81abe263b..3857b30c0 100644 --- a/backend/dataall/cdkproxy/stacks/__init__.py +++ b/backend/dataall/cdkproxy/stacks/__init__.py @@ -2,7 +2,6 @@ from .environment import EnvironmentSetup from .pipeline import PipelineStack from .manager import stack, instanciate_stack, StackManager -from .notebook import SagemakerNotebook from .redshift_cluster import RedshiftStack from .sagemakerstudio import SagemakerStudioUserProfile diff --git a/backend/dataall/cdkproxy/stacks/dataset.py b/backend/dataall/cdkproxy/stacks/dataset.py index 4ee53beb1..133b5f928 100644 --- a/backend/dataall/cdkproxy/stacks/dataset.py +++ b/backend/dataall/cdkproxy/stacks/dataset.py @@ -538,6 +538,6 @@ def __init__(self, scope, id, target_uri: str = None, **kwargs): Tags.of(self).add('Classification', dataset.confidentiality) - TagsUtil.add_tags(self) + TagsUtil.add_tags(stack=self, model=models.Dataset, target_type="dataset") CDKNagUtil.check_rules(self) diff --git a/backend/dataall/cdkproxy/stacks/environment.py b/backend/dataall/cdkproxy/stacks/environment.py index e54b24988..55bccca0e 100644 ---
a/backend/dataall/cdkproxy/stacks/environment.py +++ b/backend/dataall/cdkproxy/stacks/environment.py @@ -589,7 +589,7 @@ def __init__(self, scope, id, target_uri: str = None, **kwargs): parameter_name=f'/dataall/{self._environment.environmentUri}/sagemaker/sagemakerstudio/domain_id', ) - TagsUtil.add_tags(self) + TagsUtil.add_tags(stack=self, model=models.Environment, target_type="environment") CDKNagUtil.check_rules(self) diff --git a/backend/dataall/cdkproxy/stacks/pipeline.py b/backend/dataall/cdkproxy/stacks/pipeline.py index 995422283..616151a4e 100644 --- a/backend/dataall/cdkproxy/stacks/pipeline.py +++ b/backend/dataall/cdkproxy/stacks/pipeline.py @@ -370,7 +370,7 @@ def __init__(self, scope, id, target_uri: str = None, **kwargs): value=codepipeline_pipeline.pipeline_name, ) - TagsUtil.add_tags(self) + TagsUtil.add_tags(stack=self, model=models.DataPipeline, target_type="pipeline") CDKNagUtil.check_rules(self) diff --git a/backend/dataall/cdkproxy/stacks/policies/__init__.py b/backend/dataall/cdkproxy/stacks/policies/__init__.py index e69de29bb..964bb37a5 100644 --- a/backend/dataall/cdkproxy/stacks/policies/__init__.py +++ b/backend/dataall/cdkproxy/stacks/policies/__init__.py @@ -0,0 +1,9 @@ +"""Contains the code for creating environment policies""" + +from dataall.cdkproxy.stacks.policies import ( + _lambda, cloudformation, codestar, databrew, glue, + lakeformation, quicksight, redshift, stepfunctions, data_policy, service_policy +) + +__all__ = ["_lambda", "cloudformation", "codestar", "databrew", "glue", "lakeformation", "quicksight", + "redshift", "stepfunctions", "data_policy", "service_policy", "mlstudio"] diff --git a/backend/dataall/cdkproxy/stacks/policies/_lambda.py b/backend/dataall/cdkproxy/stacks/policies/_lambda.py index 2aa76257b..31a10baad 100644 --- a/backend/dataall/cdkproxy/stacks/policies/_lambda.py +++ b/backend/dataall/cdkproxy/stacks/policies/_lambda.py @@ -1,9 +1,13 @@ +from dataall.db import permissions from .service_policy import ServicePolicy from aws_cdk import aws_iam as iam class Lambda(ServicePolicy): - def get_statements(self): + def get_statements(self, group_permissions, **kwargs): + if permissions.CREATE_PIPELINE not in group_permissions: + return [] + statements = [ iam.PolicyStatement( actions=[ diff --git a/backend/dataall/cdkproxy/stacks/policies/cloudformation.py b/backend/dataall/cdkproxy/stacks/policies/cloudformation.py index 12eb8297a..5afb45d1b 100644 --- a/backend/dataall/cdkproxy/stacks/policies/cloudformation.py +++ b/backend/dataall/cdkproxy/stacks/policies/cloudformation.py @@ -3,7 +3,7 @@ class Cloudformation(ServicePolicy): - def get_statements(self): + def get_statements(self, group_permissions, **kwargs): statements = [ iam.PolicyStatement( actions=[ diff --git a/backend/dataall/cdkproxy/stacks/policies/codestar.py b/backend/dataall/cdkproxy/stacks/policies/codestar.py index 021409d61..32ffe25b7 100644 --- a/backend/dataall/cdkproxy/stacks/policies/codestar.py +++ b/backend/dataall/cdkproxy/stacks/policies/codestar.py @@ -1,9 +1,13 @@ +from dataall.db import permissions from .service_policy import ServicePolicy from aws_cdk import aws_iam as iam class CodeStar(ServicePolicy): - def get_statements(self): + def get_statements(self, group_permissions, **kwargs): + if permissions.CREATE_PIPELINE not in group_permissions: + return [] + statements = [ iam.PolicyStatement( actions=[ diff --git a/backend/dataall/cdkproxy/stacks/policies/databrew.py b/backend/dataall/cdkproxy/stacks/policies/databrew.py index 19aa41293..270879639 
100644 --- a/backend/dataall/cdkproxy/stacks/policies/databrew.py +++ b/backend/dataall/cdkproxy/stacks/policies/databrew.py @@ -1,9 +1,13 @@ +from dataall.db import permissions from .service_policy import ServicePolicy from aws_cdk import aws_iam as iam class Databrew(ServicePolicy): - def get_statements(self): + def get_statements(self, group_permissions, **kwargs): + if permissions.CREATE_DATASET not in group_permissions: + return [] + statements = [ iam.PolicyStatement(actions=['databrew:List*'], resources=['*']), iam.PolicyStatement( diff --git a/backend/dataall/cdkproxy/stacks/policies/glue.py b/backend/dataall/cdkproxy/stacks/policies/glue.py index 896622cfe..aa1dbe479 100644 --- a/backend/dataall/cdkproxy/stacks/policies/glue.py +++ b/backend/dataall/cdkproxy/stacks/policies/glue.py @@ -1,9 +1,13 @@ +from dataall.db import permissions from .service_policy import ServicePolicy from aws_cdk import aws_iam as iam class Glue(ServicePolicy): - def get_statements(self): + def get_statements(self, group_permissions, **kwargs): + if permissions.CREATE_DATASET not in group_permissions: + return [] + statements = [ iam.PolicyStatement( actions=[ diff --git a/backend/dataall/cdkproxy/stacks/policies/lakeformation.py b/backend/dataall/cdkproxy/stacks/policies/lakeformation.py index e495a1da2..3eb5d835c 100644 --- a/backend/dataall/cdkproxy/stacks/policies/lakeformation.py +++ b/backend/dataall/cdkproxy/stacks/policies/lakeformation.py @@ -1,10 +1,14 @@ from aws_cdk import aws_iam as iam +from dataall.db import permissions from .service_policy import ServicePolicy class LakeFormation(ServicePolicy): - def get_statements(self): + def get_statements(self, group_permissions, **kwargs): + if permissions.CREATE_DATASET not in group_permissions: + return [] + return [ iam.PolicyStatement( actions=[ diff --git a/backend/dataall/cdkproxy/stacks/policies/mlstudio.py b/backend/dataall/cdkproxy/stacks/policies/mlstudio.py new file mode 100644 index 000000000..05b44c903 --- /dev/null +++ b/backend/dataall/cdkproxy/stacks/policies/mlstudio.py @@ -0,0 +1,16 @@ +from dataall.cdkproxy.stacks.policies.service_policy import ServicePolicy + +from dataall.db import permissions +from dataall.modules.common.sagemaker.cdk.statements import create_sagemaker_statements + + +class SagemakerPolicy(ServicePolicy): + """ + Creates a sagemaker policy for accessing and interacting with ML studio + """ + + def get_statements(self, group_permissions, **kwargs): + if permissions.CREATE_SGMSTUDIO_NOTEBOOK not in group_permissions: + return [] + + return create_sagemaker_statements(self.account, self.region, self.tag_key, self.tag_value) diff --git a/backend/dataall/cdkproxy/stacks/policies/quicksight.py b/backend/dataall/cdkproxy/stacks/policies/quicksight.py index 487ddb429..e67b3436c 100644 --- a/backend/dataall/cdkproxy/stacks/policies/quicksight.py +++ b/backend/dataall/cdkproxy/stacks/policies/quicksight.py @@ -1,10 +1,14 @@ from aws_cdk import aws_iam as iam +from dataall.db import permissions from .service_policy import ServicePolicy class QuickSight(ServicePolicy): - def get_statements(self): + def get_statements(self, group_permissions, **kwargs): + if permissions.CREATE_DASHBOARD not in group_permissions: + return [] + return [ iam.PolicyStatement( actions=[ diff --git a/backend/dataall/cdkproxy/stacks/policies/redshift.py b/backend/dataall/cdkproxy/stacks/policies/redshift.py index 1c02dee66..5b3684ad4 100644 --- a/backend/dataall/cdkproxy/stacks/policies/redshift.py +++ 
b/backend/dataall/cdkproxy/stacks/policies/redshift.py @@ -1,10 +1,14 @@ from aws_cdk import aws_iam as iam +from dataall.db import permissions from .service_policy import ServicePolicy class Redshift(ServicePolicy): - def get_statements(self): + def get_statements(self, group_permissions, **kwargs): + if permissions.CREATE_REDSHIFT_CLUSTER not in group_permissions: + return [] + return [ iam.PolicyStatement( actions=[ diff --git a/backend/dataall/cdkproxy/stacks/policies/sagemaker.py b/backend/dataall/cdkproxy/stacks/policies/sagemaker.py deleted file mode 100644 index fee698989..000000000 --- a/backend/dataall/cdkproxy/stacks/policies/sagemaker.py +++ /dev/null @@ -1,140 +0,0 @@ -from .service_policy import ServicePolicy -from aws_cdk import aws_iam as iam - - -class Sagemaker(ServicePolicy): - def get_statements(self): - statements = [ - iam.PolicyStatement( - actions=[ - 'sagemaker:List*', - 'sagemaker:Describe*', - 'sagemaker:BatchGet*', - 'sagemaker:BatchDescribe*', - 'sagemaker:Search', - 'sagemaker:RenderUiTemplate', - 'sagemaker:GetSearchSuggestions', - 'sagemaker:QueryLineage', - 'sagemaker:CreateNotebookInstanceLifecycleConfig', - 'sagemaker:DeleteNotebookInstanceLifecycleConfig', - 'sagemaker:CreatePresignedDomainUrl' - ], - resources=['*'], - ), - iam.PolicyStatement( - actions=['sagemaker:AddTags'], - resources=['*'], - conditions={ - 'StringEquals': { - f'aws:ResourceTag/{self.tag_key}': [self.tag_value] - } - }, - ), - iam.PolicyStatement( - actions=['sagemaker:Delete*'], - resources=[ - f'arn:aws:sagemaker:{self.region}:{self.account}:notebook-instance/*', - f'arn:aws:sagemaker:{self.region}:{self.account}:algorithm/*', - f'arn:aws:sagemaker:{self.region}:{self.account}:model/*', - f'arn:aws:sagemaker:{self.region}:{self.account}:endpoint/*', - f'arn:aws:sagemaker:{self.region}:{self.account}:endpoint-config/*', - f'arn:aws:sagemaker:{self.region}:{self.account}:experiment/*', - f'arn:aws:sagemaker:{self.region}:{self.account}:experiment-trial/*', - f'arn:aws:sagemaker:{self.region}:{self.account}:experiment-group/*', - f'arn:aws:sagemaker:{self.region}:{self.account}:model-bias-job-definition/*', - f'arn:aws:sagemaker:{self.region}:{self.account}:model-package/*', - f'arn:aws:sagemaker:{self.region}:{self.account}:model-package-group/*', - f'arn:aws:sagemaker:{self.region}:{self.account}:model-quality-job-definition/*', - f'arn:aws:sagemaker:{self.region}:{self.account}:monitoring-schedule/*', - f'arn:aws:sagemaker:{self.region}:{self.account}:pipeline/*', - f'arn:aws:sagemaker:{self.region}:{self.account}:project/*', - f'arn:aws:sagemaker:{self.region}:{self.account}:app/*' - ], - conditions={ - 'StringEquals': { - f'aws:ResourceTag/{self.tag_key}': [self.tag_value] - } - }, - ), - iam.PolicyStatement( - actions=['sagemaker:CreateApp'], - resources=['*'] - ), - iam.PolicyStatement( - actions=['sagemaker:Create*'], - resources=['*'], - ), - iam.PolicyStatement( - actions=['sagemaker:Start*', 'sagemaker:Stop*'], - resources=[ - f'arn:aws:sagemaker:{self.region}:{self.account}:notebook-instance/*', - f'arn:aws:sagemaker:{self.region}:{self.account}:monitoring-schedule/*', - f'arn:aws:sagemaker:{self.region}:{self.account}:pipeline/*', - f'arn:aws:sagemaker:{self.region}:{self.account}:training-job/*', - f'arn:aws:sagemaker:{self.region}:{self.account}:processing-job/*', - f'arn:aws:sagemaker:{self.region}:{self.account}:hyper-parameter-tuning-job/*', - f'arn:aws:sagemaker:{self.region}:{self.account}:transform-job/*', - 
f'arn:aws:sagemaker:{self.region}:{self.account}:automl-job/*' - ], - conditions={ - 'StringEquals': { - f'aws:ResourceTag/{self.tag_key}': [self.tag_value] - } - }, - ), - iam.PolicyStatement( - actions=['sagemaker:Update*'], - resources=[ - f'arn:aws:sagemaker:{self.region}:{self.account}:notebook-instance/*', - f'arn:aws:sagemaker:{self.region}:{self.account}:notebook-instance-lifecycle-config/*', - f'arn:aws:sagemaker:{self.region}:{self.account}:studio-lifecycle-config/*', - f'arn:aws:sagemaker:{self.region}:{self.account}:endpoint/*', - f'arn:aws:sagemaker:{self.region}:{self.account}:pipeline/*', - f'arn:aws:sagemaker:{self.region}:{self.account}:pipeline-execution/*', - f'arn:aws:sagemaker:{self.region}:{self.account}:monitoring-schedule/*', - f'arn:aws:sagemaker:{self.region}:{self.account}:experiment/*', - f'arn:aws:sagemaker:{self.region}:{self.account}:experiment-trial/*', - f'arn:aws:sagemaker:{self.region}:{self.account}:experiment-trial-component/*', - f'arn:aws:sagemaker:{self.region}:{self.account}:model-package/*', - f'arn:aws:sagemaker:{self.region}:{self.account}:training-job/*', - f'arn:aws:sagemaker:{self.region}:{self.account}:project/*' - ], - conditions={ - 'StringEquals': { - f'aws:ResourceTag/{self.tag_key}': [self.tag_value] - } - }, - ), - iam.PolicyStatement( - actions=['sagemaker:InvokeEndpoint', 'sagemaker:InvokeEndpointAsync'], - resources=[ - f'arn:aws:sagemaker:{self.region}:{self.account}:endpoint/*' - ], - conditions={ - 'StringEquals': { - f'aws:ResourceTag/{self.tag_key}': [self.tag_value] - } - }, - ), - iam.PolicyStatement( - actions=[ - 'logs:CreateLogGroup', - 'logs:CreateLogStream', - 'logs:PutLogEvents'], - resources=[ - f'arn:aws:logs:{self.region}:{self.account}:log-group:/aws/sagemaker/*', - f'arn:aws:logs:{self.region}:{self.account}:log-group:/aws/sagemaker/*:log-stream:*', - ] - ), - iam.PolicyStatement( - actions=[ - 'ecr:GetAuthorizationToken', - 'ecr:BatchCheckLayerAvailability', - 'ecr:GetDownloadUrlForLayer', - 'ecr:BatchGetImage'], - resources=[ - '*' - ] - ) - ] - return statements diff --git a/backend/dataall/cdkproxy/stacks/policies/service_policy.py b/backend/dataall/cdkproxy/stacks/policies/service_policy.py index 007ee0044..2bea680fe 100644 --- a/backend/dataall/cdkproxy/stacks/policies/service_policy.py +++ b/backend/dataall/cdkproxy/stacks/policies/service_policy.py @@ -3,8 +3,6 @@ from aws_cdk import aws_iam -from ....db import permissions - logger = logging.getLogger() @@ -37,17 +35,6 @@ def generate_policies(self) -> [aws_iam.ManagedPolicy]: """ Creates aws_iam.Policy based on declared subclasses of Policy object """ - from .redshift import Redshift - from .databrew import Databrew - from .lakeformation import LakeFormation - from .sagemaker import Sagemaker - from ._lambda import Lambda - from .codestar import CodeStar - from .glue import Glue - from .stepfunctions import StepFunctions - from .quicksight import QuickSight - from .cloudformation import Cloudformation - policies: [aws_iam.ManagedPolicy] = [ # This policy covers the minumum actions required independent # of the service permissions given to the group. 
@@ -107,27 +94,9 @@ def generate_policies(self) -> [aws_iam.ManagedPolicy]: services = ServicePolicy.__subclasses__() - if permissions.CREATE_REDSHIFT_CLUSTER not in self.permissions: - services.remove(Redshift) - if permissions.CREATE_DATASET not in self.permissions: - services.remove(Databrew) - services.remove(LakeFormation) - services.remove(Glue) - if ( - permissions.CREATE_NOTEBOOK not in self.permissions - and permissions.CREATE_SGMSTUDIO_NOTEBOOK not in self.permissions - ): - services.remove(Sagemaker) - if permissions.CREATE_PIPELINE not in self.permissions: - services.remove(Lambda) - services.remove(CodeStar) - services.remove(StepFunctions) - if permissions.CREATE_DASHBOARD not in self.permissions: - services.remove(QuickSight) - statements = [] for service in services: - statements.extend(service.get_statements(self)) + statements.extend(service.get_statements(self, self.permissions)) statements_chunks: list = [ statements[i : i + 8] for i in range(0, len(statements), 8) @@ -144,7 +113,7 @@ def generate_policies(self) -> [aws_iam.ManagedPolicy]: ) return policies - def get_statements(self, **kwargs) -> List[aws_iam.PolicyStatement]: + def get_statements(self, group_permissions, **kwargs) -> List[aws_iam.PolicyStatement]: """ This method implements a policy based on a tag key and optionally a resource prefix :return: list diff --git a/backend/dataall/cdkproxy/stacks/policies/stepfunctions.py b/backend/dataall/cdkproxy/stacks/policies/stepfunctions.py index d8611b001..845c9cd34 100644 --- a/backend/dataall/cdkproxy/stacks/policies/stepfunctions.py +++ b/backend/dataall/cdkproxy/stacks/policies/stepfunctions.py @@ -1,10 +1,14 @@ from aws_cdk import aws_iam as iam +from dataall.db import permissions from .service_policy import ServicePolicy class StepFunctions(ServicePolicy): - def get_statements(self): + def get_statements(self, group_permissions, **kwargs): + if permissions.CREATE_PIPELINE not in group_permissions: + return [] + return [ iam.PolicyStatement( actions=[ diff --git a/backend/dataall/cdkproxy/stacks/redshift_cluster.py b/backend/dataall/cdkproxy/stacks/redshift_cluster.py index ee7839fc2..f546786b6 100644 --- a/backend/dataall/cdkproxy/stacks/redshift_cluster.py +++ b/backend/dataall/cdkproxy/stacks/redshift_cluster.py @@ -184,6 +184,6 @@ def __init__(self, scope, id: str, target_uri: str = None, **kwargs) -> None: hosted_rotation=aws_secretsmanager.HostedRotation.redshift_single_user(), ) - TagsUtil.add_tags(self) + TagsUtil.add_tags(stack=self, model=models.RedshiftCluster, target_type="redshift") CDKNagUtil.check_rules(self) diff --git a/backend/dataall/cdkproxy/stacks/sagemakerstudio.py b/backend/dataall/cdkproxy/stacks/sagemakerstudio.py index a858cdfc2..0e002c7e5 100644 --- a/backend/dataall/cdkproxy/stacks/sagemakerstudio.py +++ b/backend/dataall/cdkproxy/stacks/sagemakerstudio.py @@ -93,6 +93,6 @@ def __init__(self, scope, id: str, target_uri: str = None, **kwargs) -> None: .to_string() ) - TagsUtil.add_tags(self) + TagsUtil.add_tags(stack=self, model=models.SagemakerStudioUserProfile, target_type="mlstudio") CDKNagUtil.check_rules(self) diff --git a/backend/dataall/core/__init__.py b/backend/dataall/core/__init__.py new file mode 100644 index 000000000..4790c2f83 --- /dev/null +++ b/backend/dataall/core/__init__.py @@ -0,0 +1 @@ +"""The package contains the core functionality that is required by data.all to work correctly""" diff --git a/backend/dataall/core/config.py b/backend/dataall/core/config.py new file mode 100644 index 000000000..80f13e490 --- 
/dev/null
+++ b/backend/dataall/core/config.py
@@ -0,0 +1,70 @@
+"""Reads and encapsulates the configuration provided in config.json"""
+import json
+import copy
+from typing import Any, Dict
+import os
+from pathlib import Path
+
+
+class _Config:
+    """A container of properties from the configuration file
+    and any others that can be specified/overwritten later in the application"""
+
+    def __init__(self):
+        self._config = _Config._read_config_file()
+
+    def get_property(self, key: str, default=None) -> Any:
+        """
+        Retrieves a copy of the property.
+        Config uses a dot as a separator to navigate easily to the needed property, e.g.
+        some.needed.parameter is equivalent to config["some"]["needed"]["parameter"].
+        It enables fast navigation for any nested parameter
+        """
+        res = self._config
+
+        props = key.split(".")
+
+        # going through the hierarchy of json
+        for prop in props:
+            if prop not in res:
+                if default is not None:
+                    return default
+
+                raise KeyError(f"Couldn't find a property {key} in the config")
+
+            res = res[prop]
+        return copy.deepcopy(res)
+
+    def set_property(self, key: str, value: Any) -> None:
+        """
+        Sets a property into the config.
+        If the property has a dot, it is split into nested levels
+        """
+        conf = self._config
+        props = key.split(".")
+
+        for i, prop in enumerate(props):
+            if i == len(props) - 1:
+                conf[prop] = value
+            else:
+                conf[prop] = conf[prop] if prop in conf else {}
+                conf = conf[prop]
+
+    @staticmethod
+    def _read_config_file() -> Dict[str, Any]:
+        with open(_Config._path_to_file()) as config_file:
+            return json.load(config_file)
+
+    @staticmethod
+    def _path_to_file() -> str:
+        """Tries to get the config location from the environment; if not defined, resolves config.json from the repository root"""
+        path = os.getenv("config_location")
+        if path:
+            return path
+        return os.path.join(Path(__file__).parents[3], "config.json")
+
+    def __repr__(self):
+        return str(self._config)
+
+
+config = _Config()
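For illustration, a minimal sketch of how the dot-separated lookup behaves; the "modules" layout used here is an assumed example, not taken from this diff:

    from dataall.core.config import config

    # assuming config.json contains {"modules": {"notebooks": {"active": "true"}}}
    config.get_property("modules.notebooks.active")           # -> "true"
    config.get_property("modules.datasets.active", "false")   # missing key -> default is returned
    config.set_property("modules.notebooks.active", "false")  # runtime override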
diff --git a/backend/dataall/core/context.py b/backend/dataall/core/context.py
new file mode 100644
index 000000000..dcf594896
--- /dev/null
+++ b/backend/dataall/core/context.py
@@ -0,0 +1,42 @@
+"""
+API for the request context.
+The request context is a storage for data associated with the request and should be accessible
+from any part of the application that is in the request scope
+
+The module uses Flask's approach to handling requests: thread-local storage.
+That approach should work fine for AWS Lambda and for a local server that uses a Flask app
+"""
+
+from dataclasses import dataclass
+from typing import List
+
+from dataall.db.connection import Engine
+from threading import local
+import opensearchpy
+
+
+_request_storage = local()
+
+
+@dataclass(frozen=True)
+class RequestContext:
+    """Contains the context of every GraphQL request"""
+    db_engine: Engine
+    username: str
+    groups: List[str]
+    es_engine: opensearchpy.OpenSearch
+
+
+def get_context() -> RequestContext:
+    """Retrieves the context associated with a request"""
+    return _request_storage.context
+
+
+def set_context(context: RequestContext) -> None:
+    """Stores the context associated with a request"""
+    _request_storage.context = context
+
+
+def dispose_context() -> None:
+    """Disposes of the context after the request completes"""
+    _request_storage.context = None
diff --git a/backend/dataall/core/environment/__init__.py b/backend/dataall/core/environment/__init__.py
new file mode 100644
index 000000000..9673de7d7
--- /dev/null
+++ b/backend/dataall/core/environment/__init__.py
@@ -0,0 +1 @@
+"""The central package of the application to work with the environment"""
diff --git a/backend/dataall/core/environment/db/__init__.py b/backend/dataall/core/environment/db/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/backend/dataall/core/environment/db/repositories.py b/backend/dataall/core/environment/db/repositories.py
new file mode 100644
index 000000000..4b243beab
--- /dev/null
+++ b/backend/dataall/core/environment/db/repositories.py
@@ -0,0 +1,33 @@
+from dataall.core.environment.models import EnvironmentParameter
+from sqlalchemy.sql import and_
+
+
+class EnvironmentParameterRepository:
+    """CRUD operations for EnvironmentParameter"""
+
+    def __init__(self, session):
+        self._session = session
+
+    def get_param(self, env_uri, param_key):
+        return self._session.query(EnvironmentParameter).filter(
+            and_(
+                EnvironmentParameter.environmentUri == env_uri,
+                EnvironmentParameter.key == param_key
+            )
+        ).first()
+
+    def get_params(self, env_uri):
+        return self._session.query(EnvironmentParameter).filter(
+            EnvironmentParameter.environmentUri == env_uri
+        )
+
+    def update_params(self, env_uri, params):
+        """Rewrite all parameters for the environment"""
+        self.delete_params(env_uri)
+        self._session.add_all(params)
+
+    def delete_params(self, env_uri):
+        """Erase all environment parameters"""
+        self._session.query(EnvironmentParameter).filter(
+            EnvironmentParameter.environmentUri == env_uri
+        ).delete()
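A short sketch of how the thread-local context and this repository are meant to compose; the read_env_params helper below is hypothetical, named only for illustration:

    from dataall.core.context import get_context
    from dataall.core.environment.db.repositories import EnvironmentParameterRepository

    def read_env_params(env_uri: str) -> dict:
        # hypothetical helper: opens a session from the engine stored in the request context
        with get_context().db_engine.scoped_session() as session:
            params = EnvironmentParameterRepository(session).get_params(env_uri)
            return {p.key: p.value for p in params}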
diff --git a/backend/dataall/core/environment/models.py b/backend/dataall/core/environment/models.py
new file mode 100644
index 000000000..624ac0090
--- /dev/null
+++ b/backend/dataall/core/environment/models.py
@@ -0,0 +1,20 @@
+"""The module contains the database models that are related to the environment"""
+
+from sqlalchemy import Column, String, ForeignKey
+from dataall.db import Resource, Base
+
+
+class EnvironmentParameter(Base):
+    """Represents a parameter of the environment"""
+    __tablename__ = 'environment_parameters'
+    environmentUri = Column(String, ForeignKey("environment.environmentUri"), primary_key=True)
+    key = Column('paramKey', String, primary_key=True)
+    value = Column('paramValue', String, nullable=True)
+
+    def __init__(self, env_uri, key, value):
+        self.environmentUri = env_uri
+        self.key = key
+        self.value = value
+
+    def __repr__(self):
+        return f'EnvironmentParameter(paramKey={self.key}, paramValue={self.value})'
diff --git a/backend/dataall/core/permission_checker.py b/backend/dataall/core/permission_checker.py
new file mode 100644
index 000000000..784665996
--- /dev/null
+++ b/backend/dataall/core/permission_checker.py
@@ -0,0 +1,128 @@
+"""
+Contains decorators that check if a user has permission to access
+and interact with resources or perform actions in the app
+"""
+from typing import Protocol
+
+from dataall.core.context import RequestContext, get_context
+from dataall.db.api import TenantPolicy, ResourcePolicy, Environment
+
+
+class Identifiable(Protocol):
+    """Protocol to identify resources for checking permissions"""
+    def get_uri(self) -> str:
+        ...
+
+
+def _check_group_environment_permission(session, permission, uri, admin_group):
+    context: RequestContext = get_context()
+    Environment.check_group_environment_permission(
+        session=session,
+        username=context.username,
+        groups=context.groups,
+        uri=uri,
+        group=admin_group,
+        permission_name=permission,
+    )
+
+
+def _check_tenant_permission(session, permission):
+    context: RequestContext = get_context()
+    TenantPolicy.check_user_tenant_permission(
+        session=session,
+        username=context.username,
+        groups=context.groups,
+        tenant_name='dataall',
+        permission_name=permission
+    )
+
+
+def _check_resource_permission(session, uri, permission):
+    context: RequestContext = get_context()
+    ResourcePolicy.check_user_resource_permission(
+        session=session,
+        username=context.username,
+        groups=context.groups,
+        resource_uri=uri,
+        permission_name=permission,
+    )
+
+
+def _process_func(func):
+    """Helper that lets the decorators below wrap both plain functions and static methods"""
+    def no_decorated(f):
+        return f
+
+    static_func = False
+    try:
+        func.__func__
+        static_func = True
+        fn = func.__func__
+    except AttributeError:
+        fn = func
+
+    # returns the function to call and the static decorator if applied
+    return fn, staticmethod if static_func else no_decorated
+
+
+def has_resource_permission(permission: str, resource_name: str = None):
+    """
+    Decorator that checks if a user has access to the resource.
+    The method or function decorated with this decorator must receive the URI of the accessed resource.
+    Good rule of thumb: if an operation targets a specific resource,
+    hence it has a URI - it must be decorated with this decorator
+    """
+    def decorator(f):
+        fn, fn_decorator = _process_func(f)
+
+        def decorated(*args, **kwargs):
+            uri: str
+            if resource_name:
+                resource: Identifiable = kwargs[resource_name]
+                uri = resource.get_uri()
+            else:
+                uri = kwargs["uri"]
+
+            with get_context().db_engine.scoped_session() as session:
+                _check_resource_permission(session, uri, permission)
+
+            return fn(*args, **kwargs)
+
+        return fn_decorator(decorated)
+
+    return decorator
+
+
+def has_tenant_permission(permission: str):
+    """
+    Decorator to check if a user has permission to perform some action.
+ All the information about the user is retrieved from RequestContext + """ + def decorator(f): + fn, fn_decorator = _process_func(f) + + def decorated(*args, **kwargs): + with get_context().db_engine.scoped_session() as session: + _check_tenant_permission(session, permission) + + return fn(*args, **kwargs) + + return fn_decorator(decorated) + + return decorator + + +def has_group_permission(permission): + def decorator(f): + fn, fn_decorator = _process_func(f) + + def decorated(*args, admin_group, uri, **kwargs): + with get_context().db_engine.scoped_session() as session: + _check_group_environment_permission(session, permission, uri, admin_group) + + return fn(*args, uri=uri, admin_group=admin_group, **kwargs) + + return fn_decorator(decorated) + + return decorator diff --git a/backend/dataall/db/api/__init__.py b/backend/dataall/db/api/__init__.py index 765d1b68a..01647c81b 100644 --- a/backend/dataall/db/api/__init__.py +++ b/backend/dataall/db/api/__init__.py @@ -18,7 +18,6 @@ from .notification import Notification from .redshift_cluster import RedshiftCluster from .vpc import Vpc -from .notebook import Notebook from .sgm_studio_notebook import SgmStudioNotebook from .dashboard import Dashboard from .pipeline import Pipeline diff --git a/backend/dataall/db/api/environment.py b/backend/dataall/db/api/environment.py index e41386024..4a436bf9a 100644 --- a/backend/dataall/db/api/environment.py +++ b/backend/dataall/db/api/environment.py @@ -22,6 +22,8 @@ ) from ..models.Permission import PermissionType from ..paginator import Page, paginate +from dataall.core.environment.models import EnvironmentParameter +from ...core.environment.db.repositories import EnvironmentParameterRepository from ...utils.naming_convention import ( NamingConventionService, NamingConventionPattern, @@ -56,15 +58,17 @@ def create_environment(session, username, groups, uri, data=None, check_perm=Non EnvironmentDefaultIAMRoleArn=f'arn:aws:iam::{data.get("AwsAccountId")}:role/{data.get("EnvironmentDefaultIAMRoleName")}', CDKRoleArn=f"arn:aws:iam::{data.get('AwsAccountId')}:role/{data['cdk_role_name']}", dashboardsEnabled=data.get('dashboardsEnabled', False), - notebooksEnabled=data.get('notebooksEnabled', True), mlStudiosEnabled=data.get('mlStudiosEnabled', True), pipelinesEnabled=data.get('pipelinesEnabled', True), warehousesEnabled=data.get('warehousesEnabled', True), resourcePrefix=data.get('resourcePrefix'), ) + session.add(env) session.commit() + Environment._update_env_parameters(session, env, data) + env.EnvironmentDefaultBucketName = NamingConventionService( target_uri=env.environmentUri, target_label=env.label, @@ -185,8 +189,6 @@ def update_environment(session, username, groups, uri, data=None, check_perm=Non environment.tags = data.get('tags') if 'dashboardsEnabled' in data.keys(): environment.dashboardsEnabled = data.get('dashboardsEnabled') - if 'notebooksEnabled' in data.keys(): - environment.notebooksEnabled = data.get('notebooksEnabled') if 'mlStudiosEnabled' in data.keys(): environment.mlStudiosEnabled = data.get('mlStudiosEnabled') if 'pipelinesEnabled' in data.keys(): @@ -196,6 +198,8 @@ def update_environment(session, username, groups, uri, data=None, check_perm=Non if data.get('resourcePrefix'): environment.resourcePrefix = data.get('resourcePrefix') + Environment._update_env_parameters(session, environment, data) + ResourcePolicy.attach_resource_policy( session=session, resource_uri=environment.environmentUri, @@ -205,6 +209,19 @@ def update_environment(session, username, groups, uri, data=None, 
check_perm=Non ) return environment + @staticmethod + def _update_env_parameters(session, env: models.Environment, data): + """Removes old parameters and creates new parameters associated with the environment""" + params = data.get("parameters") + if not params: + return + + env_uri = env.environmentUri + new_params = [EnvironmentParameter( + env_uri, param.get("key"), param.get("value") + ) for param in params] + EnvironmentParameterRepository(session).update_params(env_uri, new_params) + @staticmethod @has_tenant_perm(permissions.MANAGE_ENVIRONMENTS) @has_resource_perm(permissions.INVITE_ENVIRONMENT_GROUP) @@ -276,9 +293,6 @@ def validate_permissions(session, uri, g_permissions, group): if permissions.CREATE_REDSHIFT_CLUSTER in g_permissions: g_permissions.append(permissions.LIST_ENVIRONMENT_REDSHIFT_CLUSTERS) - if permissions.CREATE_NOTEBOOK in g_permissions: - g_permissions.append(permissions.LIST_ENVIRONMENT_NOTEBOOKS) - if permissions.CREATE_SGMSTUDIO_NOTEBOOK in g_permissions: g_permissions.append(permissions.LIST_ENVIRONMENT_SGMSTUDIO_NOTEBOOKS) @@ -1284,63 +1298,6 @@ def paginated_environment_redshift_clusters( page=data.get('page', 1), ).to_dict() - @staticmethod - def list_environment_objects(session, environment_uri): - environment_objects = [] - datasets = ( - session.query(models.Dataset.label, models.Dataset.datasetUri) - .filter(models.Dataset.environmentUri == environment_uri) - .all() - ) - notebooks = ( - session.query( - models.SagemakerNotebook.label, - models.SagemakerNotebook.notebookUri, - ) - .filter(models.SagemakerNotebook.environmentUri == environment_uri) - .all() - ) - ml_studios = ( - session.query( - models.SagemakerStudioUserProfile.label, - models.SagemakerStudioUserProfile.sagemakerStudioUserProfileUri, - ) - .filter(models.SagemakerStudioUserProfile.environmentUri == environment_uri) - .all() - ) - redshift_clusters = ( - session.query( - models.RedshiftCluster.label, models.RedshiftCluster.clusterUri - ) - .filter(models.RedshiftCluster.environmentUri == environment_uri) - .all() - ) - pipelines = ( - session.query(models.DataPipeline.label, models.DataPipeline.DataPipelineUri) - .filter(models.DataPipeline.environmentUri == environment_uri) - .all() - ) - dashboards = ( - session.query(models.Dashboard.label, models.Dashboard.dashboardUri) - .filter(models.Dashboard.environmentUri == environment_uri) - .all() - ) - if datasets: - environment_objects.append({'type': 'Datasets', 'data': datasets}) - if notebooks: - environment_objects.append({'type': 'Notebooks', 'data': notebooks}) - if ml_studios: - environment_objects.append({'type': 'MLStudios', 'data': ml_studios}) - if redshift_clusters: - environment_objects.append( - {'type': 'RedshiftClusters', 'data': redshift_clusters} - ) - if pipelines: - environment_objects.append({'type': 'Pipelines', 'data': pipelines}) - if dashboards: - environment_objects.append({'type': 'Dashboards', 'data': dashboards}) - return environment_objects - @staticmethod def list_group_datasets(session, username, groups, uri, data=None, check_perm=None): if not data: @@ -1372,14 +1329,6 @@ def delete_environment(session, username, groups, uri, data=None, check_perm=Non 'environment', Environment.get_environment_by_uri(session, uri) ) - environment_objects = Environment.list_environment_objects(session, uri) - - if environment_objects: - raise exceptions.EnvironmentResourcesFound( - action='Delete Environment', - message='Delete all environment related objects before proceeding', - ) - env_groups = ( 
session.query(models.EnvironmentGroup) .filter(models.EnvironmentGroup.environmentUri == uri) @@ -1420,6 +1369,8 @@ def delete_environment(session, username, groups, uri, data=None, check_perm=Non ) session.delete(share) + EnvironmentParameterRepository(session).delete_params(environment.environmentUri) + return session.delete(environment) @staticmethod @@ -1465,3 +1416,7 @@ def check_group_environment_permission( resource_uri=uri, permission_name=permission_name, ) + + @staticmethod + def get_environment_parameters(session, env_uri): + return EnvironmentParameterRepository(session).get_params(env_uri) diff --git a/backend/dataall/db/api/notebook.py b/backend/dataall/db/api/notebook.py deleted file mode 100644 index de3d712d6..000000000 --- a/backend/dataall/db/api/notebook.py +++ /dev/null @@ -1,154 +0,0 @@ -import logging - -from sqlalchemy import or_ -from sqlalchemy.orm import Query - -from . import ( - has_tenant_perm, - has_resource_perm, - ResourcePolicy, - Environment, -) -from .. import models, exceptions, permissions, paginate -from ...utils.naming_convention import ( - NamingConventionService, - NamingConventionPattern, -) -from ...utils.slugify import slugify - -logger = logging.getLogger(__name__) - - -class Notebook: - @staticmethod - @has_tenant_perm(permissions.MANAGE_NOTEBOOKS) - @has_resource_perm(permissions.CREATE_NOTEBOOK) - def create_notebook( - session, username, groups, uri, data=None, check_perm=None - ) -> models.SagemakerNotebook: - - Notebook.validate_params(data) - - Environment.check_group_environment_permission( - session=session, - username=username, - groups=groups, - uri=uri, - group=data['SamlAdminGroupName'], - permission_name=permissions.CREATE_NOTEBOOK, - ) - - env = Environment.get_environment_by_uri(session, uri) - - if not env.notebooksEnabled: - raise exceptions.UnauthorizedOperation( - action=permissions.CREATE_NOTEBOOK, - message=f'Notebooks feature is disabled for the environment {env.label}', - ) - - env_group: models.EnvironmentGroup = data.get( - 'environment', - Environment.get_environment_group( - session, - group_uri=data['SamlAdminGroupName'], - environment_uri=env.environmentUri, - ), - ) - - notebook = models.SagemakerNotebook( - label=data.get('label', 'Untitled'), - environmentUri=env.environmentUri, - description=data.get('description', 'No description provided'), - NotebookInstanceName=slugify(data.get('label'), separator=''), - NotebookInstanceStatus='NotStarted', - AWSAccountId=env.AwsAccountId, - region=env.region, - RoleArn=env_group.environmentIAMRoleArn, - owner=username, - SamlAdminGroupName=data.get('SamlAdminGroupName', env.SamlGroupName), - tags=data.get('tags', []), - VpcId=data.get('VpcId'), - SubnetId=data.get('SubnetId'), - VolumeSizeInGB=data.get('VolumeSizeInGB', 32), - InstanceType=data.get('InstanceType', 'ml.t3.medium'), - ) - session.add(notebook) - session.commit() - - notebook.NotebookInstanceName = NamingConventionService( - target_uri=notebook.notebookUri, - target_label=notebook.label, - pattern=NamingConventionPattern.NOTEBOOK, - resource_prefix=env.resourcePrefix, - ).build_compliant_name() - - ResourcePolicy.attach_resource_policy( - session=session, - group=data['SamlAdminGroupName'], - permissions=permissions.NOTEBOOK_ALL, - resource_uri=notebook.notebookUri, - resource_type=models.SagemakerNotebook.__name__, - ) - - if env.SamlGroupName != notebook.SamlAdminGroupName: - ResourcePolicy.attach_resource_policy( - session=session, - group=env.SamlGroupName, - permissions=permissions.NOTEBOOK_ALL, - 
resource_uri=notebook.notebookUri, - resource_type=models.SagemakerNotebook.__name__, - ) - - return notebook - - @staticmethod - def validate_params(data): - if not data: - raise exceptions.RequiredParameter('data') - if not data.get('environmentUri'): - raise exceptions.RequiredParameter('environmentUri') - if not data.get('label'): - raise exceptions.RequiredParameter('name') - - @staticmethod - def query_user_notebooks(session, username, groups, filter) -> Query: - query = session.query(models.SagemakerNotebook).filter( - or_( - models.SagemakerNotebook.owner == username, - models.SagemakerNotebook.SamlAdminGroupName.in_(groups), - ) - ) - if filter and filter.get('term'): - query = query.filter( - or_( - models.SagemakerNotebook.description.ilike( - filter.get('term') + '%%' - ), - models.SagemakerNotebook.label.ilike(filter.get('term') + '%%'), - ) - ) - return query - - @staticmethod - def paginated_user_notebooks( - session, username, groups, uri, data=None, check_perm=None - ) -> dict: - return paginate( - query=Notebook.query_user_notebooks(session, username, groups, data), - page=data.get('page', 1), - page_size=data.get('pageSize', 10), - ).to_dict() - - @staticmethod - @has_resource_perm(permissions.GET_NOTEBOOK) - def get_notebook(session, username, groups, uri, data=None, check_perm=True): - return Notebook.get_notebook_by_uri(session, uri) - - @staticmethod - def get_notebook_by_uri(session, uri) -> models.SagemakerNotebook: - if not uri: - raise exceptions.RequiredParameter('URI') - notebook = session.query(models.SagemakerNotebook).get(uri) - if not notebook: - raise exceptions.ObjectNotFound('SagemakerNotebook', uri) - return notebook diff --git a/backend/dataall/db/api/permission_checker.py b/backend/dataall/db/api/permission_checker.py index 6fb69d6dd..ba6ba8b24 100644 --- a/backend/dataall/db/api/permission_checker.py +++ b/backend/dataall/db/api/permission_checker.py @@ -1,7 +1,12 @@ from ..api.resource_policy import ResourcePolicy from ..api.tenant_policy import TenantPolicy +from deprecated import deprecated +@deprecated( + reason="old API. Should be removed at the end of modularization. Use dataall.core.permission_checker", + action="once" +) def has_resource_perm(permission): def decorator(f): static_func = False @@ -38,6 +43,10 @@ def decorated( return decorator +@deprecated( + reason="old API. Should be removed at the end of modularization. Use dataall.core.permission_checker", + action="once" +) def has_tenant_perm(permission): def decorator(f): static_func = False diff --git a/backend/dataall/db/api/sgm_studio_notebook.py b/backend/dataall/db/api/sgm_studio_notebook.py index d947c5b0d..a721c82b2 100644 --- a/backend/dataall/db/api/sgm_studio_notebook.py +++ b/backend/dataall/db/api/sgm_studio_notebook.py @@ -3,6 +3,8 @@ from sqlalchemy import or_ from sqlalchemy.orm import Query +from dataall.db.permissions import MANAGE_SGMSTUDIO_NOTEBOOKS + from .. import exceptions, permissions, paginate, models from . 
import (
    has_tenant_perm,
@@ -16,7 +18,7 @@ class SgmStudioNotebook:

     @staticmethod
-    @has_tenant_perm(permissions.MANAGE_NOTEBOOKS)
+    @has_tenant_perm(MANAGE_SGMSTUDIO_NOTEBOOKS)
     @has_resource_perm(permissions.CREATE_SGMSTUDIO_NOTEBOOK)
     def create_notebook(session, username, groups, uri, data=None, check_perm=None):
diff --git a/backend/dataall/db/api/target_type.py b/backend/dataall/db/api/target_type.py
index 9dc4b54fd..0eecb2569 100644
--- a/backend/dataall/db/api/target_type.py
+++ b/backend/dataall/db/api/target_type.py
@@ -1,64 +1,43 @@
 import logging

-from .. import exceptions, permissions
-from .. import models
+from dataall.db import exceptions, permissions

 logger = logging.getLogger(__name__)


 class TargetType:
-    @staticmethod
-    def get_target_type_permissions():
-        return dict(
-            dataset=(permissions.GET_DATASET, permissions.UPDATE_DATASET),
-            environment=(permissions.GET_ENVIRONMENT, permissions.UPDATE_ENVIRONMENT),
-            notebook=(permissions.GET_NOTEBOOK, permissions.UPDATE_NOTEBOOK),
-            mlstudio=(
-                permissions.GET_SGMSTUDIO_NOTEBOOK,
-                permissions.UPDATE_SGMSTUDIO_NOTEBOOK,
-            ),
-            pipeline=(permissions.GET_PIPELINE, permissions.UPDATE_PIPELINE),
-            redshift=(
-                permissions.GET_REDSHIFT_CLUSTER,
-                permissions.GET_REDSHIFT_CLUSTER,
-            ),
-        )
+    """Resolves the read/write permissions for different types of resources (target types)"""
+    _TARGET_TYPES = {}
+
+    def __init__(self, name, read_permission, write_permission):
+        self.name = name
+        self.read_permission = read_permission
+        self.write_permission = write_permission
+
+        TargetType._TARGET_TYPES[name] = self

     @staticmethod
     def get_resource_update_permission_name(target_type):
         TargetType.is_supported_target_type(target_type)
-        return TargetType.get_target_type_permissions()[target_type][1]
+        return TargetType._TARGET_TYPES[target_type].write_permission

     @staticmethod
     def get_resource_read_permission_name(target_type):
         TargetType.is_supported_target_type(target_type)
-        return TargetType.get_target_type_permissions()[target_type][0]
+        return TargetType._TARGET_TYPES[target_type].read_permission

     @staticmethod
     def is_supported_target_type(target_type):
-        supported_types = [
-            'dataset',
-            'environment',
-            'notebook',
-            'mlstudio',
-            'pipeline',
-            'redshift',
-        ]
-        if target_type not in supported_types:
+        if target_type not in TargetType._TARGET_TYPES:
             raise exceptions.InvalidInput(
                 'targetType',
                 target_type,
-                ' or '.join(supported_types),
+                ' or '.join(TargetType._TARGET_TYPES.keys()),
             )

-    @staticmethod
-    def get_target_type(model_name):
-        target_types_map = dict(
-            environment=models.Environment,
-            dataset=models.Dataset,
-            notebook=models.SagemakerNotebook,
-            mlstudio=models.SagemakerStudioUserProfile,
-            pipeline=models.DataPipeline,
-            redshift=models.RedshiftCluster,
-        )
-        return [k for k, v in target_types_map.items() if v == model_name][0]
+
+TargetType("dataset", permissions.GET_DATASET, permissions.UPDATE_DATASET)
+TargetType("environment", permissions.GET_ENVIRONMENT, permissions.UPDATE_ENVIRONMENT)
+TargetType("mlstudio", permissions.GET_SGMSTUDIO_NOTEBOOK, permissions.UPDATE_SGMSTUDIO_NOTEBOOK)
+TargetType("pipeline", permissions.GET_PIPELINE, permissions.UPDATE_PIPELINE)
+TargetType("redshift", permissions.GET_REDSHIFT_CLUSTER, permissions.GET_REDSHIFT_CLUSTER)
diff --git a/backend/dataall/db/models/Dashboard.py b/backend/dataall/db/models/Dashboard.py
index a4bd97587..1a24ef1cb 100644
--- a/backend/dataall/db/models/Dashboard.py
+++ b/backend/dataall/db/models/Dashboard.py
@@ -1,4 +1,4 @@
-from sqlalchemy import Column, String
+from sqlalchemy import Column, String, ForeignKey from sqlalchemy.orm import query_expression from .. import Base, Resource, utils @@ -6,7 +6,7 @@ class Dashboard(Resource, Base): __tablename__ = 'dashboard' - environmentUri = Column(String, nullable=False) + environmentUri = Column(String, ForeignKey("environment.environmentUri"), nullable=False) organizationUri = Column(String, nullable=False) dashboardUri = Column( String, nullable=False, primary_key=True, default=utils.uuid('dashboard') diff --git a/backend/dataall/db/models/DataPipeline.py b/backend/dataall/db/models/DataPipeline.py index 4146a9db1..208443566 100644 --- a/backend/dataall/db/models/DataPipeline.py +++ b/backend/dataall/db/models/DataPipeline.py @@ -1,4 +1,4 @@ -from sqlalchemy import Column, String +from sqlalchemy import Column, String, ForeignKey from sqlalchemy.orm import query_expression from sqlalchemy.dialects import postgresql @@ -7,7 +7,7 @@ class DataPipeline(Resource, Base): __tablename__ = 'datapipeline' - environmentUri = Column(String, nullable=False) + environmentUri = Column(String, ForeignKey("environment.environmentUri"), nullable=False) DataPipelineUri = Column( String, nullable=False, primary_key=True, default=utils.uuid('DataPipelineUri') ) diff --git a/backend/dataall/db/models/Dataset.py b/backend/dataall/db/models/Dataset.py index 35de117f9..71a95fe0e 100644 --- a/backend/dataall/db/models/Dataset.py +++ b/backend/dataall/db/models/Dataset.py @@ -1,4 +1,4 @@ -from sqlalchemy import Boolean, Column, String +from sqlalchemy import Boolean, Column, String, ForeignKey from sqlalchemy.dialects import postgresql from sqlalchemy.orm import query_expression @@ -7,7 +7,7 @@ class Dataset(Resource, Base): __tablename__ = 'dataset' - environmentUri = Column(String, nullable=False) + environmentUri = Column(String, ForeignKey("environment.environmentUri"), nullable=False) organizationUri = Column(String, nullable=False) datasetUri = Column(String, primary_key=True, default=utils.uuid('dataset')) region = Column(String, default='eu-west-1') diff --git a/backend/dataall/db/models/Enums.py b/backend/dataall/db/models/Enums.py index 469eafa7a..8e981242b 100644 --- a/backend/dataall/db/models/Enums.py +++ b/backend/dataall/db/models/Enums.py @@ -73,13 +73,6 @@ class ScheduledQueryRole(Enum): NoPermission = '000' -class SagemakerNotebookRole(Enum): - Creator = '950' - Admin = '900' - Shared = '300' - NoPermission = '000' - - class SagemakerStudioRole(Enum): Creator = '950' Admin = '900' diff --git a/backend/dataall/db/models/Environment.py b/backend/dataall/db/models/Environment.py index 295f56dac..55246aabd 100644 --- a/backend/dataall/db/models/Environment.py +++ b/backend/dataall/db/models/Environment.py @@ -25,7 +25,6 @@ class Environment(Resource, Base): roleCreated = Column(Boolean, nullable=False, default=False) dashboardsEnabled = Column(Boolean, default=False) - notebooksEnabled = Column(Boolean, default=True) mlStudiosEnabled = Column(Boolean, default=True) pipelinesEnabled = Column(Boolean, default=True) warehousesEnabled = Column(Boolean, default=True) diff --git a/backend/dataall/db/models/RedshiftCluster.py b/backend/dataall/db/models/RedshiftCluster.py index db40200ae..2997da416 100644 --- a/backend/dataall/db/models/RedshiftCluster.py +++ b/backend/dataall/db/models/RedshiftCluster.py @@ -1,4 +1,4 @@ -from sqlalchemy import Column, String, ARRAY, Integer, Boolean +from sqlalchemy import Column, String, ARRAY, Integer, Boolean, ForeignKey from sqlalchemy.orm import query_expression from .. 
import utils, Resource, Base @@ -6,7 +6,7 @@ class RedshiftCluster(Resource, Base): __tablename__ = 'redshiftcluster' - environmentUri = Column(String, nullable=False) + environmentUri = Column(String, ForeignKey("environment.environmentUri"), nullable=False) organizationUri = Column(String, nullable=False) clusterUri = Column(String, primary_key=True, default=utils.uuid('cluster')) clusterArn = Column(String) diff --git a/backend/dataall/db/models/SagemakerStudio.py b/backend/dataall/db/models/SagemakerStudio.py index 3d469f0c9..9f1f167da 100644 --- a/backend/dataall/db/models/SagemakerStudio.py +++ b/backend/dataall/db/models/SagemakerStudio.py @@ -1,4 +1,4 @@ -from sqlalchemy import Column, String +from sqlalchemy import Column, String, ForeignKey from sqlalchemy.orm import query_expression from .. import Base @@ -21,7 +21,7 @@ class SagemakerStudio(Resource, Base): class SagemakerStudioUserProfile(Resource, Base): __tablename__ = 'sagemaker_studio_user_profile' - environmentUri = Column(String, nullable=False) + environmentUri = Column(String, ForeignKey("environment.environmentUri"), nullable=False) sagemakerStudioUserProfileUri = Column( String, primary_key=True, default=utils.uuid('sagemakerstudiouserprofile') ) diff --git a/backend/dataall/db/models/__init__.py b/backend/dataall/db/models/__init__.py index fdc5fbedf..1ce567c87 100644 --- a/backend/dataall/db/models/__init__.py +++ b/backend/dataall/db/models/__init__.py @@ -27,7 +27,6 @@ from .RedshiftClusterDatasetTable import RedshiftClusterDatasetTable from .ResourcePolicy import ResourcePolicy from .ResourcePolicyPermission import ResourcePolicyPermission -from .SagemakerNotebook import SagemakerNotebook from .SagemakerStudio import SagemakerStudio, SagemakerStudioUserProfile from .ShareObject import ShareObject from .ShareObjectItem import ShareObjectItem diff --git a/backend/dataall/db/permissions.py b/backend/dataall/db/permissions.py index cf921a30c..6a26b2033 100644 --- a/backend/dataall/db/permissions.py +++ b/backend/dataall/db/permissions.py @@ -25,7 +25,6 @@ MANAGE_DATASETS = 'MANAGE_DATASETS' MANAGE_REDSHIFT_CLUSTERS = 'MANAGE_REDSHIFT_CLUSTERS' MANAGE_DASHBOARDS = 'MANAGE_DASHBOARDS' -MANAGE_NOTEBOOKS = 'MANAGE_NOTEBOOKS' MANAGE_PIPELINES = 'MANAGE_PIPELINES' MANAGE_GROUPS = 'MANAGE_GROUPS' MANAGE_ENVIRONMENT = 'MANAGE_ENVIRONMENT' @@ -33,6 +32,7 @@ MANAGE_GLOSSARIES = 'MANAGE_GLOSSARIES' MANAGE_ENVIRONMENTS = 'MANAGE_ENVIRONMENTS' MANAGE_ORGANIZATIONS = 'MANAGE_ORGANIZATIONS' +MANAGE_SGMSTUDIO_NOTEBOOKS = 'MANAGE_SGMSTUDIO_NOTEBOOKS' """ ENVIRONMENT @@ -57,8 +57,6 @@ LIST_ENVIRONMENT_SHARED_WITH_OBJECTS = 'LIST_ENVIRONMENT_SHARED_WITH_OBJECTS' CREATE_REDSHIFT_CLUSTER = 'CREATE_REDSHIFT_CLUSTER' LIST_ENVIRONMENT_REDSHIFT_CLUSTERS = 'LIST_ENVIRONMENT_REDSHIFT_CLUSTERS' -CREATE_NOTEBOOK = 'CREATE_NOTEBOOK' -LIST_ENVIRONMENT_NOTEBOOKS = 'LIST_ENVIRONMENT_NOTEBOOKS' CREATE_SGMSTUDIO_NOTEBOOK = 'CREATE_SGMSTUDIO_NOTEBOOK' LIST_ENVIRONMENT_SGMSTUDIO_NOTEBOOKS = 'LIST_ENVIRONMENT_SGMSTUDIO_NOTEBOOKS' CREATE_DASHBOARD = 'CREATE_DASHBOARD' @@ -81,8 +79,6 @@ RUN_ATHENA_QUERY, CREATE_REDSHIFT_CLUSTER, LIST_ENVIRONMENT_REDSHIFT_CLUSTERS, - CREATE_NOTEBOOK, - LIST_ENVIRONMENT_NOTEBOOKS, CREATE_SGMSTUDIO_NOTEBOOK, LIST_ENVIRONMENT_SGMSTUDIO_NOTEBOOKS, CREATE_DASHBOARD, @@ -101,7 +97,6 @@ CREATE_SHARE_OBJECT, CREATE_REDSHIFT_CLUSTER, CREATE_SGMSTUDIO_NOTEBOOK, - CREATE_NOTEBOOK, CREATE_DASHBOARD, CREATE_PIPELINE, CREATE_NETWORK, @@ -126,8 +121,6 @@ CREATE_SHARE_OBJECT, CREATE_REDSHIFT_CLUSTER, LIST_ENVIRONMENT_REDSHIFT_CLUSTERS, - 
CREATE_NOTEBOOK,
-    LIST_ENVIRONMENT_NOTEBOOKS,
     LIST_ENVIRONMENT_SHARED_WITH_OBJECTS,
     CREATE_SGMSTUDIO_NOTEBOOK,
     LIST_ENVIRONMENT_SGMSTUDIO_NOTEBOOKS,
@@ -284,19 +277,18 @@
     MANAGE_DATASETS,
     MANAGE_REDSHIFT_CLUSTERS,
     MANAGE_DASHBOARDS,
-    MANAGE_NOTEBOOKS,
     MANAGE_PIPELINES,
     MANAGE_WORKSHEETS,
     MANAGE_GLOSSARIES,
     MANAGE_GROUPS,
     MANAGE_ENVIRONMENTS,
     MANAGE_ORGANIZATIONS,
+    MANAGE_SGMSTUDIO_NOTEBOOKS,
 ]

 TENANT_ALL_WITH_DESC = {k: k for k in TENANT_ALL}
 TENANT_ALL_WITH_DESC[MANAGE_DASHBOARDS] = 'Manage dashboards'
 TENANT_ALL_WITH_DESC[MANAGE_DATASETS] = 'Manage datasets'
-TENANT_ALL_WITH_DESC[MANAGE_NOTEBOOKS] = 'Manage notebooks'
 TENANT_ALL_WITH_DESC[MANAGE_REDSHIFT_CLUSTERS] = 'Manage Redshift clusters'
 TENANT_ALL_WITH_DESC[MANAGE_GLOSSARIES] = 'Manage glossaries'
 TENANT_ALL_WITH_DESC[MANAGE_WORKSHEETS] = 'Manage worksheets'
@@ -304,6 +296,7 @@
 TENANT_ALL_WITH_DESC[MANAGE_GROUPS] = 'Manage teams'
 TENANT_ALL_WITH_DESC[MANAGE_PIPELINES] = 'Manage pipelines'
 TENANT_ALL_WITH_DESC[MANAGE_ORGANIZATIONS] = 'Manage organizations'
+TENANT_ALL_WITH_DESC[MANAGE_SGMSTUDIO_NOTEBOOKS] = 'Manage ML studio notebooks'

 """
 REDSHIFT CLUSTER
 """
@@ -335,18 +328,6 @@
     GET_REDSHIFT_CLUSTER_CREDENTIALS,
 ]

-"""
-NOTEBOOKS
-"""
-GET_NOTEBOOK = 'GET_NOTEBOOK'
-UPDATE_NOTEBOOK = 'UPDATE_NOTEBOOK'
-DELETE_NOTEBOOK = 'DELETE_NOTEBOOK'
-NOTEBOOK_ALL = [
-    GET_NOTEBOOK,
-    DELETE_NOTEBOOK,
-    UPDATE_NOTEBOOK,
-]
-
 """
 SAGEMAKER STUDIO NOTEBOOKS
 """
@@ -423,7 +404,6 @@
 """
 RESOURCES_ALL
 """
-
 RESOURCES_ALL = (
     DATASET_ALL
     + DATASET_TABLE_READ
@@ -432,7 +412,6 @@
     + CONSUMPTION_ROLE_ALL
     + SHARE_OBJECT_ALL
     + REDSHIFT_CLUSTER_ALL
-    + NOTEBOOK_ALL
     + GLOSSARY_ALL
     + SGMSTUDIO_NOTEBOOK_ALL
     + DASHBOARD_ALL
@@ -444,7 +423,6 @@
 RESOURCES_ALL_WITH_DESC = {k: k for k in RESOURCES_ALL}
 RESOURCES_ALL_WITH_DESC[CREATE_DATASET] = 'Create datasets on this environment'
 RESOURCES_ALL_WITH_DESC[CREATE_DASHBOARD] = 'Create dashboards on this environment'
-RESOURCES_ALL_WITH_DESC[CREATE_NOTEBOOK] = 'Create notebooks on this environment'
 RESOURCES_ALL_WITH_DESC[CREATE_REDSHIFT_CLUSTER] = 'Create Redshift clusters on this environment'
 RESOURCES_ALL_WITH_DESC[CREATE_SGMSTUDIO_NOTEBOOK] = 'Create ML Studio profiles on this environment'
 RESOURCES_ALL_WITH_DESC[INVITE_ENVIRONMENT_GROUP] = 'Invite other teams to this environment'
diff --git a/backend/dataall/modules/__init__.py b/backend/dataall/modules/__init__.py
new file mode 100644
index 000000000..dce3a711e
--- /dev/null
+++ b/backend/dataall/modules/__init__.py
@@ -0,0 +1,33 @@
+"""
+Contains all submodules that can be plugged into the main functionality
+
+How to migrate to a new module:
+1) Create your python module
+2) Create an implementation of ModuleInterface/s in __init__.py of your module
+3) Define your module in config.json. The loader will use it to import your module
+
+Remember that there should not be any references from outside to modules.
+The rule is simple: modules can import the core/common code, but not the other way around;
+otherwise your modules will be imported unintentionally.
+You can add logging about importing the module in __init__.py to track unintentional imports
+
+Auto import of modules:
+api - contains the logic for processing GraphQL requests. It registers itself automatically;
+see bootstrap() and @cache_instances
+
+cdk - contains stacks that will be deployed to AWS when it's requested. Stacks will
Stacks will +register itself automatically if there is decorator @stack upon the class +see StackManagerFactory and @stack - for more information on stacks + +tasks - contains code for short-running tasks that will be delegated to lambda +These task will automatically register themselves when there is @Worker.handler +upon the static! method. +see WorkerHandler - for more information on short-living tasks + +Another example of auto import is service policies. If your module has a service policy +it will be automatically imported if it inherited from ServicePolicy + +Manual import: +Permissions. Make sure you have added all permission to the core permissions +Permission resolvers in TargetType. see it for reference +""" diff --git a/backend/dataall/modules/common/__init__.py b/backend/dataall/modules/common/__init__.py new file mode 100644 index 000000000..984cce4a8 --- /dev/null +++ b/backend/dataall/modules/common/__init__.py @@ -0,0 +1 @@ +"""Contains the common code that can be shared among modules""" diff --git a/backend/dataall/modules/common/sagemaker/__init__.py b/backend/dataall/modules/common/sagemaker/__init__.py new file mode 100644 index 000000000..959747d5d --- /dev/null +++ b/backend/dataall/modules/common/sagemaker/__init__.py @@ -0,0 +1 @@ +"""Common code for machine learning studio and notebooks""" diff --git a/backend/dataall/modules/common/sagemaker/cdk/__init__.py b/backend/dataall/modules/common/sagemaker/cdk/__init__.py new file mode 100644 index 000000000..e2e75f02a --- /dev/null +++ b/backend/dataall/modules/common/sagemaker/cdk/__init__.py @@ -0,0 +1 @@ +# Contains infrastructure code shared between ML studio and notebooks diff --git a/backend/dataall/modules/common/sagemaker/cdk/statements.py b/backend/dataall/modules/common/sagemaker/cdk/statements.py new file mode 100644 index 000000000..7e15bf4cf --- /dev/null +++ b/backend/dataall/modules/common/sagemaker/cdk/statements.py @@ -0,0 +1,138 @@ +from aws_cdk import aws_iam as iam + + +def create_sagemaker_statements(account: str, region: str, tag_key: str, tag_value: str): + return [ + iam.PolicyStatement( + actions=[ + 'sagemaker:List*', + 'sagemaker:List*', + 'sagemaker:Describe*', + 'sagemaker:BatchGet*', + 'sagemaker:BatchDescribe*', + 'sagemaker:Search', + 'sagemaker:RenderUiTemplate', + 'sagemaker:GetSearchSuggestions', + 'sagemaker:QueryLineage', + 'sagemaker:CreateNotebookInstanceLifecycleConfig', + 'sagemaker:DeleteNotebookInstanceLifecycleConfig', + 'sagemaker:CreatePresignedDomainUrl' + ], + resources=['*'], + ), + iam.PolicyStatement( + actions=['sagemaker:AddTags'], + resources=['*'], + conditions={ + 'StringEquals': { + f'aws:ResourceTag/{tag_key}': [tag_value] + } + }, + ), + iam.PolicyStatement( + actions=['sagemaker:Delete*'], + resources=[ + f'arn:aws:sagemaker:{region}:{account}:notebook-instance/*', + f'arn:aws:sagemaker:{region}:{account}:algorithm/*', + f'arn:aws:sagemaker:{region}:{account}:model/*', + f'arn:aws:sagemaker:{region}:{account}:endpoint/*', + f'arn:aws:sagemaker:{region}:{account}:endpoint-config/*', + f'arn:aws:sagemaker:{region}:{account}:experiment/*', + f'arn:aws:sagemaker:{region}:{account}:experiment-trial/*', + f'arn:aws:sagemaker:{region}:{account}:experiment-group/*', + f'arn:aws:sagemaker:{region}:{account}:model-bias-job-definition/*', + f'arn:aws:sagemaker:{region}:{account}:model-package/*', + f'arn:aws:sagemaker:{region}:{account}:model-package-group/*', + f'arn:aws:sagemaker:{region}:{account}:model-quality-job-definition/*', + 
diff --git a/backend/dataall/modules/common/__init__.py b/backend/dataall/modules/common/__init__.py
new file mode 100644
index 000000000..984cce4a8
--- /dev/null
+++ b/backend/dataall/modules/common/__init__.py
@@ -0,0 +1 @@
+"""Contains the common code that can be shared among modules"""
diff --git a/backend/dataall/modules/common/sagemaker/__init__.py b/backend/dataall/modules/common/sagemaker/__init__.py
new file mode 100644
index 000000000..959747d5d
--- /dev/null
+++ b/backend/dataall/modules/common/sagemaker/__init__.py
@@ -0,0 +1 @@
+"""Common code for machine learning studio and notebooks"""
diff --git a/backend/dataall/modules/common/sagemaker/cdk/__init__.py b/backend/dataall/modules/common/sagemaker/cdk/__init__.py
new file mode 100644
index 000000000..e2e75f02a
--- /dev/null
+++ b/backend/dataall/modules/common/sagemaker/cdk/__init__.py
@@ -0,0 +1 @@
+# Contains infrastructure code shared between ML studio and notebooks
diff --git a/backend/dataall/modules/common/sagemaker/cdk/statements.py b/backend/dataall/modules/common/sagemaker/cdk/statements.py
new file mode 100644
index 000000000..7e15bf4cf
--- /dev/null
+++ b/backend/dataall/modules/common/sagemaker/cdk/statements.py
@@ -0,0 +1,138 @@
+from aws_cdk import aws_iam as iam
+
+
+def create_sagemaker_statements(account: str, region: str, tag_key: str, tag_value: str):
+    return [
+        iam.PolicyStatement(
+            actions=[
+                'sagemaker:List*',
+                'sagemaker:Describe*',
+                'sagemaker:BatchGet*',
+                'sagemaker:BatchDescribe*',
+                'sagemaker:Search',
+                'sagemaker:RenderUiTemplate',
+                'sagemaker:GetSearchSuggestions',
+                'sagemaker:QueryLineage',
+                'sagemaker:CreateNotebookInstanceLifecycleConfig',
+                'sagemaker:DeleteNotebookInstanceLifecycleConfig',
+                'sagemaker:CreatePresignedDomainUrl'
+            ],
+            resources=['*'],
+        ),
+        iam.PolicyStatement(
+            actions=['sagemaker:AddTags'],
+            resources=['*'],
+            conditions={
+                'StringEquals': {
+                    f'aws:ResourceTag/{tag_key}': [tag_value]
+                }
+            },
+        ),
+        iam.PolicyStatement(
+            actions=['sagemaker:Delete*'],
+            resources=[
+                f'arn:aws:sagemaker:{region}:{account}:notebook-instance/*',
+                f'arn:aws:sagemaker:{region}:{account}:algorithm/*',
+                f'arn:aws:sagemaker:{region}:{account}:model/*',
+                f'arn:aws:sagemaker:{region}:{account}:endpoint/*',
+                f'arn:aws:sagemaker:{region}:{account}:endpoint-config/*',
+                f'arn:aws:sagemaker:{region}:{account}:experiment/*',
+                f'arn:aws:sagemaker:{region}:{account}:experiment-trial/*',
+                f'arn:aws:sagemaker:{region}:{account}:experiment-group/*',
+                f'arn:aws:sagemaker:{region}:{account}:model-bias-job-definition/*',
+                f'arn:aws:sagemaker:{region}:{account}:model-package/*',
+                f'arn:aws:sagemaker:{region}:{account}:model-package-group/*',
+                f'arn:aws:sagemaker:{region}:{account}:model-quality-job-definition/*',
+                f'arn:aws:sagemaker:{region}:{account}:monitoring-schedule/*',
+                f'arn:aws:sagemaker:{region}:{account}:pipeline/*',
+                f'arn:aws:sagemaker:{region}:{account}:project/*',
+                f'arn:aws:sagemaker:{region}:{account}:app/*'
+            ],
+            conditions={
+                'StringEquals': {
+                    f'aws:ResourceTag/{tag_key}': [tag_value]
+                }
+            },
+        ),
+        iam.PolicyStatement(
+            actions=['sagemaker:CreateApp'],
+            resources=['*']
+        ),
+        iam.PolicyStatement(
+            actions=['sagemaker:Create*'],
+            resources=['*'],
+        ),
+        iam.PolicyStatement(
+            actions=['sagemaker:Start*', 'sagemaker:Stop*'],
+            resources=[
+                f'arn:aws:sagemaker:{region}:{account}:notebook-instance/*',
+                f'arn:aws:sagemaker:{region}:{account}:monitoring-schedule/*',
+                f'arn:aws:sagemaker:{region}:{account}:pipeline/*',
+                f'arn:aws:sagemaker:{region}:{account}:training-job/*',
+                f'arn:aws:sagemaker:{region}:{account}:processing-job/*',
+                f'arn:aws:sagemaker:{region}:{account}:hyper-parameter-tuning-job/*',
+                f'arn:aws:sagemaker:{region}:{account}:transform-job/*',
+                f'arn:aws:sagemaker:{region}:{account}:automl-job/*'
+            ],
+            conditions={
+                'StringEquals': {
+                    f'aws:ResourceTag/{tag_key}': [tag_value]
+                }
+            },
+        ),
+        iam.PolicyStatement(
+            actions=['sagemaker:Update*'],
+            resources=[
+                f'arn:aws:sagemaker:{region}:{account}:notebook-instance/*',
+                f'arn:aws:sagemaker:{region}:{account}:notebook-instance-lifecycle-config/*',
+                f'arn:aws:sagemaker:{region}:{account}:studio-lifecycle-config/*',
+                f'arn:aws:sagemaker:{region}:{account}:endpoint/*',
+                f'arn:aws:sagemaker:{region}:{account}:pipeline/*',
+                f'arn:aws:sagemaker:{region}:{account}:pipeline-execution/*',
+                f'arn:aws:sagemaker:{region}:{account}:monitoring-schedule/*',
+                f'arn:aws:sagemaker:{region}:{account}:experiment/*',
+                f'arn:aws:sagemaker:{region}:{account}:experiment-trial/*',
+                f'arn:aws:sagemaker:{region}:{account}:experiment-trial-component/*',
+                f'arn:aws:sagemaker:{region}:{account}:model-package/*',
+                f'arn:aws:sagemaker:{region}:{account}:training-job/*',
+                f'arn:aws:sagemaker:{region}:{account}:project/*'
+            ],
+            conditions={
+                'StringEquals': {
+                    f'aws:ResourceTag/{tag_key}': [tag_value]
+                }
+            },
+        ),
+        iam.PolicyStatement(
+            actions=['sagemaker:InvokeEndpoint', 'sagemaker:InvokeEndpointAsync'],
+            resources=[
+                f'arn:aws:sagemaker:{region}:{account}:endpoint/*'
+            ],
+            conditions={
+                'StringEquals': {
+                    f'aws:ResourceTag/{tag_key}': [tag_value]
+                }
+            },
+        ),
+        iam.PolicyStatement(
+            actions=[
+                'logs:CreateLogGroup',
+                'logs:CreateLogStream',
+                'logs:PutLogEvents'],
+            resources=[
+                f'arn:aws:logs:{region}:{account}:log-group:/aws/sagemaker/*',
+                f'arn:aws:logs:{region}:{account}:log-group:/aws/sagemaker/*:log-stream:*',
+            ]
+        ),
+        iam.PolicyStatement(
+            actions=[
+                'ecr:GetAuthorizationToken',
+                'ecr:BatchCheckLayerAvailability',
+                'ecr:GetDownloadUrlForLayer',
+                'ecr:BatchGetImage'],
+            resources=[
+                '*'
+            ]
+        )
+    ]
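A sketch of how a module's service policy is expected to consume this helper, mirroring the SagemakerPolicy class earlier in this diff; the MyNotebookPolicy name and the 'CREATE_NOTEBOOK' permission string are assumptions for illustration, since the notebooks module's own permission constants are not shown here:

    from dataall.cdkproxy.stacks.policies.service_policy import ServicePolicy
    from dataall.modules.common.sagemaker.cdk.statements import create_sagemaker_statements

    class MyNotebookPolicy(ServicePolicy):
        # hypothetical policy; ServicePolicy subclasses are picked up automatically
        def get_statements(self, group_permissions, **kwargs):
            if 'CREATE_NOTEBOOK' not in group_permissions:  # assumed permission name
                return []
            return create_sagemaker_statements(self.account, self.region, self.tag_key, self.tag_value)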
diff --git a/backend/dataall/modules/loader.py b/backend/dataall/modules/loader.py
new file mode 100644
index 000000000..95aa2083a
--- /dev/null
+++ b/backend/dataall/modules/loader.py
@@ -0,0 +1,77 @@
+"""Loads the modules that are specified in the configuration file"""
+import importlib
+import logging
+from abc import ABC, abstractmethod
+from enum import Enum
+from typing import List
+
+from dataall.core.config import config
+
+log = logging.getLogger(__name__)
+
+_MODULE_PREFIX = "dataall.modules"
+
+
+class ImportMode(Enum):
+    """Defines the importing mode
+
+    Since there are different infrastructure components that require only part
+    of the functionality to be loaded, there should be different loading modes
+    """
+
+    API = "api"
+    CDK = "cdk"
+    TASKS = "tasks"
+
+
+class ModuleInterface(ABC):
+    """
+    An interface of the module. The implementation should be part of __init__.py of the module.
+    Contains an API that will be called from the core part
+    """
+    @classmethod
+    @abstractmethod
+    def is_supported(cls, modes: List[ImportMode]):
+        pass
+
+
+def load_modules(modes: List[ImportMode]) -> None:
+    """
+    Loads all modules from the config.
+    Loads only the requested functionality (submodules) using the modes parameter
+    """
+    try:
+        modules = config.get_property("modules")
+    except KeyError:
+        log.info('"modules" has not been found in the config file. Nothing to load')
+        return
+
+    log.info("Found %d modules in the config", len(modules))
+    for name, props in modules.items():
+        if "active" not in props:
+            raise ValueError(f"Status is not defined for {name} module")
+
+        active = props["active"]
+
+        if not active:
+            log.info(f"Module {name} is not active. Skipping...")
+            continue
+
+        if active.lower() == "true" and not _import_module(name):
+            raise ValueError(f"Couldn't find module {name} under modules directory")
+
+        log.info(f"Module {name} is loaded")
+
+    for module in ModuleInterface.__subclasses__():
+        if module.is_supported(modes):
+            module()
+
+    log.info("All modules have been imported")
+
+
+def _import_module(name):
+    try:
+        importlib.import_module(f"{_MODULE_PREFIX}.{name}")
+        return True
+    except ModuleNotFoundError:
+        return False
diff --git a/backend/dataall/modules/notebooks/__init__.py b/backend/dataall/modules/notebooks/__init__.py
new file mode 100644
index 000000000..b63e0fa51
--- /dev/null
+++ b/backend/dataall/modules/notebooks/__init__.py
@@ -0,0 +1,36 @@
+"""Contains the code related to SageMaker notebooks"""
+import logging
+
+from dataall.db.api import TargetType
+from dataall.modules.loader import ImportMode, ModuleInterface
+from dataall.modules.notebooks.db.repositories import NotebookRepository
+
+log = logging.getLogger(__name__)
+
+
+class NotebookApiModuleInterface(ModuleInterface):
+    """Implements ModuleInterface for the notebook GraphQL lambda"""
+
+    @classmethod
+    def is_supported(cls, modes):
+        return ImportMode.API in modes
+
+    def __init__(self):
+        import dataall.modules.notebooks.api
+
+        from dataall.modules.notebooks.services.permissions import GET_NOTEBOOK, UPDATE_NOTEBOOK
+        TargetType("notebook", GET_NOTEBOOK, UPDATE_NOTEBOOK)
+
+        log.info("API of sagemaker notebooks has been imported")
+
+
+class NotebookCdkModuleInterface(ModuleInterface):
+    """Implements ModuleInterface for the notebook CDK stacks"""
+
+    @classmethod
+    def is_supported(cls, modes):
+        return ImportMode.CDK in modes
+
+    def __init__(self):
+        import dataall.modules.notebooks.cdk
+        log.info("CDK stacks of sagemaker notebooks have been imported")
diff --git a/backend/dataall/modules/notebooks/api/__init__.py b/backend/dataall/modules/notebooks/api/__init__.py
new file mode 100644
index 000000000..244e796d6
--- /dev/null
+++ b/backend/dataall/modules/notebooks/api/__init__.py
@@ -0,0 +1,4 @@
+"""The package defines the schema for SageMaker notebooks"""
+from dataall.modules.notebooks.api import input_types, mutations, queries, types, resolvers
+
+__all__ = ["types", "input_types", "queries", "mutations", "resolvers"]
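For the notebooks module above to be picked up by load_modules, config.json presumably carries an entry of this shape (an assumed sketch; the actual config file is not part of this diff). Note that the loader expects the string "true", not a boolean:

    {
        "modules": {
            "notebooks": {
                "active": "true"
            }
        }
    }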
diff --git a/backend/dataall/modules/notebooks/api/enums.py b/backend/dataall/modules/notebooks/api/enums.py
new file mode 100644
index 000000000..413620cdd
--- /dev/null
+++ b/backend/dataall/modules/notebooks/api/enums.py
@@ -0,0 +1,11 @@
+"""Contains the enums GraphQL mapping for SageMaker notebooks"""
+from dataall.api.constants import GraphQLEnumMapper
+
+
+class SagemakerNotebookRole(GraphQLEnumMapper):
+    """Describes the SageMaker Notebook roles"""
+
+    CREATOR = "950"
+    ADMIN = "900"
+    SHARED = "300"
+    NO_PERMISSION = "000"
diff --git a/backend/dataall/modules/notebooks/api/input_types.py b/backend/dataall/modules/notebooks/api/input_types.py
new file mode 100644
index 000000000..d82d96eeb
--- /dev/null
+++ b/backend/dataall/modules/notebooks/api/input_types.py
@@ -0,0 +1,39 @@
+"""The module defines GraphQL input types for the SageMaker notebooks"""
+from dataall.api import gql
+
+NewSagemakerNotebookInput = gql.InputType(
+    name="NewSagemakerNotebookInput",
+    arguments=[
+        gql.Argument("label", gql.NonNullableType(gql.String)),
+        gql.Argument("description", gql.String),
+        gql.Argument("environmentUri", gql.NonNullableType(gql.String)),
+        gql.Argument("SamlAdminGroupName", gql.NonNullableType(gql.String)),
+        gql.Argument("tags", gql.ArrayType(gql.String)),
+        gql.Argument("topics", gql.String),
+        gql.Argument("VpcId", gql.String),
+        gql.Argument("SubnetId", gql.String),
+        gql.Argument("VolumeSizeInGB", gql.Integer),
+        gql.Argument("InstanceType", gql.String),
+    ],
+)
+
+ModifySagemakerNotebookInput = gql.InputType(
+    name="ModifySagemakerNotebookInput",
+    arguments=[
+        gql.Argument("label", gql.String),
+        gql.Argument("tags", gql.ArrayType(gql.String)),
+        gql.Argument("description", gql.String),
+    ],
+)
+
+SagemakerNotebookFilter = gql.InputType(
+    name="SagemakerNotebookFilter",
+    arguments=[
+        gql.Argument("term", gql.String),
+        gql.Argument("page", gql.Integer),
+        gql.Argument("pageSize", gql.Integer),
+        gql.Argument("sort", gql.String),
+        gql.Argument("limit", gql.Integer),
+        gql.Argument("offset", gql.Integer),
+    ],
+)
diff --git a/backend/dataall/modules/notebooks/api/mutations.py b/backend/dataall/modules/notebooks/api/mutations.py
new file mode 100644
index 000000000..74f24bb0f
--- /dev/null
+++ b/backend/dataall/modules/notebooks/api/mutations.py
@@ -0,0 +1,39 @@
+"""The module defines GraphQL mutations for the SageMaker notebooks"""
+from dataall.api import gql
+from dataall.modules.notebooks.api.resolvers import (
+    create_notebook,
+    delete_notebook,
+    start_notebook,
+    stop_notebook,
+)
+
+createSagemakerNotebook = gql.MutationField(
+    name="createSagemakerNotebook",
+    args=[gql.Argument(name="input", type=gql.Ref("NewSagemakerNotebookInput"))],
+    type=gql.Ref("SagemakerNotebook"),
+    resolver=create_notebook,
+)
+
+startSagemakerNotebook = gql.MutationField(
+    name="startSagemakerNotebook",
+    args=[gql.Argument(name="notebookUri", type=gql.NonNullableType(gql.String))],
+    type=gql.String,
+    resolver=start_notebook,
+)
+
+stopSagemakerNotebook = gql.MutationField(
+    name="stopSagemakerNotebook",
+    args=[gql.Argument(name="notebookUri", type=gql.NonNullableType(gql.String))],
+    type=gql.String,
+    resolver=stop_notebook,
+)
+
+deleteSagemakerNotebook = gql.MutationField(
+    name="deleteSagemakerNotebook",
+    args=[
+        gql.Argument(name="notebookUri", type=gql.NonNullableType(gql.String)),
+        gql.Argument(name="deleteFromAWS", type=gql.Boolean),
+    ],
+    type=gql.String,
+    resolver=delete_notebook,
+)
diff --git a/backend/dataall/modules/notebooks/api/queries.py b/backend/dataall/modules/notebooks/api/queries.py
new file mode 100644
index 000000000..36b13364f
--- /dev/null
+++ b/backend/dataall/modules/notebooks/api/queries.py
@@ -0,0 +1,28 @@
+"""The module defines GraphQL queries for the SageMaker notebooks"""
+from dataall.api import gql
+from dataall.modules.notebooks.api.resolvers import ( + get_notebook, + list_notebooks, + get_notebook_presigned_url +) + +getSagemakerNotebook = gql.QueryField( + name="getSagemakerNotebook", + args=[gql.Argument(name="notebookUri", type=gql.NonNullableType(gql.String))], + type=gql.Ref("SagemakerNotebook"), + resolver=get_notebook, +) + +listSagemakerNotebooks = gql.QueryField( + name="listSagemakerNotebooks", + args=[gql.Argument("filter", gql.Ref("SagemakerNotebookFilter"))], + type=gql.Ref("SagemakerNotebookSearchResult"), + resolver=list_notebooks, +) + +getSagemakerNotebookPresignedUrl = gql.QueryField( + name="getSagemakerNotebookPresignedUrl", + args=[gql.Argument(name="notebookUri", type=gql.NonNullableType(gql.String))], + type=gql.String, + resolver=get_notebook_presigned_url, +) diff --git a/backend/dataall/modules/notebooks/api/resolvers.py b/backend/dataall/modules/notebooks/api/resolvers.py new file mode 100644 index 000000000..3c0d70c67 --- /dev/null +++ b/backend/dataall/modules/notebooks/api/resolvers.py @@ -0,0 +1,120 @@ +from dataall.modules.notebooks.api.enums import SagemakerNotebookRole + +from dataall.api.context import Context +from dataall.db import exceptions +from dataall.api.Objects.Stack import stack_helper +from dataall.modules.notebooks.services.services import NotebookService, NotebookCreationRequest +from dataall.modules.notebooks.db.models import SagemakerNotebook + + +def create_notebook(context: Context, source: SagemakerNotebook, input: dict = None): + """Creates a SageMaker notebook. Deploys the notebooks stack into AWS""" + RequestValidator.validate_creation_request(input) + request = NotebookCreationRequest.from_dict(input) + return NotebookService.create_notebook( + uri=input["environmentUri"], + admin_group=input["SamlAdminGroupName"], + request=request + ) + + +def list_notebooks(context, source, filter: dict = None): + """ + Lists all SageMaker notebooks using the given filter. + If the filter is not provided, all notebooks are returned. + """ + + if not filter: + filter = {} + return NotebookService.list_user_notebooks(filter) + + +def get_notebook(context, source, notebookUri: str = None): + """Retrieve a SageMaker notebook by URI.""" + RequestValidator.required_uri(notebookUri) + return NotebookService.get_notebook(uri=notebookUri) + + +def resolve_notebook_status(context, source: SagemakerNotebook, **kwargs): + """Resolves the status of a notebook.""" + if not source: + return None + return NotebookService.get_notebook_status(uri=source.notebookUri) + + +def start_notebook(context, source: SagemakerNotebook, notebookUri: str = None): + """Starts a sagemaker notebook instance""" + RequestValidator.required_uri(notebookUri) + NotebookService.start_notebook(uri=notebookUri) + return 'Starting' + + +def stop_notebook(context, source: SagemakerNotebook, notebookUri: str = None): + """Stops a notebook instance.""" + RequestValidator.required_uri(notebookUri) + NotebookService.stop_notebook(uri=notebookUri) + return 'Stopping' + + +def get_notebook_presigned_url(context, source: SagemakerNotebook, notebookUri: str = None): + """Creates and returns a presigned url for a notebook""" + RequestValidator.required_uri(notebookUri) + return NotebookService.get_notebook_presigned_url(uri=notebookUri) + + +def delete_notebook( + context, + source: SagemakerNotebook, + notebookUri: str = None, + deleteFromAWS: bool = None, +): + """ + Deletes the SageMaker notebook. 
+ Deletes the notebooks stack from AWS if deleteFromAWS is True + """ + RequestValidator.required_uri(notebookUri) + NotebookService.delete_notebook(uri=notebookUri, delete_from_aws=deleteFromAWS) + return True + + +def resolve_user_role(context: Context, source: SagemakerNotebook): + if not source: + return None + if source.owner == context.username: + return SagemakerNotebookRole.CREATOR.value + elif context.groups and source.SamlAdminGroupName in context.groups: + return SagemakerNotebookRole.ADMIN.value + return SagemakerNotebookRole.NO_PERMISSION.value + + +def resolve_notebook_stack(context: Context, source: SagemakerNotebook, **kwargs): + if not source: + return None + return stack_helper.get_stack_with_cfn_resources( + targetUri=source.notebookUri, + environmentUri=source.environmentUri, + ) + + +class RequestValidator: + """Aggregates all validation logic for operating with notebooks""" + @staticmethod + def required_uri(uri): + if not uri: + raise exceptions.RequiredParameter('URI') + + @staticmethod + def validate_creation_request(data): + required = RequestValidator._required + if not data: + raise exceptions.RequiredParameter('data') + if not data.get('label'): + raise exceptions.RequiredParameter('name') + + required(data, "environmentUri") + required(data, "SamlAdminGroupName") + + @staticmethod + def _required(data: dict, name: str): + if not data.get(name): + raise exceptions.RequiredParameter(name) diff --git a/backend/dataall/modules/notebooks/api/types.py b/backend/dataall/modules/notebooks/api/types.py new file mode 100644 index 000000000..4ca8dcf0d --- /dev/null +++ b/backend/dataall/modules/notebooks/api/types.py @@ -0,0 +1,62 @@ +"""Defines the object types of the SageMaker notebooks""" +from dataall.api import gql +from dataall.modules.notebooks.api.resolvers import ( + resolve_notebook_stack, + resolve_notebook_status, + resolve_user_role, +) + +from dataall.api.Objects.Environment.resolvers import resolve_environment +from dataall.api.Objects.Organization.resolvers import resolve_organization_by_env + +from dataall.modules.notebooks.api.enums import SagemakerNotebookRole + +SagemakerNotebook = gql.ObjectType( + name="SagemakerNotebook", + fields=[ + gql.Field(name="notebookUri", type=gql.ID), + gql.Field(name="environmentUri", type=gql.NonNullableType(gql.String)), + gql.Field(name="label", type=gql.String), + gql.Field(name="description", type=gql.String), + gql.Field(name="tags", type=gql.ArrayType(gql.String)), + gql.Field(name="name", type=gql.String), + gql.Field(name="owner", type=gql.String), + gql.Field(name="created", type=gql.String), + gql.Field(name="updated", type=gql.String), + gql.Field(name="SamlAdminGroupName", type=gql.String), + gql.Field(name="VpcId", type=gql.String), + gql.Field(name="SubnetId", type=gql.String), + gql.Field(name="InstanceType", type=gql.String), + gql.Field(name="RoleArn", type=gql.String), + gql.Field(name="VolumeSizeInGB", type=gql.Integer), + gql.Field( + name="userRoleForNotebook", + type=SagemakerNotebookRole.toGraphQLEnum(), + resolver=resolve_user_role, + ), + gql.Field(name="NotebookInstanceStatus", type=gql.String, resolver=resolve_notebook_status), + gql.Field( + name="environment", + type=gql.Ref("Environment"), + resolver=resolve_environment, + ), + gql.Field( + name="organization", + type=gql.Ref("Organization"), + resolver=resolve_organization_by_env, + ), + gql.Field(name="stack", type=gql.Ref("Stack"), resolver=resolve_notebook_stack), + ], +) + +SagemakerNotebookSearchResult = gql.ObjectType( + 
name="SagemakerNotebookSearchResult", + fields=[ + gql.Field(name="count", type=gql.Integer), + gql.Field(name="page", type=gql.Integer), + gql.Field(name="pages", type=gql.Integer), + gql.Field(name="hasNext", type=gql.Boolean), + gql.Field(name="hasPrevious", type=gql.Boolean), + gql.Field(name="nodes", type=gql.ArrayType(SagemakerNotebook)), + ], +) diff --git a/backend/dataall/modules/notebooks/aws/__init__.py b/backend/dataall/modules/notebooks/aws/__init__.py new file mode 100644 index 000000000..873d3c5d3 --- /dev/null +++ b/backend/dataall/modules/notebooks/aws/__init__.py @@ -0,0 +1 @@ +"""Contains code that send requests to AWS using SDK (boto3)""" diff --git a/backend/dataall/modules/notebooks/aws/client.py b/backend/dataall/modules/notebooks/aws/client.py new file mode 100644 index 000000000..d584ca30e --- /dev/null +++ b/backend/dataall/modules/notebooks/aws/client.py @@ -0,0 +1,62 @@ +import logging + +from dataall.aws.handlers.sts import SessionHelper +from dataall.modules.notebooks.db.models import SagemakerNotebook +from botocore.exceptions import ClientError + +logger = logging.getLogger(__name__) + + +class SagemakerClient: + """ + A Sagemaker notebooks proxy client that is used to send requests to AWS + """ + + def __init__(self, notebook: SagemakerNotebook): + session = SessionHelper.remote_session(notebook.AWSAccountId) + self._client = session.client('sagemaker', region_name=notebook.region) + self._instance_name = notebook.NotebookInstanceName + + def get_notebook_instance_status(self) -> str: + """Remote call to AWS to check the notebook's status""" + try: + response = self._client.describe_notebook_instance( + NotebookInstanceName=self._instance_name + ) + return response.get('NotebookInstanceStatus', 'NOT FOUND') + except ClientError as e: + logger.error( + f'Could not retrieve instance {self._instance_name} status due to: {e} ' + ) + return 'NOT FOUND' + + def presigned_url(self): + """Creates a presigned url for a notebook instance by sending request to AWS""" + try: + response = self._client.create_presigned_notebook_instance_url( + NotebookInstanceName=self._instance_name + ) + return response['AuthorizedUrl'] + except ClientError as e: + raise e + + def start_instance(self): + """Starts the notebooks instance by sending a request to AWS""" + try: + status = self.get_notebook_instance_status() + self._client.start_notebook_instance(NotebookInstanceName=self._instance_name) + return status + except ClientError as e: + return e + + def stop_instance(self) -> None: + """Stops the notebooks instance by sending a request to AWS""" + try: + self._client.stop_notebook_instance(NotebookInstanceName=self._instance_name) + except ClientError as e: + raise e + + +def client(notebook: SagemakerNotebook) -> SagemakerClient: + """Factory method to retrieve the client to send request to AWS""" + return SagemakerClient(notebook) diff --git a/backend/dataall/modules/notebooks/cdk/__init__.py b/backend/dataall/modules/notebooks/cdk/__init__.py new file mode 100644 index 000000000..0c16c7577 --- /dev/null +++ b/backend/dataall/modules/notebooks/cdk/__init__.py @@ -0,0 +1,8 @@ +""" +This package contains modules that are used to create a CloudFormation stack in AWS. 
diff --git a/backend/dataall/modules/notebooks/cdk/__init__.py b/backend/dataall/modules/notebooks/cdk/__init__.py
new file mode 100644
index 000000000..0c16c7577
--- /dev/null
+++ b/backend/dataall/modules/notebooks/cdk/__init__.py
@@ -0,0 +1,8 @@
+"""
+This package contains modules that are used to create a CloudFormation stack in AWS.
+The code is invoked in ECS Fargate to initiate the creation of the stack
+"""
+from dataall.modules.notebooks.cdk import stacks
+from dataall.modules.notebooks.cdk import policies
+
+__all__ = ["stacks", "policies"]
diff --git a/backend/dataall/modules/notebooks/cdk/policies.py b/backend/dataall/modules/notebooks/cdk/policies.py
new file mode 100644
index 000000000..c6e77f9b6
--- /dev/null
+++ b/backend/dataall/modules/notebooks/cdk/policies.py
@@ -0,0 +1,16 @@
+from dataall.cdkproxy.stacks.policies.service_policy import ServicePolicy
+
+from dataall.modules.notebooks.services.permissions import CREATE_NOTEBOOK
+from dataall.modules.common.sagemaker.cdk.statements import create_sagemaker_statements
+
+
+class SagemakerPolicy(ServicePolicy):
+    """
+    Creates a sagemaker policy for accessing and interacting with notebooks
+    """
+
+    def get_statements(self, group_permissions, **kwargs):
+        if CREATE_NOTEBOOK not in group_permissions:
+            return []
+
+        return create_sagemaker_statements(self.account, self.region, self.tag_key, self.tag_value)
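`create_sagemaker_statements` comes from dataall.modules.common.sagemaker.cdk.statements, which is not part of this excerpt. Conceptually it returns tag-scoped IAM statements along these lines (a sketch only; the exact actions and ARN patterns are assumptions):

```python
from aws_cdk import aws_iam as iam


def create_sagemaker_statements(account, region, tag_key, tag_value):
    # Allow notebook operations only on instances carrying the team tag
    return [
        iam.PolicyStatement(
            actions=[
                'sagemaker:DescribeNotebookInstance',
                'sagemaker:StartNotebookInstance',
                'sagemaker:StopNotebookInstance',
                'sagemaker:CreatePresignedNotebookInstanceUrl',
            ],
            resources=[f'arn:aws:sagemaker:{region}:{account}:notebook-instance/*'],
            conditions={'StringEquals': {f'aws:ResourceTag/{tag_key}': [tag_value]}},
        ),
    ]
```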
diff --git a/backend/dataall/cdkproxy/stacks/notebook.py b/backend/dataall/modules/notebooks/cdk/stacks.py
similarity index 79%
rename from backend/dataall/cdkproxy/stacks/notebook.py
rename to backend/dataall/modules/notebooks/cdk/stacks.py
index dd80de060..ef01f7117 100644
--- a/backend/dataall/cdkproxy/stacks/notebook.py
+++ b/backend/dataall/modules/notebooks/cdk/stacks.py
@@ -1,3 +1,6 @@
+"""
+Creates a CloudFormation stack for SageMaker notebooks using cdk
+"""
 import logging
 import os
@@ -10,34 +13,43 @@
     CfnOutput,
 )
-from .manager import stack
-from ... import db
-from ...db import models
-from ...db.api import Environment
-from ...utils.cdk_nag_utils import CDKNagUtil
-from ...utils.runtime_stacks_tagging import TagsUtil
+from dataall.modules.notebooks.db.models import SagemakerNotebook
+from dataall.modules.notebooks.db import models
+from dataall.db.models import EnvironmentGroup
+
+from dataall.cdkproxy.stacks.manager import stack
+from dataall.db import Engine, get_engine
+from dataall.db.api import Environment
+from dataall.utils.cdk_nag_utils import CDKNagUtil
+from dataall.utils.runtime_stacks_tagging import TagsUtil
 
 logger = logging.getLogger(__name__)
 
 @stack(stack='notebook')
-class SagemakerNotebook(Stack):
+class NotebookStack(Stack):
+    """
+    Creates the notebook stack.
+    Once the notebook module has been imported, the class registers itself using @stack
+    and becomes reachable by HTTP request / through SQS from the GraphQL lambda
+    """
+
     module_name = __file__
 
-    def get_engine(self) -> db.Engine:
+    def get_engine(self) -> Engine:
         envname = os.environ.get('envname', 'local')
-        engine = db.get_engine(envname=envname)
+        engine = get_engine(envname=envname)
         return engine
 
-    def get_target(self, target_uri) -> models.SagemakerNotebook:
+    def get_target(self, target_uri) -> SagemakerNotebook:
         engine = self.get_engine()
         with engine.scoped_session() as session:
-            notebook = session.query(models.SagemakerNotebook).get(target_uri)
+            notebook = session.query(SagemakerNotebook).get(target_uri)
         return notebook
 
     def get_env_group(
-        self, notebook: models.SagemakerNotebook
-    ) -> models.EnvironmentGroup:
+        self, notebook: SagemakerNotebook
+    ) -> EnvironmentGroup:
         engine = self.get_engine()
         with engine.scoped_session() as session:
             env = Environment.get_environment_group(
@@ -132,6 +144,6 @@ def __init__(self, scope, id: str, target_uri: str = None, **kwargs) -> None:
             value=notebook.NotebookInstanceName,
         )
 
-        TagsUtil.add_tags(self)
+        TagsUtil.add_tags(stack=self, model=models.SagemakerNotebook, target_type="notebook")
 
         CDKNagUtil.check_rules(self)
diff --git a/backend/dataall/modules/notebooks/db/__init__.py b/backend/dataall/modules/notebooks/db/__init__.py
new file mode 100644
index 000000000..86631d191
--- /dev/null
+++ b/backend/dataall/modules/notebooks/db/__init__.py
@@ -0,0 +1 @@
+"""Contains code that interacts with the database"""
diff --git a/backend/dataall/db/models/SagemakerNotebook.py b/backend/dataall/modules/notebooks/db/models.py
similarity index 69%
rename from backend/dataall/db/models/SagemakerNotebook.py
rename to backend/dataall/modules/notebooks/db/models.py
index 675ebf334..8afe3da3d 100644
--- a/backend/dataall/db/models/SagemakerNotebook.py
+++ b/backend/dataall/modules/notebooks/db/models.py
@@ -1,13 +1,16 @@
-from sqlalchemy import Column, String, Integer
-from sqlalchemy.orm import query_expression
+"""ORM models for sagemaker notebooks"""
-from .. import Base
-from ..
import Resource, utils +from sqlalchemy import Column, String, Integer, ForeignKey + +from dataall.db import Base +from dataall.db import Resource, utils class SagemakerNotebook(Resource, Base): + """Describes ORM model for sagemaker notebooks""" + __tablename__ = 'sagemaker_notebook' - environmentUri = Column(String, nullable=False) + environmentUri = Column(String, ForeignKey("environment.environmentUri"), nullable=False) notebookUri = Column(String, primary_key=True, default=utils.uuid('notebook')) NotebookInstanceName = Column( String, nullable=False, default=utils.slugifier('label') @@ -21,4 +24,3 @@ class SagemakerNotebook(Resource, Base): SubnetId = Column(String, nullable=True) VolumeSizeInGB = Column(Integer, nullable=True) InstanceType = Column(String, nullable=True) - userRoleForNotebook = query_expression() diff --git a/backend/dataall/modules/notebooks/db/repositories.py b/backend/dataall/modules/notebooks/db/repositories.py new file mode 100644 index 000000000..cb0e3a20f --- /dev/null +++ b/backend/dataall/modules/notebooks/db/repositories.py @@ -0,0 +1,60 @@ +""" +DAO layer that encapsulates the logic and interaction with the database for notebooks +Provides the API to retrieve / update / delete notebooks +""" +from sqlalchemy import or_ +from sqlalchemy.orm import Query + +from dataall.db import paginate +from dataall.modules.notebooks.db.models import SagemakerNotebook + + +class NotebookRepository: + """DAO layer for notebooks""" + _DEFAULT_PAGE = 1 + _DEFAULT_PAGE_SIZE = 10 + + def __init__(self, session): + self._session = session + + def save_notebook(self, notebook): + """Save notebook to the database""" + self._session.add(notebook) + self._session.commit() + + def find_notebook(self, uri) -> SagemakerNotebook: + """Finds a notebook. Returns None if the notebook doesn't exist""" + return self._session.query(SagemakerNotebook).get(uri) + + def paginated_user_notebooks(self, username, groups, filter=None) -> dict: + """Returns a page of user notebooks""" + return paginate( + query=self._query_user_notebooks(username, groups, filter), + page=filter.get('page', NotebookRepository._DEFAULT_PAGE), + page_size=filter.get('pageSize', NotebookRepository._DEFAULT_PAGE_SIZE), + ).to_dict() + + def _query_user_notebooks(self, username, groups, filter) -> Query: + query = self._session.query(SagemakerNotebook).filter( + or_( + SagemakerNotebook.owner == username, + SagemakerNotebook.SamlAdminGroupName.in_(groups), + ) + ) + if filter and filter.get('term'): + query = query.filter( + or_( + SagemakerNotebook.description.ilike( + filter.get('term') + '%%' + ), + SagemakerNotebook.label.ilike(filter.get('term') + '%%'), + ) + ) + return query + + def count_notebooks(self, environment_uri): + return ( + self._session.query(SagemakerNotebook) + .filter(SagemakerNotebook.environmentUri == environment_uri) + .count() + ) diff --git a/backend/dataall/modules/notebooks/services/__init__.py b/backend/dataall/modules/notebooks/services/__init__.py new file mode 100644 index 000000000..a5df50095 --- /dev/null +++ b/backend/dataall/modules/notebooks/services/__init__.py @@ -0,0 +1,7 @@ +""" +Contains the code needed for service layer. 
+The service layer aggregates all the business logic of the module
+"""
+from dataall.modules.notebooks.services import services, permissions
+
+__all__ = ["services", "permissions"]
diff --git a/backend/dataall/modules/notebooks/services/permissions.py b/backend/dataall/modules/notebooks/services/permissions.py
new file mode 100644
index 000000000..4baeb2947
--- /dev/null
+++ b/backend/dataall/modules/notebooks/services/permissions.py
@@ -0,0 +1,49 @@
+"""
+Adds the module's permissions to the global permissions.
+Contains permissions for sagemaker notebooks
+"""
+
+from dataall.db.permissions import (
+    ENVIRONMENT_ALL,
+    ENVIRONMENT_INVITED,
+    RESOURCES_ALL_WITH_DESC,
+    RESOURCES_ALL,
+    ENVIRONMENT_INVITATION_REQUEST,
+    TENANT_ALL,
+    TENANT_ALL_WITH_DESC
+)
+
+GET_NOTEBOOK = "GET_NOTEBOOK"
+UPDATE_NOTEBOOK = "UPDATE_NOTEBOOK"
+DELETE_NOTEBOOK = "DELETE_NOTEBOOK"
+CREATE_NOTEBOOK = "CREATE_NOTEBOOK"
+MANAGE_NOTEBOOKS = "MANAGE_NOTEBOOKS"
+
+NOTEBOOK_ALL = [
+    GET_NOTEBOOK,
+    DELETE_NOTEBOOK,
+    UPDATE_NOTEBOOK,
+]
+
+LIST_ENVIRONMENT_NOTEBOOKS = 'LIST_ENVIRONMENT_NOTEBOOKS'
+
+ENVIRONMENT_ALL.append(LIST_ENVIRONMENT_NOTEBOOKS)
+ENVIRONMENT_ALL.append(CREATE_NOTEBOOK)
+ENVIRONMENT_INVITED.append(LIST_ENVIRONMENT_NOTEBOOKS)
+ENVIRONMENT_INVITED.append(CREATE_NOTEBOOK)
+ENVIRONMENT_INVITATION_REQUEST.append(LIST_ENVIRONMENT_NOTEBOOKS)
+ENVIRONMENT_INVITATION_REQUEST.append(CREATE_NOTEBOOK)
+
+TENANT_ALL.append(MANAGE_NOTEBOOKS)
+TENANT_ALL_WITH_DESC[MANAGE_NOTEBOOKS] = "Manage notebooks"
+
+
+RESOURCES_ALL.append(CREATE_NOTEBOOK)
+RESOURCES_ALL.extend(NOTEBOOK_ALL)
+RESOURCES_ALL.append(LIST_ENVIRONMENT_NOTEBOOKS)
+
+RESOURCES_ALL_WITH_DESC[CREATE_NOTEBOOK] = "Create notebooks on this environment"
+RESOURCES_ALL_WITH_DESC[LIST_ENVIRONMENT_NOTEBOOKS] = "List notebooks on this environment"
+RESOURCES_ALL_WITH_DESC[GET_NOTEBOOK] = "General permission to get notebooks"
+RESOURCES_ALL_WITH_DESC[DELETE_NOTEBOOK] = "Permission to delete notebooks"
+RESOURCES_ALL_WITH_DESC[UPDATE_NOTEBOOK] = "Permission to edit notebooks"
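Because these registrations run at import time, merely loading the notebooks module mutates the core permission lists. A quick illustration of that side effect (snippet for demonstration, not part of the change):

```python
from dataall.db.permissions import ENVIRONMENT_ALL, RESOURCES_ALL, TENANT_ALL
from dataall.modules.notebooks.services import permissions as notebook_permissions

# Importing the module above has already appended its entries to the core lists
assert notebook_permissions.CREATE_NOTEBOOK in ENVIRONMENT_ALL
assert notebook_permissions.MANAGE_NOTEBOOKS in TENANT_ALL
assert set(notebook_permissions.NOTEBOOK_ALL) <= set(RESOURCES_ALL)
```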
diff --git a/backend/dataall/modules/notebooks/services/services.py b/backend/dataall/modules/notebooks/services/services.py
new file mode 100644
index 000000000..c2b9b9b27
--- /dev/null
+++ b/backend/dataall/modules/notebooks/services/services.py
@@ -0,0 +1,230 @@
+"""
+A service layer for sagemaker notebooks
+Central part for working with notebooks
+"""
+import contextlib
+import dataclasses
+import logging
+from dataclasses import dataclass, field
+from typing import List, Dict
+
+from dataall.api.Objects.Stack import stack_helper
+from dataall.core.context import get_context as context
+from dataall.core.environment.db.repositories import EnvironmentParameterRepository
+from dataall.db.api import (
+    ResourcePolicy,
+    Environment, KeyValueTag, Stack,
+)
+from dataall.db import models, exceptions
+from dataall.modules.notebooks.aws.client import client
+from dataall.modules.notebooks.db.repositories import NotebookRepository
+from dataall.utils.naming_convention import (
+    NamingConventionService,
+    NamingConventionPattern,
+)
+from dataall.utils.slugify import slugify
+from dataall.modules.notebooks.db.models import SagemakerNotebook
+from dataall.modules.notebooks.services import permissions
+from dataall.modules.notebooks.services.permissions import MANAGE_NOTEBOOKS, CREATE_NOTEBOOK
+from dataall.core.permission_checker import has_resource_permission, has_tenant_permission, has_group_permission
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class NotebookCreationRequest:
+    """A request dataclass for notebook creation. Adds default values for missing parameters"""
+    label: str
+    VpcId: str
+    SubnetId: str
+    SamlAdminGroupName: str
+    environment: Dict = field(default_factory=dict)
+    description: str = "No description provided"
+    VolumeSizeInGB: int = 32
+    InstanceType: str = "ml.t3.medium"
+    tags: List[str] = field(default_factory=list)
+
+    @classmethod
+    def from_dict(cls, env):
+        """Copies only the fields defined on the dataclass and creates an instance of the class"""
+        fields = set([f.name for f in dataclasses.fields(cls)])
+        return cls(**{
+            k: v for k, v in env.items()
+            if k in fields
+        })
+
+
+class NotebookService:
+    """
+    Encapsulates the logic of interactions with sagemaker notebooks.
+    """
+
+    _NOTEBOOK_RESOURCE_TYPE = "notebook"
+
+    @staticmethod
+    @has_tenant_permission(MANAGE_NOTEBOOKS)
+    @has_resource_permission(CREATE_NOTEBOOK)
+    @has_group_permission(CREATE_NOTEBOOK)
+    def create_notebook(*, uri: str, admin_group: str, request: NotebookCreationRequest) -> SagemakerNotebook:
+        """
+        Creates a notebook and attaches policies to it.
+        Throws an exception if notebooks are not enabled for the environment
+        """
+
+        with _session() as session:
+            env = Environment.get_environment_by_uri(session, uri)
+            enabled = EnvironmentParameterRepository(session).get_param(uri, "notebooksEnabled")
+
+            if not enabled or enabled.lower() != "true":
+                raise exceptions.UnauthorizedOperation(
+                    action=CREATE_NOTEBOOK,
+                    message=f'Notebooks feature is disabled for the environment {env.label}',
+                )
+
+            env_group = request.environment
+            if not env_group:
+                env_group = Environment.get_environment_group(
+                    session,
+                    group_uri=admin_group,
+                    environment_uri=env.environmentUri,
+                )
+
+            notebook = SagemakerNotebook(
+                label=request.label,
+                environmentUri=env.environmentUri,
+                description=request.description,
+                NotebookInstanceName=slugify(request.label, separator=''),
+                NotebookInstanceStatus='NotStarted',
+                AWSAccountId=env.AwsAccountId,
+                region=env.region,
+                RoleArn=env_group.environmentIAMRoleArn,
+                owner=context().username,
+                SamlAdminGroupName=admin_group,
+                tags=request.tags,
+                VpcId=request.VpcId,
+                SubnetId=request.SubnetId,
+                VolumeSizeInGB=request.VolumeSizeInGB,
+                InstanceType=request.InstanceType,
+            )
+
+            NotebookRepository(session).save_notebook(notebook)
+
+            notebook.NotebookInstanceName = NamingConventionService(
+                target_uri=notebook.notebookUri,
+                target_label=notebook.label,
+                pattern=NamingConventionPattern.NOTEBOOK,
+                resource_prefix=env.resourcePrefix,
+            ).build_compliant_name()
+
+            ResourcePolicy.attach_resource_policy(
+                session=session,
+                group=request.SamlAdminGroupName,
+                permissions=permissions.NOTEBOOK_ALL,
+                resource_uri=notebook.notebookUri,
+                resource_type=SagemakerNotebook.__name__,
+            )
+
+            if env.SamlGroupName != admin_group:
+                ResourcePolicy.attach_resource_policy(
+                    session=session,
+                    group=env.SamlGroupName,
+                    permissions=permissions.NOTEBOOK_ALL,
+                    resource_uri=notebook.notebookUri,
+                    resource_type=SagemakerNotebook.__name__,
+                )
+
+            Stack.create_stack(
+                session=session,
+                environment_uri=notebook.environmentUri,
+                target_type='notebook',
+                target_uri=notebook.notebookUri,
+                target_label=notebook.label,
+            )
+
+        stack_helper.deploy_stack(targetUri=notebook.notebookUri)
+
+        return notebook
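As a usage sketch (values illustrative), the resolver layer builds the request from the raw GraphQL input and calls the service with keyword-only arguments, which the permission decorators read to resolve the target:

```python
request = NotebookCreationRequest.from_dict({
    'label': 'my-notebook',
    'SamlAdminGroupName': 'scientists',
    'VpcId': 'vpc-123567',
    'SubnetId': 'subnet-123567',
    'environmentUri': 'env-uri',  # not a dataclass field, so from_dict drops it
})
assert request.InstanceType == 'ml.t3.medium'  # default filled in
assert request.VolumeSizeInGB == 32            # default filled in

notebook = NotebookService.create_notebook(
    uri='env-uri', admin_group='scientists', request=request
)
```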
+    @staticmethod
+    def list_user_notebooks(filter) -> dict:
+        """Lists the notebooks of the current user, filtered by the given filter param"""
+        with _session() as session:
+            return NotebookRepository(session).paginated_user_notebooks(
+                username=context().username,
+                groups=context().groups,
+                filter=filter
+            )
+
+    @staticmethod
+    @has_resource_permission(permissions.GET_NOTEBOOK)
+    def get_notebook(*, uri) -> SagemakerNotebook:
+        """Gets a notebook by uri"""
+        with _session() as session:
+            return NotebookService._get_notebook(session, uri)
+
+    @staticmethod
+    @has_resource_permission(permissions.UPDATE_NOTEBOOK)
+    def start_notebook(*, uri):
+        """Starts the notebook instance"""
+        notebook = NotebookService.get_notebook(uri=uri)
+        client(notebook).start_instance()
+
+    @staticmethod
+    @has_resource_permission(permissions.UPDATE_NOTEBOOK)
+    def stop_notebook(*, uri: str) -> None:
+        """Stops the notebook instance"""
+        notebook = NotebookService.get_notebook(uri=uri)
+        client(notebook).stop_instance()
+
+    @staticmethod
+    @has_resource_permission(permissions.GET_NOTEBOOK)
+    def get_notebook_presigned_url(*, uri: str) -> str:
+        """Creates and returns a presigned url for a notebook"""
+        notebook = NotebookService.get_notebook(uri=uri)
+        return client(notebook).presigned_url()
+
+    @staticmethod
+    @has_resource_permission(permissions.GET_NOTEBOOK)
+    def get_notebook_status(*, uri) -> str:
+        """Retrieves the notebook status"""
+        notebook = NotebookService.get_notebook(uri=uri)
+        return client(notebook).get_notebook_instance_status()
+
+    @staticmethod
+    @has_resource_permission(permissions.DELETE_NOTEBOOK)
+    def delete_notebook(*, uri: str, delete_from_aws: bool):
+        """Deletes the notebook from the database and, if delete_from_aws is True, from AWS as well"""
+        with _session() as session:
+            notebook = NotebookService._get_notebook(session, uri)
+            KeyValueTag.delete_key_value_tags(session, notebook.notebookUri, 'notebook')
+            session.delete(notebook)
+
+            ResourcePolicy.delete_resource_policy(
+                session=session,
+                resource_uri=notebook.notebookUri,
+                group=notebook.SamlAdminGroupName,
+            )
+
+            env: models.Environment = Environment.get_environment_by_uri(
+                session, notebook.environmentUri
+            )
+
+            if delete_from_aws:
+                stack_helper.delete_stack(
+                    target_uri=uri,
+                    accountid=env.AwsAccountId,
+                    cdk_role_arn=env.CDKRoleArn,
+                    region=env.region
+                )
+
+    @staticmethod
+    def _get_notebook(session, uri) -> SagemakerNotebook:
+        notebook = NotebookRepository(session).find_notebook(uri)
+
+        if not notebook:
+            raise exceptions.ObjectNotFound('SagemakerNotebook', uri)
+        return notebook
+
+
+def _session():
+    return context().db_engine.scoped_session()
diff --git a/backend/dataall/modules/notebooks/tasks/__init__.py b/backend/dataall/modules/notebooks/tasks/__init__.py
new file mode 100644
index 000000000..7da194e3b
--- /dev/null
+++ b/backend/dataall/modules/notebooks/tasks/__init__.py
@@ -0,0 +1 @@
+"""Currently notebooks don't have tasks, but this module is needed for correct loading"""
diff --git a/backend/dataall/utils/runtime_stacks_tagging.py b/backend/dataall/utils/runtime_stacks_tagging.py
index d95dc7d8a..02f0a9c3f 100644
--- a/backend/dataall/utils/runtime_stacks_tagging.py
+++ b/backend/dataall/utils/runtime_stacks_tagging.py
@@ -27,13 +27,13 @@ def __init__(self, stack):
         self.stack = stack
 
     @classmethod
-    def add_tags(cls, stack: Stack) -> [tuple]:
+    def add_tags(cls, stack: Stack, model, target_type) -> [tuple]:
         """
         A class method that adds tags to a Stack
         """
         # Get the list of tags to be added from the tag factory
-        stack_tags_to_add = cls.tag_factory(stack)
+        stack_tags_to_add = cls.tag_factory(stack, model,
target_type) # Add the tags to the Stack for tag in stack_tags_to_add: @@ -42,35 +42,22 @@ def add_tags(cls, stack: Stack) -> [tuple]: return stack_tags_to_add @classmethod - def tag_factory(cls, stack: Stack) -> typing.List[typing.Tuple]: + def tag_factory(cls, stack: Stack, model_name, target_type) -> typing.List[typing.Tuple]: """ A class method that returns tags to be added to a Stack (based on Stack type) """ _stack_tags = [] - # Dictionary that resolves the Stack class name to the GraphQL model - stack_model = dict( - Dataset=models.Dataset, - EnvironmentSetup=models.Environment, - SagemakerStudioDomain=models.SagemakerStudioUserProfile, - SagemakerStudioUserProfile=models.SagemakerStudioUserProfile, - SagemakerNotebook=models.SagemakerNotebook, - PipelineStack=models.DataPipeline, - CDKPipelineStack=models.DataPipeline, - RedshiftStack=models.RedshiftCluster, - ) - engine = cls.get_engine() # Initialize references to stack's environment and organisation with engine.scoped_session() as session: - model_name = stack_model[stack.__class__.__name__] target_stack = cls.get_target(session, stack, model_name) environment = cls.get_environment(session, target_stack) organisation = cls.get_organization(session, environment) key_value_tags: [models.KeyValueTag] = cls.get_model_key_value_tags( - session, stack, model_name + session, stack, target_type ) cascaded_tags: [models.KeyValueTag] = cls.get_environment_cascade_key_value_tags( session, environment.environmentUri @@ -145,13 +132,13 @@ def get_environment(cls, session, target_stack): return environment @classmethod - def get_model_key_value_tags(cls, session, stack, model_name): + def get_model_key_value_tags(cls, session, stack, target_type): return [ (kv.key, kv.value) for kv in db.api.KeyValueTag.find_key_value_tags( session, stack.target_uri, - db.api.TargetType.get_target_type(model_name), + target_type, ) ] diff --git a/backend/docker/prod/ecs/Dockerfile b/backend/docker/prod/ecs/Dockerfile index b272902af..2dd661665 100644 --- a/backend/docker/prod/ecs/Dockerfile +++ b/backend/docker/prod/ecs/Dockerfile @@ -52,6 +52,9 @@ ADD backend/dataall /dataall ADD backend/blueprints /blueprints ADD backend/cdkproxymain.py /cdkproxymain.py +ENV config_location="/config.json" +COPY config.json /config.json + RUN mkdir -p dataall/cdkproxy/assets/glueprofilingjob/jars RUN mkdir -p blueprints/ml_data_pipeline/engine/glue/jars RUN curl https://repo1.maven.org/maven2/com/amazon/deequ/deequ/$DEEQU_VERSION/deequ-$DEEQU_VERSION.jar --output dataall/cdkproxy/assets/glueprofilingjob/jars/deequ-$DEEQU_VERSION.jar diff --git a/backend/docker/prod/lambda/Dockerfile b/backend/docker/prod/lambda/Dockerfile index 42b98d65a..74609e98c 100644 --- a/backend/docker/prod/lambda/Dockerfile +++ b/backend/docker/prod/lambda/Dockerfile @@ -23,6 +23,9 @@ RUN $PYTHON_VERSION -m pip install -r requirements.txt -t . COPY backend/. ./ +ENV config_location="config.json" +COPY config.json ./config.json + ## You must add the Lambda Runtime Interface Client (RIC) for your runtime. 
RUN $PYTHON_VERSION -m pip install awslambdaric --target ${FUNCTION_DIR} diff --git a/backend/local.cdkapi.server.py b/backend/local_cdkapi_server.py similarity index 100% rename from backend/local.cdkapi.server.py rename to backend/local_cdkapi_server.py diff --git a/backend/local.graphql.server.py b/backend/local_graphql_server.py similarity index 88% rename from backend/local.graphql.server.py rename to backend/local_graphql_server.py index 094e53052..3783ba0a3 100644 --- a/backend/local.graphql.server.py +++ b/backend/local_graphql_server.py @@ -14,6 +14,9 @@ from dataall.aws.handlers.service_handlers import Worker from dataall.db import get_engine, Base, create_schema_and_tables, init_permissions, api from dataall.searchproxy import connect, run_query +from dataall.modules.loader import load_modules, ImportMode +from dataall.core.config import config +from dataall.core.context import set_context, dispose_context, RequestContext import logging @@ -27,10 +30,12 @@ es = connect(envname=ENVNAME) logger.info('Connected') # create_schema_and_tables(engine, envname=ENVNAME) +load_modules(modes=[ImportMode.API, ImportMode.TASKS]) Base.metadata.create_all(engine.engine) CDKPROXY_URL = ( 'http://cdkproxy:2805' if ENVNAME == 'dkrcompose' else 'http://localhost:2805' ) +config.set_property("cdk_proxy_url", CDKPROXY_URL) init_permissions(engine) @@ -80,13 +85,16 @@ def request_context(headers, mock=False): permissions=db.permissions.TENANT_ALL, tenant_name='dataall', ) + + set_context(RequestContext(engine, username, groups, es)) + + # TODO: remove when the migration to a new RequestContext API is complete. Used only for backward compatibility context = Context( engine=engine, es=es, schema=schema, username=username, groups=groups, - cdkproxyurl=CDKPROXY_URL, ) return context.__dict__ @@ -129,19 +137,22 @@ def esproxy(): def graphql_server(): print('.............................') # GraphQL queries are always sent as POST - print(request.data) + logger.debug(request.data) data = request.get_json() - print(request_context(request.headers, mock=True)) + + context = request_context(request.headers, mock=True) + logger.debug(context) # Note: Passing the request to the context is optional. # In Flask, the current request is always accessible as flask.request success, result = graphql_sync( schema, data, - context_value=request_context(request.headers, mock=True), + context_value=context, debug=app.debug, ) + dispose_context() status_code = 200 if success else 400 return jsonify(result), status_code diff --git a/backend/migrations/versions/5fc49baecea4_add_enviromental_parameters.py b/backend/migrations/versions/5fc49baecea4_add_enviromental_parameters.py new file mode 100644 index 000000000..061d9b81a --- /dev/null +++ b/backend/migrations/versions/5fc49baecea4_add_enviromental_parameters.py @@ -0,0 +1,153 @@ +"""add_enviromental_parameters + +Revision ID: 5fc49baecea4 +Revises: d05f9a5b215e +Create Date: 2023-02-20 14:28:13.331670 + +""" +from typing import List + +from alembic import op +from sqlalchemy import Boolean, Column, String, orm +from sqlalchemy.ext.declarative import declarative_base +from dataall.db import Resource +from dataall.db.api.permission import Permission +from dataall.db.models import TenantPolicy, TenantPolicyPermission, PermissionType +from dataall.db.permissions import MANAGE_SGMSTUDIO_NOTEBOOKS +from dataall.modules.notebooks.services.permissions import MANAGE_NOTEBOOKS + +# revision identifiers, used by Alembic. 
+revision = "5fc49baecea4" +down_revision = "509997f0a51e" +branch_labels = None +depends_on = None + +Base = declarative_base() + + +class Environment(Resource, Base): + __tablename__ = "environment" + environmentUri = Column(String, primary_key=True) + notebooksEnabled = Column(Boolean) + + +class EnvironmentParameter(Resource, Base): + __tablename__ = "environment_parameters" + environmentUri = Column(String, primary_key=True) + paramKey = Column(String, primary_key=True), + paramValue = Column(String, nullable=True) + + +class SagemakerNotebook(Resource, Base): + __tablename__ = "sagemaker_notebook" + environmentUri = Column(String, nullable=False) + notebookUri = Column(String, primary_key=True) + + +def upgrade(): + """ + The script does the following migration: + 1) creation of the environment_parameters and environment_resources tables + 2) Migration xxxEnabled to the environment_parameters table + 3) Dropping the xxxEnabled columns from the environment_parameters + 4) Migrate permissions + """ + try: + bind = op.get_bind() + session = orm.Session(bind=bind) + + print("Creating of environment_parameters table...") + op.create_table( + "environment_parameters", + Column("environmentUri", String, primary_key=True), + Column("paramKey", String, primary_key=True), + Column("paramValue", String, nullable=False), + ) + print("Creation of environment_parameters is done") + + print("Migrating the environmental parameters...") + envs: List[Environment] = session.query(Environment).all() + params: List[EnvironmentParameter] = [] + for env in envs: + _add_param_if_exists( + params, env, "notebooksEnabled", str(env.notebooksEnabled).lower() # for frontend + ) + + session.add_all(params) + print("Migration of the environmental parameters has been complete") + + op.drop_column("environment", "notebooksEnabled") + print("Dropped the columns from the environment table ") + + create_foreign_key_to_env(op, 'sagemaker_notebook') + create_foreign_key_to_env(op, 'dataset') + create_foreign_key_to_env(op, 'sagemaker_studio_user_profile') + create_foreign_key_to_env(op, 'redshiftcluster') + create_foreign_key_to_env(op, 'datapipeline') + create_foreign_key_to_env(op, 'dashboard') + + print("Saving new MANAGE_SGMSTUDIO_NOTEBOOKS permission") + Permission.init_permissions(session) + + manage_notebooks = Permission.get_permission_by_name( + session, MANAGE_NOTEBOOKS, PermissionType.TENANT.name + ) + manage_mlstudio = Permission.get_permission_by_name( + session, MANAGE_SGMSTUDIO_NOTEBOOKS, PermissionType.TENANT.name + ) + + permissions = ( + session.query(TenantPolicyPermission) + .filter(TenantPolicyPermission.permission == manage_notebooks.permissionUri) + .all() + ) + + for permission in permissions: + session.add(TenantPolicyPermission( + sid=permission.sid, + permissionUri=manage_mlstudio.permissionUri, + )) + session.commit() + + except Exception as ex: + print(f"Failed to execute the migration script due to: {ex}") + + +def downgrade(): + try: + bind = op.get_bind() + session = orm.Session(bind=bind) + + op.drop_constraint("fk_notebook_env_uri", "sagemaker_notebook") + op.add_column("environment", Column("notebooksEnabled", Boolean, default=True)) + + params = session.query(EnvironmentParameter).all() + envs = [] + for param in params: + envs.append(Environment( + environmentUri=param.environmentUri, + notebooksEnabled=params["notebooksEnabled"] == "true" + )) + + session.add_all(envs) + op.drop_table("environment_parameters") + + except Exception as ex: + print(f"Failed to execute the rollback 
+
+
+def _add_param_if_exists(params: List[EnvironmentParameter], env: Environment, key, val) -> None:
+    if val is not None:
+        params.append(EnvironmentParameter(
+            environmentUri=env.environmentUri,
+            paramKey=key,
+            paramValue=str(val).lower()
+        ))
+
+
+def create_foreign_key_to_env(op, table: str):
+    op.create_foreign_key(
+        f"fk_{table}_env_uri",
+        table, "environment",
+        ["environmentUri"], ["environmentUri"],
+    )
diff --git a/backend/requirements.txt b/backend/requirements.txt
index 7429e61c9..81eeba8a9 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -14,4 +14,5 @@ PyYAML==6.0
 requests==2.27.1
 requests_aws4auth==1.1.1
 sqlalchemy==1.3.24
-starlette==0.25.0
\ No newline at end of file
+starlette==0.25.0
+deprecated==1.2.13
\ No newline at end of file
diff --git a/config.json b/config.json
new file mode 100644
index 000000000..e0a9d85d0
--- /dev/null
+++ b/config.json
@@ -0,0 +1,7 @@
+{
+  "modules": {
+    "notebooks": {
+      "active": "true"
+    }
+  }
+}
\ No newline at end of file
diff --git a/deploy/stacks/container.py b/deploy/stacks/container.py
index dfc667c25..c313df82e 100644
--- a/deploy/stacks/container.py
+++ b/deploy/stacks/container.py
@@ -1,3 +1,4 @@
+from typing import Dict
 from aws_cdk import (
     aws_ec2 as ec2,
     aws_ecs as ecs,
@@ -28,6 +29,7 @@ def __init__(
         **kwargs,
     ):
         super().__init__(scope, id, **kwargs)
+        self._envname = envname
 
         if self.node.try_get_context('image_tag'):
             image_tag = self.node.try_get_context('image_tag')
@@ -62,11 +64,7 @@ def __init__(
             image=ecs.ContainerImage.from_ecr_repository(
                 repository=ecr_repository, tag=cdkproxy_image_tag
             ),
-            environment={
-                'AWS_REGION': self.region,
-                'envname': envname,
-                'LOGLEVEL': 'DEBUG',
-            },
+            environment=self._create_env('DEBUG'),
             command=['python3.8', '-m', 'dataall.tasks.cdkproxy'],
             logging=ecs.LogDriver.aws_logs(
                 stream_prefix='task',
@@ -99,11 +97,7 @@ def __init__(
             command=['python3.8', '-m', 'dataall.tasks.tables_syncer'],
             container_id=f'container',
             ecr_repository=ecr_repository,
-            environment={
-                'AWS_REGION': self.region,
-                'envname': envname,
-                'LOGLEVEL': 'INFO',
-            },
+            environment=self._create_env('INFO'),
             image_tag=cdkproxy_image_tag,
             log_group=self.create_log_group(
                 envname, resource_prefix, log_group_name='tables-syncer'
@@ -123,11 +117,7 @@ def __init__(
             command=['python3.8', '-m', 'dataall.tasks.catalog_indexer'],
             container_id=f'container',
             ecr_repository=ecr_repository,
-            environment={
-                'AWS_REGION': self.region,
-                'envname': envname,
-                'LOGLEVEL': 'INFO',
-            },
+            environment=self._create_env('INFO'),
             image_tag=cdkproxy_image_tag,
             log_group=self.create_log_group(
                 envname, resource_prefix, log_group_name='catalog-indexer'
@@ -147,11 +137,7 @@ def __init__(
             command=['python3.8', '-m', 'dataall.tasks.stacks_updater'],
             container_id=f'container',
             ecr_repository=ecr_repository,
-            environment={
-                'AWS_REGION': self.region,
-                'envname': envname,
-                'LOGLEVEL': 'INFO',
-            },
+            environment=self._create_env('INFO'),
             image_tag=cdkproxy_image_tag,
             log_group=self.create_log_group(
                 envname, resource_prefix, log_group_name='stacks-updater'
@@ -171,11 +157,7 @@ def __init__(
             command=['python3.8', '-m', 'dataall.tasks.bucket_policy_updater'],
             container_id=f'container',
             ecr_repository=ecr_repository,
-            environment={
-                'AWS_REGION': self.region,
-                'envname': envname,
-                'LOGLEVEL': 'INFO',
-            },
+            environment=self._create_env('DEBUG'),
             image_tag=cdkproxy_image_tag,
             log_group=self.create_log_group(
                 envname, resource_prefix, log_group_name='policies-updater'
@@ -201,11 +183,7 @@ def
__init__( ], container_id=f'container', ecr_repository=ecr_repository, - environment={ - 'AWS_REGION': self.region, - 'envname': envname, - 'LOGLEVEL': 'INFO', - }, + environment=self._create_env('INFO'), image_tag=cdkproxy_image_tag, log_group=self.create_log_group( envname, resource_prefix, log_group_name='subscriptions' @@ -236,11 +214,7 @@ def __init__( image=ecs.ContainerImage.from_ecr_repository( repository=ecr_repository, tag=cdkproxy_image_tag ), - environment={ - 'AWS_REGION': self.region, - 'envname': envname, - 'LOGLEVEL': 'DEBUG', - }, + environment=self._create_env('DEBUG'), command=['python3.8', '-m', 'dataall.tasks.share_manager'], logging=ecs.LogDriver.aws_logs( stream_prefix='task', @@ -512,3 +486,11 @@ def set_scheduled_task( # security_groups=[security_group], ) return scheduled_task + + def _create_env(self, log_lvl) -> Dict: + return { + 'AWS_REGION': self.region, + 'envname': self._envname, + 'LOGLEVEL': log_lvl, + 'config_location': '/config.json' + } diff --git a/deploy/stacks/pipeline.py b/deploy/stacks/pipeline.py index d00a7db13..49107bb9e 100644 --- a/deploy/stacks/pipeline.py +++ b/deploy/stacks/pipeline.py @@ -452,6 +452,7 @@ def set_quality_gate_stage(self): commands=[ 'mkdir -p source_build', 'mv backend ./source_build/', + 'mv config.json ./source_build/', 'cd source_build/ && zip -r ../source_build/source_build.zip *', f'aws s3api put-object --bucket {self.pipeline_bucket.bucket_name} --key source_build.zip --body source_build.zip', ], @@ -483,6 +484,7 @@ def set_quality_gate_stage(self): commands=[ 'mkdir -p source_build', 'mv backend ./source_build/', + 'mv config.json ./source_build/', 'cd source_build/ && zip -r ../source_build/source_build.zip *', f'aws s3api put-object --bucket {self.pipeline_bucket.bucket_name} --key source_build.zip --body source_build.zip', ], diff --git a/docker-compose.yaml b/docker-compose.yaml index 0668ffcd8..0aecb2555 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -16,10 +16,12 @@ services: - db environment: envname: 'dkrcompose' + config_location: "/config.json" AWS_DEFAULT_REGION: "eu-west-1" volumes: - ./backend:/code - $HOME/.aws/credentials:/root/.aws/credentials:ro + - ./config.json:/config.json restart: on-failure:60 platform: @@ -29,17 +31,19 @@ services: build: context: ./backend dockerfile: docker/dev/Dockerfile - entrypoint: /bin/bash -c "../build/wait-for-it.sh elasticsearch:9200 -t 30 && python3.8 local.graphql.server.py" + entrypoint: /bin/bash -c "../build/wait-for-it.sh elasticsearch:9200 -t 30 && python3.8 local_graphql_server.py" expose: - 5000 ports: - 5000:5000 environment: envname: 'dkrcompose' + config_location: "/config.json" AWS_DEFAULT_REGION: "eu-west-1" volumes: - ./backend:/code - $HOME/.aws/credentials:/root/.aws/credentials:ro + - ./config.json:/config.json depends_on: - db - elasticsearch diff --git a/frontend/src/api/Environment/createEnvironment.js b/frontend/src/api/Environment/createEnvironment.js index 1314f0bf3..af4972d84 100644 --- a/frontend/src/api/Environment/createEnvironment.js +++ b/frontend/src/api/Environment/createEnvironment.js @@ -14,10 +14,13 @@ const createEnvironment = (input) => ({ AwsAccountId created dashboardsEnabled - notebooksEnabled mlStudiosEnabled pipelinesEnabled warehousesEnabled + parameters { + key + value + } } } ` diff --git a/frontend/src/api/Environment/getEnvironment.js b/frontend/src/api/Environment/getEnvironment.js index 6096392f7..70ccc54a5 100644 --- a/frontend/src/api/Environment/getEnvironment.js +++ 
b/frontend/src/api/Environment/getEnvironment.js @@ -15,7 +15,6 @@ const getEnvironment = ({ environmentUri }) => ({ label AwsAccountId dashboardsEnabled - notebooksEnabled mlStudiosEnabled pipelinesEnabled warehousesEnabled @@ -51,7 +50,6 @@ const getEnvironment = ({ environmentUri }) => ({ resources } dashboardsEnabled - notebooksEnabled mlStudiosEnabled pipelinesEnabled warehousesEnabled @@ -60,6 +58,10 @@ const getEnvironment = ({ environmentUri }) => ({ privateSubnetIds publicSubnetIds } + parameters { + key + value + } } } ` diff --git a/frontend/src/api/Environment/listOrganizationEnvironments.js b/frontend/src/api/Environment/listOrganizationEnvironments.js index 23834051a..c5244c34b 100644 --- a/frontend/src/api/Environment/listOrganizationEnvironments.js +++ b/frontend/src/api/Environment/listOrganizationEnvironments.js @@ -33,7 +33,6 @@ const listOrganizationEnvironments = ({ organizationUri, filter }) => ({ environmentType AwsAccountId dashboardsEnabled - notebooksEnabled mlStudiosEnabled pipelinesEnabled warehousesEnabled @@ -50,6 +49,10 @@ const listOrganizationEnvironments = ({ organizationUri, filter }) => ({ outputs resources } + parameters { + key + value + } } } } diff --git a/frontend/src/api/Environment/updateEnvironment.js b/frontend/src/api/Environment/updateEnvironment.js index 9a715b91f..ddd851637 100644 --- a/frontend/src/api/Environment/updateEnvironment.js +++ b/frontend/src/api/Environment/updateEnvironment.js @@ -17,11 +17,14 @@ const updateEnvironment = ({ environmentUri, input }) => ({ SamlGroupName AwsAccountId dashboardsEnabled - notebooksEnabled mlStudiosEnabled pipelinesEnabled warehousesEnabled created + parameters { + key + value + } } } ` diff --git a/frontend/src/views/Environments/EnvironmentCreateForm.js b/frontend/src/views/Environments/EnvironmentCreateForm.js index 77bc5ee04..90732f0da 100644 --- a/frontend/src/views/Environments/EnvironmentCreateForm.js +++ b/frontend/src/views/Environments/EnvironmentCreateForm.js @@ -154,12 +154,17 @@ const EnvironmentCreateForm = (props) => { description: values.description, region: values.region, dashboardsEnabled: values.dashboardsEnabled, - notebooksEnabled: values.notebooksEnabled, mlStudiosEnabled: values.mlStudiosEnabled, pipelinesEnabled: values.pipelinesEnabled, warehousesEnabled: values.warehousesEnabled, EnvironmentDefaultIAMRoleName: values.EnvironmentDefaultIAMRoleName, - resourcePrefix: values.resourcePrefix + resourcePrefix: values.resourcePrefix, + parameters: [ + { + key: "notebooksEnabled", + value: String(values.notebooksEnabled) + } + ] }) ); if (!response.errors) { diff --git a/frontend/src/views/Environments/EnvironmentEditForm.js b/frontend/src/views/Environments/EnvironmentEditForm.js index 2f7c53342..45676a18c 100644 --- a/frontend/src/views/Environments/EnvironmentEditForm.js +++ b/frontend/src/views/Environments/EnvironmentEditForm.js @@ -48,7 +48,9 @@ const EnvironmentEditForm = (props) => { getEnvironment({ environmentUri: params.uri }) ); if (!response.errors && response.data.getEnvironment) { - setEnv(response.data.getEnvironment); + const environment = response.data.getEnvironment + environment.parameters = Object.fromEntries(environment.parameters.map(x => [x.key, x.value])) + setEnv(environment); } else { const error = response.errors ? 
response.errors[0].message @@ -72,11 +74,16 @@ const EnvironmentEditForm = (props) => { tags: values.tags, description: values.description, dashboardsEnabled: values.dashboardsEnabled, - notebooksEnabled: values.notebooksEnabled, mlStudiosEnabled: values.mlStudiosEnabled, pipelinesEnabled: values.pipelinesEnabled, warehousesEnabled: values.warehousesEnabled, - resourcePrefix: values.resourcePrefix + resourcePrefix: values.resourcePrefix, + parameters: [ + { + key: "notebooksEnabled", + value: String(values.notebooksEnabled) + } + ] } }) ); @@ -192,7 +199,7 @@ const EnvironmentEditForm = (props) => { description: env.description, tags: env.tags || [], dashboardsEnabled: env.dashboardsEnabled, - notebooksEnabled: env.notebooksEnabled, + notebooksEnabled: env.parameters["notebooksEnabled"] === 'true', mlStudiosEnabled: env.mlStudiosEnabled, pipelinesEnabled: env.pipelinesEnabled, warehousesEnabled: env.warehousesEnabled, diff --git a/frontend/src/views/Environments/EnvironmentFeatures.js b/frontend/src/views/Environments/EnvironmentFeatures.js index af23b6ffa..07d3e52cb 100644 --- a/frontend/src/views/Environments/EnvironmentFeatures.js +++ b/frontend/src/views/Environments/EnvironmentFeatures.js @@ -51,8 +51,8 @@ const EnvironmentFeatures = (props) => { Notebooks - diff --git a/frontend/src/views/Environments/EnvironmentView.js b/frontend/src/views/Environments/EnvironmentView.js index 71c271af3..099519f1e 100644 --- a/frontend/src/views/Environments/EnvironmentView.js +++ b/frontend/src/views/Environments/EnvironmentView.js @@ -122,13 +122,11 @@ const EnvironmentView = () => { getEnvironment({ environmentUri: params.uri }) ); if (!response.errors && response.data.getEnvironment) { - setEnv(response.data.getEnvironment); - setStack(response.data.getEnvironment.stack); - setIsAdmin( - ['Admin', 'Owner'].indexOf( - response.data.getEnvironment.userRoleInEnvironment - ) !== -1 - ); + const environment = response.data.getEnvironment + environment.parameters = Object.fromEntries(environment.parameters.map(x => [x.key, x.value])) + setEnv(environment); + setStack(environment.stack); + setIsAdmin(['Admin', 'Owner'].indexOf(environment.userRoleInEnvironment) !== -1); } else { const error = response.errors ? 
response.errors[0].message diff --git a/tests/api/client.py b/tests/api/client.py index 20dbb8c22..f54c45d66 100644 --- a/tests/api/client.py +++ b/tests/api/client.py @@ -7,6 +7,10 @@ from flask import Flask, request, jsonify, Response from munch import DefaultMunch import dataall +from dataall.core.context import set_context, dispose_context, RequestContext +from dataall.core.config import config + +config.set_property("cdk_proxy_url", "mock_url") class ClientWrapper: @@ -58,6 +62,9 @@ def graphql_server(): username = request.headers.get('Username', 'anonym') groups = json.loads(request.headers.get('Groups', '[]')) + + set_context(RequestContext(db, username, groups, es)) + success, result = graphql_sync( schema, data, @@ -67,11 +74,11 @@ def graphql_server(): 'username': username, 'groups': groups, 'es': es, - 'cdkproxyurl': 'cdkproxyurl', }, debug=app.debug, ) + dispose_context() status_code = 200 if success else 400 return jsonify(result), status_code diff --git a/tests/api/conftest.py b/tests/api/conftest.py index e2541ac72..65dc6934b 100644 --- a/tests/api/conftest.py +++ b/tests/api/conftest.py @@ -297,7 +297,10 @@ def factory( def env(client): cache = {} - def factory(org, envname, owner, group, account, region, desc='test'): + def factory(org, envname, owner, group, account, region, desc='test', parameters=None): + if parameters == None: + parameters = {} + key = f"{org.organizationUri}{envname}{owner}{''.join(group or '-')}{account}{region}" if cache.get(key): return cache[key] @@ -314,6 +317,10 @@ def factory(org, envname, owner, group, account, region, desc='test'): region name owner + parameters { + key + value + } } }""", username=f'{owner}', @@ -328,6 +335,7 @@ def factory(org, envname, owner, group, account, region, desc='test'): 'SamlGroupName': f'{group}', 'dashboardsEnabled': True, 'vpcId': 'vpc-123456', + 'parameters': [{'key': k, 'value': v} for k, v in parameters.items()] }, ) cache[key] = response.data.createEnvironment @@ -671,42 +679,6 @@ def cluster(env_fixture, org_fixture, client, group): yield res.data.createRedshiftCluster -@pytest.fixture(scope='module') -def sgm_notebook(client, tenant, group, env_fixture) -> dataall.db.models.SagemakerNotebook: - response = client.query( - """ - mutation createSagemakerNotebook($input:NewSagemakerNotebookInput){ - createSagemakerNotebook(input:$input){ - notebookUri - label - description - tags - owner - userRoleForNotebook - SamlAdminGroupName - VpcId - SubnetId - VolumeSizeInGB - InstanceType - } - } - """, - input={ - 'label': 'my pipeline', - 'SamlAdminGroupName': group.name, - 'tags': [group.name], - 'environmentUri': env_fixture.environmentUri, - 'VpcId': 'vpc-123567', - 'SubnetId': 'subnet-123567', - 'VolumeSizeInGB': 32, - 'InstanceType': 'ml.m5.xlarge', - }, - username='alice', - groups=[group.name], - ) - yield response.data.createSagemakerNotebook - - @pytest.fixture(scope='module') def pipeline(client, tenant, group, env_fixture) -> models.DataPipeline: response = client.query( diff --git a/tests/api/test_environment.py b/tests/api/test_environment.py index e961a445c..a84535d41 100644 --- a/tests/api/test_environment.py +++ b/tests/api/test_environment.py @@ -20,8 +20,8 @@ def env1(env, org1, user, group, tenant, module_mocker): yield env1 -def test_get_environment(client, org1, env1, group): - response = client.query( +def get_env(client, env1, group): + return client.query( """ query GetEnv($environmentUri:String!){ getEnvironment(environmentUri:$environmentUri){ @@ -35,7 +35,6 @@ def 
test_get_environment(client, org1, env1, group): SamlGroupName owner dashboardsEnabled - notebooksEnabled mlStudiosEnabled pipelinesEnabled warehousesEnabled @@ -43,6 +42,10 @@ def test_get_environment(client, org1, env1, group): EcsTaskArn EcsTaskId } + parameters { + key + value + } } } """, @@ -50,6 +53,9 @@ def test_get_environment(client, org1, env1, group): environmentUri=env1.environmentUri, groups=[group.name], ) + +def test_get_environment(client, org1, env1, group): + response = get_env(client, env1, group) assert ( response.data.getEnvironment.organization.organizationUri == org1.organizationUri @@ -57,7 +63,6 @@ def test_get_environment(client, org1, env1, group): assert response.data.getEnvironment.owner == 'alice' assert response.data.getEnvironment.AwsAccountId == env1.AwsAccountId assert response.data.getEnvironment.dashboardsEnabled - assert response.data.getEnvironment.notebooksEnabled assert response.data.getEnvironment.mlStudiosEnabled assert response.data.getEnvironment.pipelinesEnabled assert response.data.getEnvironment.warehousesEnabled @@ -88,8 +93,7 @@ def test_get_environment_object_not_found(client, org1, env1, group): def test_update_env(client, org1, env1, group): - response = client.query( - """ + query = """ mutation UpdateEnv($environmentUri:String!,$input:ModifyEnvironmentInput){ updateEnvironment(environmentUri:$environmentUri,input:$input){ organization{ @@ -103,63 +107,55 @@ def test_update_env(client, org1, env1, group): tags resourcePrefix dashboardsEnabled - notebooksEnabled mlStudiosEnabled pipelinesEnabled warehousesEnabled - + parameters { + key + value + } } } - """, + """ + + response = client.query(query, username='alice', environmentUri=env1.environmentUri, input={ 'label': 'DEV', 'tags': ['test', 'env'], 'dashboardsEnabled': False, - 'notebooksEnabled': False, 'mlStudiosEnabled': False, 'pipelinesEnabled': False, 'warehousesEnabled': False, + 'parameters': [ + { + 'key': 'notebooksEnabled', + 'value': 'True' + } + ], 'resourcePrefix': 'customer-prefix_AZ390 ', }, groups=[group.name], ) assert 'InvalidInput' in response.errors[0].message - response = client.query( - """ - mutation UpdateEnv($environmentUri:String!,$input:ModifyEnvironmentInput){ - updateEnvironment(environmentUri:$environmentUri,input:$input){ - organization{ - organizationUri - } - label - AwsAccountId - region - SamlGroupName - owner - tags - resourcePrefix - dashboardsEnabled - notebooksEnabled - mlStudiosEnabled - pipelinesEnabled - warehousesEnabled - - } - } - """, + response = client.query(query, username='alice', environmentUri=env1.environmentUri, input={ 'label': 'DEV', 'tags': ['test', 'env'], 'dashboardsEnabled': False, - 'notebooksEnabled': False, 'mlStudiosEnabled': False, 'pipelinesEnabled': False, 'warehousesEnabled': False, + 'parameters': [ + { + 'key': 'notebooksEnabled', + 'value': 'True' + } + ], 'resourcePrefix': 'customer-prefix', }, groups=[group.name], @@ -178,9 +174,54 @@ def test_update_env(client, org1, env1, group): assert not response.data.updateEnvironment.mlStudiosEnabled assert not response.data.updateEnvironment.pipelinesEnabled assert not response.data.updateEnvironment.warehousesEnabled + assert response.data.updateEnvironment.parameters + assert response.data.updateEnvironment.parameters[0]["key"] == "notebooksEnabled" + assert response.data.updateEnvironment.parameters[0]["value"] == "True" assert response.data.updateEnvironment.resourcePrefix == 'customer-prefix' +def test_update_params(client, org1, env1, group): + def 
diff --git a/tests/api/test_stack.py b/tests/api/test_stack.py
index fd834f7e8..bcb9f28f0 100644
--- a/tests/api/test_stack.py
+++ b/tests/api/test_stack.py
@@ -5,7 +5,6 @@ def test_update_stack(
     pipeline,
     env_fixture,
     dataset_fixture,
-    sgm_notebook,
     sgm_studio,
     cluster,
 ):
@@ -26,11 +25,6 @@ def test_update_stack(
         response.data.updateStack.targetUri
         == sgm_studio.sagemakerStudioUserProfileUri
     )
-    response = update_stack_query(
-        client, sgm_notebook.notebookUri, 'notebook', group.name
-    )
-    assert response.data.updateStack.targetUri == sgm_notebook.notebookUri
-
     response = update_stack_query(client, cluster.clusterUri, 'redshift', group.name)
     assert response.data.updateStack.targetUri == cluster.clusterUri
diff --git a/tests/cdkproxy/conftest.py b/tests/cdkproxy/conftest.py
index d2160dde5..c83d0028b 100644
--- a/tests/cdkproxy/conftest.py
+++ b/tests/cdkproxy/conftest.py
@@ -159,25 +159,6 @@ def sgm_studio(db, env: models.Environment) -> models.SagemakerStudioUserProfile
     yield notebook
 
 
-@pytest.fixture(scope='module', autouse=True)
-def notebook(db, env: models.Environment) -> models.SagemakerNotebook:
-    with db.scoped_session() as session:
-        notebook = models.SagemakerNotebook(
-            label='thistable',
-            NotebookInstanceStatus='RUNNING',
-            owner='me',
-            AWSAccountId=env.AwsAccountId,
-            region=env.region,
-            environmentUri=env.environmentUri,
-            RoleArn=env.EnvironmentDefaultIAMRoleArn,
-            SamlAdminGroupName='admins',
-            VolumeSizeInGB=32,
-            InstanceType='ml.t3.medium',
-        )
-        session.add(notebook)
-        yield notebook
-
-
 @pytest.fixture(scope='module', autouse=True)
 def pipeline1(db, env: models.Environment) -> models.DataPipeline:
     with db.scoped_session() as session:
@@ -234,6 +215,7 @@ def pip_envs(db, env: models.Environment, pipeline1: models.DataPipeline) -> mod
     yield api.Pipeline.query_pipeline_environments(session=session, uri=pipeline1.DataPipelineUri)
 
 
+
 @pytest.fixture(scope='module', autouse=True)
 def redshift_cluster(db, env: models.Environment) -> models.RedshiftCluster:
     with db.scoped_session() as session:
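Reviewer note: the notebook fixtures leave `tests/cdkproxy` because notebook stacks are now registered by the module loader, and the `load_modules(modes=[...])` call added to `tests/conftest.py` below mirrors the per-process calls in the API, task, and CDK entry points. As a rough sketch of what the loader might do, assuming the `dataall.modules.<name>.<mode>` submodule layout and the `modules.<name>.active` config keys used elsewhere in this PR (the real loader may differ):

```python
# Sketch only: ImportMode members come from this PR; the submodule layout is an assumption.
import importlib
from enum import Enum, auto
from typing import List

from dataall.core.config import config


class ImportMode(Enum):
    API = auto()
    CDK = auto()
    TASKS = auto()


def load_modules(modes: List[ImportMode]) -> None:
    for name, props in config.get_property('modules', {}).items():
        if str(props.get('active', 'false')).lower() != 'true':
            continue  # inactive modules are never imported
        for mode in modes:
            # e.g. dataall.modules.notebooks.api for ImportMode.API
            importlib.import_module(f'dataall.modules.{name}.{mode.name.lower()}')
```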
diff --git a/tests/conftest.py b/tests/conftest.py
index b247f0659..a67d6bd41 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,7 +1,9 @@
 import os
 import pytest
 import dataall
+from dataall.modules.loader import load_modules, ImportMode
 
+load_modules(modes=[ImportMode.TASKS, ImportMode.API, ImportMode.CDK])
 ENVNAME = os.environ.get('envname', 'pytest')
diff --git a/tests/core/__init__.py b/tests/core/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/core/test_config.py b/tests/core/test_config.py
new file mode 100644
index 000000000..3222e4144
--- /dev/null
+++ b/tests/core/test_config.py
@@ -0,0 +1,28 @@
+from dataall.core.config import config
+
+
+def test_config():
+    config.set_property("k1", "v1")
+    assert config.get_property("k1") == "v1"
+
+    assert config.get_property("not_exist", "default1") == "default1"
+
+    config.set_property("a.b.c", "d")
+    assert config.get_property("a.b.c") == "d"
+    assert "c" in config.get_property("a.b")
+    assert "k" not in config.get_property("a.b")
+    assert config.get_property("a.b.k", "default2") == "default2"
+    assert "b" in config.get_property("a")
+
+    config.set_property("a.b.e", "f")
+    assert config.get_property("a.b.c") == "d"
+    assert config.get_property("a.b.e") == "f"
+
+
+def test_default_config():
+    """Checks that properties are read correctly"""
+    modules = config.get_property("modules")
+    assert "notebooks" in modules
+    assert "active" in modules["notebooks"]
+
+    assert config.get_property("modules.notebooks.active") == "true"
diff --git a/tests/modules/__init__.py b/tests/modules/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/modules/notebooks/__init__.py b/tests/modules/notebooks/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/modules/notebooks/cdk/__init__.py b/tests/modules/notebooks/cdk/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/modules/notebooks/cdk/conftest.py b/tests/modules/notebooks/cdk/conftest.py
new file mode 100644
index 000000000..9bc2ff169
--- /dev/null
+++ b/tests/modules/notebooks/cdk/conftest.py
@@ -0,0 +1,34 @@
+import pytest
+
+from dataall.modules.notebooks.db.models import SagemakerNotebook
+from dataall.db import models
+from tests.cdkproxy.conftest import org, env
+
+
+@pytest.fixture(scope='module', autouse=True)
+def stack_org(db) -> models.Organization:
+    yield org
+
+
+@pytest.fixture(scope='module', autouse=True)
+def stack_env(db, stack_org: models.Organization) -> models.Environment:
+    yield env
+
+
+@pytest.fixture(scope='module', autouse=True)
+def notebook(db, env: models.Environment) -> SagemakerNotebook:
+    with db.scoped_session() as session:
+        notebook = SagemakerNotebook(
+            label='thistable',
+            NotebookInstanceStatus='RUNNING',
+            owner='me',
+            AWSAccountId=env.AwsAccountId,
+            region=env.region,
+            environmentUri=env.environmentUri,
+            RoleArn=env.EnvironmentDefaultIAMRoleArn,
+            SamlAdminGroupName='admins',
+            VolumeSizeInGB=32,
+            InstanceType='ml.t3.medium',
+        )
+        session.add(notebook)
+        yield notebook
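Reviewer note: `tests/core/test_config.py` above fully specifies the dotted-key behaviour of `dataall.core.config.config`. A minimal implementation that satisfies those assertions could look like the following; the class internals (including the lock) are assumptions, and the loading of the default configuration file that `test_default_config` relies on is deliberately omitted:

```python
# Sketch only: passes the assertions in tests/core/test_config.py; the real class may differ.
import threading
from typing import Any


class Config:
    def __init__(self):
        self._config = {}  # the real code presumably seeds this from a default config file
        self._lock = threading.Lock()

    def get_property(self, key: str, default: Any = None) -> Any:
        # Walk the nested dicts along the dotted key; fall back to the default on any miss.
        node = self._config
        for part in key.split('.'):
            if not isinstance(node, dict) or part not in node:
                return default
            node = node[part]
        return node

    def set_property(self, key: str, value: Any) -> None:
        # Create intermediate dicts as needed so "a.b.c" and "a.b.e" can coexist.
        with self._lock:
            node = self._config
            *parents, leaf = key.split('.')
            for part in parents:
                node = node.setdefault(part, {})
            node[leaf] = value


config = Config()
```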
diff --git a/tests/cdkproxy/test_sagemaker_notebook_stack.py b/tests/modules/notebooks/cdk/test_sagemaker_notebook_stack.py
similarity index 79%
rename from tests/cdkproxy/test_sagemaker_notebook_stack.py
rename to tests/modules/notebooks/cdk/test_sagemaker_notebook_stack.py
index a17b673f2..6cc8dc442 100644
--- a/tests/cdkproxy/test_sagemaker_notebook_stack.py
+++ b/tests/modules/notebooks/cdk/test_sagemaker_notebook_stack.py
@@ -3,21 +3,21 @@
 import pytest
 from aws_cdk import App
 
-from dataall.cdkproxy.stacks import SagemakerNotebook
+from dataall.modules.notebooks.cdk.stacks import NotebookStack
 
 
 @pytest.fixture(scope='function', autouse=True)
 def patch_methods(mocker, db, notebook, env, org):
     mocker.patch(
-        'dataall.cdkproxy.stacks.notebook.SagemakerNotebook.get_engine',
+        'dataall.modules.notebooks.cdk.stacks.NotebookStack.get_engine',
         return_value=db,
     )
     mocker.patch(
         'dataall.aws.handlers.sts.SessionHelper.get_delegation_role_name',
         return_value="dataall-pivot-role-name-pytest",
     )
     mocker.patch(
-        'dataall.cdkproxy.stacks.notebook.SagemakerNotebook.get_target',
+        'dataall.modules.notebooks.cdk.stacks.NotebookStack.get_target',
         return_value=notebook,
     )
     mocker.patch(
@@ -40,7 +40,7 @@ def patch_methods(mocker, db, notebook, env, org):
 @pytest.fixture(scope='function', autouse=True)
 def template(notebook):
     app = App()
-    SagemakerNotebook(app, 'SagemakerNotebook', target_uri=notebook.notebookUri)
+    NotebookStack(app, 'SagemakerNotebook', target_uri=notebook.notebookUri)
     return json.dumps(app.synth().get_stack_by_name('SagemakerNotebook').template)
diff --git a/tests/modules/notebooks/conftest.py b/tests/modules/notebooks/conftest.py
new file mode 100644
index 000000000..7ba5fbe5a
--- /dev/null
+++ b/tests/modules/notebooks/conftest.py
@@ -0,0 +1,50 @@
+import pytest
+
+from dataall.modules.notebooks.db.models import SagemakerNotebook
+from tests.api.client import client, app
+from tests.api.conftest import *
+
+
+@pytest.fixture(scope='module')
+def env_fixture(env, org_fixture, user, group, tenant, module_mocker):
+    module_mocker.patch('requests.post', return_value=True)
+    module_mocker.patch('dataall.api.Objects.Environment.resolvers.check_environment', return_value=True)
+    env1 = env(org_fixture, 'dev', 'alice', 'testadmins', '111111111111', 'eu-west-1',
+               parameters={'notebooksEnabled': 'True'})
+    yield env1
+
+
+@pytest.fixture(scope='module')
+def sgm_notebook(client, tenant, group, env_fixture) -> SagemakerNotebook:
+    response = client.query(
+        """
+        mutation createSagemakerNotebook($input:NewSagemakerNotebookInput){
+            createSagemakerNotebook(input:$input){
+                notebookUri
+                label
+                description
+                tags
+                owner
+                userRoleForNotebook
+                SamlAdminGroupName
+                VpcId
+                SubnetId
+                VolumeSizeInGB
+                InstanceType
+            }
+        }
+        """,
+        input={
+            'label': 'my best notebook ever',
+            'SamlAdminGroupName': group.name,
+            'tags': [group.name],
+            'environmentUri': env_fixture.environmentUri,
+            'VpcId': 'vpc-123567',
+            'SubnetId': 'subnet-123567',
+            'VolumeSizeInGB': 32,
+            'InstanceType': 'ml.m5.xlarge',
+        },
+        username='alice',
+        groups=[group.name],
+    )
+    yield response.data.createSagemakerNotebook
diff --git a/tests/modules/notebooks/test_notebook_stack.py b/tests/modules/notebooks/test_notebook_stack.py
new file mode 100644
index 000000000..fc65e9af4
--- /dev/null
+++ b/tests/modules/notebooks/test_notebook_stack.py
@@ -0,0 +1,18 @@
+
+def test_notebook_stack(client, sgm_notebook, group):
+    response = client.query(
+        """
+        mutation updateStack($targetUri:String!, $targetType:String!){
+            updateStack(targetUri:$targetUri, targetType:$targetType){
+                stackUri
+                targetUri
+                name
+            }
+        }
+        """,
+        targetUri=sgm_notebook.notebookUri,
+        targetType="notebook",
+        username="alice",
+        groups=[group.name],
+    )
+    assert response.data.updateStack.targetUri == sgm_notebook.notebookUri
\ No newline at end of file
diff --git a/tests/api/test_sagemaker_notebook.py b/tests/modules/notebooks/test_sagemaker_notebook.py
similarity index 64%
rename from tests/api/test_sagemaker_notebook.py
rename to tests/modules/notebooks/test_sagemaker_notebook.py
index 1861936ad..5fd4e4d16 100644
--- a/tests/api/test_sagemaker_notebook.py
+++ b/tests/modules/notebooks/test_sagemaker_notebook.py
@@ -1,6 +1,16 @@
 import pytest
-import dataall
+
+class MockSagemakerClient:
+    def start_instance(self):
+        return "Starting"
+
+    def stop_instance(self):
+        return True
+
+    def get_notebook_instance_status(self):
+        return "INSERVICE"
+
 
 
 @pytest.fixture(scope='module')
@@ -10,75 +20,35 @@
 def org1(org, user, group, tenant):
 
 
 @pytest.fixture(scope='module')
-def env1(env, org1, user, group, tenant, module_mocker):
+def env1(env, org1, user, group, tenant, db, module_mocker):
     module_mocker.patch('requests.post', return_value=True)
     module_mocker.patch(
         'dataall.api.Objects.Environment.resolvers.check_environment',
         return_value=True,
     )
-    env1 = env(org1, 'dev', user.userName, group.name, '111111111111', 'eu-west-1')
+    env1 = env(org1, 'dev', user.userName, group.name, '111111111111', 'eu-west-1',
+               parameters={"notebooksEnabled": "True"})
     yield env1
 
 
-@pytest.fixture(scope='module', autouse=True)
-def sgm_notebook(client, tenant, group, env1) -> dataall.db.models.SagemakerNotebook:
-    response = client.query(
-        """
-        mutation createSagemakerNotebook($input:NewSagemakerNotebookInput){
-            createSagemakerNotebook(input:$input){
-                notebookUri
-                label
-                description
-                tags
-                owner
-                userRoleForNotebook
-                SamlAdminGroupName
-                VpcId
-                SubnetId
-                VolumeSizeInGB
-                InstanceType
-            }
-        }
-        """,
-        input={
-            'label': 'my pipeline',
-            'SamlAdminGroupName': group.name,
-            'tags': [group.name],
-            'environmentUri': env1.environmentUri,
-            'VpcId': 'vpc-123567',
-            'SubnetId': 'subnet-123567',
-            'VolumeSizeInGB': 32,
-            'InstanceType': 'ml.m5.xlarge',
-        },
-        username='alice',
-        groups=[group.name],
-    )
-    assert response.data.createSagemakerNotebook.notebookUri
-    assert response.data.createSagemakerNotebook.SamlAdminGroupName == group.name
-    assert response.data.createSagemakerNotebook.VpcId == 'vpc-123567'
-    assert response.data.createSagemakerNotebook.SubnetId == 'subnet-123567'
-    assert response.data.createSagemakerNotebook.InstanceType == 'ml.m5.xlarge'
-    assert response.data.createSagemakerNotebook.VolumeSizeInGB == 32
-    return response.data.createSagemakerNotebook
+def test_sgm_notebook(sgm_notebook, group):
+    assert sgm_notebook.notebookUri
+    assert sgm_notebook.SamlAdminGroupName == group.name
+    assert sgm_notebook.VpcId == 'vpc-123567'
+    assert sgm_notebook.SubnetId == 'subnet-123567'
+    assert sgm_notebook.InstanceType == 'ml.m5.xlarge'
+    assert sgm_notebook.VolumeSizeInGB == 32
 
 
 @pytest.fixture(scope='module', autouse=True)
 def patch_aws(module_mocker):
     module_mocker.patch(
-        'dataall.aws.handlers.sagemaker.Sagemaker.start_instance',
-        return_value='Starting',
-    )
-    module_mocker.patch(
-        'dataall.aws.handlers.sagemaker.Sagemaker.stop_instance', return_value=True
-    )
-    module_mocker.patch(
-        'dataall.aws.handlers.sagemaker.Sagemaker.get_notebook_instance_status',
-        return_value='INSERVICE',
+        "dataall.modules.notebooks.services.services.client",
+        return_value=MockSagemakerClient(),
     )
 
 
-def test_list_notebooks(client, env1, db, org1, user, group, sgm_notebook, patch_aws):
-    response = client.query(
-        """
+def test_list_notebooks(client, user, group, sgm_notebook):
+    query = """
         query ListSagemakerNotebooks($filter:SagemakerNotebookFilter){
             listSagemakerNotebooks(filter:$filter){
                 count
@@ -94,17 +64,28 @@ def test_list_notebooks(client, user, group, sgm_notebook):
                 }
             }
         }
-        """,
+        """
+
+    response = client.query(
+        query,
         filter=None,
         username=user.userName,
         groups=[group.name],
     )
+
+    assert len(response.data.listSagemakerNotebooks['nodes']) == 1
+
+    response = client.query(
+        query,
+        filter={"term": "my best"},
+        username=user.userName,
+        groups=[group.name],
+    )
+    assert len(response.data.listSagemakerNotebooks['nodes']) == 1
 
 
-def test_nopermissions_list_notebooks(
-    client, env1, db, org1, user2, group2, sgm_notebook, patch_aws
-):
+def test_nopermissions_list_notebooks(client, user2, group2, sgm_notebook):
     response = client.query(
         """
         query ListSagemakerNotebooks($filter:SagemakerNotebookFilter){
@@ -130,7 +111,7 @@ def test_nopermissions_list_notebooks(
     assert len(response.data.listSagemakerNotebooks['nodes']) == 0
 
 
-def test_get_notebook(client, env1, db, org1, user, group, sgm_notebook, patch_aws):
+def test_get_notebook(client, user, group, sgm_notebook):
     response = client.query(
         """
@@ -148,7 +129,7 @@ def test_get_notebook(client, user, group, sgm_notebook):
     assert response.data.getSagemakerNotebook.notebookUri == sgm_notebook.notebookUri
 
 
-def test_action_notebook(client, env1, db, org1, user, group, sgm_notebook, patch_aws):
+def test_action_notebook(client, user, group, sgm_notebook):
     response = client.query(
         """
         mutation stopSagemakerNotebook($notebookUri:String!){
@@ -174,7 +155,7 @@ def test_action_notebook(client, user, group, sgm_notebook):
     assert response.data.startSagemakerNotebook == 'Starting'
 
 
-def test_delete_notebook(client, env1, db, org1, user, group, patch_aws, sgm_notebook):
+def test_delete_notebook(client, user, group, sgm_notebook):
     response = client.query(
         """
diff --git a/tests/utils/clients/graphql.py b/tests/utils/clients/graphql.py
index 8117e437e..b44cd68b5 100644
--- a/tests/utils/clients/graphql.py
+++ b/tests/utils/clients/graphql.py
@@ -7,6 +7,7 @@
 from flask import Flask, request, jsonify, Response
 from dotted.collection import DottedCollection
 import dataall
+from dataall.core.context import set_context, RequestContext, dispose_context
 
 
 class ClientWrapper:
@@ -59,6 +60,8 @@ def graphql_server():
         username = request.headers.get('Username', 'anonym')
         groups = json.loads(request.headers.get('Groups', '[]'))
+
+        set_context(RequestContext(db, username, groups))
         success, result = graphql_sync(
             schema,
             data,
@@ -71,6 +74,7 @@ def graphql_server():
             debug=app.debug,
         )
 
+        dispose_context()
         status_code = 200 if success else 400
         return jsonify(result), status_code