diff --git a/flask_appbuilder/babel/manager.py b/flask_appbuilder/babel/manager.py index 523d8681e..b011e0102 100644 --- a/flask_appbuilder/babel/manager.py +++ b/flask_appbuilder/babel/manager.py @@ -1,11 +1,10 @@ import os from flask import has_request_context, request, session +from flask_appbuilder.babel.views import LocaleView +from flask_appbuilder.basemanager import BaseManager from flask_babel import Babel -from .views import LocaleView -from ..basemanager import BaseManager - class BabelManager(BaseManager): diff --git a/flask_appbuilder/models/base.py b/flask_appbuilder/models/base.py index 6a1529e6e..3d4273038 100644 --- a/flask_appbuilder/models/base.py +++ b/flask_appbuilder/models/base.py @@ -41,6 +41,8 @@ class BaseInterface: ) general_error_message = lazy_gettext("General Error") + database_error_message = lazy_gettext("Database Error") + """ Tuple with message and text with severity type ex: ("Added Row", "info") """ message = () @@ -103,13 +105,13 @@ def get_values_item(self, item, show_columns): def _get_values(self, lst, list_columns): """ - Get Values: formats values for list template. - returns [{'col_name':'col_value',....},{'col_name':'col_value',....}] + Get Values: formats values for list template. + returns [{'col_name':'col_value',....},{'col_name':'col_value',....}] - :param lst: - The list of item objects from query - :param list_columns: - The list of columns to include + :param lst: + The list of item objects from query + :param list_columns: + The list of columns to include """ retlst = [] for item in lst: @@ -121,13 +123,13 @@ def _get_values(self, lst, list_columns): def get_values(self, lst, list_columns): """ - Get Values: formats values for list template. - returns [{'col_name':'col_value',....},{'col_name':'col_value',....}] + Get Values: formats values for list template. + returns [{'col_name':'col_value',....},{'col_name':'col_value',....}] - :param lst: - The list of item objects from query - :param list_columns: - The list of columns to include + :param lst: + The list of item objects from query + :param list_columns: + The list of columns to include """ for item in lst: retdict = {} @@ -137,7 +139,7 @@ def get_values(self, lst, list_columns): def get_values_json(self, lst, list_columns): """ - Converts list of objects from query to JSON + Converts list of objects from query to JSON """ result = [] for item in self.get_values(lst, list_columns): @@ -264,19 +266,19 @@ def get_min_length(self, col_name): def add(self, item): """ - Adds object + Adds object """ raise NotImplementedError def edit(self, item): """ - Edit (change) object + Edit (change) object """ raise NotImplementedError def delete(self, item): """ - Deletes object + Deletes object """ raise NotImplementedError @@ -285,7 +287,7 @@ def get_col_default(self, col_name): def get_keys(self, lst): """ - return a list of pk values from object list + return a list of pk values from object list """ pk_name = self.get_pk_name() if self.is_pk_composite(): @@ -295,7 +297,7 @@ def get_keys(self, lst): def get_pk_name(self): """ - Returns the primary key name + Returns the primary key name """ raise NotImplementedError @@ -308,8 +310,8 @@ def get_pk_value(self, item): def get(self, pk, filter=None): """ - return the record from key, you can optionally pass filters - if pk exits on the db but filters exclude it it will return none. + return the record from key, you can optionally pass filters + if pk exits on the db but filters exclude it it will return none. """ pass @@ -318,11 +320,11 @@ def get_related_model(self, prop): def get_related_interface(self, col_name): """ - Returns a BaseInterface for the related model - of column name. + Returns a BaseInterface for the related model + of column name. - :param col_name: Column name with relation - :return: BaseInterface + :param col_name: Column name with relation + :return: BaseInterface """ raise NotImplementedError @@ -334,25 +336,25 @@ def get_related_fk(self, model): def get_columns_list(self): """ - Returns a list of all the columns names + Returns a list of all the columns names """ return [] def get_user_columns_list(self): """ - Returns a list of user viewable columns names + Returns a list of user viewable columns names """ return self.get_columns_list() def get_search_columns_list(self): """ - Returns a list of searchable columns names + Returns a list of searchable columns names """ return [] def get_order_columns_list(self, list_columns=None): """ - Returns a list of order columns names + Returns a list of order columns names """ return [] diff --git a/flask_appbuilder/models/sqla/interface.py b/flask_appbuilder/models/sqla/interface.py index c98372d02..b2e7517f4 100644 --- a/flask_appbuilder/models/sqla/interface.py +++ b/flask_appbuilder/models/sqla/interface.py @@ -1,14 +1,11 @@ # -*- coding: utf-8 -*- from contextlib import suppress import logging -import sys from typing import Any, Dict, List, Optional, Tuple, Type, Union from flask_appbuilder._compat import as_unicode from flask_appbuilder.const import ( - LOGMSG_ERR_DBI_ADD_GENERIC, LOGMSG_ERR_DBI_DEL_GENERIC, - LOGMSG_ERR_DBI_EDIT_GENERIC, LOGMSG_WAR_DBI_ADD_INTEGRITY, LOGMSG_WAR_DBI_DEL_INTEGRITY, LOGMSG_WAR_DBI_EDIT_INTEGRITY, @@ -736,11 +733,8 @@ def add(self, item: Model, raise_exception: bool = False) -> bool: raise e return False except Exception as e: - self.message = ( - as_unicode(self.general_error_message + " " + str(sys.exc_info()[0])), - "danger", - ) - log.exception(LOGMSG_ERR_DBI_ADD_GENERIC.format(str(e))) + self.message = (as_unicode(self.database_error_message), "danger") + log.exception("Database error") self.session.rollback() if raise_exception: raise e @@ -760,11 +754,8 @@ def edit(self, item: Model, raise_exception: bool = False) -> bool: raise e return False except Exception as e: - self.message = ( - as_unicode(self.general_error_message + " " + str(sys.exc_info()[0])), - "danger", - ) - log.exception(LOGMSG_ERR_DBI_EDIT_GENERIC.format(str(e))) + self.message = (as_unicode(self.database_error_message), "danger") + log.exception("Database error") self.session.rollback() if raise_exception: raise e @@ -785,11 +776,8 @@ def delete(self, item: Model, raise_exception: bool = False) -> bool: raise e return False except Exception as e: - self.message = ( - as_unicode(self.general_error_message + " " + str(sys.exc_info()[0])), - "danger", - ) - log.exception(LOGMSG_ERR_DBI_DEL_GENERIC.format(str(e))) + self.message = (as_unicode(self.database_error_message), "danger") + log.exception("Database error") self.session.rollback() if raise_exception: raise e @@ -809,10 +797,7 @@ def delete_all(self, items: List[Model]) -> bool: self.session.rollback() return False except Exception as e: - self.message = ( - as_unicode(self.general_error_message + " " + str(sys.exc_info()[0])), - "danger", - ) + self.message = (as_unicode(self.database_error_message), "danger") log.exception(LOGMSG_ERR_DBI_DEL_GENERIC.format(str(e))) self.session.rollback() return False diff --git a/flask_appbuilder/tests/security/test_mvc_security.py b/flask_appbuilder/tests/security/test_mvc_security.py index e3dc5c6cb..b268211e8 100644 --- a/flask_appbuilder/tests/security/test_mvc_security.py +++ b/flask_appbuilder/tests/security/test_mvc_security.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + from flask_appbuilder import ModelView from flask_appbuilder.exceptions import PasswordComplexityValidationError from flask_appbuilder.models.sqla.filters import FilterEqual @@ -422,3 +424,117 @@ def test_register_user(self): ) self.db.session.delete(user) self.db.session.commit() + + def test_edit_user(self): + """ + Test edit user + """ + client = self.app.test_client() + _ = self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN) + + _tmp_user = self.create_user( + self.appbuilder, + "tmp_user", + "password1", + "", + first_name="tmp", + last_name="user", + email="tmp@fab.org", + role_names=["Admin"], + ) + + # use all required params + rv = client.get(f"/users/edit/{_tmp_user.id}", follow_redirects=True) + data = rv.data.decode("utf-8") + self.assertIn("Edit User", data) + rv = client.post( + f"/users/edit/{_tmp_user.id}", + data=dict( + first_name=_tmp_user.first_name, + last_name=_tmp_user.last_name, + username=_tmp_user.username, + email="changed@changed.org", + roles=_tmp_user.roles[0].id, + ), + follow_redirects=True, + ) + data = rv.data.decode("utf-8") + self.assertIn("Changed Row", data) + + user = ( + self.db.session.query(User) + .filter(User.username == _tmp_user.username) + .one_or_none() + ) + + assert user.email == "changed@changed.org" + self.db.session.delete(user) + self.db.session.commit() + + def test_edit_user_email_validation(self): + """ + Test edit user with email not null validation + """ + client = self.app.test_client() + _ = self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN) + + read_ony_user: User = ( + self.db.session.query(User) + .filter(User.username == USERNAME_READONLY) + .one_or_none() + ) + + # use all required params + rv = client.get(f"/users/edit/{read_ony_user.id}", follow_redirects=True) + data = rv.data.decode("utf-8") + self.assertIn("Edit User", data) + rv = client.post( + f"/users/edit/{read_ony_user.id}", + data=dict( + first_name=read_ony_user.first_name, + last_name=read_ony_user.last_name, + username=read_ony_user.username, + email=None, + roles=read_ony_user.roles[0].id, + ), + follow_redirects=True, + ) + data = rv.data.decode("utf-8") + self.assertIn("This field is required", data) + + def test_edit_user_db_fail(self): + """ + Test edit user with DB fail + """ + client = self.app.test_client() + _ = self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN) + + read_ony_user: User = ( + self.db.session.query(User) + .filter(User.username == USERNAME_READONLY) + .one_or_none() + ) + + # use all required params + rv = client.get(f"/users/edit/{read_ony_user.id}", follow_redirects=True) + data = rv.data.decode("utf-8") + self.assertIn("Edit User", data) + + with patch.object(self.appbuilder.session, "merge") as mock_merge: + with patch.object(self.appbuilder.sm, "has_access", return_value=True) as _: + mock_merge.side_effect = Exception("BANG!") + + rv = client.post( + f"/users/edit/{read_ony_user.id}", + data=dict( + first_name=read_ony_user.first_name, + last_name=read_ony_user.last_name, + username=read_ony_user.username, + email="changed@changed.org", + roles=read_ony_user.roles[0].id, + ), + follow_redirects=True, + ) + + data = rv.data.decode("utf-8") + self.assertIn("Database Error", data)