diff --git a/metadata/metadata_service/api/swagger_doc/user/detail_put.yml b/metadata/metadata_service/api/swagger_doc/user/detail_put.yml new file mode 100644 index 0000000000..b0677e497e --- /dev/null +++ b/metadata/metadata_service/api/swagger_doc/user/detail_put.yml @@ -0,0 +1,36 @@ +Create or update a user +--- +tags: + - 'user' +requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/UserDetailFields' + description: User attribute fields + required: true +responses: + 200: + description: 'Existing user found and updated' + content: + application/json: + schema: + $ref: '#/components/schemas/UserDetailFields' + 201: + description: 'New user created' + content: + application/json: + schema: + $ref: '#/components/schemas/UserDetailFields' + 400: + description: 'Bad Request' + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + 500: + description: 'Internal server error' + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' diff --git a/metadata/metadata_service/api/user.py b/metadata/metadata_service/api/user.py index 4fc0713986..1f6c3ff498 100644 --- a/metadata/metadata_service/api/user.py +++ b/metadata/metadata_service/api/user.py @@ -1,6 +1,7 @@ # Copyright Contributors to the Amundsen project. # SPDX-License-Identifier: Apache-2.0 +import json import logging from http import HTTPStatus from typing import (Any, Dict, Iterable, List, Mapping, Optional, # noqa: F401 @@ -11,7 +12,9 @@ from amundsen_common.models.user import UserSchema from flasgger import swag_from from flask import current_app as app +from flask import request from flask_restful import Resource +from marshmallow.exceptions import ValidationError as SchemaValidationError from metadata_service.api import BaseAPI from metadata_service.entity.resource_type import (ResourceType, @@ -44,6 +47,34 @@ def get(self, *, id: Optional[str] = None) -> Iterable[Union[Mapping, int, None] else: return super().get(id=id) + @swag_from('swagger_doc/user/detail_put.yml') + def put(self) -> Iterable[Union[Mapping, int, None]]: + """ + Create or update a user. Serializes the data in the request body + using the UserSchema, validating the inputs in the process to ensure + all downstream processes leverage clean data, and passes the User + object to the client to create or update the user record. + """ + if not request.data: + return {'message': 'No user information provided in the request.'}, HTTPStatus.BAD_REQUEST + + try: + user_attributes = json.loads(request.data) + schema = UserSchema() + user = schema.load(user_attributes) + + new_user, user_created = self.client.create_update_user(user=user) + resp_code = HTTPStatus.CREATED if user_created else HTTPStatus.OK + return schema.dumps(new_user), resp_code + + except SchemaValidationError as schema_err: + err_msg = 'User inputs provided are not valid: %s' % schema_err + return {'message': err_msg}, HTTPStatus.BAD_REQUEST + + except Exception: + LOGGER.exception('UserDetailAPI PUT Failed') + return {'message': 'Internal server error!'}, HTTPStatus.INTERNAL_SERVER_ERROR + class UserFollowsAPI(Resource): """ diff --git a/metadata/metadata_service/proxy/atlas_proxy.py b/metadata/metadata_service/proxy/atlas_proxy.py index 68344b80e3..b5579d9aba 100644 --- a/metadata/metadata_service/proxy/atlas_proxy.py +++ b/metadata/metadata_service/proxy/atlas_proxy.py @@ -454,6 +454,9 @@ def _get_owners(self, data_owners: list, fallback_owner: str = None) -> List[Use def get_user(self, *, id: str) -> Union[UserEntity, None]: pass + def create_update_user(self, *, user: User) -> Tuple[User, bool]: + pass + def get_users(self) -> List[UserEntity]: pass diff --git a/metadata/metadata_service/proxy/base_proxy.py b/metadata/metadata_service/proxy/base_proxy.py index db476badf2..1928c9f4ee 100644 --- a/metadata/metadata_service/proxy/base_proxy.py +++ b/metadata/metadata_service/proxy/base_proxy.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABCMeta, abstractmethod -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union from amundsen_common.models.dashboard import DashboardSummary from amundsen_common.models.lineage import Lineage @@ -27,6 +27,18 @@ class BaseProxy(metaclass=ABCMeta): def get_user(self, *, id: str) -> Union[User, None]: pass + @abstractmethod + def create_update_user(self, *, user: User) -> Tuple[User, bool]: + """ + Allows creating and updating users. Returns a tuple of the User + object that has been created or updated as well as a flag that + depicts whether or no the user was created or updated. + + :param user: a User object + :return: Tuple of [User object, bool (True = created, False = updated)] + """ + pass + @abstractmethod def get_users(self) -> List[User]: pass diff --git a/metadata/metadata_service/proxy/gremlin_proxy.py b/metadata/metadata_service/proxy/gremlin_proxy.py index 13f975e63b..1d8cfa011a 100644 --- a/metadata/metadata_service/proxy/gremlin_proxy.py +++ b/metadata/metadata_service/proxy/gremlin_proxy.py @@ -8,7 +8,7 @@ from datetime import date, datetime, timedelta from operator import attrgetter from typing import (Any, Callable, Dict, Iterable, List, Mapping, Optional, - Sequence, Set, Type, TypeVar, Union, no_type_check, + Sequence, Set, Tuple, Type, TypeVar, Union, no_type_check, overload) from urllib.parse import unquote @@ -1010,6 +1010,9 @@ def _get_user(self, *, id: str, executor: ExecuteQuery) -> Union[User, None]: user.manager_fullname = _safe_get(managers[0], 'full_name', default=None) if managers else None return user + def create_update_user(self, *, user: User) -> Tuple[User, bool]: + pass + @timer_with_counter @overrides def get_users(self) -> List[User]: diff --git a/metadata/metadata_service/proxy/neo4j_proxy.py b/metadata/metadata_service/proxy/neo4j_proxy.py index b334c8fc0e..fbf2482969 100644 --- a/metadata/metadata_service/proxy/neo4j_proxy.py +++ b/metadata/metadata_service/proxy/neo4j_proxy.py @@ -18,6 +18,7 @@ Reader, Source, Stat, Table, Tag, User, Watermark) from amundsen_common.models.user import User as UserEntity +from amundsen_common.models.user import UserSchema from beaker.cache import CacheManager from beaker.util import parse_cache_config_options from flask import current_app, has_app_context @@ -42,6 +43,12 @@ # Expire cache every 11 hours + jitter _GET_POPULAR_TABLE_CACHE_EXPIRY_SEC = 11 * 60 * 60 + randint(0, 3600) + +CREATED_EPOCH_MS = 'publisher_created_epoch_ms' +LAST_UPDATED_EPOCH_MS = 'publisher_last_updated_epoch_ms' +PUBLISHED_TAG_PROPERTY_NAME = 'published_tag' + + LOGGER = logging.getLogger(__name__) @@ -958,6 +965,61 @@ def get_user(self, *, id: str) -> Union[UserEntity, None]: return self._build_user_from_record(record=record, manager_name=manager_name) + def create_update_user(self, *, user: User) -> Tuple[User, bool]: + """ + Create a user if it does not exist, otherwise update the user. Required + fields for creating / updating a user are validated upstream to this when + the User object is created. + + :param user: + :return: + """ + user_data = UserSchema().dump(user) + user_props = self._create_props_body(user_data, 'usr') + + create_update_user_query = textwrap.dedent(""" + MERGE (usr:User {key: $user_id}) + on CREATE SET %s, usr.%s=timestamp() + on MATCH SET %s + RETURN usr, usr.%s = timestamp() as created + """ % (user_props, CREATED_EPOCH_MS, user_props, CREATED_EPOCH_MS)) + + try: + tx = self._driver.session().begin_transaction() + result = tx.run(create_update_user_query, user_data) + + user_result = result.single() + if not user_result: + raise RuntimeError('Failed to create user with data %s' % user_data) + tx.commit() + + new_user = self._build_user_from_record(user_result['usr']) + new_user_created = True if user_result['created'] is True else False + + except Exception as e: + if not tx.closed(): + tx.rollback() + # propagate the exception back to api + raise e + + return new_user, new_user_created + + def _create_props_body(self, + record_dict: dict, + identifier: str) -> str: + """ + Creates a Neo4j property body by converting a dictionary into a comma + separated string of KEY = VALUE. + """ + props = [] + for k, v in record_dict.items(): + if v: + props.append(f'{identifier}.{k} = ${k}') + + props.append(f"{identifier}.{PUBLISHED_TAG_PROPERTY_NAME} = 'api_create_update_user'") + props.append(f"{identifier}.{LAST_UPDATED_EPOCH_MS} = timestamp()") + return ', '.join(props) + def get_users(self) -> List[UserEntity]: statement = "MATCH (usr:User) WHERE usr.is_active = true RETURN collect(usr) as users" diff --git a/metadata/tests/unit/api/test_user.py b/metadata/tests/unit/api/test_user.py index 2dd17a0d57..5e0703a787 100644 --- a/metadata/tests/unit/api/test_user.py +++ b/metadata/tests/unit/api/test_user.py @@ -1,6 +1,7 @@ # Copyright Contributors to the Amundsen project. # SPDX-License-Identifier: Apache-2.0 +import json import unittest from http import HTTPStatus from unittest import mock @@ -37,6 +38,32 @@ def test_gets(self) -> None: self.assertEqual(list(response)[1], HTTPStatus.OK) self.mock_client.get_users.assert_called_once() + def test_put(self) -> None: + m = MagicMock() + m.data = json.dumps({'email': 'create_email@email.com'}) + with mock.patch("metadata_service.api.user.request", m): + # Test user creation + create_email = {'email': 'test_email'} + self.mock_client.create_update_user.return_value = create_email, True + test_user, test_user_created = self.api.put() + self.assertEqual(test_user, json.dumps(create_email)) + self.assertEqual(test_user_created, HTTPStatus.CREATED) + + # Test user update + update_email = {'email': 'update_email@email.com'} + self.mock_client.create_update_user.return_value = update_email, False + test_user2, test_user_updated = self.api.put() + self.assertEqual(test_user2, json.dumps(update_email)) + self.assertEqual(test_user_updated, HTTPStatus.OK) + + def test_put_no_inputs(self) -> None: + # Test no data provided + m2 = MagicMock() + m2.data = {} + with mock.patch("metadata_service.api.user.request", m2): + _, status_code = self.api.put() + self.assertEquals(status_code, HTTPStatus.BAD_REQUEST) + class UserFollowsAPITest(unittest.TestCase): diff --git a/metadata/tests/unit/proxy/test_neo4j_proxy.py b/metadata/tests/unit/proxy/test_neo4j_proxy.py index 973dd09caf..57ee7658a7 100644 --- a/metadata/tests/unit/proxy/test_neo4j_proxy.py +++ b/metadata/tests/unit/proxy/test_neo4j_proxy.py @@ -622,6 +622,24 @@ def test_get_user_other_key_values(self) -> None: neo4j_user = neo4j_proxy.get_user(id='test_email') self.assertEqual(neo4j_user.other_key_values, {'mode_user_id': 'mode_foo_bar'}) + def test_put_user_new_user(self) -> None: + """ + Test creating a new user + :return: + """ + with patch.object(GraphDatabase, 'driver') as mock_driver: + mock_transaction = mock_driver.return_value.session.return_value.begin_transaction.return_value + mock_run = mock_transaction.run + mock_commit = mock_transaction.commit + + test_user = MagicMock() + + neo4j_proxy = Neo4jProxy(host='DOES_NOT_MATTER', port=0000) + neo4j_proxy.create_update_user(user=test_user) + + self.assertEqual(mock_run.call_count, 1) + self.assertEqual(mock_commit.call_count, 1) + def test_get_users(self) -> None: with patch.object(GraphDatabase, 'driver'), patch.object(Neo4jProxy, '_execute_cypher_query') as mock_execute: test_user = {