diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index f54f78bb..3a2ac459 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -63,7 +63,7 @@ jobs: psql -h localhost -U postgres coaster_test -c "grant all privileges on schema public to $(whoami); grant all privileges on all tables in schema public to $(whoami); grant all privileges on all sequences in schema public to $(whoami);" - name: Test with pytest run: | - pytest --ignore-flaky --showlocals --cov=coaster + pytest --showlocals --cov=coaster - name: Prepare coverage report run: | mkdir -p coverage diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d3c2fd82..328e904c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ # See https://pre-commit.com for more information # See https://pre-commit.com/hooks.html for more hooks ci: - skip: ['yesqa', 'no-commit-to-branch'] + skip: ['no-commit-to-branch'] repos: - repo: https://github.com/asottile/pyupgrade rev: v3.15.2 @@ -13,35 +13,7 @@ repos: hooks: - id: ruff args: ['--fix', '--exit-non-zero-on-fix'] - # Extra args, only after removing flake8 and yesqa: '--extend-select', 'RUF100' - - repo: https://github.com/asottile/yesqa - rev: v1.5.0 - hooks: - - id: yesqa - additional_dependencies: &flake8deps - - flake8-assertive - # - flake8-annotations - - flake8-blind-except - - flake8-builtins - - flake8-comprehensions - # - flake8-docstrings - - flake8-isort - - flake8-logging-format - - flake8-mutable - - flake8-print - - pep8-naming - - toml - - tomli - - repo: https://github.com/PyCQA/isort - rev: 5.13.2 - hooks: - - id: isort - additional_dependencies: - - tomli - - repo: https://github.com/psf/black - rev: 24.4.2 - hooks: - - id: black + - id: ruff-format - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.10.0 hooks: @@ -56,6 +28,7 @@ repos: ] additional_dependencies: - flask + - quart - lxml-stubs - sqlalchemy - toml @@ -66,11 +39,6 @@ repos: - types-requests - types-toml - typing-extensions - - repo: https://github.com/PyCQA/flake8 - rev: 7.0.0 - hooks: - - id: flake8 - additional_dependencies: *flake8deps - repo: https://github.com/pre-commit/pygrep-hooks rev: v1.10.0 hooks: @@ -88,6 +56,7 @@ repos: ] additional_dependencies: - flask + - quart - sqlalchemy - tomli - repo: https://github.com/PyCQA/bandit diff --git a/dev_requirements.txt b/dev_requirements.txt index eb09a985..a4bf1d9a 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -1,15 +1,12 @@ -flake8 -flake8-assertive -flake8-blind-except -flake8-builtins -flake8-coding -flake8-comprehensions -flake8-isort -flake8-logging-format -flake8-mutable -flake8-print>=5.0.0 -isort -pep8-naming +mypy pre-commit pylint +ruff toml +types-bleach +types-Markdown +types-pytz +types-PyYAML +types-requests +types-toml +typing-extensions diff --git a/docs/compat.rst b/docs/compat.rst new file mode 100644 index 00000000..a92b0d76 --- /dev/null +++ b/docs/compat.rst @@ -0,0 +1,2 @@ +.. automodule:: coaster.compat + :members: diff --git a/docs/index.rst b/docs/index.rst index b4cc878f..184e8378 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -25,7 +25,7 @@ Coaster is available under the BSD license, the same license as Flask. views/index sqlalchemy/index db - nlp + compat Indices and tables ================== diff --git a/docs/nlp.rst b/docs/nlp.rst deleted file mode 100644 index fd280b21..00000000 --- a/docs/nlp.rst +++ /dev/null @@ -1,2 +0,0 @@ -.. automodule:: coaster.nlp - :members: diff --git a/pyproject.toml b/pyproject.toml index 58f4207b..eb717d06 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ classifiers = [ ] dependencies = [ 'aniso8601', + 'asgiref', 'base58>=2.0.0', 'bleach', 'blinker', @@ -45,7 +46,6 @@ dependencies = [ 'isoweek', 'Markdown>=3.2.0', 'markupsafe', - 'nltk>=3.4.5', 'pymdown-extensions>=8.0', 'pytz', 'semantic-version>=2.8.0', @@ -103,11 +103,18 @@ sections = ['FUTURE', 'STDLIB', 'THIRDPARTY', 'FIRSTPARTY', 'LOCALFOLDER'] [tool.pytest.ini_options] pythonpath = 'src' -required_plugins = ['pytest-env', 'pytest-rerunfailures', 'pytest-socket'] +required_plugins = [ + 'pytest-asyncio', + 'pytest-env', + 'pytest-rerunfailures', + 'pytest-socket', +] +asyncio_mode = 'auto' minversion = '6.0' -addopts = '--doctest-modules --ignore setup.py --cov-report=term-missing' +addopts = '--doctest-modules --ignore setup.py --cov-report=term-missing --strict-markers' doctest_optionflags = ['ALLOW_UNICODE', 'ALLOW_BYTES'] env = ['FLASK_ENV=testing'] +markers = ["has_server_name: App fixture has a server name in config"] [tool.pylint.master] max-parents = 10 @@ -138,20 +145,21 @@ disable = [ 'too-many-lines', 'too-many-locals', 'too-many-public-methods', - 'unused-argument', 'unsupported-membership-test', + 'unused-argument', # These need some serious refactoring, so disabled for now 'too-many-branches', 'too-many-nested-blocks', 'too-many-statements', - # Let Black, isort and ruff handle these + # Let Ruff handle these + 'consider-using-f-string', 'line-too-long', - 'wrong-import-position', - 'wrong-import-order', - # Let flake8 handle these 'missing-class-docstring', 'missing-function-docstring', 'missing-module-docstring', + 'superfluous-parens', + 'wrong-import-order', + 'wrong-import-position', ] [tool.mypy] @@ -176,63 +184,6 @@ exclude_dirs = ['node_modules', 'build/lib'] skips = ['*/*_test.py', '*/test_*.py'] [tool.ruff] -# This is a slight customisation of the default rules -# 1. Rule E402 (module-level import not top-level) is disabled as isort handles it -# 2. Rule E501 (line too long) is left to Black; some strings are worse for wrapping - -# Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default. -lint.select = ["E", "F"] -lint.ignore = ["E402", "E501"] - -# Allow autofix for all enabled rules (when `--fix`) is provided. -lint.fixable = [ - "A", - "B", - "C", - "D", - "E", - "F", - "G", - "I", - "N", - "Q", - "S", - "T", - "W", - "ANN", - "ARG", - "BLE", - "COM", - "DJ", - "DTZ", - "EM", - "ERA", - "EXE", - "FBT", - "ICN", - "INP", - "ISC", - "NPY", - "PD", - "PGH", - "PIE", - "PL", - "PT", - "PTH", - "PYI", - "RET", - "RSE", - "RUF", - "SIM", - "SLF", - "TCH", - "TID", - "TRY", - "UP", - "YTT", -] -lint.unfixable = [] - # Exclude a variety of commonly ignored directories. exclude = [ ".bzr", @@ -260,15 +211,89 @@ exclude = [ # Same as Black. line-length = 88 -# Allow unused variables when underscore-prefixed. -lint.dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" - # Target Python 3.9 target-version = "py39" -[tool.ruff.lint.mccabe] -# Unlike Flake8, default to a complexity level of 10. -max-complexity = 10 +[tool.ruff.format] +docstring-code-format = true +quote-style = "preserve" + +[tool.ruff.lint] +select = [ + "A", # flake8-builtins + "ANN", # flake8-annotations + "ARG", # flake8-unused-arguments + "ASYNC", # flake8-async + "B", # flake8-bugbear + "BLE", # flake8-blind-except + "C", # pylint convention + "D", # pydocstyle + "C4", # flake8-comprehensions + "E", # Error + "EM", # flake8-errmsg + "EXE", # flake8-executable + "F", # pyflakes + "FA", # flake8-future-annotations + "G", # flake8-logging-format + "I", # isort + "INP", # flake8-no-pep420 + "INT", # flake8-gettext + "ISC", # flake8-implicit-str-concat + "N", # pep8-naming + "PIE", # flake8-pie + "PT", # flake8-pytest-style + "PYI", # flake8-pyi + "RET", # flake8-return + "RUF", # Ruff + "S", # flake8-bandit + "SIM", # flake8-simplify + "SLOT", # flake8-slots + "T20", # flake8-print + "TRIO", # flake8-trio + "UP", # pyupgrade + "W", # Warnings + "YTT", # flake8-2020 +] +ignore = [ + "ANN002", # `*args` is implicit `Any` + "ANN003", # `**kwargs` is implicit `Any` + "ANN101", # `self` type is implicit + "ANN102", # `cls` type is implicit + "ANN401", # Allow `Any` type + "C901", # TODO: Remove after code refactoring + "D101", + "D102", + "D103", + "D105", # Magic methods don't need docstrings + "D106", # Nested classes don't need docstrings + "D107", # `__init__` doesn't need a docstring + "D203", # No blank lines before class docstring + "D212", # Allow multiline docstring to start on next line after quotes + "D213", # But also allow multiline docstring to start right after quotes + "E402", # Allow top-level imports after statements + "E501", # Allow long lines if the formatter can't fix it + "EM101", # Allow Exception("string") + "EM102", # Allow Exception(f"string") + "ISC001", # Allow implicitly concatenated string literals (required for formatter) + "RUF012", # Allow mutable ClassVar without annotation (conflicts with SQLAlchemy) + "SLOT000", # Don't require `__slots__` for subclasses of str +] + +# Allow autofix for all enabled rules (when `--fix`) is provided. +fixable = ["ALL"] +unfixable = [] + +# Allow these characters in strings +allowed-confusables = ["‘", "’", "–"] + +[tool.ruff.lint.extend-per-file-ignores] +"__init__.py" = ["E402"] # Allow non-top-level imports +"tests/**.py" = [ + "S101", # Allow assert + "ANN001", # Args don't need types (usually fixtures) + "N802", # Fixture returning a class may be named per class name convention + "N803", # Args don't require naming convention (fixture could be a class) +] [tool.ruff.lint.isort] # These config options should match isort config above under [tool.isort] @@ -276,15 +301,22 @@ combine-as-imports = true extra-standard-library = ['typing_extensions'] split-on-trailing-comma = false relative-imports-order = 'furthest-to-closest' -known-first-party = ['coaster'] +known-first-party = [] section-order = [ 'future', 'standard-library', 'third-party', 'first-party', + 'repo', 'local-folder', ] +[tool.ruff.lint.isort.sections] +repo = ['coaster'] + [tool.ruff.lint.flake8-pytest-style] fixture-parentheses = false mark-parentheses = false + +[tool.ruff.lint.pyupgrade] +keep-runtime-typing = true diff --git a/src/coaster/__init__.py b/src/coaster/__init__.py index 7a7581c6..74ab6b54 100644 --- a/src/coaster/__init__.py +++ b/src/coaster/__init__.py @@ -1,3 +1,4 @@ # flake8: noqa """Coaster provides various support modules for Flask apps.""" + from ._version import * diff --git a/src/coaster/app.py b/src/coaster/app.py index 53f47832..c67fc593 100644 --- a/src/coaster/app.py +++ b/src/coaster/app.py @@ -1,7 +1,4 @@ -""" -App configuration -================= -""" +"""App configuration.""" # pyright: reportMissingImports=false @@ -29,13 +26,9 @@ from flask.json.provider import DefaultJSONProvider from flask.sessions import SecureCookieSessionInterface -try: # Flask >= 3.0 # pragma: no cover - from flask.sansio.app import App as FlaskApp -except ModuleNotFoundError: # Flask < 3.0 - from flask import Flask as FlaskApp - from . import logger from .auth import current_auth +from .compat import BaseApp from .views import current_view __all__ = [ @@ -58,13 +51,13 @@ try: import tomllib as mod_tomllib # type: ignore[no-redef] # Python >= 3.11 except ModuleNotFoundError: - try: + try: # noqa: SIM105 import tomli as mod_tomli # type: ignore[no-redef,unused-ignore] except ModuleNotFoundError: pass -try: +try: # noqa: SIM105 import yaml as mod_yaml except ModuleNotFoundError: pass @@ -135,10 +128,11 @@ def __class__(self) -> type: """Mimic wrapped engine's class.""" if self._engines: return type(self._engines[0]) - return super().__class__ + return KeyRotationWrapper @__class__.setter - def __class__(self, value: Any) -> NoReturn: # noqa: F811 + def __class__(self, value: Any) -> NoReturn: + # This setter is required for static type checkers raise TypeError("__class__ cannot be set.") def __init__( @@ -148,23 +142,23 @@ def __init__( **kwargs: Any, ) -> None: """Init key rotation wrapper.""" - if isinstance(secret_keys, str): # type: ignore[unreachable] + if isinstance(secret_keys, (str, bytes)): # type: ignore[unreachable] raise ValueError("Secret keys must be a list") if not secret_keys: raise ValueError("No secret keys in the list") self._engines = [cls(key, **kwargs) for key in secret_keys] - def __getattr__(self, attr: str) -> Any: + def __getattr__(self, name: str) -> Any: """Read a wrapped attribute.""" - item = getattr(self._engines[0], attr) - return self._make_wrapper(attr) if callable(item) else item + item = getattr(self._engines[0], name) + return self._make_wrapper(name) if callable(item) else item - def _make_wrapper(self, attr: str) -> Callable: + def _make_wrapper(self, name: str) -> Callable: def wrapper(*args: Any, **kwargs: Any) -> Any: saved_exc: Exception = _sentinel_keyrotation_exception for engine in self._engines: try: - return getattr(engine, attr)(*args, **kwargs) + return getattr(engine, name)(*args, **kwargs) except itsdangerous.BadSignature as exc: saved_exc = exc # We've run out of engines and all of them reported BadSignature. @@ -178,7 +172,7 @@ class RotatingKeySecureCookieSessionInterface(SecureCookieSessionInterface): """Replaces the serializer with key rotation support.""" def get_signing_serializer( # type: ignore[override] - self, app: FlaskApp + self, app: BaseApp ) -> Optional[KeyRotationWrapper]: """Return serializers wrapped for key rotation.""" if not app.config.get('SECRET_KEYS'): @@ -217,7 +211,7 @@ def default(o: Any) -> Any: def init_app( - app: FlaskApp, + app: BaseApp, config: Optional[list[Literal['env', 'py', 'json', 'toml', 'yaml']]] = None, *, env_prefix: Optional[Union[str, Sequence[str]]] = None, @@ -285,7 +279,7 @@ def init_app( for config_option in config: if config_option == 'env': if env_prefix is None: - # Use Flask's default env prefix + # Use Flask or Quart's default env prefix app.config.from_prefixed_env() elif isinstance(env_prefix, str): # Use the app's requested env prefix @@ -335,7 +329,7 @@ def init_app( def load_config_from_file( - app: FlaskApp, + app: BaseApp, filepath: str, load: Optional[Callable] = None, text: Optional[bool] = None, diff --git a/src/coaster/assets.py b/src/coaster/assets.py index 826a2fb0..e97f6049 100644 --- a/src/coaster/assets.py +++ b/src/coaster/assets.py @@ -1,6 +1,5 @@ """ -Assets -====== +Assets. Coaster provides a simple asset management system for semantically versioned assets using the semantic_version_ and webassets_ libraries. Many popular libraries such as @@ -15,6 +14,7 @@ .. _webassets: http://elsdoerfer.name/docs/webassets/ .. _Webpack: https://webpack.js.org/ """ +# spell-checker:ignore webassets sourcecode endassets from __future__ import annotations @@ -22,13 +22,18 @@ import warnings from collections import defaultdict from collections.abc import Iterator, Mapping, Sequence -from typing import Any, Final, Optional, Union +from typing import TYPE_CHECKING, Any, Final, Optional, Union from urllib.parse import urljoin -from flask import Flask, current_app +from flask import Flask from flask_assets import Bundle from semantic_version import SimpleSpec, Version +from .compat import current_app, sync_await + +if TYPE_CHECKING: + from quart import Quart + _VERSION_SPECIFIER_RE = re.compile('[<=>!*]') # Version is not used here but is made available for others to import from @@ -67,6 +72,7 @@ class VersionedAssets(defaultdict): To use, initialize a container for your assets:: from coaster.assets import VersionedAssets, Version + assets = VersionedAssets() And then populate it with your assets. The simplest way is by specifying @@ -79,14 +85,16 @@ class VersionedAssets(defaultdict): a list or tuple of requirements followed by the actual asset:: assets['jquery.form.js'][Version('2.96.0')] = ( - 'jquery.js', 'js/jquery.form-2.96.js') + 'jquery.js', + 'js/jquery.form-2.96.js', + ) You may have an asset that provides replacement functionality for another asset:: assets['zepto.js'][Version('1.0.0-rc1')] = { 'provides': 'jquery.js', 'bundle': 'js/zepto-1.0rc1.js', - } + } Assets specified as a dictionary can have three keys: @@ -113,6 +121,7 @@ class VersionedAssets(defaultdict): To use these assets in a Flask app, register the assets with an environment:: from flask_assets import Environment + appassets = Environment(app) appassets.register('js_all', assets.require('jquery.js', ...)) @@ -277,7 +286,7 @@ class WebpackManifest(Mapping): def __init__( self, - app: Optional[Flask] = None, + app: Optional[Union[Flask, Quart]] = None, *, filepath: str = 'static/manifest.json', urlpath: Optional[str] = None, @@ -294,13 +303,24 @@ def __init__( if app is not None: self.init_app(app, _warning_stack_level=3) - def init_app(self, app: Flask, _warning_stack_level: int = 2) -> None: + def _read_resource(self, app: Flask) -> Union[str, bytes]: + with app.open_resource(self.filepath) as resource: + return resource.read() + + async def _async_read_resource(self, app: Quart) -> Union[str, bytes]: + async with await app.open_resource(self.filepath) as resource: + return await resource.read() + + def init_app(self, app: Union[Flask, Quart], _warning_stack_level: int = 2) -> None: """Configure WebpackManifest on a Flask app.""" # Step 1: Open manifest.json and validate basic structure (incl. legacy check) - with app.open_resource(self.filepath) as resource: - # Use ``json.loads`` because a substitute JSON implementation may not - # support the ``load`` method (eg: orjson has ``loads`` but not ``load``) - assets = app.json.loads(resource.read()) + if isinstance(app, Flask): + resource_content = self._read_resource(app) + else: + resource_content = sync_await(self._async_read_resource(app)) + # Use ``json.loads`` because a substitute JSON implementation may not + # support the ``load`` method (eg: orjson has ``loads`` but not ``load``) + assets = app.json.loads(resource_content) if not isinstance(assets, dict): raise ValueError( f"File `{self.filepath}` must contain a JSON object at the root level" diff --git a/src/coaster/auth.py b/src/coaster/auth.py index 73361b6b..b84a5e4d 100644 --- a/src/coaster/auth.py +++ b/src/coaster/auth.py @@ -1,6 +1,5 @@ """ -Authentication management -========================= +Authentication management. Coaster provides a :obj:`current_auth` for handling authentication. Login managers must comply with its API for Coaster's view handlers to work. @@ -24,11 +23,9 @@ from threading import Lock from typing import Any, NoReturn, TypeVar, cast -from flask import Flask, current_app, g -from flask.globals import request_ctx from werkzeug.local import LocalProxy -from werkzeug.wrappers import Response as BaseResponse +from .compat import BaseApp, BaseResponse, current_app, flask_g, quart_g, request_ctx from .utils import InspectableSet __all__ = [ @@ -91,7 +88,7 @@ def add_auth_attribute(attr: str, value: Any, actor: bool = False) -> None: if attr == 'user': # Special-case 'user' for compatibility with Flask-Login - if g: + if g := (quart_g or flask_g): g._login_user = value # A user is always an actor actor = True @@ -168,14 +165,14 @@ def __init__(self, is_placeholder: bool = False) -> None: object.__setattr__(self, 'actor', None) object.__setattr__(self, 'user', None) - def __setattr__(self, attr: str, value: Any) -> NoReturn: - if hasattr(self, attr) and getattr(self, attr) is value: + def __setattr__(self, name: str, value: Any) -> NoReturn: + if hasattr(self, name) and getattr(self, name) is value: # This test is used to allow in-place mutations such as: # current_auth.permissions |= {extra} return # type: ignore[misc] raise TypeError('current_auth is read-only') - def __delattr__(self, attr: str) -> NoReturn: + def __delattr__(self, name: str) -> NoReturn: raise TypeError('current_auth is read-only') def __contains__(self, attr: str) -> bool: @@ -191,19 +188,19 @@ def get(self, attr: str, default: Any = None) -> Any: def __repr__(self) -> str: # pragma: no cover return f'CurrentAuth(is_placeholder={self.is_placeholder})' - def __getattr__(self, attr: str) -> Any: + def __getattr__(self, name: str) -> Any: """Init :class:`CurrentAuth` on first attribute access.""" with _prop_lock: if 'actor' in self.__dict__: # CurrentAuth already initialized - raise AttributeError(attr) + raise AttributeError(name) self.__dict__['actor'] = None self.__dict__.setdefault('user', None) self._call_login_manager() try: - return self.__dict__[attr] + return self.__dict__[name] except KeyError: - raise AttributeError(attr) from None + raise AttributeError(name) from None def _call_login_manager(self) -> None: """Call the app's login manager on first access of user or actor (internal).""" @@ -221,7 +218,7 @@ def _call_login_manager(self) -> None: # In case the login manager did not call :func:`add_auth_attribute`, we'll # need to do it if self.__dict__.get('user') is None: - user = g.get('_login_user') + user = (quart_g or flask_g).get('_login_user') if user is not None: self.__dict__['user'] = user # Set actor=user only if the login manager did not add another actor @@ -248,7 +245,7 @@ def _set_auth_cookie_after_request(response: _Response) -> _Response: return response -def init_app(app: Flask) -> None: +def init_app(app: BaseApp) -> None: """Optionally initialize current_auth for auth cookie management in an app.""" app.config.setdefault('AUTH_COOKIE_NAME', 'auth') for our_config, flask_config in [ @@ -285,7 +282,7 @@ def __call__(self) -> CurrentAuth: if ca is None: # 3. If not, create it ca = self.cls() - request_ctx.current_auth = ca # type: ignore[attr-defined] + request_ctx.current_auth = ca # type: ignore[union-attr] elif not isinstance(ca, self.cls): # If ca is not an instance of self.cls but self.cls is a subclass of # ca.__class__, then re-create with self.cls. This is needed because @@ -297,7 +294,7 @@ def __call__(self) -> CurrentAuth: if issubclass(self.cls, ca.__class__): new_ca = self.cls() new_ca.__dict__.update(ca.__dict__) - request_ctx.current_auth = new_ca # type: ignore[attr-defined] + request_ctx.current_auth = new_ca # type: ignore[union-attr] ca = new_ca # 4. Return current_auth return ca diff --git a/src/coaster/compat.py b/src/coaster/compat.py new file mode 100644 index 00000000..4ab7040f --- /dev/null +++ b/src/coaster/compat.py @@ -0,0 +1,317 @@ +"""Async compatibility between Flask and Quart.""" + +# pyright: reportMissingImports=false +# pylint: disable=ungrouped-imports + +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from inspect import isawaitable, iscoroutinefunction +from typing import TYPE_CHECKING, Any, AnyStr, Optional, TypeVar, Union, overload +from typing_extensions import Literal, ParamSpec + +from asgiref.sync import async_to_sync +from flask import ( + current_app as flask_current_app, + g as flask_g, + has_request_context as flask_has_request_context, + make_response as flask_make_response, + render_template as flask_render_template, + render_template_string as flask_render_template_string, + request as flask_request, +) +from flask.globals import request_ctx as flask_request_ctx +from werkzeug.datastructures import CombinedMultiDict, MultiDict +from werkzeug.wrappers import Response as WerkzeugResponse + +# MARK: Gated imports ------------------------------------------------------------------ + +try: # Flask >= 3.0 + from flask.sansio.app import App as BaseApp + from flask.sansio.blueprints import Blueprint as BaseBlueprint +except ModuleNotFoundError: # Flask < 3.0 + from flask import Blueprint as BaseBlueprint, Flask as BaseApp + +try: # Werkzeug >= 3.0 + from werkzeug.sansio.request import Request as BaseRequest + from werkzeug.sansio.response import Response as BaseResponse +except ModuleNotFoundError: # Werkzeug < 3.0 + # pylint: disable=reimported + from werkzeug.wrappers import ( # type: ignore[assignment] + Request as BaseRequest, + Response as BaseResponse, + ) + +try: + from quart import ( + current_app as quart_current_app, + g as quart_g, + has_request_context as quart_has_request_context, + make_response as quart_make_response, + render_template as quart_render_template, + render_template_string as quart_render_template_string, + request as quart_request, + ) + from quart.globals import request_ctx as quart_request_ctx +except ModuleNotFoundError: + quart_current_app = None # type: ignore[assignment] + quart_g = None # type: ignore[assignment] + quart_request = None # type: ignore[assignment] + quart_has_request_context = None # type: ignore[assignment] + quart_render_template = None # type: ignore[assignment] + quart_render_template_string = None # type: ignore[assignment] + quart_request_ctx = None # type: ignore[assignment] + + +if TYPE_CHECKING: + from flask import Flask, Request as FlaskRequest + from flask.ctx import RequestContext as FlaskRequestContext + from quart import Quart, Request as QuartRequest, Response as QuartResponse + from quart.ctx import RequestContext as QuartRequestContext + +__all__ = [ + 'BaseApp', + 'BaseBlueprint', + 'BaseRequest', + 'BaseResponse', + 'async_render_template_string', + 'async_render_template', + 'async_request', + 'current_app_object', + 'current_app', + 'flask_g', + 'has_request_context', + 'quart_g', + 'request_ctx', +] + + +# MARK: Cross-compatible helpers ------------------------------------------------------- + + +class QuartFlaskWrapper: + """ + Proxy to Quart or Flask source objects. + + This object does not implement any magic methods other than meth:`__bool__` and does + not resolve API differences. + """ + + _quart_source: Any + _flask_source: Any + + def __init__(self, quart_source: Any, flask_source: Any) -> None: + object.__setattr__(self, '_quart_source', quart_source) + object.__setattr__(self, '_flask_source', flask_source) + + def __bool__(self) -> bool: + return bool(self._quart_source or self._flask_source) + + def __getattr__(self, name: str) -> Any: + if self._quart_source: + return getattr(self._quart_source, name) + return getattr(self._flask_source, name) + + def __setattr__(self, name: str, value: Any) -> None: + if self._quart_source: + setattr(self._quart_source, name, value) + setattr(self._flask_source, name, value) + + def __delattr__(self, name: str) -> None: + if self._quart_source: + delattr(self._quart_source, name) + delattr(self._flask_source, name) + + +current_app: Union[Flask, Quart] +current_app = QuartFlaskWrapper( # type: ignore[assignment] + quart_current_app, flask_current_app +) +request_ctx: Union[FlaskRequestContext, QuartRequestContext] +request_ctx = QuartFlaskWrapper( # type: ignore[assignment] + quart_request_ctx, flask_request_ctx +) + +request: Union[FlaskRequest, QuartRequest] +request = QuartFlaskWrapper( # type: ignore[assignment] + quart_request, flask_request +) + + +def current_app_object() -> Optional[Union[Flask, Quart]]: + """Get current app from Quart or Flask (unwrapping the proxy).""" + # pylint: disable=protected-access + if quart_current_app: + return quart_current_app._get_current_object() # type: ignore[attr-defined] + if flask_current_app: + return flask_current_app._get_current_object() # type: ignore[attr-defined] + return None + + +def has_request_context() -> bool: + """Check for request context in Quart or Flask.""" + return ( + quart_has_request_context is not None and quart_has_request_context() + ) or flask_has_request_context() + + +# MARK: Async helpers ------------------------------------------------------------------ + + +class AsyncRequestWrapper: + """Mimic Quart's async request when operating under Flask.""" + + def __bool__(self) -> bool: + return bool(quart_request or flask_request) + + @property + async def data(self) -> bytes: + if quart_request: + return await quart_request.data + return flask_request.data + + @overload + async def get_data( + self, cache: bool, as_text: Literal[False], parse_form_data: bool + ) -> bytes: ... + + @overload + async def get_data( + self, cache: bool, as_text: Literal[True], parse_form_data: bool + ) -> str: ... + + @overload + async def get_data( + self, cache: bool = True, as_text: bool = False, parse_form_data: bool = False + ) -> AnyStr: ... + + async def get_data( + self, cache: bool = True, as_text: bool = False, parse_form_data: bool = False + ) -> AnyStr: + if quart_request: + return await quart_request.get_data(cache, as_text, parse_form_data) + return flask_request.get_data( # type: ignore[call-overload, return-value] + cache, as_text, parse_form_data + ) + + @property + async def json(self) -> Optional[Any]: + if quart_request: + return await quart_request.json + return flask_request.json + + async def get_json(self) -> Optional[Any]: + if quart_request: + return await quart_request.get_json() + return flask_request.get_json() + + @property + async def form(self) -> MultiDict: + if quart_request: + return await quart_request.form + return flask_request.form + + @property + async def files(self) -> MultiDict: + if quart_request: + return await quart_request.files + return flask_request.files + + @property + async def values(self) -> CombinedMultiDict: + if quart_request: + return await quart_request.values + return flask_request.values + + async def send_push_promise(self, path: str) -> None: + if quart_request: + await quart_request.send_push_promise(path) + # Do nothing if Flask + + async def close(self) -> None: + if quart_request: + return await quart_request.close() + return flask_request.close() + + # Proxy all other attributes to Quart or Flask + + def __getattr__(self, name: str) -> Any: + if quart_request: + return getattr(quart_request, name) + return getattr(flask_request, name) + + def __setattr__(self, name: str, value: Any) -> None: + if quart_request: + setattr(quart_request, name, value) + setattr(flask_request, name, value) + + def __delattr__(self, name: str) -> None: + if quart_request: + delattr(quart_request, name) + delattr(flask_request, name) + + +async_request = AsyncRequestWrapper() + + +async def async_make_response(*args: Any) -> Union[WerkzeugResponse, QuartResponse]: + """Make a response, auto-selecting between Quart and Flask.""" + if quart_current_app: + return await quart_make_response(*args) + return flask_make_response(*args) + + +async def async_render_template( + template_name_or_list: Union[str, list[str]], **context: Any +) -> str: + """Async render_template, auto-selecting between Quart and Flask.""" + if quart_current_app: + return await quart_render_template(template_name_or_list, **context) + return flask_render_template( + template_name_or_list, # type: ignore[arg-type] + **context, + ) + + +async def async_render_template_string(source: str, **context: Any) -> str: + """Async render_template_string, auto-selecting between Quart and Flask.""" + if quart_current_app: + return await quart_render_template_string(source, **context) + return flask_render_template_string(source, **context) + + +# MARK: Async to Sync helpers ---------------------------------------------------------- + +_P = ParamSpec('_P') +_R_co = TypeVar('_R_co', covariant=True) + + +@async_to_sync +async def sync_await(awaitable: Awaitable[_R_co]) -> _R_co: + """Implement await statement in a sync context.""" + return await awaitable + + +def ensure_sync( + func: Union[ + Callable[_P, Awaitable[_R_co]], + Callable[_P, _R_co], + ], +) -> Callable[_P, _R_co]: + """Run a possibly-async function in a sync context.""" + if not callable(func): + raise TypeError("Function is not callable.") + if iscoroutinefunction(func) or iscoroutinefunction( + getattr(func, '__call__', func) # noqa: B004 + ): + return async_to_sync(func) # type: ignore[arg-type] + + def check_return(*args: _P.args, **kwargs: _P.kwargs) -> _R_co: + result = func(*args, **kwargs) + if isawaitable(result): + return sync_await(result) + # The typeguard for isawaitable doesn't narrow in the negative context, so we + # need a type-ignore here: + return result # type: ignore[return-value] + + return check_return diff --git a/src/coaster/db.py b/src/coaster/db.py index f673bd55..1baca014 100644 --- a/src/coaster/db.py +++ b/src/coaster/db.py @@ -1,6 +1,5 @@ """ -Flask-SQLAlchemy instance -------------------------- +Flask-SQLAlchemy instance. .. deprecated:: 0.7.0 Coaster provides a global instance of Flask-SQLAlchemy for convenience, but this is diff --git a/src/coaster/gfm.py b/src/coaster/gfm.py deleted file mode 100644 index 3b97a48f..00000000 --- a/src/coaster/gfm.py +++ /dev/null @@ -1,3 +0,0 @@ -from .utils import markdown - -__all__ = ['markdown'] diff --git a/src/coaster/logger.py b/src/coaster/logger.py index 4b01d6f2..4eed7fef 100644 --- a/src/coaster/logger.py +++ b/src/coaster/logger.py @@ -1,12 +1,11 @@ """ -Logger -======= +Exception logger. Coaster can help your application log errors at run-time. Initialize with :func:`coaster.logger.init_app`. If you use :func:`coaster.app.init_app`, this is done automatically for you. """ - +# spell-checker:ignore typeshed apikey stripetoken cardnumber levelname # pyright: reportMissingImports=false from __future__ import annotations @@ -33,15 +32,11 @@ from logging import _SysExcInfoType import requests -from flask import g, request, session +from flask import request, session from flask.config import Config -try: # Flask >= 3.0 - from flask.sansio.app import App as FlaskApp -except ModuleNotFoundError: - from flask import Flask as FlaskApp - from .auth import current_auth +from .compat import BaseApp, flask_g, quart_g # Regex for credit card numbers _card_re = re.compile(r'\b(?:\d[ -]*?){13,16}\b') @@ -125,6 +120,9 @@ def pprint_with_indent(dictlike: dict, outfile: IO, indent: int = 4) -> None: out.close() +STACK_FRAMES_NOTICE = "Stack frames (most recent call first):" + + class LocalVarFormatter(logging.Formatter): """Log the contents of local variables in the stack frame.""" @@ -133,7 +131,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.lock = Lock() - def format(self, record: logging.LogRecord) -> str: # noqa: A003 + def format(self, record: logging.LogRecord) -> str: """ Format the specified record as text. @@ -143,7 +141,7 @@ def format(self, record: logging.LogRecord) -> str: # noqa: A003 if ( record.exc_info and record.exc_text - and "Stack frames (most recent call first)" not in record.exc_text + and STACK_FRAMES_NOTICE not in record.exc_text ): record.exc_text = None return super().format(record) @@ -174,16 +172,15 @@ def formatException(self, ei: _SysExcInfoType) -> str: # noqa: N802 # original __repr__ while this is still dumping. original_config_repr = Config.__repr__ Config.__repr__ = ( # type: ignore[method-assign] - lambda self: '' + lambda self: '' # noqa: ARG005 ) value_cache: dict[Any, str] = {} - print('\n----------\n', file=sio) # noqa: T201 - # XXX: The following text is used as a signature in :meth:`format` above - print("Stack frames (most recent call first):", file=sio) # noqa: T201 + print('\n----------\n', file=sio) + print(STACK_FRAMES_NOTICE, file=sio) for frame in stack: - print('\n----\n', file=sio) # noqa: T201 - print( # noqa: T201 + print('\n----\n', file=sio) + print( f"Frame {frame.f_code.co_name} in {frame.f_code.co_filename} at" f" line {frame.f_lineno}", file=sio, @@ -194,20 +191,20 @@ def formatException(self, ei: _SysExcInfoType) -> str: # noqa: N802 value = RepeatValueIndicator(value_cache[idvalue]) else: value_cache[idvalue] = f"{frame.f_code.co_name}.{attr}" - print(f"\t{attr:>20} = ", end=' ', file=sio) # noqa: T201 + print(f"\t{attr:>20} = ", end=' ', file=sio) try: - print(repr(filtered_value(attr, value)), file=sio) # noqa: T201 - except Exception: # noqa: B902 # pylint: disable=broad-except + print(repr(filtered_value(attr, value)), file=sio) + except Exception: # noqa: BLE001 # pylint: disable=broad-except # We need a bare except clause because this is the exception # handler. It can't have exceptions of its own. - print("", file=sio) # noqa: T201 + print("", file=sio) del value_cache Config.__repr__ = original_config_repr # type: ignore[method-assign] if request: - print('\n----------\n', file=sio) # noqa: T201 - print("Request context:", file=sio) # noqa: T201 + print('\n----------\n', file=sio) + print("Request context:", file=sio) request_data = { 'form': { k: filtered_value(k, v) @@ -226,32 +223,32 @@ def formatException(self, ei: _SysExcInfoType) -> str: # noqa: N802 } try: pprint_with_indent(request_data, sio) - except Exception: # noqa: B902 # pylint: disable=broad-except - print("", file=sio) # noqa: T201 + except Exception: # noqa: BLE001 # pylint: disable=broad-except + print("", file=sio) if session: - print('\n----------\n', file=sio) # noqa: T201 - print("Session cookie contents:", file=sio) # noqa: T201 + print('\n----------\n', file=sio) + print("Session cookie contents:", file=sio) try: pprint_with_indent(dict(session), sio) - except Exception: # noqa: B902 # pylint: disable=broad-except - print("", file=sio) # noqa: T201 + except Exception: # noqa: BLE001 # pylint: disable=broad-except + print("", file=sio) - if g: - print('\n----------\n', file=sio) # noqa: T201 - print("App context:", file=sio) # noqa: T201 + if quart_g or flask_g: + print('\n----------\n', file=sio) + print("App context:", file=sio) try: - pprint_with_indent(vars(g), sio) - except Exception: # noqa: B902 # pylint: disable=broad-except - print("", file=sio) # noqa: T201 + pprint_with_indent(vars(quart_g or flask_g), sio) + except Exception: # noqa: BLE001 # pylint: disable=broad-except + print("", file=sio) if current_auth: - print('\n----------\n', file=sio) # noqa: T201 - print("Current auth:", file=sio) # noqa: T201 + print('\n----------\n', file=sio) + print("Current auth:", file=sio) try: pprint_with_indent(vars(current_auth), sio) - except Exception: # noqa: B902 # pylint: disable=broad-except - print("", file=sio) # noqa: T201 + except Exception: # noqa: BLE001 # pylint: disable=broad-except + print("", file=sio) s = sio.getvalue() sio.close() @@ -350,7 +347,7 @@ def emit(self, record: logging.LogRecord) -> None: ).start() with self.throttle_lock: self.throttle_cache[throttle_key] = datetime.now() - except Exception: # nosec # noqa: B902 # pylint: disable=broad-except + except Exception: # noqa: BLE001 # pylint: disable=broad-except self.handleError(record) @@ -379,11 +376,10 @@ def emit(self, record: logging.LogRecord) -> None: < timedelta(minutes=5) ): return - # pylint: disable=consider-using-f-string - text = '{levelname} in {name}: {message}'.format( - levelname=escape(record.levelname, False), - name=escape(self.app_name, False), - message=escape(record.message, False), + text = ( + f'{escape(record.levelname, False)}' + f' in {escape(self.app_name, False)}:' + f' {escape(record.message, False)}' ) if record.exc_info: # Reverse the traceback, after dropping the first line with @@ -425,7 +421,7 @@ def emit(self, record: logging.LogRecord) -> None: ).start() with self.throttle_lock: self.throttle_cache[throttle_key] = datetime.now() - except Exception: # noqa: B902 # pylint: disable=broad-except + except Exception: # noqa: BLE001 # pylint: disable=broad-except self.handleError(record) @@ -441,7 +437,7 @@ def emit(self, record: logging.LogRecord) -> None: } -def init_app(app: FlaskApp, _warning_stacklevel: int = 2) -> None: +def init_app(app: BaseApp, _warning_stacklevel: int = 2) -> None: """ Enable logging for an app using :class:`LocalVarFormatter`. @@ -470,10 +466,13 @@ def init_app(app: FlaskApp, _warning_stacklevel: int = 2) -> None: Format for ``LOG_SLACK_WEBHOOKS``:: - LOG_SLACK_WEBHOOKS = [{ - 'levelnames': ['WARNING', 'ERROR', 'CRITICAL'], - 'url': 'https://hooks.slack.com/...' - }, ...] + LOG_SLACK_WEBHOOKS = [ + { + 'levelnames': ['WARNING', 'ERROR', 'CRITICAL'], + 'url': 'https://hooks.slack.com/...', + }, + ..., + ] """ # --- Prevent dupe init diff --git a/src/coaster/nlp.py b/src/coaster/nlp.py deleted file mode 100644 index 95ea12bd..00000000 --- a/src/coaster/nlp.py +++ /dev/null @@ -1,48 +0,0 @@ -""" -Natural language processing -=========================== - -Provides a wrapper around NLTK to extract named entities from HTML text:: - - from coaster.utils import text_blocks - from coaster.nlp import extract_named_entities - - html = "

This is some HTML-formatted text.

In two paragraphs.

" - textlist = text_blocks(html) # Returns a list of paragraphs. - entities = extract_named_entities(textlist) -""" - -from __future__ import annotations - -from collections.abc import Iterable - -import nltk - - -def extract_named_entities(text_blocks: Iterable[str]) -> set[str]: - """Return a set of named entities extracted from the provided text blocks.""" - sentences = [] - for text in text_blocks: - sentences.extend(nltk.sent_tokenize(text)) - - tokenized_sentences = [nltk.word_tokenize(sentence) for sentence in sentences] - tagged_sentences = [nltk.pos_tag(sentence) for sentence in tokenized_sentences] - chunked_sentences = nltk.ne_chunk_sents(tagged_sentences, binary=True) - - def extract_entity_names(tree: nltk.Tree) -> list[str]: - entity_names = [] - - if hasattr(tree, 'label'): - if tree.label() == 'NE': - entity_names.append(' '.join(child[0] for child in tree)) - else: - for child in tree: - entity_names.extend(extract_entity_names(child)) - - return entity_names - - entity_names = [] - for tree in chunked_sentences: - entity_names.extend(extract_entity_names(tree)) - - return set(entity_names) diff --git a/src/coaster/signals.py b/src/coaster/signals.py index d36d265b..e5a96010 100644 --- a/src/coaster/signals.py +++ b/src/coaster/signals.py @@ -1,3 +1,5 @@ +"""Coaster signals.""" + from blinker import Namespace coaster_signals = Namespace() diff --git a/src/coaster/sqlalchemy/annotations.py b/src/coaster/sqlalchemy/annotations.py index 8a06094e..72fdee74 100644 --- a/src/coaster/sqlalchemy/annotations.py +++ b/src/coaster/sqlalchemy/annotations.py @@ -1,6 +1,5 @@ """ -SQLAlchemy attribute annotations --------------------------------- +SQLAlchemy attribute annotations. Annotations are strings attached to attributes that serve as a programmer reference on how those attributes are meant to be used. They can be used to @@ -19,12 +18,11 @@ natural_key = annotation_wrapper('natural_key', "Natural key for this model") + class MyModel(Model): __tablename__ = 'my_model' id: Mapped[int] = immutable(sa.orm.mapped_column(sa.Integer, primary_key=True)) - name: Mapped[str] = natural_key(sa.orm.mapped_column( - sa.Unicode(250), unique=True - )) + name: Mapped[str] = natural_key(sa.orm.mapped_column(sa.Unicode(250), unique=True)) @classmethod def get(cls, **kwargs): diff --git a/src/coaster/sqlalchemy/columns.py b/src/coaster/sqlalchemy/columns.py index fb8b28e5..4caddc91 100644 --- a/src/coaster/sqlalchemy/columns.py +++ b/src/coaster/sqlalchemy/columns.py @@ -1,7 +1,4 @@ -""" -SQLAlchemy column types ------------------------ -""" +"""SQLAlchemy column types.""" from __future__ import annotations @@ -51,13 +48,13 @@ def coerce_compared_value(self, op: Any, value: Any) -> sa.types.TypeEngine: """Coerce an incoming value using the JSON type's default handler.""" return self.impl.coerce_compared_value(op, value) - def process_bind_param(self, value: Any, dialect: sa.Dialect) -> Any: + def process_bind_param(self, value: Any, _dialect: sa.Dialect) -> Any: """Convert a Python value into a JSON string for the database.""" if value is not None: value = json.dumps(value, default=str) # Callable default return value - def process_result_value(self, value: Any, dialect: sa.Dialect) -> Any: + def process_result_value(self, value: Any, _dialect: sa.Dialect) -> Any: """Convert a JSON string from the database into a dict.""" if value is not None and isinstance(value, str): # Psycopg2 >= 2.5 will auto-decode JSON columns, so @@ -70,7 +67,7 @@ def process_result_value(self, value: Any, dialect: sa.Dialect) -> Any: class MutableDict(Mutable, dict): @classmethod - def coerce(cls, key: Any, value: Any) -> Optional[MutableDict]: + def coerce(cls, _key: Any, value: Any) -> Optional[MutableDict]: """Convert plain dictionaries to MutableDict.""" if value is None: return None @@ -146,7 +143,7 @@ def process_bind_param(self, value: Any, dialect: sa.Dialect) -> Optional[str]: raise ValueError("Missing URL host") return value - def process_result_value(self, value: Any, dialect: sa.Dialect) -> Optional[furl]: + def process_result_value(self, value: Any, _dialect: sa.Dialect) -> Optional[furl]: """Cast URL loaded from database into a furl object.""" if value is not None: return self.url_parser(value) diff --git a/src/coaster/sqlalchemy/comparators.py b/src/coaster/sqlalchemy/comparators.py index e3e2a070..4566065b 100644 --- a/src/coaster/sqlalchemy/comparators.py +++ b/src/coaster/sqlalchemy/comparators.py @@ -1,12 +1,10 @@ -""" -Enhanced query and custom comparators -------------------------------------- -""" +"""Enhanced query and custom comparators.""" from __future__ import annotations +import contextlib from collections.abc import Iterator -from typing import Any, Optional, TypeVar, Union +from typing import Any, Optional, Union from uuid import UUID import sqlalchemy as sa @@ -23,9 +21,6 @@ ] -_T = TypeVar('_T', bound=Any) - - class SplitIndexComparator(Comparator): """Base class for comparators that split a string and compare with one part.""" @@ -42,23 +37,23 @@ def __init__( def _decode(self, other: str) -> Any: raise NotImplementedError - def __eq__(self, other: Any) -> sa.ColumnElement[bool]: # type: ignore[override] + def __eq__(self, other: object) -> sa.ColumnElement[bool]: # type: ignore[override] try: - other = self._decode(other) + other = self._decode(other) # type: ignore[arg-type] except (ValueError, TypeError): # If other could not be decoded, we do not match. return sa.sql.expression.false() - return self.__clause_element__() == other + return self.__clause_element__() == other # type: ignore[return-value] is_ = __eq__ # type: ignore[assignment] - def __ne__(self, other: Any) -> sa.ColumnElement[bool]: # type: ignore[override] + def __ne__(self, other: object) -> sa.ColumnElement[bool]: # type: ignore[override] try: - other = self._decode(other) + other = self._decode(other) # type: ignore[arg-type] except (ValueError, TypeError): # If other could not be decoded, we are not equal. return sa.sql.expression.true() - return self.__clause_element__() != other + return self.__clause_element__() != other # type: ignore[return-value] isnot = __ne__ # type: ignore[assignment] is_not = __ne__ # type: ignore[assignment] @@ -68,10 +63,8 @@ def in_(self, other: Any) -> sa.ColumnElement[bool]: # type: ignore[override] def errordecode(otherlist: Any) -> Iterator[str]: for val in otherlist: - try: + with contextlib.suppress(ValueError, TypeError): yield self._decode(val) - except (ValueError, TypeError): - pass valid_values = list(errordecode(other)) if not valid_values: diff --git a/src/coaster/sqlalchemy/functions.py b/src/coaster/sqlalchemy/functions.py index 78f62dc6..9a4f5a4b 100644 --- a/src/coaster/sqlalchemy/functions.py +++ b/src/coaster/sqlalchemy/functions.py @@ -1,7 +1,4 @@ -""" -Helper functions ----------------- -""" +"""SQLAlchemy helper functions.""" from __future__ import annotations @@ -34,23 +31,31 @@ class UtcNow(sa.sql.functions.GenericFunction): """Provide func.utcnow() that guarantees UTC timestamp.""" - type = sa.TIMESTAMP() # noqa: A003 + type = sa.TIMESTAMP() identifier = 'utcnow' inherit_cache = True @compiles(UtcNow) -def _utcnow_default(element: UtcNow, _compiler: Any, **kwargs) -> str: +def _utcnow_default(_element: UtcNow, _compiler: Any, **_kwargs) -> str: return 'CURRENT_TIMESTAMP' @compiles(UtcNow, 'mysql') -def _utcnow_mysql(element: UtcNow, _compiler: Any, **kwargs) -> str: # pragma: no cover +def _utcnow_mysql( # pragma: no cover + _element: UtcNow, + _compiler: Any, + **_kwargs, +) -> str: return 'UTC_TIMESTAMP()' @compiles(UtcNow, 'mssql') -def _utcnow_mssql(element: UtcNow, _compiler: Any, **kwargs) -> str: # pragma: no cover +def _utcnow_mssql( # pragma: no cover + _element: UtcNow, + _compiler: Any, + **_kwargs, +) -> str: return 'SYSUTCDATETIME()' @@ -230,6 +235,7 @@ def _validate_child( sa.event.listen( primary_table, 'after_create', + # spell-checker:ignore parentcol plpgsql sa.DDL( ''' CREATE FUNCTION %(function)s() RETURNS TRIGGER AS $$ @@ -283,7 +289,7 @@ def _validate_child( def auto_init_default( - column: Union[sa.orm.ColumnProperty, sa.orm.InstrumentedAttribute] + column: Union[sa.orm.ColumnProperty, sa.orm.InstrumentedAttribute], ) -> None: """ Set the default value of a column on first access. @@ -295,6 +301,7 @@ def auto_init_default( class MyModel(Model): column: Mapped[PyType] = sa.orm.mapped_column(SqlType, default="value") + auto_init_default(MyModel.column) """ if isinstance(column, sa.orm.ColumnProperty): diff --git a/src/coaster/sqlalchemy/immutable_annotation.py b/src/coaster/sqlalchemy/immutable_annotation.py index eb24f163..bef4f21f 100644 --- a/src/coaster/sqlalchemy/immutable_annotation.py +++ b/src/coaster/sqlalchemy/immutable_annotation.py @@ -1,7 +1,4 @@ -""" -Immutable annotation --------------------- -""" +"""Immutable annotation.""" from __future__ import annotations @@ -64,11 +61,11 @@ def immutable_column_set_listener( # skipcq: PTC-W0065 # SQLAlchemy >= 1.4 it appears to also be used in place of NO_VALUE. # NO_VALUE is for columns that have no value (either never set, or not # loaded). Because of this ambiguity, we pair it with a test for persistence - if old_value == value: - pass - elif ( - old_value is NEVER_SET or old_value is NO_VALUE - ) and target.persistent is False: + if ( + old_value == value + or (old_value is NEVER_SET or old_value is NO_VALUE) + and target.persistent is False + ): pass else: raise ImmutableColumnError(cls.__name__, attr, old_value, value) diff --git a/src/coaster/sqlalchemy/markdown.py b/src/coaster/sqlalchemy/markdown.py index e330a2f7..89da7b8b 100644 --- a/src/coaster/sqlalchemy/markdown.py +++ b/src/coaster/sqlalchemy/markdown.py @@ -86,7 +86,7 @@ def __json__(self) -> Any: return {'text': self._text, 'html': self._html} # Compare text value - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: """Compare for equality.""" return isinstance(other, MarkdownComposite) and ( self.__composite_values__() == other.__composite_values__() @@ -116,7 +116,7 @@ def __bool__(self) -> bool: return bool(self._text) @classmethod - def coerce(cls, key: str, value: Any) -> MarkdownComposite: + def coerce(cls, _key: str, value: Any) -> MarkdownComposite: """Allow a composite column to be assigned a string value.""" return cls(value) diff --git a/src/coaster/sqlalchemy/mixins.py b/src/coaster/sqlalchemy/mixins.py index f254ab57..c9c64553 100644 --- a/src/coaster/sqlalchemy/mixins.py +++ b/src/coaster/sqlalchemy/mixins.py @@ -1,20 +1,22 @@ """ -SQLAlchemy mixin classes ------------------------- +SQLAlchemy mixin classes. Coaster provides a number of mixin classes for SQLAlchemy models. To use in your Flask app:: from sqlalchemy.orm import DeclarativeBase from flask_sqlalchemy import SQLAlchemy - from coaster.sqlalchemy import BaseMixin, ModelBase + from coaster.sqlalchemy import BaseMixin, ModelBase, Query + class Model(ModelBase, DeclarativeBase): '''Model base class.''' - db = SQLAlchemy(metadata=Model.metadata) + + db = SQLAlchemy(metadata=Model.metadata, query_class=Query) Model.init_flask_sqlalchemy(db) + class MyModel(BaseMixin[int], Model): # Integer serial primary key; alt: UUID __tablename__ = 'my_model' @@ -46,14 +48,8 @@ class MyModel(BaseMixin[int], Model): # Integer serial primary key; alt: UUID from typing_extensions import Self, TypeVar, get_original_bases from uuid import UUID, uuid4 -from flask import current_app, url_for - -try: # Flask >= 3.0 - from flask.sansio.app import App as FlaskApp -except ModuleNotFoundError: # Flask < 3.0 - from flask import Flask as FlaskApp - import sqlalchemy as sa +from flask import url_for from sqlalchemy import event from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import Mapped, declarative_mixin, declared_attr, synonym @@ -61,6 +57,7 @@ class MyModel(BaseMixin[int], Model): # Integer serial primary key; alt: UUID from werkzeug.routing import BuildError from ..auth import current_auth +from ..compat import BaseApp, current_app_object from ..typing import ReturnDecorator, WrappedFunc from ..utils import ( InspectableSet, @@ -122,11 +119,12 @@ class IdMixin(Generic[PkeyType]): from uuid import UUID + class MyModel(IdMixin[UUID], Model): # or IdMixin[int] ... - class OtherModel(BaseMixin[UUID], Model): - ... + + class OtherModel(BaseMixin[UUID], Model): ... The legacy method using a flag also works, but will break type discovery for the id column in static type analysis (mypy or pyright):: @@ -157,6 +155,7 @@ def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None: " to the base class (`IdMixin[int]` or `IdMixin[UUID]`) instead of" " specifying `__uuid_primary_key__` directly", PkeyWarning, + stacklevel=2, ) for base in get_original_bases(cls): @@ -209,7 +208,7 @@ def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None: @immutable @declared_attr @classmethod - def id(cls) -> Mapped[PkeyType]: # noqa: A003 + def id(cls) -> Mapped[PkeyType]: """Database identity for this model.""" if cls.__uuid_primary_key__: return sa.orm.mapped_column( @@ -248,10 +247,7 @@ def url_id_uuid_comparator(cls: type[Self]) -> SqlUuidHexComparator: url_id_uuid_func.__name__ = 'url_id' url_id_uuid_func.__doc__ = url_id_uuid_func.__doc__ - url_id_property = hybrid_property(url_id_uuid_func).comparator( - url_id_uuid_comparator - ) - return url_id_property + return hybrid_property(url_id_uuid_func).comparator(url_id_uuid_comparator) def url_id_int_func(self: Self) -> str: """URL-safe representation of the integer id as a string.""" @@ -263,10 +259,7 @@ def url_id_int_comparator(cls: type[Self]) -> SqlSplitIdComparator: url_id_int_func.__name__ = 'url_id' url_id_int_func.__doc__ = url_id_int_func.__doc__ - url_id_property = hybrid_property(url_id_int_func).comparator( - url_id_int_comparator - ) - return url_id_property + return hybrid_property(url_id_int_func).comparator(url_id_int_comparator) url_id: declared_attr[str] = declared_attr(__url_id) del __url_id @@ -406,7 +399,11 @@ class PermissionMixin: should use the role granting mechanism in :class:`RoleMixin`. """ - def permissions(self, actor: Any, inherited: Optional[set[str]] = None) -> set[str]: + def permissions( + self, + actor: Any, # noqa: ARG002 + inherited: Optional[set[str]] = None, + ) -> set[str]: """Return permissions available to the given user on this object.""" if inherited is not None: return set(inherited) @@ -476,15 +473,9 @@ def __getitem__(self, key: str) -> str: raise KeyError(key) from exc def __len__(self) -> int: - # pylint: disable=protected-access + capp = current_app_object() return len(self.obj.url_for_endpoints[None]) + ( - len( - self.obj.url_for_endpoints.get( - current_app._get_current_object(), {} # type: ignore[attr-defined] - ) - ) - if current_app - else 0 + len(self.obj.url_for_endpoints.get(capp, {})) if capp else 0 ) def __iter__(self) -> Iterator[str]: @@ -493,13 +484,9 @@ def __iter__(self) -> Iterator[str]: # 3. Confirm the action does not require additional parameters # 4. Yield whatever passes the tests current_roles = self.obj.roles_for(current_auth.actor, current_auth.anchors) + capp = current_app_object() for app, app_actions in self.obj.url_for_endpoints.items(): - # pylint: disable=protected-access - if app is None or ( - current_app - and app - is current_app._get_current_object() # type: ignore[attr-defined] - ): + if app is None or app is capp: for action, endpoint_data in app_actions.items(): if not endpoint_data.requires_kwargs and ( endpoint_data.roles is None @@ -515,12 +502,12 @@ class UrlForMixin: #: different endpoints in different apps. The app may also be None as fallback. Each #: subclass will get its own dictionary. This particular dictionary is only used as #: an inherited fallback. - url_for_endpoints: ClassVar[ - dict[Optional[FlaskApp], dict[str, UrlEndpointData]] - ] = {None: {}} + url_for_endpoints: ClassVar[dict[Optional[BaseApp], dict[str, UrlEndpointData]]] = { + None: {} + } #: Mapping of {app: {action: (classview, attr)}} view_for_endpoints: ClassVar[ - dict[Optional[FlaskApp], dict[str, tuple[Any, str]]] + dict[Optional[BaseApp], dict[str, tuple[Any, str]]] ] = {} #: Dictionary of URLs available on this object @@ -528,12 +515,7 @@ class UrlForMixin: def url_for(self, action: str = 'view', **kwargs) -> str: """Return public URL to this instance for a given action (default 'view').""" - # pylint: disable=protected-access - app = ( - current_app._get_current_object() # type: ignore[attr-defined] - if current_app - else None - ) + app = current_app_object() if app is not None and action in self.url_for_endpoints.get(app, {}): endpoint_data = self.url_for_endpoints[app][action] else: @@ -582,13 +564,13 @@ def is_url_for( cls, __action: str, __endpoint: Optional[str] = None, - __app: Optional[FlaskApp] = None, + __app: Optional[BaseApp] = None, /, _external: Optional[bool] = None, **paramattrs: Union[str, tuple[str, ...], Callable[[Any], str]], ) -> ReturnDecorator: """ - Decorator that registers a view as a :meth:`url_for` target. + Register a view as a :meth:`url_for` target. :param __action: Action to register a URL under :param __endpoint: View endpoint name to pass to Flask's ``url_for`` @@ -615,7 +597,7 @@ def register_endpoint( action: str, *, endpoint: str, - app: Optional[FlaskApp], + app: Optional[BaseApp], paramattrs: Mapping[str, Union[str, tuple[str, ...], Callable[[Any], str]]], roles: Optional[Collection[str]] = None, external: Optional[bool] = None, @@ -626,15 +608,14 @@ def register_endpoint( :param view_func: View handler to be registered :param str action: Action to register a URL under :param str endpoint: View endpoint name to pass to Flask's ``url_for`` - :param app: Flask app (default: `None`) + :param app: Flask or Quart app (default: `None`) :param external: If `True`, URLs are assumed to be external-facing by default :param roles: Roles to which this URL is available, required by :class:`UrlDict` :param dict paramattrs: Mapping of URL parameter to attribute name on the object """ if 'url_for_endpoints' not in cls.__dict__: - cls.url_for_endpoints = { - None: {} - } # Stick it into the class with the first endpoint + # Stick it into the class with the first endpoint + cls.url_for_endpoints = {None: {}} cls.url_for_endpoints.setdefault(app, {}) paramattrs = dict(paramattrs) @@ -656,7 +637,7 @@ def register_endpoint( @classmethod def register_view_for( - cls, app: Optional[FlaskApp], action: str, classview: Any, attr: str + cls, app: Optional[BaseApp], action: str, classview: Any, attr: str ) -> None: """Register a classview and view method for a given app and action.""" if 'view_for_endpoints' not in cls.__dict__: @@ -666,14 +647,13 @@ def register_view_for( def view_for(self, action: str = 'view') -> Any: """Return the classview view method that handles the specified action.""" # pylint: disable=protected-access - app = current_app._get_current_object() # type: ignore[attr-defined] + app = current_app_object() view, attr = self.view_for_endpoints[app][action] return getattr(view(self), attr) def classview_for(self, action: str = 'view') -> Any: """Return the classview containing the view method for the specified action.""" - # pylint: disable=protected-access - app = current_app._get_current_object() # type: ignore[attr-defined] + app = current_app_object() return self.view_for_endpoints[app][action][0](self) @@ -721,10 +701,7 @@ class BaseNameMixin(BaseMixin[PkeyType, ActorType]): # Drop CHECK constraint first in case it was already present op.drop_constraint(tablename + '_name_check', tablename) # Create CHECK constraint - op.create_check_constraint( - tablename + '_name_check', - tablename, - "name <> ''") + op.create_check_constraint(tablename + '_name_check', tablename, "name <> ''") """ #: Prevent use of these reserved names @@ -855,9 +832,7 @@ class BaseScopedNameMixin(BaseMixin[PkeyType, ActorType]): class Event(BaseScopedNameMixin, Model): __tablename__ = 'event' - organizer_id: Mapped[int] = sa.orm.mapped_column(sa.ForeignKey( - 'organizer.id' - )) + organizer_id: Mapped[int] = sa.orm.mapped_column(sa.ForeignKey('organizer.id')) organizer: Mapped[Organizer] = relationship(Organizer) parent = sa.orm.synonym('organizer') __table_args__ = (sa.UniqueConstraint('organizer_id', 'name'),) @@ -874,10 +849,7 @@ class Event(BaseScopedNameMixin, Model): # Drop CHECK constraint first in case it was already present op.drop_constraint(tablename + '_name_check', tablename) # Create CHECK constraint - op.create_check_constraint( - tablename + '_name_check', - tablename, - "name <> ''") + op.create_check_constraint(tablename + '_name_check', tablename, "name <> ''") """ #: Prevent use of these reserved names @@ -1047,10 +1019,7 @@ class BaseIdNameMixin(BaseMixin[PkeyType, ActorType]): # Drop CHECK constraint first in case it was already present op.drop_constraint(tablename + '_name_check', tablename) # Create CHECK constraint - op.create_check_constraint( - tablename + '_name_check', - tablename, - "name <> ''") + op.create_check_constraint(tablename + '_name_check', tablename, "name <> ''") """ #: Allow blank names after all? @@ -1150,7 +1119,8 @@ def url_name_uuid_b58(self) -> str: def _url_name_uuid_b58_comparator(cls) -> SqlUuidB58Comparator: """Return SQL comparator for name and UUID in Base58 format.""" return SqlUuidB58Comparator( - cls.uuid, splitindex=-1 # type: ignore[attr-defined] + cls.uuid, # type: ignore[attr-defined] + splitindex=-1, ) @@ -1235,9 +1205,7 @@ class BaseScopedIdNameMixin(BaseScopedIdMixin[PkeyType, ActorType]): class Event(BaseScopedIdNameMixin, Model): __tablename__ = 'event' - organizer_id: Mapped[int] = sa.orm.mapped_column(sa.ForeignKey( - 'organizer.id' - )) + organizer_id: Mapped[int] = sa.orm.mapped_column(sa.ForeignKey('organizer.id')) organizer: Mapped[Organizer] = relationship(Organizer) parent = sa.orm.synonym('organizer') __table_args__ = (sa.UniqueConstraint('organizer_id', 'url_id'),) @@ -1254,10 +1222,7 @@ class Event(BaseScopedIdNameMixin, Model): # Drop CHECK constraint first in case it was already present op.drop_constraint(tablename + '_name_check', tablename) # Create CHECK constraint - op.create_check_constraint( - tablename + '_name_check', - tablename, - "name <> ''") + op.create_check_constraint(tablename + '_name_check', tablename, "name <> ''") """ #: Allow blank names after all? @@ -1363,7 +1328,8 @@ def url_name_uuid_b58(self) -> str: def _url_name_uuid_b58_comparator(cls) -> SqlUuidB58Comparator: """Return SQL comparator for name and UUID in Base58 format.""" return SqlUuidB58Comparator( - cls.uuid, splitindex=-1 # type: ignore[attr-defined] + cls.uuid, # type: ignore[attr-defined] + splitindex=-1, ) diff --git a/src/coaster/sqlalchemy/model.py b/src/coaster/sqlalchemy/model.py index cc5f842f..8243d52c 100644 --- a/src/coaster/sqlalchemy/model.py +++ b/src/coaster/sqlalchemy/model.py @@ -1,6 +1,5 @@ """ -Flask-SQLAlchemy-compatible model base class --------------------------------------------- +Flask-SQLAlchemy-compatible model base class. Flask-SQLAlchemy's ``db.Model`` is not compatible with PEP 484 type hinting. Coaster provides a replacement :class:`ModelBase` base class. To use, combine it with @@ -11,14 +10,16 @@ db = SQLAlchemy() + class MyModel(db.Model): others = db.relationship('Other', lazy='dynamic') + class MyBindModel(db.Model): __bind_key__ = 'my_bind' - class Other(db.Model): - ... + + class Other(db.Model): ... Replace with:: @@ -26,16 +27,17 @@ class Other(db.Model): from typing import List from sqlalchemy.orm import DeclarativeBase from flask_sqlalchemy import SQLAlchemy - from coaster.sqlalchemy import ( - DeclarativeBase, DynamicMapped, ModelBase, relationship - ) + from coaster.sqlalchemy import DeclarativeBase, DynamicMapped, ModelBase, relationship + class Model(ModelBase, DeclarativeBase): # ModelBase must be before DeclarativeBase pass + class BindModel(ModelBase, DeclarativeBase): __bind_key__ = 'my_bind' + class MyModel(Model): # __tablename__ is not autogenerated with ModelBase and must be specified __tablename__ = 'my_model' @@ -46,11 +48,13 @@ class MyModel(Model): # imported from Coaster others: DynamicMapped[Other] = relationship(lazy='dynamic') + class MyBindModel(BindModel): __tablename__ = 'my_bind_model' - class Other(Model): - ... + + class Other(Model): ... + db = SQLAlchemy(metadata=Model.metadata) # Use the base model's metadata Model.init_flask_sqlalchemy(db) @@ -133,16 +137,16 @@ class ModelWarning(UserWarning): # --- SQLAlchemy type aliases ---------------------------------------------------------- -bigint: TypeAlias = Annotated[int, mapped_column(sa.BigInteger())] -smallint: TypeAlias = Annotated[int, mapped_column(sa.SmallInteger())] -int_pkey: TypeAlias = Annotated[int, mapped_column(primary_key=True)] -uuid4_pkey: TypeAlias = Annotated[ +bigint: TypeAlias = Annotated[int, mapped_column(sa.BigInteger())] # noqa: PYI042 +smallint: TypeAlias = Annotated[int, mapped_column(sa.SmallInteger())] # noqa: PYI042 +int_pkey: TypeAlias = Annotated[int, mapped_column(primary_key=True)] # noqa: PYI042 +uuid4_pkey: TypeAlias = Annotated[ # noqa: PYI042 uuid.UUID, mapped_column(primary_key=True, default=uuid.uuid4) ] -timestamp: TypeAlias = Annotated[ +timestamp: TypeAlias = Annotated[ # noqa: PYI042 datetime.datetime, mapped_column(sa.TIMESTAMP(timezone=True)) ] -timestamp_now: TypeAlias = Annotated[ +timestamp_now: TypeAlias = Annotated[ # noqa: PYI042 datetime.datetime, mapped_column( sa.TIMESTAMP(timezone=True), @@ -150,7 +154,7 @@ class ModelWarning(UserWarning): nullable=False, ), ] -jsonb: TypeAlias = Annotated[ +jsonb: TypeAlias = Annotated[ # noqa: PYI042 dict, mapped_column(sa.JSON().with_variant(postgresql.JSONB, 'postgresql')) ] @@ -199,7 +203,7 @@ def get_or_404(self, ident: Any, description: Optional[str] = None) -> _T_co: :param ident: The primary key to query :param description: A custom message to show on the error page """ - rv = self.get(ident) # pylint: disable=assignment-from-no-return + rv = self.get(ident) if rv is None: abort(404, description=description) diff --git a/src/coaster/sqlalchemy/registry.py b/src/coaster/sqlalchemy/registry.py index 597e1468..f41a79b9 100644 --- a/src/coaster/sqlalchemy/registry.py +++ b/src/coaster/sqlalchemy/registry.py @@ -1,6 +1,5 @@ """ -Model helper registry ---------------------- +Model helper registry. Provides a :class:`Registry` type and a :class:`RegistryMixin` base class with three registries, used by other mixin classes. @@ -8,14 +7,14 @@ Helper classes such as forms and views can be registered to the model and later accessed from an instance:: - class MyModel(BaseMixin, Model): - ... + class MyModel(BaseMixin, Model): ... - class MyForm(Form): - ... - class MyView(ModelView): - ... + class MyForm(Form): ... + + + class MyView(ModelView): ... + MyModel.forms.main = MyForm MyModel.views.main = MyView @@ -254,7 +253,7 @@ def __get__( if TYPE_CHECKING: # Tell Mypy that it's okay for code to attempt reading an attr - def __getattr__(self, attr: str) -> Any: ... + def __getattr__(self, name: str) -> Any: ... class InstanceRegistry(Generic[_RT, _T]): @@ -273,15 +272,15 @@ def __init__(self, registry: _RT, obj: _T) -> None: self.__registry = registry self.__obj = obj - def __getattr__(self, attr: str) -> Any: + def __getattr__(self, name: str) -> Any: """Access a registry member.""" registry = self.__registry obj = self.__obj - func = getattr(registry, attr) # Raise AttributeError if unknown - kwarg = registry._members[attr] + func = getattr(registry, name) # Raise AttributeError if unknown + kwarg = registry._members[name] # If attr is a property, return the result - if attr in registry._properties: + if name in registry._properties: if kwarg is not None: return func(**{kwarg: obj}) return func(obj) @@ -289,12 +288,9 @@ def __getattr__(self, attr: str) -> Any: # These checks are cached to __dict__ so __getattr__ won't be called again: # If attr is a cached property, cache and return the result - if attr in registry._cached_properties: - if kwarg is not None: - val = func(**{kwarg: obj}) - else: - val = func(obj) - setattr(self, attr, val) + if name in registry._cached_properties: + val = func(**{kwarg: obj}) if kwarg is not None else func(obj) + setattr(self, name, val) return val # Not a property or cached_property. Construct a partial, cache and return it @@ -302,7 +298,7 @@ def __getattr__(self, attr: str) -> Any: partial_func = partial(func, **{kwarg: obj}) else: partial_func = partial(func, obj) - setattr(self, attr, partial_func) + setattr(self, name, partial_func) return partial_func diff --git a/src/coaster/sqlalchemy/roles.py b/src/coaster/sqlalchemy/roles.py index a081d600..923adb36 100644 --- a/src/coaster/sqlalchemy/roles.py +++ b/src/coaster/sqlalchemy/roles.py @@ -1,6 +1,5 @@ """ -Role-based access control -------------------------- +Role-based access control. Coaster provides a :class:`RoleMixin` class that can be used to define role-based access control to the attributes and methods of any SQLAlchemy model. :class:`RoleMixin` is a @@ -32,12 +31,14 @@ app = Flask(__name__) db = SQLAlchemy(app) + @declarative_mixin class ColumnMixin: ''' Mixin class that offers some columns to the RoleModel class below, demonstrating two ways to use `with_roles`. ''' + @with_roles(rw={'owner'}) @declared_attr def mixed_in1(cls) -> Mapped[str]: @@ -57,7 +58,7 @@ class RoleModel(ColumnMixin, RoleMixin, Model): # Avoid this approach in a parent or mixin class as definitions will # be lost if the subclass does not copy `__roles__`. - __roles__ = { + __roles__: ClassVar = { 'all': { 'read': {'id', 'name', 'title'}, }, @@ -72,18 +73,16 @@ class RoleModel(ColumnMixin, RoleMixin, Model): id: Mapped[int] = sa.orm.mapped_column(sa.Integer, primary_key=True) name: Mapped[str] = with_roles( # Specify read+write access - sa.orm.mapped_column(sa.Unicode(250)), - rw={'owner'} + sa.orm.mapped_column(sa.Unicode(250)), rw={'owner'} ) user_id: Mapped[int] = sa.orm.mapped_column( - sa.ForeignKey('user.id'), - nullable=False + sa.ForeignKey('user.id'), nullable=False ) user: Mapped[User] = with_roles( relationship(User), grants={'owner'}, # Use `grants` here or `granted_by` in `__roles__` - ) + ) # `with_roles` can also be called later. This is required for # properties, where roles must be assigned after the property is @@ -153,7 +152,6 @@ def roles_for( from typing_extensions import Self, TypeAlias, TypeVar import sqlalchemy as sa -from flask import g from sqlalchemy import event, select from sqlalchemy.exc import NoInspectionAvailable, NoResultFound from sqlalchemy.ext.orderinglist import OrderingList @@ -176,6 +174,7 @@ def roles_for( from sqlalchemy.schema import SchemaItem from ..auth import current_auth +from ..compat import flask_g, quart_g from ..utils import InspectableSet, is_collection, is_dunder, nary_op from .functions import idfilters from .model import AppenderQuery @@ -578,7 +577,7 @@ def __bool__(self) -> bool: else any(self._role_is_present(role) for role in self.obj.__roles__) ) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if isinstance(other, LazyRoleSet): return self.obj == other.obj and self.actor == other.actor return self._contents() == other @@ -659,7 +658,8 @@ class DynamicAssociationProxy(Generic[_V, _R]): # Proxy to an attribute on the target of the relationship (specifying the type): Document.child_attributes = DynamicAssociationProxy[attr_type, rel_type]( - 'child_relationship', 'attribute') + 'child_relationship', 'attribute' + ) This proxy does not provide access to the query capabilities of dynamic relationships. It merely optimizes for containment queries. A query like this:: @@ -783,7 +783,7 @@ def __bool__(self) -> bool: assert relattr.session is not None # nosec B101 return relattr.session.query(relattr.exists()).scalar() - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if isinstance(other, DynamicAssociationProxyBind): return ( self.obj == other.obj @@ -863,7 +863,7 @@ def __class__(self) -> type[RoleMixinType]: return self._obj.__class__ @__class__.setter - def __class__(self, value: Any) -> NoReturn: # noqa: F811 + def __class__(self, value: Any) -> NoReturn: raise TypeError("__class__ cannot be set") def __init__( @@ -908,7 +908,10 @@ def __init__( def __repr__(self) -> str: return f'RoleAccessProxy(obj={self._obj!r}, roles={self.current_roles!r})' - def current_access(self, datasets: Optional[Sequence[str]] = None) -> Self: + def current_access( + self, + datasets: Optional[Sequence[str]] = None, # noqa: ARG002 + ) -> Self: """Mimic :meth:`RoleMixin.current_access`, but simply return self.""" return self @@ -1004,21 +1007,21 @@ def _all_read(self) -> set[str]: object.__setattr__(self, '_all_read_cache', available_read_attrs) return available_read_attrs - def __getattr__(self, attr: str) -> Any: + def __getattr__(self, name: str) -> Any: # See also __getitem__, which doesn't consult _call - if self.__attr_available(attr, 'read') or self.__attr_available(attr, 'call'): - return self.__get_processed_attr(attr) + if self.__attr_available(name, 'read') or self.__attr_available(name, 'call'): + return self.__get_processed_attr(name) raise AttributeError( - f"{self._obj.__class__.__qualname__}.{attr};" + f"{self._obj.__class__.__qualname__}.{name};" f" current roles {self.current_roles!r}" ) - def __setattr__(self, attr: str, value: Any) -> None: + def __setattr__(self, name: str, value: Any) -> None: # See also __setitem__ - if self.__attr_available(attr, 'write'): - return setattr(self._obj, attr, value) + if self.__attr_available(name, 'write'): + return setattr(self._obj, name, value) raise AttributeError( - f"{self._obj.__class__.__qualname__}.{attr};" + f"{self._obj.__class__.__qualname__}.{name};" f" current roles {self.current_roles!r}" ) @@ -1064,7 +1067,7 @@ def __json__(self) -> Any: ) return dict(self) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if other == self._obj: return True if ( @@ -1150,6 +1153,7 @@ def with_roles( title: Mapped[str] = with_roles(sa.orm.mapped_column(sa.Unicode), read={'all'}) + @with_roles(read={'all'}) @hybrid_property def url_id(self) -> str: @@ -1162,10 +1166,12 @@ def url_id(self) -> str: def title(self) -> str: return self._title + @title.setter def title(self, value: str) -> None: self._title = value + # Either of the following is fine, since with_roles annotates objects # instead of wrapping them. The return value can be discarded if it's # already present on the host object: @@ -1189,11 +1195,10 @@ class RoleModel(Model): user_id: Mapped[int] = sa.orm.mapped_column(sa.ForeignKey('user.id')) user: Mapped[UserModel] = relationship(UserModel) - document_id: Mapped[int] = sa.orm.mapped_column(sa.ForeignKey( - 'document.id' - )) + document_id: Mapped[int] = sa.orm.mapped_column(sa.ForeignKey('document.id')) document: Mapped[DocumentModel] = relationship(DocumentModel) + DocumentModel.rolemodels = with_roles( relationship(RoleModel), grants_via={'user': {'role1', 'role2'}} ) @@ -1302,6 +1307,7 @@ def role_check( class MyModel(RoleMixin, ...): ... + @role_check('reader', 'viewer') # Takes multiple roles as necessary def has_reader_role( self, actor: Optional[ActorType], anchors: Sequence[Any] = () @@ -1335,7 +1341,7 @@ def _(self) -> Iterable[ActorType]: """ def decorator( - func: Callable[[_CRM, Optional[_CRA], Sequence[Any]], bool] + func: Callable[[_CRM, Optional[_CRA], Sequence[Any]], bool], ) -> ConditionalRole[_CRM, _CRA]: return ConditionalRole(roles, func) @@ -1422,7 +1428,7 @@ def __iter__(self) -> Iterator[_CRA]: return iter(()) return iter(iter_func(self.__self__)) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if isinstance(other, ConditionalRoleBind): return ( self._rolecheck == other._rolecheck and self.__self__ == other.__self__ @@ -1438,15 +1444,15 @@ class RoleMixin(Generic[ActorType]): Subclasses must define a :attr:`__roles__` dictionary with roles and the attributes they have call, read and write access to:: - __roles__ = { + __roles__: ClassVar = { 'role_name': { 'call': {'meth1', 'meth2'}, 'read': {'attr1', 'attr2'}, 'write': {'attr1', 'attr2'}, 'granted_by': {'rel1', 'rel2'}, 'granted_via': {'rel1': 'attr1', 'rel2': 'attr2'}, - }, - } + }, + } The ``granted_by`` key works in reverse: if the actor is present in any of the attributes in the set, they are granted that role via :meth:`roles_for`. @@ -1534,10 +1540,11 @@ def current_roles(self) -> InspectableSet[LazyRoleSet]: :meth:`roles_for` instead, or use `current_roles` only after roles are changed. """ - cache = getattr(g, '_coaster_role_cache', None) + cache = getattr(quart_g or flask_g, '_coaster_role_cache', None) if cache is None: cache = {} - g._coaster_role_cache = cache # pylint: disable=protected-access + # pylint: disable=protected-access + (quart_g or flask_g)._coaster_role_cache = cache cache_key = (self, current_auth.actor, current_auth.anchors) if cache_key not in cache: cache[cache_key] = InspectableSet( @@ -1735,7 +1742,7 @@ def access_for( __datasets__ = { 'primary': {'uuid', 'name', 'title', 'children', 'parent'}, - 'related': {'uuid', 'name', 'title'} + 'related': {'uuid', 'name', 'title'}, } Objects and related objects can be safely enumerated like this:: @@ -1842,7 +1849,8 @@ def _configure_roles(_mapper: Any, cls: type[RoleMixin]) -> None: ) for role in data.grants: granted_by = cls.__roles__.setdefault(role, {}).setdefault( - 'granted_by', [] # List as it needs to be ordered + 'granted_by', + [], # List as it needs to be ordered ) if name not in granted_by: granted_by.append(name) diff --git a/src/coaster/sqlalchemy/statemanager.py b/src/coaster/sqlalchemy/statemanager.py index 35a0476e..a999d301 100644 --- a/src/coaster/sqlalchemy/statemanager.py +++ b/src/coaster/sqlalchemy/statemanager.py @@ -1,6 +1,5 @@ """ -States and transitions ----------------------- +States and transitions. :class:`StateManager` wraps a SQLAlchemy column with an enum to facilitate state inspection, and to control state change via transitions. Sample usage:: @@ -30,23 +29,21 @@ class MyPost(BaseMixin, Model): sa.Integer, StateManager.check_constraint('state', MY_STATE, sa.Integer), default=MY_STATE.DRAFT, - nullable=False + nullable=False, ) _reviewstate: Mapped[int] = sa.orm.mapped_column( 'reviewstate', sa.Integer, StateManager.check_constraint('reviewstate', REVIEW_STATE, sa.Integer), default=REVIEW_STATE.UNSUBMITTED, - nullable=False + nullable=False, ) # The state managers controlling the columns. If the host type is optionally # provided as a generic type argument, it will be applied to the lambda # functions in add_conditional_state for static type checking state = StateManager['MyPost']('_state', MY_STATE, doc="The post's state") - reviewstate = StateManager( - '_reviewstate', REVIEW_STATE, doc="Reviewer's state" - ) + reviewstate = StateManager('_reviewstate', REVIEW_STATE, doc="Reviewer's state") # Datetime for the additional states and transitions timestamp: Mapped[datetime] = sa.orm.mapped_column( @@ -59,13 +56,11 @@ class MyPost(BaseMixin, Model): state.add_conditional_state( 'RECENT', state.PUBLISHED, - lambda post: post.datetime > datetime.utcnow() - timedelta(hours=1) + lambda post: post.datetime > datetime.utcnow() - timedelta(hours=1), ) # REDRAFTABLE = DRAFT or PENDING or RECENT - state.add_state_group( - 'REDRAFTABLE', state.DRAFT, state.PENDING, state.RECENT - ) + state.add_state_group('REDRAFTABLE', state.DRAFT, state.PENDING, state.RECENT) # Transitions change FROM one state TO another, and can require another state # manager to be in a specific state @@ -522,10 +517,10 @@ def __init__( def __repr__(self) -> str: return repr(self._mstate) - def __getattr__(self, attr: str) -> Any: - return getattr(self._mstate, attr) + def __getattr__(self, name: str) -> Any: + return getattr(self._mstate, name) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: return ( isinstance(other, ManagedStateInstance) and self._mstate == other._mstate @@ -606,10 +601,7 @@ def add_transition( # transition state_values = {} # Value: ManagedState # Step 1: Convert ManagedStateGroup into a list of ManagedState items - if isinstance(from_, ManagedStateGroup): - from_all = from_.states - else: # ManagedState - from_all = [from_] + from_all = from_.states if isinstance(from_, ManagedStateGroup) else [from_] # Step 2: Unroll grouped values from the original LabeledEnum for mstate in from_all: if is_collection(mstate.value): @@ -722,7 +714,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: self.obj, transition=self.statetransition, exception=e ) return e.result - except Exception as e: # noqa: B902 + except Exception as e: transition_exception.send( self.obj, transition=self.statetransition, exception=e ) @@ -753,7 +745,7 @@ class StateManager(Generic[_SG]): """ #: Host class for the state manager - cls: type + cls: type[_SG] #: All possible states by name states: dict[str, Union[ManagedState, ManagedStateGroup]] #: All static states, back-referenced by value (group and conditional excluded) @@ -766,7 +758,8 @@ class StateManager(Generic[_SG]): def __init__( self, propname: str, lenum: type[LabeledEnum], doc: Optional[str] = None ) -> None: - self.cls = object # Depend on __set_name__ to update + # Depend on __set_name__ to update + self.cls = object # type: ignore[assignment] self.propname = propname self.name = '' # Currently unknown name, will only be known in __set_name__ self.lenum = lenum @@ -965,7 +958,7 @@ def transition( """ def decorator( - f: Union[StateTransition, Callable[Concatenate[Any, _P], _R]] + f: Union[StateTransition, Callable[Concatenate[Any, _P], _R]], ) -> StateTransition[_P, _R]: if isinstance(f, StateTransition): f.add_transition(self, from_, to, data) @@ -1028,7 +1021,7 @@ def group( @staticmethod def check_constraint( column: str, - enum: Union[type[Enum], type[LabeledEnum]], + enum: type[Union[Enum, LabeledEnum]], type_: Optional[Union[type[sa.types.TypeEngine], sa.types.TypeEngine]] = None, **kwargs: Any, ) -> sa.CheckConstraint: @@ -1043,7 +1036,7 @@ class MyModel(Model): 'state', sa.Integer, StateManager.check_constraint('state', MY_ENUM, sa.Integer), - default=MY_ENUM.DEFAULT + default=MY_ENUM.DEFAULT, ) state = StateManager(_state, MY_ENUM) @@ -1056,7 +1049,9 @@ class MyModel(Model): print( str( StateManager.check_constraint( - 'your_column', YOUR_ENUM, sa.Integer # Or specific column type + 'your_column', + YOUR_ENUM, + sa.Integer, # Or specific column type ).sqltext.compile(compile_kwargs={'literal_binds': True}) ) ) @@ -1072,7 +1067,7 @@ class MyModel(Model): values = enum.keys() else: values = [_member.value for _member in enum] - return sa.CheckConstraint(sa.Column(column, type_).in_(values)) + return sa.CheckConstraint(sa.Column(column, type_).in_(values), **kwargs) if TYPE_CHECKING: # Stub for mypy to recognise names added by _add_state_internal. There is a diff --git a/src/coaster/typing.py b/src/coaster/typing.py index 78535d56..77aece28 100644 --- a/src/coaster/typing.py +++ b/src/coaster/typing.py @@ -1,11 +1,9 @@ -""" -Coaster types -------------- -""" +"""Coaster types.""" from __future__ import annotations -from typing import Any, Callable, Protocol, TypeVar +from collections.abc import Coroutine +from typing import Any, Callable, Protocol, TypeVar, Union from typing_extensions import ParamSpec, TypeAlias WrappedFunc = TypeVar('WrappedFunc', bound=Callable) @@ -16,7 +14,12 @@ class Method(Protocol[_P, _R_co]): - """Protocol for an instance method.""" + """Protocol for an instance method (sync or async).""" # pylint: disable=no-self-argument - def __call__(__self, self: Any, *args: _P.args, **kwargs: _P.kwargs) -> _R_co: ... + def __call__( # noqa: D102,RUF100 + __self, # noqa: N805 + self: Any, + *args: _P.args, + **kwargs: _P.kwargs, + ) -> Union[Coroutine[Any, Any, _R_co], _R_co]: ... diff --git a/src/coaster/utils/__init__.py b/src/coaster/utils/__init__.py index 10d041d5..8aaf03b4 100644 --- a/src/coaster/utils/__init__.py +++ b/src/coaster/utils/__init__.py @@ -8,7 +8,6 @@ in Flask-based applications. """ - from .classes import * from .datetime import * from .markdown import * diff --git a/src/coaster/utils/classes.py b/src/coaster/utils/classes.py index f746fdb8..ac7f1fd9 100644 --- a/src/coaster/utils/classes.py +++ b/src/coaster/utils/classes.py @@ -1,7 +1,4 @@ -""" -Utility classes ---------------- -""" +"""Utility classes.""" from __future__ import annotations @@ -312,6 +309,8 @@ def __dataclass_repr__(self) -> str: class NameTitle(NamedTuple): + """Name and title pair.""" + name: str title: str @@ -328,7 +327,6 @@ def __new__( name: str, bases: tuple[type[Any], ...], attrs: dict[str, Any], - **kwargs: Any, ) -> type[LabeledEnum]: labels: dict[str, Any] = {} names: dict[str, Any] = {} @@ -607,12 +605,12 @@ def __bool__(self) -> bool: def __getitem__(self, key: Any) -> bool: return key in self._members # Return True if present, False otherwise - def __setattr__(self, attr: str, _value: Any) -> NoReturn: + def __setattr__(self, name: str, _value: Any) -> NoReturn: """Prevent accidental attempts to set a value.""" - raise AttributeError(attr) + raise AttributeError(name) - def __getattr__(self, attr: str) -> bool: - return attr in self._members # Return True if present, False otherwise + def __getattr__(self, name: str) -> bool: + return name in self._members # Return True if present, False otherwise def _op_bool(self, op: str, other: Any) -> bool: """Return result of a boolean operation.""" @@ -630,11 +628,11 @@ def __lt__(self, other: Any) -> bool: """Return self < other.""" return self._op_bool('__lt__', other) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: """Return self == other.""" return self._op_bool('__eq__', other) - def __ne__(self, other: Any) -> bool: + def __ne__(self, other: object) -> bool: """Return self != other.""" return self._op_bool('__ne__', other) diff --git a/src/coaster/utils/datetime.py b/src/coaster/utils/datetime.py index 6f233b1b..e5ee4af5 100644 --- a/src/coaster/utils/datetime.py +++ b/src/coaster/utils/datetime.py @@ -1,8 +1,6 @@ -""" -Date, time and timezone utilities ---------------------------------- -""" +"""Date, time and timezone utilities.""" +# spell-checker:ignore isoweek from __future__ import annotations from datetime import date, datetime, timedelta, tzinfo @@ -24,13 +22,6 @@ 'ParseError', ] -# --- Thread safety fix ---------------------------------------------------------------- - -# Force import of strptime. This was previously used in :func:`parse_isoformat`, -# but we have left this in because it could break elsewhere. -# https://stackoverflow.com/q/16309650/78903 -datetime.strptime('20160816', '%Y%m%d') - def utcnow() -> datetime: """Return the current time at UTC with `tzinfo` set.""" @@ -107,8 +98,9 @@ def midnight_to_utc( datetime.datetime(2016, 12, 31, 18, 30, tzinfo=) >>> midnight_to_utc(datetime(2017, 1, 1), naive=True) datetime.datetime(2017, 1, 1, 0, 0) - >>> midnight_to_utc(pytz.timezone('Asia/Kolkata').localize(datetime(2017, 1, 1)), - ... naive=True) + >>> midnight_to_utc( + ... pytz.timezone('Asia/Kolkata').localize(datetime(2017, 1, 1)), naive=True + ... ) datetime.datetime(2016, 12, 31, 18, 30) >>> midnight_to_utc(date(2017, 1, 1)) datetime.datetime(2017, 1, 1, 0, 0, tzinfo=) @@ -118,16 +110,14 @@ def midnight_to_utc( datetime.datetime(2016, 12, 31, 18, 30, tzinfo=) >>> midnight_to_utc(datetime(2017, 1, 1), timezone='Asia/Kolkata') datetime.datetime(2016, 12, 31, 18, 30, tzinfo=) - >>> midnight_to_utc(pytz.timezone('Asia/Kolkata').localize(datetime(2017, 1, 1)), - ... timezone='UTC') + >>> midnight_to_utc( + ... pytz.timezone('Asia/Kolkata').localize(datetime(2017, 1, 1)), timezone='UTC' + ... ) datetime.datetime(2017, 1, 1, 0, 0, tzinfo=) """ tz: Union[tzinfo, BaseTzInfo] if timezone: - if isinstance(timezone, str): - tz = pytz.timezone(timezone) - else: - tz = timezone + tz = pytz.timezone(timezone) if isinstance(timezone, str) else timezone elif isinstance(dt, datetime) and dt.tzinfo: tz = dt.tzinfo else: diff --git a/src/coaster/utils/markdown.py b/src/coaster/utils/markdown.py index ded1540a..8f62a7a2 100644 --- a/src/coaster/utils/markdown.py +++ b/src/coaster/utils/markdown.py @@ -1,6 +1,5 @@ """ -Markdown processor -================== +Markdown processor. Markdown parser with a number of sane defaults that resembles GitHub-Flavoured Markdown (GFM). @@ -124,7 +123,8 @@ def extendMarkdown(self, md: Markdown) -> None: # noqa: N802 JavascriptProtocolExtension(), ] -default_markdown_extensions = default_markdown_extensions_html + [ +default_markdown_extensions = [ + *default_markdown_extensions_html, 'pymdownx.highlight', 'pymdownx.inlinehilite', 'pymdownx.tasklist', diff --git a/src/coaster/utils/misc.py b/src/coaster/utils/misc.py index 4718be2c..25e16338 100644 --- a/src/coaster/utils/misc.py +++ b/src/coaster/utils/misc.py @@ -1,8 +1,6 @@ -""" -Miscellaneous utilities ------------------------ -""" +"""Miscellaneous utilities.""" +# spell-checker:ignore newsecret newpin checkused nullint nullstr getbool dunder from __future__ import annotations import email.utils @@ -55,9 +53,9 @@ # --- Common delimiters and punctuation ------------------------------------------------ -_strip_re = re.compile('[\'"`‘’“”′″‴]+') +_strip_re = re.compile('[\'"`‘’“”′″‴]+') # noqa: RUF001 _punctuation_re = re.compile( - '[\x00-\x1f +!#$%&()*\\-/<=>?@\\[\\\\\\]^_{|}:;,.…‒–—―«»]+' + '[\x00-\x1f +!#$%&()*\\-/<=>?@\\[\\\\\\]^_{|}:;,.…‒–—―«»]+' # noqa: RUF001 ) _ipv4_re = re.compile( r'^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}' @@ -162,7 +160,7 @@ def uuid1mc_from_datetime(dt: Union[datetime, float]) -> uuid.UUID: True >>> u1.time - u2.time < 5000 True - >>> d2 = datetime.fromtimestamp((u2.time - 0x01b21dd213814000) * 100 / 1e9) + >>> d2 = datetime.fromtimestamp((u2.time - 0x01B21DD213814000) * 100 / 1e9) >>> d2 == dt True """ @@ -334,15 +332,21 @@ def make_name( 'candidate2' >>> make_name('Candidate', checkused=lambda c: c in ['candidate'], counter=1) 'candidate1' - >>> make_name('Candidate', - ... checkused=lambda c: c in ['candidate', 'candidate1', 'candidate2'], counter=1) + >>> make_name( + ... 'Candidate', + ... checkused=lambda c: c in ['candidate', 'candidate1', 'candidate2'], + ... counter=1, + ... ) 'candidate3' >>> make_name('Long title, but snipped', maxlength=20) 'long-title-but-snipp' >>> len(make_name('Long title, but snipped', maxlength=20)) 20 - >>> make_name('Long candidate', maxlength=10, - ... checkused=lambda c: c in ['long-candi', 'long-cand1']) + >>> make_name( + ... 'Long candidate', + ... maxlength=10, + ... checkused=lambda c: c in ['long-candi', 'long-cand1'], + ... ) 'long-cand2' >>> make_name('Lǝnkǝran') 'lankaran' @@ -361,7 +365,7 @@ def make_name( 'testing-more-slashes' >>> make_name('What if a HTML ') 'what-if-a-html-tag' - >>> make_name('These are equivalent to \x01 through \x1A') + >>> make_name('These are equivalent to \x01 through \x1a') 'these-are-equivalent-to-through' >>> make_name("feedback;\x00") 'feedback' @@ -391,7 +395,7 @@ def make_name( return candidate -def format_currency(value: Union[int, float], decimals: int = 2) -> str: +def format_currency(value: float, decimals: int = 2) -> str: """ Return a number suitably formatted for display as currency. @@ -445,7 +449,7 @@ def md5sum(data: str) -> str: >>> len(md5sum('random text')) 32 """ - return hashlib.md5(data.encode('utf-8')).hexdigest() # nosec # skipcq: PTC-W1003 + return hashlib.md5(data.encode('utf-8'), usedforsecurity=False).hexdigest() def getbool(value: Union[str, int, bool, None]) -> Optional[bool]: @@ -554,7 +558,7 @@ def my_func(this=None, that=None, other=None): # 2. len(kwargs) - kwargs.values().count(None) # # This is 2x faster than the first method under Python 2.7. Unfortunately, - # it doesn't work in Python 3 because `kwargs.values()` is a view that doesn't + # it does not work in Python 3 because `kwargs.values()` is a view that does not # have a `count` method. It needs to be cast into a tuple/list first, but # remains faster despite the cast's slowdown. Tuples are faster than lists. diff --git a/src/coaster/utils/text.py b/src/coaster/utils/text.py index abab3291..53118925 100644 --- a/src/coaster/utils/text.py +++ b/src/coaster/utils/text.py @@ -1,7 +1,4 @@ -""" -Text processing utilities -------------------------- -""" +"""Text processing utilities.""" from __future__ import annotations @@ -74,7 +71,8 @@ re_singleline_spaces = re.compile( '[' + unicode_extended_whitespace + ']', re.UNICODE | re.MULTILINE ) -re_multiline_spaces = re.compile( # This is missing \u2028 and \u2029 (separators) +re_multiline_spaces = re.compile( + # This is missing \u2028 and \u2029 (separators) '[' + ascii_whitespace_without_newline + unicode_format_whitespace + ']', re.UNICODE | re.MULTILINE, ) @@ -141,7 +139,7 @@ def dont_linkify_filenames( return attrs -LINKIFY_CALLBACKS = list(DEFAULT_CALLBACKS) + [dont_linkify_filenames] +LINKIFY_CALLBACKS = [*DEFAULT_CALLBACKS, dont_linkify_filenames] # type: ignore[list-item] def sanitize_html( @@ -285,8 +283,7 @@ def subloop( subloop(None, doc) # Replace   with ' ' - blocks = [t.replace('\xa0', ' ') for t in blocks] - return blocks + return [t.replace('\xa0', ' ') for t in blocks] def normalize_spaces(text: str) -> str: @@ -342,9 +339,7 @@ def deobfuscate_email(text: str) -> str: # Find the "at" text = _deobfuscate_at1_re.sub('@', text) text = _deobfuscate_at2_re.sub(r'\1@\2', text) - text = _deobfuscate_at3_re.sub(r'\1@\2', text) - - return text + return _deobfuscate_at3_re.sub(r'\1@\2', text) def simplify_text(text: str) -> str: @@ -356,7 +351,8 @@ def simplify_text(text: str) -> str: >>> simplify_text("Awesome Coder, wanted at Awesome Company! ") 'awesome coder wanted at awesome company' >>> simplify_text("Awesome Coder, wanted at Awesome Company! ") == ( - ... 'awesome coder wanted at awesome company') + ... 'awesome coder wanted at awesome company' + ... ) True """ text = text.translate(text.maketrans('', '', string.punctuation)).lower() diff --git a/src/coaster/views/classview.py b/src/coaster/views/classview.py index 38ce0910..f3c63e9b 100644 --- a/src/coaster/views/classview.py +++ b/src/coaster/views/classview.py @@ -1,8 +1,8 @@ """ -Class-based views ------------------ +Class-based views. -Group related views into a class for easier management. +Group related views into a class for easier management. See :class:`ClassView` and +:class:`ModelView` for two different ways to use them. """ # pyright: reportMissingImports=false @@ -10,7 +10,7 @@ from __future__ import annotations import warnings -from collections.abc import Collection +from collections.abc import Awaitable, Collection, Coroutine from functools import partial, update_wrapper, wraps from inspect import iscoroutinefunction from typing import ( @@ -37,29 +37,21 @@ ) from flask import abort, make_response, redirect, request +from flask.blueprints import BlueprintSetupState from flask.globals import _cv_app, app_ctx from flask.typing import ResponseReturnValue - -try: # Flask >= 3.0 - from flask.sansio.app import App as FlaskApp - from flask.sansio.blueprints import Blueprint, BlueprintSetupState -except ModuleNotFoundError: # Flask < 3.0 - from flask import Blueprint, Flask as FlaskApp - from flask.blueprints import BlueprintSetupState - from furl import furl from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.orm.descriptor_props import SynonymProperty from sqlalchemy.orm.properties import RelationshipProperty from werkzeug.local import LocalProxy from werkzeug.routing import Map as WzMap, Rule as WzRule -from werkzeug.wrappers import Response as BaseResponse from ..auth import add_auth_attribute, current_auth +from ..compat import BaseApp, BaseBlueprint, BaseResponse from ..sqlalchemy import PermissionMixin, Query, UrlForMixin from ..typing import Method from ..utils import InspectableSet -from .misc import ensure_sync __all__ = [ # Functions @@ -108,13 +100,9 @@ def __call__(self, __decorated: ViewMethod[_P, _R_co]) -> ViewMethod[_P, _R_co]: @overload def __call__(self, __decorated: Method[_P, _R_co]) -> ViewMethod[_P, _R_co]: ... - def __call__( # skipcq: PTC-W0049 + def __call__( self, - __decorated: Union[ - ClassViewType, - Method[_P, _R_co], - ViewMethod[_P, _R_co], - ], + __decorated: Union[ClassViewType, Method[_P, _R_co], ViewMethod[_P, _R_co]], ) -> Union[ClassViewType, ViewMethod[_P, _R_co]]: ... @@ -127,9 +115,8 @@ def __call__(self, __decorated: ViewMethod[_P, _R_co]) -> ViewMethod[_P, _R_co]: @overload def __call__(self, __decorated: Method[_P, _R_co]) -> ViewMethod[_P, _R_co]: ... - def __call__( # skipcq: PTC-W0049 - self, - __decorated: Union[Method[_P, _R_co], ViewMethod[_P, _R_co]], + def __call__( + self, __decorated: Union[Method[_P, _R_co], ViewMethod[_P, _R_co]] ) -> ViewMethod[_P, _R_co]: ... @@ -138,7 +125,7 @@ class InitAppCallback(Protocol): def __call__( self, - app: Union[FlaskApp, Blueprint], + app: Union[BaseApp, BaseBlueprint], rule: str, endpoint: str, view_func: Callable, @@ -161,7 +148,7 @@ def _get_arguments_from_rule( def route( rule: str, init_app: Optional[ - Union[FlaskApp, Blueprint, tuple[Union[FlaskApp, Blueprint], ...]] + Union[BaseApp, BaseBlueprint, tuple[Union[BaseApp, BaseBlueprint], ...]] ] = None, **options: Any, ) -> RouteDecoratorProtocol: @@ -202,23 +189,28 @@ def decorator( ClassViewType, Method[_P, _R_co], ViewMethod[_P, _R_co], - ] + ], ) -> Union[ClassViewType, ViewMethod[_P, _R_co]]: # Are we decorating a ClassView? If so, annotate the ClassView and return it - if isinstance(decorated, type) and issubclass(decorated, ClassView): - if '__routes__' not in decorated.__dict__: - decorated.__routes__ = [] - decorated.__routes__.append((rule, options)) - if init_app is not None: - apps = init_app if isinstance(init_app, tuple) else (init_app,) - for each in apps: - decorated.init_app(each) - return decorated + if isinstance(decorated, type): + if issubclass(decorated, ClassView): + if '__routes__' not in decorated.__dict__: + decorated.__routes__ = [] + decorated.__routes__.append((rule, options)) + if init_app is not None: + apps = init_app if isinstance(init_app, tuple) else (init_app,) + for each in apps: + decorated.init_app(each) + return decorated + raise TypeError("@route can only decorate ClassView subclasses") if init_app is not None: raise TypeError( "@route accepts init_app only when decorating a ClassView or ModelView" ) + + if isinstance(decorated, AsyncViewMethod) or iscoroutinefunction(decorated): + return AsyncViewMethod(decorated, rule=rule, rule_options=options) return ViewMethod(decorated, rule=rule, rule_options=options) return decorator @@ -239,8 +231,10 @@ def decorator(decorated: ViewMethod[_P, _R_co]) -> ViewMethod[_P, _R_co]: ... def decorator(decorated: Method[_P, _R_co]) -> ViewMethod[_P, _R_co]: ... def decorator( - decorated: Union[ViewMethod[_P, _R_co], Method[_P, _R_co]] + decorated: Union[ViewMethod[_P, _R_co], Method[_P, _R_co]], ) -> ViewMethod[_P, _R_co]: + if isinstance(decorated, AsyncViewMethod) or iscoroutinefunction(decorated): + return AsyncViewMethod(decorated, data=kwargs) return ViewMethod(decorated, data=kwargs) return decorator @@ -289,7 +283,7 @@ class ViewMethod(Generic[_P, _R_co]): #: Template-accessible name, same as :attr:`__name__` name: str #: The unmodified wrapped method, made available for future decorators - __func__: Callable[Concatenate[Any, _P], _R_co] + __func__: Callable[Concatenate[Any, _P], Any] #: The wrapped method with the class's :attr:`~ClassView.__decorators__` applied decorated_func: Callable #: The actual view function registered to Flask, responsible for creating an @@ -302,7 +296,7 @@ class ViewMethod(Generic[_P, _R_co]): def __init__( self, - decorated: Union[Callable[Concatenate[Any, _P], _R_co], ViewMethod[_P, _R_co]], + decorated: Union[Method[_P, _R_co], ViewMethod[_P, _R_co]], rule: Optional[str] = None, rule_options: Optional[dict[str, Any]] = None, data: Optional[dict[str, Any]] = None, @@ -347,6 +341,7 @@ class CrudView: @route('delete', methods=['GET', 'POST']) def delete(self): ... + @route('/') class MyModelView(CrudView, ModelView[MyModel]): @route('remove', methods=['GET', 'POST']) # Add another route @@ -380,6 +375,7 @@ def delete(self): ... @viewdata() # This creates a ViewMethod with no routes def latent(self): ... + @route('/') class MyModelView(CrudView, ModelView[MyModel]): delete = CrudView.delete.with_route('remove', methods=['GET', 'POST']) @@ -418,7 +414,10 @@ def __get__( return bind def __call__( # pylint: disable=no-self-argument - __self, self: ClassViewSubtype, *args: _P.args, **kwargs: _P.kwargs + __self, # noqa: N805 + self: ClassViewSubtype, + *args: _P.args, + **kwargs: _P.kwargs, ) -> _R_co: # Mimic an unbound method call return __self.__func__(self, *args, **kwargs) @@ -469,33 +468,69 @@ def __set_name__(self, owner: type[ClassViewSubtype], name: str) -> None: self.decorated_func = decorated_func - # TODO: Make async_view_func if `__func__` or `decorated_func` is async, and - # expect the class to provide an `async_dispatch_request` - def view_func(**view_args: Any) -> BaseResponse: - """ - The actual view function registered to Flask, responsible for dispatch. - - This function creates an instance of the view class, then calls - :meth:`~ViewClass.dispatch_request` on it passing in :attr:`decorated_func`. - """ - # Instantiate the view class. We depend on its __init__ requiring no args - viewinst = owner() - # Declare ourselves (the ViewMethod) as the current view. The bind makes - # equivalence tests possible, such as ``self.current_method == self.index`` - viewinst.current_method = ViewMethodBind(self, viewinst) - # Place view arguments in the instance, in case they are needed outside the - # dispatch process - viewinst.view_args = view_args - # Place the view instance on the app context for :obj:`current_view` to - # discover - if app_ctx: - app_ctx.current_view = viewinst # type: ignore[attr-defined] - # Call the view class's dispatch method. View classes can customise this - # for desired behaviour. - return viewinst.dispatch_request( - decorated_func, view_args # type: ignore[arg-type] + if iscoroutinefunction(self.__func__) and not iscoroutinefunction( + decorated_func + ): + raise TypeError( + f"{self.__qualname__} is async, but one of the decorators is not" ) + if iscoroutinefunction(decorated_func): + + async def view_func(**view_args: Any) -> BaseResponse: + """ + Dispatch Flask/Quart view. + + This function creates an instance of the view class, then calls + :meth:`~ViewClass.async_dispatch_request` on it passing in + :attr:`decorated_func`. + """ + # Instantiate the view class. We depend on its __init__ requiring no + # args + viewinst = owner() + # Declare ourselves (the AsyncViewMethod) as the current view. The bind + # makes equivalence tests possible, such as ``self.current_method == + # self.index`` + viewinst.current_method = AsyncViewMethodBind(self, viewinst) + # Place view arguments in the instance, in case they are needed outside + # the dispatch process + viewinst.view_args = view_args + # Place the view instance on the app context for :obj:`current_view` to + # discover + if app_ctx: + app_ctx.current_view = viewinst # type: ignore[attr-defined] + # Call the view class's dispatch method. View classes can customise this + # for desired behaviour. + return await viewinst.async_dispatch_request(decorated_func, view_args) + + else: + + def view_func(**view_args: Any) -> BaseResponse: # type: ignore[misc] + """ + Dispatch Flask/Quart view. + + This function creates an instance of the view class, then calls + :meth:`~ViewClass.dispatch_request` on it passing in + :attr:`decorated_func`. + """ + # Instantiate the view class. We depend on its __init__ requiring no + # args + viewinst = owner() + # Declare ourselves (the ViewMethod) as the current view. The bind makes + # equivalence tests possible, such as ``self.current_method == + # self.index`` + viewinst.current_method = ViewMethodBind(self, viewinst) + # Place view arguments in the instance, in case they are needed outside + # the dispatch process + viewinst.view_args = view_args + # Place the view instance on the app context for :obj:`current_view` to + # discover + if app_ctx: + app_ctx.current_view = viewinst # type: ignore[attr-defined] + # Call the view class's dispatch method. View classes can customise this + # for desired behaviour. + return viewinst.dispatch_request(decorated_func, view_args) + # Make view_func resemble the decorated function... view_func = update_wrapper(view_func, decorated_func) # ...but give view_func the name of the method in the class. @@ -515,7 +550,7 @@ def view_func(**view_args: Any) -> BaseResponse: def init_app( self, - app: Union[FlaskApp, Blueprint], + app: Union[BaseApp, BaseBlueprint], cls: type[ClassView], callback: Optional[InitAppCallback] = None, ) -> None: @@ -536,6 +571,41 @@ def init_app( callback(app, use_rule, endpoint, self.view_func, **use_options) +class AsyncViewMethod(ViewMethod[_P, _R_co]): + """Async variant of :class:`ViewMethod.""" + + @overload + def __get__(self, obj: None, cls: Optional[type[Any]] = None) -> Self: ... + + @overload + def __get__( + self, obj: Any, cls: Optional[type[Any]] = None + ) -> AsyncViewMethodBind[_P, _R_co]: ... + + def __get__( + self, obj: Optional[Any], cls: Optional[type[Any]] = None + ) -> Union[Self, AsyncViewMethodBind[_P, _R_co]]: + if obj is None: + return self + bind = AsyncViewMethodBind(self, obj) + if '__slots__' not in cls.__dict__: + # Cache it in the instance obj for repeat access. Since we are a non-data + # descriptor (no __set__ or __delete__ methods), the instance dict will have + # first priority for future lookups + setattr(obj, self.__name__, bind) + return bind + + # pylint: disable=no-self-argument, invalid-overridden-method + async def __call__( # type: ignore[override] + __self, # noqa: N805 + self: ClassViewSubtype, + *args: _P.args, + **kwargs: _P.kwargs, + ) -> _R_co: + # Mimic an unbound method call + return await __self.__func__(self, *args, **kwargs) + + class ViewMethodBind(Generic[_P, _R_co]): """Wrapper for :class:`ViewMethod` binding it to an instance of the view class.""" @@ -547,7 +617,7 @@ class ViewMethodBind(Generic[_P, _R_co]): __qualname__: str __module__: str __doc__: Optional[str] - __func__: Callable[Concatenate[Any, _P], _R_co] + __func__: Callable[Concatenate[Any, _P], Any] decorated_func: Callable view_func: Callable default_endpoint: str @@ -578,7 +648,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R_co: def __getattr__(self, name: str) -> Any: return getattr(self._view_method, name) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if isinstance(other, ViewMethodBind): return ( self._view_method == other._view_method @@ -596,6 +666,20 @@ def is_available(self) -> bool: return getattr(func, 'is_available', lambda _: True)(self.__self__) +class AsyncViewMethodBind(ViewMethodBind[_P, _R_co]): + """Wrapper for :class:`ViewMethod` binding it to an instance of the view class.""" + + __slots__ = () + + # pylint: disable=invalid-overridden-method + async def __call__( # type: ignore[override] + self, *args: _P.args, **kwargs: _P.kwargs + ) -> _R_co: + # Treat this like a call to the original method and not to the view. + # As per the __decorators__ spec, we call .__func__, not .decorated_func + return await self._view_method.__func__(self.__self__, *args, **kwargs) + + class ClassView: """ Base class for defining a collection of views that are related to each other. @@ -617,6 +701,7 @@ def index(): def about(): return render_template('about.html.jinja2') + IndexView.init_app(app) The :func:`route` decorator on the class specifies the base rule, which is prefixed @@ -696,7 +781,7 @@ def current_handler(self) -> ViewMethodBind: ) return self.current_method - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: return type(other) is type(self) def dispatch_request( @@ -717,13 +802,11 @@ def dispatch_request( :param view_args: View arguments, to be passed on to the view method """ # Call the :meth:`before_request` method - resp = ensure_sync(self.before_request)() + resp = self.before_request() # pylint: disable=assignment-from-none if resp is not None: - return ensure_sync(self.after_request)(make_response(resp)) + return self.after_request(make_response(resp)) # Call the view method, then pass the response to :meth:`after_response` - return ensure_sync(self.after_request)( - make_response(ensure_sync(view)(self, **view_args)) - ) + return self.after_request(make_response(view(self, **view_args))) def before_request(self) -> Optional[ResponseReturnValue]: """ @@ -746,6 +829,7 @@ def after_request(self, response: BaseResponse) -> BaseResponse: class MyView(ClassView): ... + def after_request(self, response): response = super().after_request(response) ... # Process here @@ -756,6 +840,69 @@ def after_request(self, response): """ return response + async def async_dispatch_request( + self, + view: Callable[..., Awaitable[ResponseReturnValue]], + view_args: dict[str, Any], + ) -> BaseResponse: + """ + Async view dispatcher that invokes before and after-view hooks. + + Calls :meth:`async_before_request`, the view, and then + :meth:`async_after_request`. If :meth:`async_before_request` returns a non-None + response, the view is skipped and flow proceeds to :meth:`async_after_request`. + + Generic subclasses may override this to provide a custom flow. + :class:`ModelView` overrides to insert a model loading phase. + + :param view: View method wrapped in specified decorators. The dispatcher must + call this + :param view_args: View arguments, to be passed on to the view method + """ + # Call the :meth:`async_before_request` method + resp = await self.async_before_request() + if resp is not None: + return await self.async_after_request(make_response(resp)) + # Call the view method, then pass the response to :meth:`async_after_response` + return await self.async_after_request( + make_response(await view(self, **view_args)) + ) + + async def async_before_request(self) -> Optional[ResponseReturnValue]: + """ + Process request before the async view method. + + This method is called after the app's ``before_request`` handlers, and before + the class's view method. Subclasses and mixin classes may define their own + :meth:`async_before_request` to pre-process requests. This method receives + context via `self`, in particular via :attr:`current_method` and + :attr:`view_args`. The default implementation calls :meth:`before_request`. + """ + return self.before_request() + + async def async_after_request(self, response: BaseResponse) -> BaseResponse: + """ + Process response returned by async view. + + This method is called with the response from the view method. It must return a + valid response object. Subclasses and mixin classes may override this to perform + any necessary post-processing:: + + class MyView(ClassView): + ... + + async def async_after_request(self, response): + response = await super().async_after_request(response) + ... # Process here + return response + + The default implementation calls :meth:`after_request`. + + :param response: Response from the view method + :return: Response object + """ + return self.after_request(response) + def is_available(self) -> bool: """ Return `True` if *any* view method in the class is currently available. @@ -794,7 +941,7 @@ def __init_subclass__(cls) -> None: @classmethod def init_app( cls, - app: Union[FlaskApp, Blueprint], + app: Union[BaseApp, BaseBlueprint], callback: Optional[InitAppCallback] = None, ) -> None: """ @@ -822,9 +969,9 @@ class ModelView(ClassView, Generic[ModelType]): @route('/doc/', init_app=app) class DocumentView(UrlForView, InstanceLoader, ModelView): model = Document - route_model_map = { - 'document': 'name' - } + route_model_map: ClassVar = { + 'document': 'name', + } @route('') @render_with(json=True) @@ -856,7 +1003,7 @@ def view(self): model = MyModel # This is auto-inserted when using ModelView[MyModel] obj: MyModel # This is auto-inserted when using ModelView[MyModel] - route_model_map = { + route_model_map: ClassVar = { 'document': 'name', # Map 'document' in URL to obj.name 'parent': 'parent.name', # Map 'parent' to obj.parent.name } @@ -895,12 +1042,11 @@ class can be used together in the same view, with the class taking priority. subclassed:: class Mixin: - class GetAttr: - ... + class GetAttr: ... + class MyModelView(Mixin, ModelView[MyModel]): - class GetAttr(Mixin.GetAttr): - ... + class GetAttr(Mixin.GetAttr): ... :class:`~ModelView.GetAttr` is verbose but its utility shows in static type checking and code refactoring. @@ -914,8 +1060,9 @@ def __init__(self, obj: Optional[ModelType] = None) -> None: registry:: @MyModel.views('main') - class MyModelView(ModelView[MyModel]): - ... + class MyModelView(ModelView[MyModel]): ... + + view = obj.views.main() # Same as `view = MyModelView(obj)` @@ -939,8 +1086,8 @@ def __init_subclass__(cls) -> None: break super().__init_subclass__() - def __eq__(self, other: Any) -> bool: - return type(other) is type(self) and other.obj == self.obj + def __eq__(self, other: object) -> bool: + return isinstance(other, self.__class__) and other.obj == self.obj def dispatch_request( self, view: Callable[..., ResponseReturnValue], view_args: dict[str, Any] @@ -960,17 +1107,50 @@ def dispatch_request( to the view """ # Call the :meth:`before_request` method - resp = ensure_sync(self.before_request)() + resp = self.before_request() # pylint: disable=assignment-from-none if resp is not None: - return ensure_sync(self.after_request)(make_response(resp)) + return self.after_request(make_response(resp)) # Load the database model - resp = ensure_sync(self.load)(**view_args) + resp = self.load(**view_args) if resp is not None: - return ensure_sync(self.after_request)(make_response(resp)) + return self.after_request(make_response(resp)) # Trigger post-load processing of the object self.post_load() # Call the view method, then pass the response to :meth:`after_response` - return ensure_sync(self.after_request)(make_response(ensure_sync(view)(self))) + return self.after_request(make_response(view(self))) + + async def async_dispatch_request( + self, + view: Callable[..., Awaitable[ResponseReturnValue]], + view_args: dict[str, Any], + ) -> BaseResponse: + """ + Dispatch an async view. + + Calls :meth:`before_request`, :meth:`load`, the view, and then + :meth:`after_request`. + + If :meth:`before_request` or :meth:`load` return a non-None response, it will + skip ahead to :meth:`after_request`, allowing either of these to override the + view. + + :param view: View method wrapped in specified decorators + :param dict view_args: View arguments, to be passed on to :meth:`load` but not + to the view + """ + # Call the :meth:`before_request` method (optionally async) + resp = await self.async_before_request() + if resp is not None: + # If it had a response, skip the view and call after_request, then return + return await self.async_after_request(make_response(resp)) + # Load the database model + resp = await self.async_load(**view_args) + if resp is not None: + return await self.async_after_request(make_response(resp)) + # Trigger post-load processing of the object + self.post_load() + # Call the view method, then pass the response to :meth:`async_after_response` + return await self.async_after_request(make_response(await view(self))) if TYPE_CHECKING: # Type-checking version without arg-spec to let subclasses specify explicit args @@ -1015,6 +1195,21 @@ def load(self, **__view_args) -> Optional[ResponseReturnValue]: self.obj = self.loader(**__view_args) return self.after_loader() + if TYPE_CHECKING: + # Type-checking version without arg-spec to let subclasses specify explicit args + async_load: Callable[..., Coroutine[Any, Any, Optional[ResponseReturnValue]]] + + else: + # Actual default implementation has variadic arguments + async def async_load(self, **__view_args) -> Optional[ResponseReturnValue]: + """ + Load the database object given view parameters. + + The default implementation calls :meth:`load`. Subclasses should override + this to make an actual async implementation. + """ + return self.load(**__view_args) + def after_loader( # pylint: disable=useless-return self, ) -> Optional[ResponseReturnValue]: @@ -1113,7 +1308,7 @@ class UrlForView: @classmethod def init_app( cls, - app: Union[FlaskApp, Blueprint], + app: Union[BaseApp, BaseBlueprint], callback: Optional[InitAppCallback] = None, ) -> None: """Register view on an app.""" @@ -1121,20 +1316,20 @@ def init_app( def register_view_on_model( cls: type[ModelView], callback: Optional[InitAppCallback], - app: Union[FlaskApp, Blueprint], + app: Union[BaseApp, BaseBlueprint], rule: str, endpoint: str, view_func: Callable, **options: Any, ) -> None: def register_paths_from_app( - reg_app: FlaskApp, + reg_app: BaseApp, reg_rule: str, reg_endpoint: str, reg_options: dict[str, Any], ) -> None: model = cls.model - assert issubclass(model, UrlForMixin) # nosec B101 + assert issubclass(model, UrlForMixin) # nosec B101 # noqa: S101 # Only pass in the attrs that are included in the rule. # 1. Extract list of variables from the rule rulevars = _get_arguments_from_rule( @@ -1191,23 +1386,23 @@ def blueprint_postprocess(state: BlueprintSetupState) -> None: ) register_paths_from_app(state.app, reg_rule, reg_endpoint, reg_options) - if isinstance(app, FlaskApp): + if isinstance(app, BaseApp): register_paths_from_app(app, rule, endpoint, options) - elif isinstance(app, Blueprint): + elif isinstance(app, BaseBlueprint): app.record(blueprint_postprocess) else: raise TypeError(f"App must be Flask or Blueprint: {app!r}") if callback: # pragma: no cover callback(app, rule, endpoint, view_func, **options) - assert issubclass(cls, ModelView) # nosec B101 + assert issubclass(cls, ModelView) # nosec B101 # noqa: S101 super().init_app( # type: ignore[misc] app, callback=partial(register_view_on_model, cls, callback) ) def url_change_check( - f: Callable[_P, _R_co] + f: Callable[_P, _R_co], ) -> Callable[_P, Union[_R_co, BaseResponse]]: """ Decorate view method in a :class:`ModelView` to check for a change in URL. @@ -1219,7 +1414,7 @@ def url_change_check( @route('/doc/') class MyModelView(UrlForView, InstanceLoader, ModelView): model = MyModel - route_model_map = {'document': 'url_id_name'} + route_model_map: ClassVar = {'document': 'url_id_name'} @route('') @url_change_check @@ -1294,7 +1489,7 @@ class UrlChangeCheck: @route('/doc/') class MyModelView(UrlChangeCheck, UrlForView, InstanceLoader, ModelView): model = MyModel - route_model_map = {'document': 'url_id_name'} + route_model_map: ClassVar = {'document': 'url_id_name'} @route('') @render_with(json=True) @@ -1369,8 +1564,7 @@ def loader(self, **view_args: Any) -> Any: query = query.filter(source == value) else: query = query.filter(getattr(self.model, name) == value) - obj = query.one_or_404() - return obj + return query.one_or_404() return None diff --git a/src/coaster/views/decorators.py b/src/coaster/views/decorators.py index 058dba79..82f78940 100644 --- a/src/coaster/views/decorators.py +++ b/src/coaster/views/decorators.py @@ -1,24 +1,28 @@ -""" -View decorators ---------------- - -Decorators for view handlers. - -All items in this module can be imported directly from :mod:`coaster.views`. -""" +"""View decorators.""" +# spell-checker:ignore requestargs from __future__ import annotations -from collections.abc import Collection, Container, Iterable, Mapping +from collections.abc import Awaitable, Collection, Container, Iterable, Mapping from functools import wraps from inspect import iscoroutinefunction -from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, TypeVar, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Literal, + Optional, + Protocol, + TypeVar, + Union, + cast, + overload, +) from typing_extensions import ParamSpec from flask import ( Response, abort, - g, jsonify, make_response, redirect, @@ -32,8 +36,8 @@ from werkzeug.wrappers import Response as WerkzeugResponse from ..auth import add_auth_attribute, current_auth +from ..compat import BaseResponse, ensure_sync, flask_g, quart_g from ..utils import InspectableSet, is_collection -from .misc import ensure_sync __all__ = [ 'ReturnRenderWith', @@ -50,7 +54,7 @@ 'requires_permission', ] -ReturnRenderWithData = Mapping[str, Any] +ReturnRenderWithData = Mapping[str, object] ReturnRenderWithResponse = Union[WerkzeugResponse, ReturnRenderWithData] ReturnRenderWithHeaders = Union[list[tuple[str, str]], dict[str, str], Headers] ReturnRenderWith = Union[ @@ -74,12 +78,7 @@ class RequestValueError(BadRequest, ValueError): def requestargs( *args: Union[str, tuple[str, Callable[[str], Any]]], - source: Union[ - Literal['values'], - Literal['form'], - Literal['query'], - Literal['body'], - ] = 'values', + source: Literal['values', 'form', 'query', 'body'] = 'values', ) -> Callable[[Callable[_VP, _VR_co]], Callable[_VP, _VR_co]]: """ Decorate a function to load parameters from the request if not supplied directly. @@ -87,8 +86,7 @@ def requestargs( Usage:: @requestargs('param1', ('param2', int), 'param3[]', ...) - def function(param1, param2=0, param3=None): - ... + def function(param1, param2=0, param3=None): ... :func:`requestargs` takes a list of parameters to pass to the wrapped function, with an optional filter (useful to convert incoming string request data into integers @@ -208,35 +206,38 @@ def wrapper(*args: _VP.args, **kwargs: _VP.kwargs) -> _VR_co: def requestquery( - *args: Union[str, tuple[str, Callable[[str], Any]]] + *args: Union[str, tuple[str, Callable[[str], Any]]], ) -> Callable[[Callable[_VP, _VR_co]], Callable[_VP, _VR_co]]: """Like :func:`requestargs`, but loads from request.args (the query string).""" return requestargs(*args, source='query') def requestform( - *args: Union[str, tuple[str, Callable[[str], Any]]] + *args: Union[str, tuple[str, Callable[[str], Any]]], ) -> Callable[[Callable[_VP, _VR_co]], Callable[_VP, _VR_co]]: """Like :func:`requestargs`, but loads from request.form (the form submission).""" return requestargs(*args, source='form') def requestbody( - *args: Union[str, tuple[str, Callable[[str], Any]]] + *args: Union[str, tuple[str, Callable[[str], Any]]], ) -> Callable[[Callable[_VP, _VR_co]], Callable[_VP, _VR_co]]: """Like :func:`requestargs`, but loads from form or JSON basis content type.""" return requestargs(*args, source='body') def load_model( - model: type[Any], - attributes: dict[str, str], + model: Union[type[Any], list[type[Any]], tuple[type[Any], ...]], + attributes: dict[str, Union[str, Callable[[dict, dict], Any]]], parameter: str, kwargs: bool = False, permission: Optional[Union[str, set[str]]] = None, addlperms: Optional[Union[Iterable[str], Callable[[], Iterable[str]]]] = None, urlcheck: Collection[str] = (), -) -> Callable[[Callable[..., _VR]], Callable[..., Union[_VR, WerkzeugResponse]]]: +) -> Callable[ + [Callable[..., _VR]], + Callable[..., Union[_VR, BaseResponse, Awaitable[BaseResponse]]], +]: """ Decorate a view to load a model given a query parameter. @@ -244,7 +245,7 @@ def load_model( @app.route('/') @load_model(Profile, {'name': 'profile'}, 'profileob') - def profile_view(profileob): + def profile_view(profileob: Profile) -> ResponseReturnValue: # 'profileob' is now a Profile model instance. # The load_model decorator replaced this: # profileob = Profile.query.filter_by(name=profile).first_or_404() @@ -254,7 +255,7 @@ def profile_view(profileob): @app.route('/') @load_model(Profile, {'name': 'profile'}, 'profile') - def profile_view(profile: Profile): + def profile_view(profile: Profile) -> ResponseReturnValue: return f"Hello, {profile.name}" ``load_model`` aborts with a 404 if no instance is found. @@ -302,11 +303,22 @@ def profile_view(profile: Profile): def load_models( - *chain, permission: Optional[Union[str, set[str]]] = None, **config -) -> Callable[[Callable[..., _VR]], Callable[..., Union[_VR, WerkzeugResponse]]]: + *chain: tuple[ + Union[type[Any], list[type[Any]], tuple[type[Any], ...]], + dict[str, Union[str, Callable[[dict, dict], Any]]], + str, + ], + permission: Optional[Union[str, set[str]]] = None, + **config, +) -> Callable[ + [Callable[..., _VR]], + Callable[..., Union[_VR, BaseResponse, Awaitable[BaseResponse]]], +]: """ - Decorator to load a chain of models from the given parameters. This works just like - :func:`load_model` and accepts the same parameters, with some small differences. + Load a chain of models from the given parameters. + + This works just like :func:`load_model` and accepts the same parameters, with some + small differences. :param chain: The chain is a list of tuples of (``model``, ``attributes``, ``parameter``). Lists and tuples can be used interchangeably. All retrieved @@ -330,21 +342,25 @@ def load_models( @load_models( (Folder, {'name': 'folder_name'}, 'folder'), (Page, {'name': 'page_name', 'parent': 'folder'}, 'page'), - permission='view') - def show_page(folder, page): + permission='view', + ) + def show_page(folder: Folder, page: Page) -> ResponseReturnValue: return render_template('page.html', folder=folder, page=page) """ - def decorator(f: Callable[..., _VR]) -> Callable[..., Union[_VR, WerkzeugResponse]]: - @wraps(f) - def wrapper(*args, **kwargs) -> Union[_VR, WerkzeugResponse]: + def decorator( + f: Callable[..., _VR], + ) -> Callable[..., Union[_VR, BaseResponse, Awaitable[BaseResponse]]]: + def loader(kwargs: dict[str, Any]) -> Union[dict[str, Any], BaseResponse]: view_args: Optional[dict[str, Any]] request_endpoint: str = request.endpoint # type: ignore[assignment] permissions: Optional[set[str]] = None permission_required = ( {permission} if isinstance(permission, str) - else set(permission) if permission is not None else None + else set(permission) + if permission is not None + else None ) url_check_attributes = config.get('urlcheck', []) result: dict[str, Any] = {} @@ -353,7 +369,7 @@ def wrapper(*args, **kwargs) -> Union[_VR, WerkzeugResponse]: models = (models,) item = None url_check = False - url_check_paramvalues = {} + url_check_paramvalues: dict[str, tuple[Union[str, Callable], Any]] = {} for model in models: query = model.query for k, v in attributes.items(): @@ -398,22 +414,22 @@ def wrapper(*args, **kwargs) -> Union[_VR, WerkzeugResponse]: if callable(addlperms): addlperms = addlperms() or [] permissions.update(addlperms) - if g: # XXX: Deprecated + if g := (quart_g or flask_g): # XXX: Deprecated g.permissions = permissions if request: add_auth_attribute('permissions', InspectableSet(permissions)) - if ( - url_check and request.method == 'GET' - ): # Only do urlcheck redirects on GET requests + if url_check and request.method == 'GET': + # Only do url_check redirects on GET requests url_redirect = False view_args = None - for k, v in url_check_paramvalues.items(): - uparam, uvalue = v - if getattr(item, k) != uvalue: + for k2, v2 in url_check_paramvalues.items(): + uparam, uvalue = v2 + if (vvalue := getattr(item, k2)) != uvalue: url_redirect = True if view_args is None: view_args = dict(request.view_args or {}) - view_args[uparam] = getattr(item, k) + if isinstance(uparam, str): + view_args[uparam] = vvalue if url_redirect: if view_args is None: location = url_for(request_endpoint) @@ -422,7 +438,7 @@ def wrapper(*args, **kwargs) -> Union[_VR, WerkzeugResponse]: if request.query_string: location = location + '?' + request.query_string.decode() return redirect(location, code=302) - if parameter.startswith('g.'): + if parameter.startswith('g.') and g: parameter = parameter[2:] setattr(g, parameter, item) result[parameter] = item @@ -430,9 +446,29 @@ def wrapper(*args, **kwargs) -> Union[_VR, WerkzeugResponse]: permissions is None or not permission_required & permissions ): abort(403) - if config.get('kwargs'): - return ensure_sync(f)(*args, kwargs=kwargs, **result) - return ensure_sync(f)(*args, **result) + return result + + if iscoroutinefunction(f): + + @wraps(f) + async def wrapper(*args, **kwargs) -> Union[_VR, BaseResponse]: + result = loader(kwargs) + if isinstance(result, BaseResponse): + return result + if config.get('kwargs'): + return await f(*args, kwargs=kwargs, **result) + return await f(*args, **result) + + else: + + @wraps(f) + def wrapper(*args, **kwargs) -> Union[_VR, BaseResponse]: + result = loader(kwargs) + if isinstance(result, BaseResponse): + return result + if config.get('kwargs'): + return f(*args, kwargs=kwargs, **result) + return f(*args, **result) return wrapper @@ -473,23 +509,30 @@ def render_with( def myview(): return {'data': 'value'} + @app.route('/myview_with_json') @render_with('myview.html', json=True) def myview_no_json(): return {'data': 'value'} + @app.route('/otherview') - @render_with({ - 'text/html': 'otherview.html', - 'text/xml': 'otherview.xml'}) + @render_with( + { + 'text/html': 'otherview.html', + 'text/xml': 'otherview.xml', + } + ) def otherview(): return {'data': 'value'} + @app.route('/404view') @render_with('myview.html') def myview(): return {'error': '404 Not Found'}, 404 + @app.route('/headerview') @render_with('myview.html') def myview(): @@ -519,10 +562,7 @@ def myview(): str, Union[str, Callable[[ReturnRenderWithData], ResponseReturnValue]] ] default_mimetype: Optional[str] = None - if json: - templates = {'application/json': jsonify} - else: - templates = {} + templates = {'application/json': jsonify} if json else {} if isinstance(template, str): templates['*/*'] = template elif isinstance(template, dict): @@ -547,7 +587,7 @@ def myview(): template_mimetypes.remove('*/*') def decorator( - f: Callable[_VP, ReturnRenderWith] + f: Callable[_VP, ReturnRenderWith], ) -> Callable[_VP, WerkzeugResponse]: @wraps(f) def wrapper(*args: _VP.args, **kwargs: _VP.kwargs) -> WerkzeugResponse: @@ -558,7 +598,7 @@ def wrapper(*args: _VP.args, **kwargs: _VP.kwargs) -> WerkzeugResponse: result = ensure_sync(f)(*args, **kwargs) if not render or not request: - # Return value is not a WerkzeugResponse here + # Return value is not a BaseResponse here return result # type: ignore[return-value] # Is the result a Response object? Don't attempt rendering @@ -638,6 +678,18 @@ def wrapper(*args: _VP.args, **kwargs: _VP.kwargs) -> WerkzeugResponse: return decorator +class CorsDecoratorProtocol(Protocol): + @overload + def __call__( + self, __decorated: Callable[_VP, Awaitable[ResponseReturnValue]] + ) -> Callable[_VP, Awaitable[BaseResponse]]: ... + + @overload + def __call__( + self, __decorated: Callable[_VP, ResponseReturnValue] + ) -> Callable[_VP, BaseResponse]: ... + + def cors( origins: Union[Literal['*'], Container[str], Callable[[str], bool]], methods: Iterable[str] = ( @@ -657,7 +709,7 @@ def cors( 'X-Requested-With', ), max_age: Optional[int] = None, -) -> Callable[[Callable[_VP, ResponseReturnValue]], Callable[_VP, WerkzeugResponse]]: +) -> CorsDecoratorProtocol: """ Add CORS headers to the decorated view function. @@ -668,9 +720,9 @@ def cors( The :obj:`origins` parameter may be one of: - 1. A callable that receives the origin as a parameter. + 1. A callable that receives the origin as a parameter and returns True/False. 2. A list of origins. - 3. ``*``, indicating that this resource is accessible by any origin. + 3. Literal['*'], indicating that this resource is accessible by any origin. Example use:: @@ -679,60 +731,72 @@ def cors( app = Flask(__name__) + @app.route('/any') @cors('*') def any_origin(): return Response() + @app.route('/static', methods=['GET', 'POST']) @cors( ['https://hasgeek.com'], methods=['GET', 'POST'], headers=['Content-Type', 'X-Requested-With'], - max_age=3600) + max_age=3600, + ) def static_list(): return Response() + def check_origin(origin): # check if origin should be allowed return True + @app.route('/callable', methods=['GET']) @cors(check_origin) def callable_function(): return Response() """ + @overload def decorator( - f: Callable[_VP, ResponseReturnValue] - ) -> Callable[_VP, WerkzeugResponse]: - @wraps(f) - def wrapper(*args: _VP.args, **kwargs: _VP.kwargs) -> WerkzeugResponse: + f: Callable[_VP, Awaitable[ResponseReturnValue]], + ) -> Callable[_VP, Awaitable[BaseResponse]]: ... + + @overload + def decorator( + f: Callable[_VP, ResponseReturnValue], + ) -> Callable[_VP, BaseResponse]: ... + + def decorator( + f: Union[ + Callable[_VP, ResponseReturnValue], + Callable[_VP, Awaitable[ResponseReturnValue]], + ], + ) -> Union[Callable[_VP, BaseResponse], Callable[_VP, Awaitable[BaseResponse]]]: + def check_origin() -> Optional[str]: origin = request.headers.get('Origin') if not origin or origin == 'null': if request.method == 'OPTIONS': abort(400) - # If no Origin header is supplied, CORS checks don't apply - return make_response(ensure_sync(f)(*args, **kwargs)) + return None if request.method not in methods: abort(405) - if origins == '*': - pass - elif is_collection(origins) and origin in origins: # type: ignore[operator] - pass - elif callable(origins) and origins(origin): - pass - else: + if not ( + origins == '*' + or ( + is_collection(origins) and origin in origins # type: ignore[operator] + ) + or (callable(origins) and origins(origin)) + ): abort(403) + return origin - if request.method == 'OPTIONS': - # pre-flight request - resp = Response() - else: - resp = make_response(ensure_sync(f)(*args, **kwargs)) - + def set_headers(origin: str, resp: BaseResponse) -> BaseResponse: resp.headers['Access-Control-Allow-Origin'] = origin resp.headers['Access-Control-Allow-Methods'] = ', '.join(methods) resp.headers['Access-Control-Allow-Headers'] = ', '.join(headers) @@ -743,6 +807,36 @@ def wrapper(*args: _VP.args, **kwargs: _VP.kwargs) -> WerkzeugResponse: return resp + if iscoroutinefunction(f): + + @wraps(f) + async def wrapper(*args: _VP.args, **kwargs: _VP.kwargs) -> BaseResponse: + origin = check_origin() + if origin is None: + # If no Origin header is supplied, CORS checks don't apply + return make_response(await f(*args, **kwargs)) + if request.method == 'OPTIONS': + # pre-flight request + resp = Response() + else: + resp = make_response(await f(*args, **kwargs)) + return set_headers(origin, resp) + + else: + + @wraps(f) + def wrapper(*args: _VP.args, **kwargs: _VP.kwargs) -> BaseResponse: + origin = check_origin() + if origin is None: + # If no Origin header is supplied, CORS checks don't apply + return make_response(f(*args, **kwargs)) + if request.method == 'OPTIONS': + # pre-flight request + resp = Response() + else: + resp = make_response(f(*args, **kwargs)) + return set_headers(origin, resp) + wrapper.provide_automatic_options = False # type: ignore[attr-defined] wrapper.required_methods = ['OPTIONS'] # type: ignore[attr-defined] @@ -752,7 +846,7 @@ def wrapper(*args: _VP.args, **kwargs: _VP.kwargs) -> WerkzeugResponse: def requires_permission( - permission: Union[str, set[str]] + permission: Union[str, set[str]], ) -> Callable[[Callable[_VP, _VR_co]], Callable[_VP, _VR_co]]: """ Decorate to require a permission to be present in ``current_auth.permissions``. diff --git a/src/coaster/views/misc.py b/src/coaster/views/misc.py index b02ae5e3..321c717d 100644 --- a/src/coaster/views/misc.py +++ b/src/coaster/views/misc.py @@ -1,41 +1,19 @@ -""" -Miscellaneous view helpers --------------------------- - -Helper functions for view handlers. - -All items in this module can be imported directly from :mod:`coaster.views`. -""" +"""Miscellaneous view helpers.""" # pyright: reportMissingImports=false from __future__ import annotations -import asyncio import re from collections.abc import Container -from inspect import iscoroutinefunction -from typing import Any, Optional, Union, cast +from typing import Any, Optional, Union from urllib.parse import urlsplit -from flask import ( - Response, - current_app, - json, - request, - session as request_session, - url_for, -) +from flask import Response, json, session as request_session, url_for from werkzeug.exceptions import MethodNotAllowed, NotFound -from werkzeug.routing import MapAdapter, RequestRedirect, Rule - -try: - from asgiref.sync import async_to_sync -except ModuleNotFoundError: - async_to_sync = None # type: ignore[assignment, misc] - +from werkzeug.routing import RequestRedirect, Rule -from ..typing import WrappedFunc +from ..compat import async_request, current_app __all__ = ['get_current_url', 'get_next_url', 'jsonp', 'endpoint_for'] @@ -43,8 +21,8 @@ def _index_url() -> str: - if request: - return request.script_root or '/' + if async_request: + return async_request.script_root or '/' return '/' @@ -54,7 +32,7 @@ def _clean_external_url( """Allow external URLs if they match current request's hostname.""" # Do the domains and ports match? pnext = urlsplit(url) - preq = urlsplit(request.url) + preq = urlsplit(async_request.url) if pnext.scheme and pnext.scheme.lower() not in allowed_schemes: # Not an allowed scheme, quit return '' @@ -84,17 +62,17 @@ def get_current_url() -> str: if current_app.config.get('SERVER_NAME') and ( # Check current hostname against server name, ignoring port numbers, if any # (split on ':') - request.environ['HTTP_HOST'].split(':', 1)[0] + async_request.environ['HTTP_HOST'].split(':', 1)[0] != current_app.config['SERVER_NAME'].split(':', 1)[0] ): - return request.url + return async_request.url url = ( - url_for(request.endpoint, **(request.view_args or {})) - if request.endpoint is not None # Will be None in a 404 handler - else request.url + url_for(async_request.endpoint, **(async_request.view_args or {})) + if async_request.endpoint is not None # Will be None in a 404 handler + else async_request.url ) - query = request.query_string + query = async_request.query_string if query: return url + '?' + query.decode() return url @@ -118,18 +96,20 @@ def get_next_url( or the script root (typically ``/``). """ if session: - next_url = request_session.pop('next', None) or request.args.get('next', '') + next_url = request_session.pop('next', None) or async_request.args.get( + 'next', '' + ) else: - next_url = request.args.get('next', '') + next_url = async_request.args.get('next', '') if next_url and not external: next_url = _clean_external_url(next_url) if next_url: return next_url - if referrer and request.referrer: + if referrer and async_request.referrer: if external: - return request.referrer - return _clean_external_url(request.referrer) or ( + return async_request.referrer + return _clean_external_url(async_request.referrer) or ( default if default is not None else _index_url() ) return default if default is not None else _index_url() @@ -144,7 +124,7 @@ def jsonp(*args: Any, **kwargs: Any) -> Response: :func:`~coaster.views.decorators.cors` decorator. """ data = json.dumps(dict(*args, **kwargs), indent=2) - callback = request.args.get('callback', request.args.get('jsonp')) + callback = async_request.args.get('callback', async_request.args.get('jsonp')) if callback and __jsoncallback_re.search(callback) is not None: data = callback + '(' + data + ');' mimetype = 'application/javascript' @@ -155,7 +135,7 @@ def jsonp(*args: Any, **kwargs: Any) -> Response: def endpoint_for( url: str, - method: Optional[str] = None, + method: str = 'GET', return_rule: bool = False, follow_redirects: bool = True, ) -> tuple[Optional[Union[Rule, str]], dict[str, Any]]: @@ -174,29 +154,36 @@ def endpoint_for( # We require an absolute URL return None, {} - # Take the current runtime environment... - environ = dict(request.environ) - # ...but replace the HTTP host with the URL's host... - environ['HTTP_HOST'] = parsed_url.netloc - # ...and the path with the URL's path (after discounting the app path, if not - # hosted at root). - environ['PATH_INFO'] = parsed_url.path[len(environ.get('SCRIPT_NAME', '')) :] - # Create a new request with this environment... - url_request = current_app.request_class(environ) - # ...and a URL adapter with the new request. - url_adapter = cast(MapAdapter, current_app.create_url_adapter(url_request)) + use_host = current_app.config['SERVER_NAME'] or parsed_url.netloc + if async_request: + use_root_path = async_request.root_path + use_scheme = async_request.scheme + else: + use_root_path = '/' + use_scheme = 'https' + url_adapter = current_app.url_map.bind( + use_host, + use_root_path, + parsed_url.netloc[: -len(use_host) - 1] + if parsed_url.netloc.endswith('.' + use_host) + else None, + use_scheme, + method, + '/', + None, + ) # Run three hostname tests, one of which must pass: # 1. Does the URL map have host matching enabled? If so, the URL adapter will # validate the hostname. - if current_app.url_map.host_matching: + if current_app.url_map.host_matching: # noqa: SIM114 pass # 2. If not, does the domain match? url_adapter.server_name will prefer # app.config['SERVER_NAME'], but if that is not specified, it will take it from the # environment. - elif parsed_url.netloc == url_adapter.server_name: + elif parsed_url.netloc == url_adapter.server_name: # noqa: SIM114 pass # 3. If subdomain matching is enabled, does the subdomain match? @@ -229,20 +216,3 @@ def endpoint_for( pass # If we got here, no endpoint was found. return None, {} - - -def ensure_sync(func: WrappedFunc) -> WrappedFunc: - """Help use Flask's ensure_sync outside a request context.""" - if current_app: - return cast(WrappedFunc, current_app.ensure_sync(func)) - - if not iscoroutinefunction(func): - return func - - if async_to_sync is not None: - return async_to_sync(func) # type: ignore[return-value] - - return cast( # type: ignore[unreachable] - WrappedFunc, - lambda *args, **kwargs: (asyncio.run(func(*args, **kwargs))), - ) diff --git a/test_requirements.txt b/test_requirements.txt index 9cc4a25f..916aee54 100644 --- a/test_requirements.txt +++ b/test_requirements.txt @@ -4,10 +4,11 @@ psycopg Pygments pylint pytest +pytest-asyncio>=0.23 pytest-cov pytest-env -pytest-ignore-flaky pytest-rerunfailures pytest-socket pyyaml +quart toml diff --git a/tests/coaster_tests/__init__.py b/tests/coaster_tests/__init__.py index e69de29b..0b7f5dc6 100644 --- a/tests/coaster_tests/__init__.py +++ b/tests/coaster_tests/__init__.py @@ -0,0 +1 @@ +"""Coaster tests.""" diff --git a/tests/coaster_tests/alembic/versions/132231d12fcd_test.py b/tests/coaster_tests/alembic/versions/132231d12fcd_test.py index 60267992..bdfabb16 100644 --- a/tests/coaster_tests/alembic/versions/132231d12fcd_test.py +++ b/tests/coaster_tests/alembic/versions/132231d12fcd_test.py @@ -1,4 +1,6 @@ -"""Test Migration. +# noqa: INP001 +""" +Test Migration. Revision ID: 132231d12fcd Revises: None diff --git a/tests/coaster_tests/app_test.py b/tests/coaster_tests/app_test.py index 1c643ccf..32e60107 100644 --- a/tests/coaster_tests/app_test.py +++ b/tests/coaster_tests/app_test.py @@ -177,7 +177,7 @@ def test_logging_handler(self) -> None: for handler in self.another_app.logger.handlers: try: raise Exception # pylint: disable=broad-exception-raised - except Exception: # noqa: B902 # pylint: disable=W0703 + except Exception: # noqa: BLE001 # pylint: disable=broad-except formatter = handler.formatter if isinstance(formatter, LocalVarFormatter): formatter.formatException(sys.exc_info()) @@ -253,7 +253,8 @@ def test_key_rotation_wrapper() -> None: # The KeyRotationWrapper has a safety catch for when a string secret is provided with pytest.raises(ValueError, match="Secret keys must be a list"): KeyRotationWrapper( - itsdangerous.URLSafeSerializer, 'secret_key' # type: ignore[arg-type] + itsdangerous.URLSafeSerializer, + 'secret_key', # type: ignore[arg-type] ) # If KeyRotationWrapper somehow loses its engines, we'll get a RuntimeError instead diff --git a/tests/coaster_tests/assets_test.py b/tests/coaster_tests/assets_test.py index a8d83c6e..948e6871 100644 --- a/tests/coaster_tests/assets_test.py +++ b/tests/coaster_tests/assets_test.py @@ -1,6 +1,6 @@ """Tests for asset management helpers.""" -# pylint: disable=redefined-outer-name,pointless-statement +# pylint: disable=redefined-outer-name import json import logging @@ -153,7 +153,7 @@ def test_create_empty_manifest() -> None: assert manifest.get('random') is None assert manifest.get('random', 'default-value') == 'default-value' with pytest.raises(KeyError): - manifest['random'] + _ = manifest['random'] def test_unaffiliated_manifest(app1: Flask) -> None: @@ -167,7 +167,7 @@ def test_unaffiliated_manifest(app1: Flask) -> None: assert manifest.get('random') is None assert manifest.get('random', 'default-value') == 'default-value' with pytest.raises(KeyError): - manifest['random'] + _ = manifest['random'] @pytest.mark.parametrize( @@ -228,7 +228,7 @@ def test_load_manifest_from_file(app1: Flask) -> None: assert manifest.get('other.css') is None assert manifest.get('other.css', 'default-value') == 'default-value' with pytest.raises(KeyError): - manifest['other.css'] + _ = manifest['other.css'] def test_manifest_limited_to_app_with_context(app1: Flask, app2: Flask) -> None: @@ -287,7 +287,7 @@ def test_load_manifest_from_file_with_custom_basepath(app1: Flask) -> None: assert manifest.get('other.css') is None assert manifest.get('other.css', 'default-value') == 'default-value' with pytest.raises(KeyError): - manifest['other.css'] + _ = manifest['other.css'] def test_manifest_disable_substitutions(app1: Flask, app2: Flask) -> None: @@ -321,9 +321,9 @@ def test_manifest_disable_substitutions(app1: Flask, app2: Flask) -> None: assert manifest1['index.scss'] == 'test-index.css' assert manifest2['index.scss'] == '/nosub/test-index.css' with pytest.raises(KeyError): - manifest1['index.css'] + _ = manifest1['index.css'] with pytest.raises(KeyError): - manifest2['index.css'] + _ = manifest2['index.css'] def test_compiled_regex_substitutes(app1: Flask) -> None: @@ -555,14 +555,14 @@ def test_keyerror_caplog(caplog: pytest.LogCaptureFixture, app1: Flask) -> None: caplog.clear() # Without an app context, KeyError will not be logged with pytest.raises(KeyError): - manifest['does-not-exist.css'] + _ = manifest['does-not-exist.css'] assert caplog.record_tuples == [] with app1.app_context(): assert manifest['exists.css'] == 'asset-exists.css' # A successful lookup will not be logged assert caplog.record_tuples == [] with pytest.raises(KeyError): - manifest['does-not-exist.css'] + _ = manifest['does-not-exist.css'] assert caplog.record_tuples == [ ( __name__, diff --git a/tests/coaster_tests/auth_test.py b/tests/coaster_tests/auth_test.py index f98ac604..022abc5d 100644 --- a/tests/coaster_tests/auth_test.py +++ b/tests/coaster_tests/auth_test.py @@ -227,7 +227,7 @@ def test_current_auth_with_user_loaded( assert current_auth.is_authenticated # type: ignore[unreachable] assert current_auth assert current_auth.user is not None - assert current_auth.user == user # type: ignore[unreachable] + assert current_auth.user == user assert current_auth.actor == user @@ -284,7 +284,7 @@ def test_other_actor_authenticated(models: SimpleNamespace) -> None: @pytest.mark.usefixtures('request_ctx', 'login_manager') def test_auth_anchor() -> None: - """A request starts with zero anchors, but they can be added""" + """A request starts with zero anchors, but they can be added.""" assert not current_auth.anchors add_auth_anchor('test-anchor') assert current_auth.anchors @@ -299,7 +299,7 @@ def test_has_current_auth() -> None: assert not current_auth.is_placeholder assert not request_has_auth() # Invoke current_auth to check for a user - current_auth.is_anonymous # pylint: disable=W0104 + _anon = current_auth.is_anonymous assert request_has_auth() diff --git a/tests/coaster_tests/conftest.py b/tests/coaster_tests/conftest.py index 57638819..218c1482 100644 --- a/tests/coaster_tests/conftest.py +++ b/tests/coaster_tests/conftest.py @@ -2,10 +2,17 @@ # pylint: disable=redefined-outer-name +from __future__ import annotations + +import asyncio +import contextvars import sys +import traceback import unittest +from collections.abc import Coroutine, Generator from os import environ -from typing import cast +from pathlib import Path +from typing import Any, Optional, Union, cast import pytest import sqlalchemy as sa @@ -76,3 +83,87 @@ def tearDown(self) -> None: self.session.rollback() db.drop_all() self.ctx.pop() + + +# Patch for asyncio tests, adapted from +# https://github.com/Donate4Fun/donate4fun/blob/273a4e/tests/fixtures.py + + +class CustomEventLoopPolicy(asyncio.DefaultEventLoopPolicy): + def __init__(self, context: Optional[contextvars.Context]) -> None: + super().__init__() + self.context = context + + def task_factory( + self, + loop: asyncio.AbstractEventLoop, + factory: Union[Coroutine, Generator], + context: Optional[contextvars.Context] = None, + ) -> Task311: + if context is None: + context = self.context + stack = traceback.extract_stack() + for frame in stack[-2::-1]: + package_name = Path(frame.filename).parts[-2] + if package_name != 'asyncio': + if package_name == 'pytest_asyncio': + # This function was called from pytest_asyncio, use shared context + break + # This function was called from somewhere else, create context copy + context = None + break + return Task311(factory, loop=loop, context=context) + + def new_event_loop(self) -> asyncio.AbstractEventLoop: + loop = super().new_event_loop() + loop.set_task_factory(self.task_factory) + return loop + + +@pytest.fixture(scope='session') +def event_loop_policy() -> Generator[CustomEventLoopPolicy, Any, None]: + policy = CustomEventLoopPolicy(contextvars.copy_context()) + yield policy + policy.get_event_loop().close() + + +# pylint: disable=protected-access +class Task311(asyncio.tasks._PyTask): # type: ignore[name-defined] + """Backport of Task from CPython 3.11 for passing context from fixture to test.""" + + def __init__( + self, + coro: Union[Coroutine, Generator], + *, + loop: Optional[asyncio.AbstractEventLoop] = None, + name: Optional[str] = None, + context: Optional[contextvars.Context] = None, + ) -> None: + super( + asyncio.tasks._PyTask, # type: ignore[attr-defined] + self, + ).__init__(loop=loop) + if self._source_traceback: + del self._source_traceback[-1] + if not asyncio.coroutines.iscoroutine(coro): + # raise after Future.__init__(), attrs are required for __del__ + # prevent logging for pending task in __del__ + self._log_destroy_pending = False + raise TypeError(f"a coroutine was expected, got {coro!r}") + + if name is None: + self._name = f'Task-{asyncio.tasks._task_name_counter()}' # type: ignore[attr-defined] + else: + self._name = str(name) + + self._num_cancels_requested = 0 + self._must_cancel = False + self._fut_waiter = None + self._coro = coro + if context is None: + self._context = contextvars.copy_context() + else: + self._context = context + + self._loop.call_soon(self._Task__step, context=self._context) + asyncio.tasks._register_task(self) diff --git a/tests/coaster_tests/logger_test.py b/tests/coaster_tests/logger_test.py index d68cdc74..72e0a1ec 100644 --- a/tests/coaster_tests/logger_test.py +++ b/tests/coaster_tests/logger_test.py @@ -1,3 +1,5 @@ +"""Tests for Coaster's log formatter.""" + from io import StringIO from coaster.logger import RepeatValueIndicator, filtered_value, pprint_with_indent diff --git a/tests/coaster_tests/settings.py b/tests/coaster_tests/settings.py index c8929540..586448db 100644 --- a/tests/coaster_tests/settings.py +++ b/tests/coaster_tests/settings.py @@ -1,7 +1,5 @@ -""" -Note: This is a test config file used by test_app.py. -""" +"""Note: This is a test config file used by test_app.py.""" SETTINGS_KEY = 'settings' -SECRET_KEY = 'd vldvnvnvjn' # nosec +SECRET_KEY = 'd vldvnvnvjn' # nosec B105 # noqa: S105 SQLALCHEMY_DATABASE_URI = 'postgresql+psycopg://localhost/coaster_test' diff --git a/tests/coaster_tests/sqlalchemy_annotations_test.py b/tests/coaster_tests/sqlalchemy_annotations_test.py index 99911bb2..53ad39d2 100644 --- a/tests/coaster_tests/sqlalchemy_annotations_test.py +++ b/tests/coaster_tests/sqlalchemy_annotations_test.py @@ -1,3 +1,5 @@ +"""Test column annotations.""" + import warnings from typing import Any, Optional @@ -92,7 +94,7 @@ class PolymorphicParent(BaseMixin, Model): class PolymorphicChild(PolymorphicParent): __tablename__ = 'polymorphic_child' - id = sa.orm.mapped_column( # type: ignore[assignment] # noqa: A003 + id = sa.orm.mapped_column( # type: ignore[assignment] None, sa.ForeignKey('polymorphic_parent.id', ondelete='CASCADE'), primary_key=True, @@ -134,8 +136,7 @@ def test_annotation_in_annotations() -> None: assert issubclass(model, ModelBase) for annotation in (immutable, cached): assert ( - annotation.__name__ - in model.__column_annotations__ # type: ignore[attr-defined] + annotation.__name__ in model.__column_annotations__ # type: ignore[attr-defined] ) @@ -144,12 +145,10 @@ def test_attr_in_annotations() -> None: for model in (IdOnly, IdUuid, UuidOnly): assert issubclass(model, ModelBase) assert ( - 'is_immutable' - in model.__column_annotations__['immutable'] # type: ignore[attr-defined] + 'is_immutable' in model.__column_annotations__['immutable'] # type: ignore[attr-defined] ) assert ( - 'is_cached' - in model.__column_annotations__['cached'] # type: ignore[attr-defined] + 'is_cached' in model.__column_annotations__['cached'] # type: ignore[attr-defined] ) @@ -384,7 +383,7 @@ def test_polymorphic_immutable(self) -> None: child.also_immutable = 'yy' def test_synonym_annotation(self) -> None: - """The immutable annotation can be bypassed via synonyms""" + """The immutable annotation can be bypassed via synonyms.""" syna = SynonymAnnotation(col_regular='a', col_immutable='b') # The columns behave as expected: assert syna.col_regular == 'a' diff --git a/tests/coaster_tests/sqlalchemy_markdowncolumn_test.py b/tests/coaster_tests/sqlalchemy_markdowncolumn_test.py index d95dbf64..3a951f9c 100644 --- a/tests/coaster_tests/sqlalchemy_markdowncolumn_test.py +++ b/tests/coaster_tests/sqlalchemy_markdowncolumn_test.py @@ -2,8 +2,8 @@ from typing import Optional -from coaster.gfm import markdown from coaster.sqlalchemy import BaseMixin, MarkdownColumn +from coaster.utils import markdown from .conftest import AppTestCase, Model, db @@ -26,7 +26,7 @@ class MarkdownHtmlData(BaseMixin, Model): value_html: Optional[str] -def fake_markdown(text: str) -> str: +def fake_markdown(_text: str) -> str: return 'fake-markdown' diff --git a/tests/coaster_tests/sqlalchemy_models_test.py b/tests/coaster_tests/sqlalchemy_models_test.py index be3424f3..8373998b 100644 --- a/tests/coaster_tests/sqlalchemy_models_test.py +++ b/tests/coaster_tests/sqlalchemy_models_test.py @@ -7,7 +7,7 @@ from datetime import datetime, timedelta from time import sleep -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional from uuid import UUID import pytest @@ -71,7 +71,7 @@ class UnnamedDocument(BaseMixin, Model): class NamedDocument(BaseNameMixin, Model): __tablename__ = 'named_document' - reserved_names = ['new'] + reserved_names = ('new',) container_id: Mapped[Optional[int]] = sa.orm.mapped_column( sa.ForeignKey('container.id') ) @@ -83,7 +83,7 @@ class NamedDocument(BaseNameMixin, Model): class NamedDocumentBlank(BaseNameMixin, Model): __tablename__ = 'named_document_blank' __name_blank_allowed__ = True - reserved_names = ['new'] + reserved_names = ('new',) container_id: Mapped[Optional[int]] = sa.orm.mapped_column( sa.ForeignKey('container.id') ) @@ -94,7 +94,7 @@ class NamedDocumentBlank(BaseNameMixin, Model): class ScopedNamedDocument(BaseScopedNameMixin, Model): __tablename__ = 'scoped_named_document' - reserved_names = ['new'] + reserved_names = ('new',) container_id: Mapped[Optional[int]] = sa.orm.mapped_column( sa.ForeignKey('container.id') ) @@ -196,13 +196,13 @@ class User(BaseMixin, Model): class MyData(Model): __tablename__ = 'my_data' - id: Mapped[int] = sa.orm.mapped_column(sa.Integer, primary_key=True) # noqa: A003 + id: Mapped[int] = sa.orm.mapped_column(sa.Integer, primary_key=True) data: Mapped[Optional[dict]] = sa.orm.mapped_column(JsonDict) class MyUrlModel(Model): __tablename__ = 'my_url' - id: Mapped[int] = sa.orm.mapped_column(sa.Integer, primary_key=True) # noqa: A003 + id: Mapped[int] = sa.orm.mapped_column(sa.Integer, primary_key=True) url: Mapped[Optional[furl]] = sa.orm.mapped_column(UrlType) url_all_scheme: Mapped[Optional[furl]] = sa.orm.mapped_column(UrlType(schemes=None)) url_custom_scheme: Mapped[Optional[furl]] = sa.orm.mapped_column( @@ -229,7 +229,7 @@ class UuidKey(BaseMixin[UUID, Any], Model): class UuidKeyNoDefault(BaseMixin[UUID, Any], Model): __tablename__ = 'uuid_key_no_default' - id: Mapped[UUID] = sa.orm.mapped_column( # type: ignore[assignment] # noqa: A003 + id: Mapped[UUID] = sa.orm.mapped_column( # type: ignore[assignment] sa.Uuid, primary_key=True ) @@ -269,10 +269,11 @@ class UuidMixinKey(UuidMixin, BaseMixin[UUID, Any], Model): class ParentForPrimary(BaseMixin, Model): __tablename__ = 'parent_for_primary' - # The relationship must be explicitly defined for type hinting to work. - # add_primary_relationship will replace this with a fleshed-out relationship - # for SQLAlchemy configuration - primary_child: Mapped[Optional[ChildForPrimary]] = relationship() + if TYPE_CHECKING: + # The relationship must be explicitly defined for type hinting to work. + # add_primary_relationship will replace this with a fleshed-out relationship + # for SQLAlchemy configuration + primary_child: Mapped[Optional[ChildForPrimary]] = relationship() class ChildForPrimary(BaseMixin, Model): @@ -393,13 +394,12 @@ def test_named(self) -> None: 'invalid1', title='Invalid1', non_existent_field="I don't belong here." ) + NamedDocument.upsert('valid1', title='Valid1') + self.session.commit() with pytest.raises(TypeError): - NamedDocument.upsert('valid1', title='Valid1') - self.session.commit() NamedDocument.upsert( 'valid1', title='Invalid1', non_existent_field="I don't belong here." ) - self.session.commit() # TODO: Versions of this test are required for BaseNameMixin, # BaseScopedNameMixin, BaseIdNameMixin and BaseScopedIdNameMixin @@ -483,7 +483,6 @@ def test_scoped_named(self) -> None: title='Invalid1', non_existent_field="I don't belong here.", ) - self.session.commit() def test_scoped_named_short_title(self) -> None: """Test the short_title method of BaseScopedNameMixin.""" @@ -503,7 +502,7 @@ def test_scoped_named_short_title(self) -> None: assert d1.short_title == "Contained" def test_id_named(self) -> None: - """Documents with a global id in the URL""" + """Documents with a global id in the URL.""" c1 = self.make_container() d1 = IdNamedDocument(title="Hello", content="World", container=c1) self.session.add(d1) @@ -522,7 +521,7 @@ def test_id_named(self) -> None: assert d3.url_name == '3-hello' def test_scoped_id(self) -> None: - """Documents with a container-specific id in the URL""" + """Documents with a container-specific id in the URL.""" c1 = self.make_container() d1 = ScopedIdDocument(content="Hello", container=c1) u = User(username="foo") @@ -550,7 +549,7 @@ def test_scoped_id(self) -> None: assert d4.url_id == 3 def test_scoped_id_named(self) -> None: - """Documents with a container-specific id and name in the URL""" + """Documents with a container-specific id and name in the URL.""" c1 = self.make_container() d1 = ScopedIdNamedDocument(title="Hello", content="World", container=c1) self.session.add(d1) @@ -777,21 +776,21 @@ def test_urltype(self) -> None: assert str(m1.url_custom_scheme) == "ftp://example.com" def test_urltype_invalid(self) -> None: + m1 = MyUrlModel(url="example.com") + self.session.add(m1) with pytest.raises(StatementError): - m1 = MyUrlModel(url="example.com") - self.session.add(m1) self.session.commit() def test_urltype_invalid_without_scheme(self) -> None: + m2 = MyUrlModel(url="//example.com") + self.session.add(m2) with pytest.raises(StatementError): - m2 = MyUrlModel(url="//example.com") - self.session.add(m2) self.session.commit() def test_urltype_invalid_without_host(self) -> None: + m2 = MyUrlModel(url="https:///test") + self.session.add(m2) with pytest.raises(StatementError): - m2 = MyUrlModel(url="https:///test") - self.session.add(m2) self.session.commit() def test_urltype_empty(self) -> None: @@ -803,15 +802,15 @@ def test_urltype_empty(self) -> None: assert str(m1.url_custom_scheme) == "" def test_urltype_invalid_scheme_default(self) -> None: + m1 = MyUrlModel(url="magnet://example.com") + self.session.add(m1) with pytest.raises(StatementError): - m1 = MyUrlModel(url="magnet://example.com") - self.session.add(m1) self.session.commit() def test_urltype_invalid_scheme_custom(self) -> None: + m1 = MyUrlModel(url_custom_scheme="magnet://example.com") + self.session.add(m1) with pytest.raises(StatementError): - m1 = MyUrlModel(url_custom_scheme="magnet://example.com") - self.session.add(m1) self.session.commit() def test_urltype_optional_scheme(self) -> None: @@ -819,9 +818,9 @@ def test_urltype_optional_scheme(self) -> None: self.session.add(m1) self.session.commit() + m2 = MyUrlModel(url_optional_scheme="example.com/test") + self.session.add(m2) with pytest.raises(StatementError): - m2 = MyUrlModel(url_optional_scheme="example.com/test") - self.session.add(m2) self.session.commit() def test_urltype_optional_host(self) -> None: @@ -829,9 +828,9 @@ def test_urltype_optional_host(self) -> None: self.session.add(m1) self.session.commit() + m2 = MyUrlModel(url_optional_host="https:///test") + self.session.add(m2) with pytest.raises(StatementError): - m2 = MyUrlModel(url_optional_host="https:///test") - self.session.add(m2) self.session.commit() def test_urltype_optional_scheme_host(self) -> None: @@ -852,9 +851,7 @@ def test_query(self) -> None: Container.query.one_or_none() def test_failsafe_add(self) -> None: - """ - failsafe_add gracefully handles IntegrityError from dupe entries - """ + """`failsafe_add` gracefully handles IntegrityError from dupe entries.""" d1 = NamedDocument(name='add_and_commit_test', title="Test") d1a = failsafe_add(self.session, d1, name='add_and_commit_test') assert d1a is d1 # We got back what we created, so the commit succeeded @@ -865,9 +862,7 @@ def test_failsafe_add(self) -> None: assert d2a is d1 def test_failsafe_add_existing(self) -> None: - """ - failsafe_add doesn't fail if the item is already in the session - """ + """`failsafe_add` doesn't fail if the item is already in the session.""" d1 = NamedDocument(name='add_and_commit_test', title="Test") d1a = failsafe_add(self.session, d1, name='add_and_commit_test') assert d1a is d1 # We got back what we created, so the commit succeeded @@ -879,25 +874,18 @@ def test_failsafe_add_existing(self) -> None: assert d2a is d1 def test_failsafe_add_fail(self) -> None: - """ - failsafe_add passes through errors occuring from bad data - """ + """`failsafe_add` passes through errors occurring from bad data.""" d1 = NamedDocument(name='missing_title') with pytest.raises(IntegrityError): failsafe_add(self.session, d1, name='missing_title') def test_failsafe_add_silent_fail(self) -> None: - """ - failsafe_add does not raise IntegrityError with bad data - when no filters are provided - """ + """`failsafe_add` without filters does not raise IntegrityError.""" d1 = NamedDocument(name='missing_title') assert failsafe_add(self.session, d1) is None def test_uuid_key(self) -> None: - """ - Models with a UUID primary key work as expected - """ + """Models with a UUID primary key work as expected.""" u1 = UuidKey() u2 = UuidKey() self.session.add(u1) @@ -922,10 +910,10 @@ def test_uuid_key(self) -> None: def test_uuid_url_id(self) -> None: """ - IdMixin provides a url_id that renders as a string of either the - integer primary key or the UUID primary key. In addition, UuidMixin - provides a uuid_hex that always renders a UUID against either the - id or uuid columns. + IdMixin provides a url_id that renders as a string of int or UUID pkey. + + In addition, UuidMixin provides a uuid_hex that always renders a UUID against + either the id or uuid columns. """ # TODO: This test is a little muddled because UuidMixin renamed # its url_id property (which overrode IdMixin's url_id) to uuid_hex. @@ -1115,9 +1103,7 @@ def test_uuid_url_id(self) -> None: assert u4 == qu4 def test_uuid_buid_uuid_b58(self) -> None: - """ - UuidMixin provides uuid_b64 (also as buid) and uuid_b58 - """ + """UuidMixin provides uuid_b64 (also as buid) and uuid_b58.""" u1 = NonUuidMixinKey() u2 = UuidMixinKey() db.session.add_all([u1, u2]) @@ -1127,13 +1113,13 @@ def test_uuid_buid_uuid_b58(self) -> None: assert isinstance(u1.uuid, UUID) assert isinstance(u2.uuid, UUID) - # Test readbility of `buid` attribute + # Test readability of `buid` attribute assert u1.buid == uuid_to_base64(u1.uuid) assert len(u1.buid) == 22 # This is a 22-char B64 representation assert u2.buid == uuid_to_base64(u2.uuid) assert len(u2.buid) == 22 # This is a 22-char B64 representation - # Test readbility of `uuid_b58` attribute + # Test readability of `uuid_b58` attribute assert u1.uuid_b58 == uuid_to_base58(u1.uuid) assert len(u1.uuid_b58) in (21, 22) # 21 or 22-char B58 representation assert u2.uuid_b58 == uuid_to_base58(u2.uuid) @@ -1231,6 +1217,8 @@ def test_uuid_buid_uuid_b58(self) -> None: def test_uuid_url_id_name(self) -> None: """ + Check fields provided by BaseIdNameMixin. + BaseIdNameMixin models with UUID primary or secondary keys should generate properly formatted url_id, url_id_name and url_name_uuid_b58. The url_id_name and url_name_uuid_b58 fields should be queryable as well. @@ -1288,10 +1276,7 @@ def test_uuid_url_id_name(self) -> None: assert q58u3 == u3 def test_uuid_default(self) -> None: - """ - Models with a UUID primary or secondary key have a default value before - adding to session - """ + """UUID columns have a default value before database commit.""" uuid_no = NonUuidKey() uuid_yes = UuidKey() uuid_no_default = UuidKeyNoDefault() @@ -1312,16 +1297,14 @@ def test_uuid_default(self) -> None: assert u4 is None # UuidMixin works likewise - um1 = uuidm_no.uuid # type: ignore[unreachable] + um1 = uuidm_no.uuid assert isinstance(um1, UUID) - um2 = uuidm_yes.uuid # This should generate uuidm_yes.id + um2 = uuidm_yes.uuid # This should generate `uuidm_yes.id` assert isinstance(um2, UUID) assert uuidm_yes.id == uuidm_yes.uuid def test_parent_child_primary(self) -> None: - """ - Test parents with multiple children and a primary child - """ + """Test parents with multiple children and a primary child.""" parent1 = ParentForPrimary() parent2 = ParentForPrimary() child1a = ChildForPrimary(parent=parent1) @@ -1364,14 +1347,14 @@ def test_parent_child_primary(self) -> None: assert qparent2.primary_child == child2a # # A parent can't have a default that is another's child - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="not affiliated with"): parent1.primary_child = child2b # The default hasn't changed despite the validation error assert parent1.primary_child == child1a - # Unsetting the default removes the relationship row, - # but does not remove the child instance from the db + # Clearing the default removes the relationship row, but does not remove the + # child instance from the db parent1.primary_child = None self.session.commit() assert ( @@ -1395,9 +1378,7 @@ def test_parent_child_primary(self) -> None: assert ParentForPrimary.query.count() == 2 def test_auto_init_default(self) -> None: - """ - Calling ``auto_init_default`` on a column makes it load defaults automatically - """ + """Calling ``auto_init_default`` sets the default on first access.""" d1 = DefaultValue() d2 = DefaultValue(value='not-default') d3 = DefaultValue() diff --git a/tests/coaster_tests/sqlalchemy_registry_test.py b/tests/coaster_tests/sqlalchemy_registry_test.py index 486b69b7..3e8878f5 100644 --- a/tests/coaster_tests/sqlalchemy_registry_test.py +++ b/tests/coaster_tests/sqlalchemy_registry_test.py @@ -2,6 +2,7 @@ # pylint: disable=redefined-outer-name,protected-access +from collections.abc import Callable from types import SimpleNamespace from typing import Any @@ -16,7 +17,7 @@ @pytest.fixture -def CallableRegistry(): # noqa: N802 +def CallableRegistry() -> type: """Callable registry with a positional parameter.""" class CallableRegistry: @@ -26,7 +27,7 @@ class CallableRegistry: @pytest.fixture -def PropertyRegistry(): # noqa: N802 +def PropertyRegistry() -> type: """Registry with property and a positional parameter.""" class PropertyRegistry: @@ -36,7 +37,7 @@ class PropertyRegistry: @pytest.fixture -def CachedPropertyRegistry(): # noqa: N802 +def CachedPropertyRegistry() -> type: """Registry with cached property and a positional parameter.""" class CachedPropertyRegistry: @@ -46,7 +47,7 @@ class CachedPropertyRegistry: @pytest.fixture -def CallableParamRegistry(): # noqa: N802 +def CallableParamRegistry() -> type: """Callable registry with a keyword parameter.""" class CallableParamRegistry: @@ -56,7 +57,7 @@ class CallableParamRegistry: @pytest.fixture -def PropertyParamRegistry(): # noqa: N802 +def PropertyParamRegistry() -> type: """Registry with property and a keyword parameter.""" class PropertyParamRegistry: @@ -66,7 +67,7 @@ class PropertyParamRegistry: @pytest.fixture -def CachedPropertyParamRegistry(): # noqa: N802 +def CachedPropertyParamRegistry() -> type: """Registry with cached property and a keyword parameter.""" class CachedPropertyParamRegistry: @@ -77,13 +78,13 @@ class CachedPropertyParamRegistry: @pytest.fixture def all_registry_hosts( - CallableRegistry, # noqa: N803 + CallableRegistry, PropertyRegistry, CachedPropertyRegistry, CallableParamRegistry, PropertyParamRegistry, CachedPropertyParamRegistry, -): +) -> list[type]: """All test registries as a list.""" return [ CallableRegistry, @@ -96,17 +97,17 @@ def all_registry_hosts( @pytest.fixture(scope='module') -def registry_member(): +def registry_member() -> Callable: """Test registry member function.""" - def member(pos=None, kwparam=None): + def member(pos=None, kwparam=None) -> Any: # noqa: ARG001 pass return member @pytest.fixture(scope='session') -def registrymixin_models(): +def registrymixin_models() -> SimpleNamespace: """Fixtures for RegistryMixin tests.""" # pylint: disable=possibly-unused-variable @@ -146,7 +147,7 @@ def __init__(self, obj: Any = None) -> None: # Sample registered item 3 @RegistryTest1.features('is1') @RegistryTest2.features() - def is1(obj): + def is1(obj) -> bool: """Assert object is instance of RegistryTest1.""" return isinstance(obj, RegistryTest1) @@ -242,13 +243,13 @@ def test_registry_property_cached_property() -> None: def test_add_to_registry( - CallableRegistry, # noqa: N803 + CallableRegistry, PropertyRegistry, CachedPropertyRegistry, CallableParamRegistry, PropertyParamRegistry, CachedPropertyParamRegistry, -): +) -> None: """A member can be added to registries and accessed as per registry settings.""" @CallableRegistry.registry() @@ -257,7 +258,7 @@ def test_add_to_registry( @CallableParamRegistry.registry() @PropertyParamRegistry.registry() @CachedPropertyParamRegistry.registry() - def member(pos=None, kwparam=None): + def member(pos=None, kwparam=None) -> Any: return (pos, kwparam) callable_host = CallableRegistry() @@ -279,35 +280,36 @@ def member(pos=None, kwparam=None): def test_property_cache_mismatch( - PropertyRegistry, CachedPropertyRegistry # noqa: N803 -): + PropertyRegistry, + CachedPropertyRegistry, +) -> None: """A registry's default setting must be explicitly turned off if conflicting.""" with pytest.raises(TypeError): @PropertyRegistry.registry(cached_property=True) - def member1(pos=None, kwparam=None): + def member1(pos=None, kwparam=None) -> Any: return (pos, kwparam) with pytest.raises(TypeError): @CachedPropertyRegistry.registry(property=True) - def member2(pos=None, kwparam=None): + def member2(pos=None, kwparam=None) -> Any: return (pos, kwparam) @PropertyRegistry.registry(cached_property=True, property=False) @CachedPropertyRegistry.registry(property=True, cached_property=False) - def member(pos=None, kwparam=None): + def member(pos=None, kwparam=None) -> Any: return (pos, kwparam) def test_add_to_registry_host( - CallableRegistry, # noqa: N803 + CallableRegistry, PropertyRegistry, CachedPropertyRegistry, CallableParamRegistry, PropertyParamRegistry, CachedPropertyParamRegistry, -): +) -> None: """A member can be added as a function, overriding default settings.""" @CallableRegistry.registry() @@ -316,7 +318,7 @@ def test_add_to_registry_host( @CallableParamRegistry.registry() @PropertyParamRegistry.registry(property=False) @CachedPropertyParamRegistry.registry(cached_property=False) - def member(pos=None, kwparam=None): + def member(pos=None, kwparam=None) -> Any: return (pos, kwparam) callable_host = CallableRegistry() @@ -338,13 +340,13 @@ def member(pos=None, kwparam=None): def test_add_to_registry_property( - CallableRegistry, # noqa: N803 + CallableRegistry, PropertyRegistry, CachedPropertyRegistry, CallableParamRegistry, PropertyParamRegistry, CachedPropertyParamRegistry, -): +) -> None: """A member can be added as a property, overriding default settings.""" @CallableRegistry.registry(property=True) @@ -353,7 +355,7 @@ def test_add_to_registry_property( @CallableParamRegistry.registry(property=True) @PropertyParamRegistry.registry(property=True) @CachedPropertyParamRegistry.registry(property=True, cached_property=False) - def member(pos=None, kwparam=None): + def member(pos=None, kwparam=None) -> Any: return (pos, kwparam) callable_host = CallableRegistry() @@ -375,13 +377,13 @@ def member(pos=None, kwparam=None): def test_add_to_registry_cached_property( - CallableRegistry, # noqa: N803 + CallableRegistry, PropertyRegistry, CachedPropertyRegistry, CallableParamRegistry, PropertyParamRegistry, CachedPropertyParamRegistry, -): +) -> None: """A member can be added as a property, overriding default settings.""" @CallableRegistry.registry(property=True) @@ -390,7 +392,7 @@ def test_add_to_registry_cached_property( @CallableParamRegistry.registry(property=True) @PropertyParamRegistry.registry(property=True) @CachedPropertyParamRegistry.registry(property=True, cached_property=False) - def member(pos=None, kwparam=None): + def member(pos=None, kwparam=None) -> Any: return (pos, kwparam) callable_host = CallableRegistry() @@ -411,7 +413,7 @@ def member(pos=None, kwparam=None): ) -def test_add_to_registry_custom_name(all_registry_hosts, registry_member): +def test_add_to_registry_custom_name(all_registry_hosts, registry_member) -> None: """Members can be added to a registry with a custom name.""" assert registry_member.__name__ == 'member' for host in all_registry_hosts: @@ -427,7 +429,7 @@ def test_add_to_registry_custom_name(all_registry_hosts, registry_member): assert host.registry.member is registry_member -def test_add_to_registry_underscore(all_registry_hosts, registry_member): +def test_add_to_registry_underscore(all_registry_hosts, registry_member) -> None: """Registry member names cannot start with an underscore.""" for host in all_registry_hosts: with pytest.raises(AttributeError): @@ -436,7 +438,7 @@ def test_add_to_registry_underscore(all_registry_hosts, registry_member): host.registry._new_member = registry_member -def test_add_to_registry_dupe(all_registry_hosts, registry_member): +def test_add_to_registry_dupe(all_registry_hosts, registry_member) -> None: """Registry member names cannot be duplicates of an existing name.""" for host in all_registry_hosts: host.registry()(registry_member) @@ -453,11 +455,11 @@ def test_add_to_registry_dupe(all_registry_hosts, registry_member): def test_cached_properties_are_cached( - PropertyRegistry, # noqa: N803 + PropertyRegistry, CachedPropertyRegistry, PropertyParamRegistry, CachedPropertyParamRegistry, -): +) -> None: """Cached properties are truly cached.""" # Register registry member @@ -465,7 +467,7 @@ def test_cached_properties_are_cached( @CachedPropertyRegistry.registry() @PropertyParamRegistry.registry() @CachedPropertyParamRegistry.registry() - def member(pos=None, kwparam=None): + def member(pos=None, kwparam=None) -> Any: return [pos, kwparam] # Lists are different each call property_host = PropertyRegistry() diff --git a/tests/coaster_tests/sqlalchemy_roles_test.py b/tests/coaster_tests/sqlalchemy_roles_test.py index ae1bb259..c6dd4f40 100644 --- a/tests/coaster_tests/sqlalchemy_roles_test.py +++ b/tests/coaster_tests/sqlalchemy_roles_test.py @@ -8,7 +8,7 @@ import json from collections.abc import Iterable, MutableSet, Sequence from dataclasses import dataclass -from typing import Any, Optional +from typing import Any, ClassVar, Optional import pytest import sqlalchemy as sa @@ -84,15 +84,18 @@ class RoleModel(DeclaredAttrMixin, RoleMixin, Model): # Approach one, declare roles in advance. # 'all' is a special role that is always granted from the base class - __roles__ = {'all': {'read': {'id', 'name', 'title'}}} + __roles__: ClassVar = {'all': {'read': {'id', 'name', 'title'}}} - __datasets__ = {'minimal': {'id', 'name'}, 'extra': {'id', 'name', 'mixed_in1'}} + __datasets__: ClassVar = { + 'minimal': {'id', 'name'}, + 'extra': {'id', 'name', 'mixed_in1'}, + } # Additional dataset members are defined using with_roles # Approach two, annotate roles on the attributes. # These annotations always add to anything specified in __roles__ - id: Mapped[int] = sa_orm.mapped_column(sa.Integer, primary_key=True) # noqa: A003 + id: Mapped[int] = sa_orm.mapped_column(sa.Integer, primary_key=True) # Specify read+write access name: Mapped[str] = with_roles(sa_orm.mapped_column(sa.Unicode(250)), rw={'owner'}) @@ -136,7 +139,7 @@ class AutoRoleModel(RoleMixin, Model): # This model doesn't specify __roles__. It only uses with_roles. # It should still work - id: Mapped[int] = sa_orm.mapped_column(sa.Integer, primary_key=True) # noqa: A003 + id: Mapped[int] = sa_orm.mapped_column(sa.Integer, primary_key=True) with_roles(id, read={'all'}) name: Mapped[Optional[str]] = sa_orm.mapped_column(sa.Unicode(250)) @@ -168,7 +171,7 @@ class RelationshipChild(BaseNameMixin, Model): ) parent: Mapped[RelationshipParent] = relationship(back_populates='children_list') - __roles__ = {'all': {'read': {'name', 'title', 'parent'}}} + __roles__: ClassVar = {'all': {'read': {'name', 'title', 'parent'}}} __datasets__ = { 'primary': {'name', 'title', 'parent'}, 'related': {'name', 'title'}, @@ -215,7 +218,7 @@ class RelationshipParent(BaseNameMixin, Model): ) ) - __roles__ = { + __roles__: ClassVar = { 'all': { 'read': { 'name', @@ -253,7 +256,7 @@ class RoleGrantMany(BaseMixin, Model): __tablename__ = 'role_grant_many' - __roles__ = { + __roles__: ClassVar = { 'primary_role': {'granted_by': ['primary_users']}, 'secondary_role': {'granted_by': ['secondary_users']}, } @@ -368,7 +371,7 @@ class MultiroleDocument(BaseMixin, Model): # Acquire role1 through both relationships (query and list relationships) # Acquire role2 and role3 via only one relationship each # This contrived setup is only to test that it works via all relationship types - __roles__ = { + __roles__: ClassVar = { 'parent_role': {'granted_via': {'parent': 'user'}}, 'parent_other_role': {'granted_via': {'parent': 'user'}}, 'role1': {'granted_via': {'rel_lazy': 'user', 'rel_list': 'user'}}, @@ -443,11 +446,11 @@ def default(self, o: Any) -> Any: class TestCoasterRoles(AppTestCase): def test_base_is_clean(self) -> None: - """Specifying roles never mutates RoleMixin.__roles__""" + """Specifying roles never mutates RoleMixin.__roles__.""" assert RoleMixin.__roles__ == {} def test_role_dict(self) -> None: - """Roles may be declared multiple ways and they all work""" + """Roles may be declared multiple ways and they all work.""" assert RoleModel.__roles__ == { 'all': {'call': {'hello'}, 'read': {'id', 'name', 'title', 'mixed_in2'}}, 'editor': {'read': {'mixed_in2'}, 'write': {'title', 'mixed_in2'}}, @@ -473,34 +476,31 @@ def test_role_dict(self) -> None: } def test_autorole_dict(self) -> None: - """A model without __roles__, using only with_roles, also works as expected""" + """A model without __roles__, using only with_roles, also works as expected.""" assert AutoRoleModel.__roles__ == { 'all': {'read': {'id', 'name'}}, 'owner': {'read': {'name'}, 'write': {'name'}}, } def test_basemixin_roles(self) -> None: - """A model with BaseMixin by default exposes nothing to the 'all' role""" + """A model with BaseMixin by default exposes nothing to the 'all' role.""" assert BaseModel.__roles__.get('all', {}).get('read', set()) == set() def test_uuidmixin_roles(self) -> None: - """ - A model with UuidMixin provides 'all' read access to uuid, uuid_b58 and uuid_b64 - among others. - """ + """A model with UuidMixin provides read access to 'all' role for some cols.""" assert 'read' in UuidModel.__roles__['all'] assert {'uuid', 'buid', 'uuid_b58', 'uuid_b64'} <= UuidModel.__roles__['all'][ 'read' ] def test_roles_for_anon(self) -> None: - """An anonymous actor should have 'all' and 'anon' roles""" + """An anonymous actor should have 'all' and 'anon' roles.""" rm = RoleModel(name='test', title='Test') roles = rm.roles_for(actor=None) assert roles == {'all', 'anon'} def test_roles_for_actor(self) -> None: - """An actor (but anchors) must have 'all' and 'auth' roles""" + """An actor (but anchors) must have 'all' and 'auth' roles.""" rm = RoleModel(name='test', title='Test') roles = rm.roles_for(actor=1) assert roles == {'all', 'auth'} @@ -508,13 +508,13 @@ def test_roles_for_actor(self) -> None: assert roles == {'all', 'anon'} def test_roles_for_owner(self) -> None: - """Presenting the correct anchor grants 'owner' role""" + """Presenting the correct anchor grants 'owner' role.""" rm = RoleModel(name='test', title='Test') roles = rm.roles_for(anchors=('owner-secret',)) assert roles == {'all', 'anon', 'owner'} def test_current_roles(self) -> None: - """Current roles are available""" + """Current roles are available.""" rm = RoleModel(name='test', title='Test') roles = rm.current_roles assert roles == {'all', 'anon'} @@ -523,21 +523,21 @@ def test_current_roles(self) -> None: assert not roles.owner def test_access_for_syntax(self) -> None: - """access_for can be called with either roles or actor for identical outcomes""" + """access_for can be called with either roles or actor for identical outcomes.""" rm = RoleModel(name='test', title='Test') proxy1 = rm.access_for(roles=rm.roles_for(actor=None)) proxy2 = rm.access_for(actor=None) assert proxy1 == proxy2 def test_access_for_all(self) -> None: - """All actors should be able to read some fields""" + """All actors should be able to read some fields.""" arm = AutoRoleModel(name='test') proxy = arm.access_for(actor=None) assert len(proxy) == 2 assert set(proxy.keys()) == {'id', 'name'} def test_current_access(self) -> None: - """Current access is available""" + """Current access is available.""" arm = AutoRoleModel(name='test') proxy = arm.current_access() assert len(proxy) == 2 @@ -550,7 +550,7 @@ def test_current_access(self) -> None: assert not roles.owner def test_json_protocol(self) -> None: - """Cast to JSON happens with __json__""" + """Cast to JSON happens with __json__.""" arm = AutoRoleModel(name='test') json_str = json.dumps(arm, cls=JsonProtocolEncoder) data = json.loads(json_str) @@ -561,7 +561,7 @@ def test_json_protocol(self) -> None: } def test_attr_dict_access(self) -> None: - """Proxies support identical attribute and dictionary access""" + """Proxies support identical attribute and dictionary access.""" rm = RoleModel(name='test', title='Test') proxy = rm.access_for(actor=None) assert 'name' in proxy @@ -569,7 +569,7 @@ def test_attr_dict_access(self) -> None: assert proxy['name'] == 'test' def test_diff_roles(self) -> None: - """Different roles get different access""" + """Different roles get different access.""" rm = RoleModel(name='test', title='Test') proxy1 = rm.access_for(roles={'all'}) proxy2 = rm.access_for(roles={'owner'}) @@ -595,7 +595,7 @@ def test_diff_roles(self) -> None: } def test_diff_roles_single_model_dataset(self) -> None: - """Data profiles constrain the attributes available via enumeration""" + """Data profiles constrain the attributes available via enumeration.""" rm = RoleModel(name='test', title='Test') proxy1a = rm.access_for(roles={'all'}, datasets=('minimal',)) proxy2a = rm.access_for(roles={'owner'}, datasets=('minimal',)) @@ -621,7 +621,7 @@ def test_diff_roles_single_model_dataset(self) -> None: assert RoleModel.__datasets__['third'] == {'title'} def test_write_without_read(self) -> None: - """A proxy may allow writes without allowing reads""" + """A proxy may allow writes without allowing reads.""" rm = RoleModel(name='test', title='Test') proxy = rm.access_for(roles={'owner'}) assert rm.title == 'Test' @@ -630,12 +630,12 @@ def test_write_without_read(self) -> None: proxy['title'] = 'Changed again' assert rm.title == 'Changed again' with pytest.raises(AttributeError): - proxy.title # pylint: disable=pointless-statement + _ = proxy.title with pytest.raises(KeyError): - proxy['title'] # pylint: disable=pointless-statement + _ = proxy['title'] def test_no_write(self) -> None: - """A proxy will disallow writes if the role doesn't permit it""" + """A proxy will disallow writes if the role doesn't permit it.""" rm = RoleModel(name='test', title='Test') proxy = rm.access_for(roles={'editor'}) assert rm.title == 'Test' @@ -651,7 +651,7 @@ def test_no_write(self) -> None: assert rm.name == 'test' def test_method_call(self) -> None: - """Method calls are allowed as calling is just an alias for reading""" + """Method calls are allowed as calling is just an alias for reading.""" rm = RoleModel(name='test', title='Test') proxy1 = rm.access_for(roles={'all'}) proxy2 = rm.access_for(roles={'owner'}) @@ -662,7 +662,7 @@ def test_method_call(self) -> None: proxy2['hello']() def test_dictionary_comparison(self) -> None: - """A proxy can be compared with a dictionary""" + """A proxy can be compared with a dictionary.""" rm = RoleModel(name='test', title='Test') proxy = rm.access_for(roles={'all'}) assert proxy == {'id': None, 'name': 'test', 'title': 'Test', 'mixed_in2': None} @@ -682,11 +682,11 @@ def test_bad_decorator(self) -> None: with pytest.raises(TypeError): @with_roles({'all'}) # type: ignore[operator] - def f(): + def f() -> None: pass def test_access_for_roles_and_actor_or_anchors(self) -> None: - """access_for accepts roles or actor/anchors, not both/all""" + """access_for accepts roles or actor/anchors, not both/all.""" rm = RoleModel(name='test', title='Test') with pytest.raises(TypeError): rm.access_for(roles={'all'}, actor=1) @@ -696,7 +696,7 @@ def test_access_for_roles_and_actor_or_anchors(self) -> None: rm.access_for(roles={'all'}, actor=1, anchors=('owner-secret',)) def test_scalar_relationship(self) -> None: - """Scalar relationships are automatically wrapped in an access proxy""" + """Scalar relationships are automatically wrapped in an access proxy.""" parent = RelationshipParent(title="Parent") child = RelationshipChild(title="Child", parent=parent) self.session.add_all([parent, child]) @@ -710,7 +710,7 @@ def test_scalar_relationship(self) -> None: # TODO: Test for other roles using the actor parameter def test_collection_relationship(self) -> None: - """Collection relationships are automatically wrapped in an access proxy""" + """Collection relationships are automatically wrapped in an access proxy.""" parent = RelationshipParent(title="Parent") child = RelationshipChild(title="Child", parent=parent) self.session.add_all([parent, child]) @@ -741,7 +741,7 @@ def test_collection_relationship(self) -> None: assert proxy.children_dict_column['child'].title == child.title def test_cascading_datasets(self) -> None: - """Test data profile cascades""" + """Test data profile cascades.""" parent = RelationshipParent(title="Parent") child = RelationshipChild(title="Child", parent=parent) self.session.add_all([parent, child]) @@ -800,7 +800,7 @@ def test_cascading_datasets(self) -> None: assert pchild.parent is not None def test_missing_dataset(self) -> None: - """A missing dataset will raise a KeyError indicating what is missing where""" + """A missing dataset will raise a KeyError indicating what is missing where.""" parent = RelationshipParent(title="Parent") self.session.add(parent) self.session.commit() @@ -909,7 +909,7 @@ def test_actors_with_invalid(self) -> None: next(m1.actors_with('owner')) # skipcq: PTC-W0063 def test_role_grant_synonyms(self) -> None: - """Test that synonyms reflect the underlying attribute""" + """Test that synonyms reflect the underlying attribute.""" rgs = RoleGrantSynonym(datacol='abc') assert rgs.datacol == 'abc' assert rgs.altcol == 'abc' @@ -974,7 +974,7 @@ def test_dynamic_association_proxy(self) -> None: assert parent1.children_names[child1.name] == child1 assert parent1.children_names[child2.name] == child2 with pytest.raises(KeyError): - parent1.children_names[child3.name] # pylint: disable=pointless-statement + _ = parent1.children_names[child3.name] assert parent1.children_names.get(child3.name) is None assert dict(parent1.children_names) == { child1.name: child1, @@ -990,7 +990,7 @@ def test_dynamic_association_proxy(self) -> None: p1b = parent1.children_names assert p1a is p1b assert p1a == p1b # Test __eq__ - assert not p1a != p1b # Test __ne__ + assert not (p1a != p1b) # Test __ne__ # noqa: SIM202 assert p1a != parent2.children_names # Cross-check with an unrelated proxy def test_dynamic_association_proxy_qattr(self) -> None: @@ -1033,7 +1033,7 @@ def test_dynamic_association_proxy_qattr(self) -> None: assert parent1.children_namesq[child1.name] == child1 assert parent1.children_namesq[child2.name] == child2 with pytest.raises(KeyError): - parent1.children_namesq[child3.name] # pylint: disable=pointless-statement + _ = parent1.children_namesq[child3.name] assert parent1.children_namesq.get(child3.name) is None assert dict(parent1.children_namesq) == { child1.name: child1, @@ -1049,13 +1049,11 @@ def test_dynamic_association_proxy_qattr(self) -> None: p1b = parent1.children_namesq assert p1a is p1b assert p1a == p1b # Test __eq__ - assert not p1a != p1b # Test __ne__ + assert not (p1a != p1b) # Test __ne__ # noqa: SIM202 assert p1a != parent2.children_namesq # Cross-check with an unrelated proxy def test_granted_via(self) -> None: - """ - Roles can be granted via related objects - """ + """Roles can be granted via related objects.""" u1 = RoleUser() u2 = RoleUser() u3 = RoleUser() @@ -1219,7 +1217,7 @@ def test_granted_via(self) -> None: assert 'parent_role_shared' in croles3 def test_granted_via_error(self) -> None: - """A misconfigured granted_via declaration will raise an error""" + """A misconfigured granted_via declaration will raise an error.""" user = RoleUser() document = MultiroleDocument() membership = RoleMembership(doc=document, user=user) @@ -1229,9 +1227,7 @@ def test_granted_via_error(self) -> None: _ = 'incorrectly_specified_role' in roles def test_actors_from_granted_via(self) -> None: - """ - actors_with will find actors whose roles are declared in granted_via - """ + """actors_with will find actors whose roles are declared in granted_via.""" u1 = RoleUser() u2 = RoleUser() u3 = RoleUser() @@ -1298,7 +1294,7 @@ def test_actors_from_granted_via(self) -> None: class TestLazyRoleSet: - """Tests for LazyRoleSet, isolated from RoleMixin""" + """Tests for LazyRoleSet, isolated from RoleMixin.""" class EmptyDocument(RoleMixin): # Test LazyRoleSet without the side effects of roles defined in the document @@ -1306,30 +1302,30 @@ class EmptyDocument(RoleMixin): class Document(RoleMixin): _user: Optional[TestLazyRoleSet.User] = None - _userlist = () - __roles__ = {'owner': {'granted_by': ['user', 'userlist']}} + _userlist: Sequence[TestLazyRoleSet.User] = () + __roles__: ClassVar = {'owner': {'granted_by': ['user', 'userlist']}} # Test flags accessed_user: bool = False accessed_userlist: bool = False @property - def user(self): + def user(self) -> Optional[TestLazyRoleSet.User]: self.accessed_user = True return self._user @user.setter - def user(self, value): + def user(self, value: Optional[TestLazyRoleSet.User]) -> None: self._user = value self.accessed_user = False @property - def userlist(self): + def userlist(self) -> Sequence[TestLazyRoleSet.User]: self.accessed_userlist = True return self._userlist @userlist.setter - def userlist(self, value): + def userlist(self, value: Sequence[TestLazyRoleSet.User]) -> None: self._userlist = value self.accessed_userlist = False @@ -1416,7 +1412,7 @@ def test_set_operations(self) -> None: assert r2 == r def test_has_any(self) -> None: - """Test the has_any method""" + """Test the has_any method.""" doc = self.Document() user = self.User() doc.user = user @@ -1532,9 +1528,7 @@ def test_inspectable_lazyroleset(self) -> None: assert d.accessed_userlist is True def test_offered_roles(self) -> None: - """ - Test that an object with an `offered_roles` method is a RoleGrantABC type - """ + """Test that an object with an `offered_roles` method is a RoleGrantABC type.""" role_membership = RoleMembership() assert issubclass(RoleMembership, RoleGrantABC) assert isinstance(role_membership, RoleGrantABC) @@ -1560,14 +1554,16 @@ class RoleCheckModel(RoleMixin[User]): created_by: TestConditionalRole.User owners: Optional[list[TestConditionalRole.User]] = None - __roles__ = { + __roles__: ClassVar = { 'creator': {'granted_via': {'created_by': None}}, 'owner': {'granted_by': ['owners']}, } @role_check('reader', 'viewer') # Takes multiple roles as necessary def has_reader_role( - self, actor: Optional[TestConditionalRole.User], anchors: Sequence[Any] = () + self, + actor: Optional[TestConditionalRole.User], + _anchors: Sequence[Any] = (), ) -> bool: # If this object is public, everyone gets the 'reader' role if self.public: @@ -1583,7 +1579,9 @@ def _(self) -> Iterable[TestConditionalRole.User]: @role_check('owner') def creator_is_owner( - self, actor: Optional[TestConditionalRole.User], anchors: Sequence[Any] = () + self, + actor: Optional[TestConditionalRole.User], + _anchors: Sequence[Any] = (), ) -> bool: """ Only validates the creator as owner, while failing everyone else. diff --git a/tests/coaster_tests/statemanager_test.py b/tests/coaster_tests/statemanager_test.py index f36e792b..2f48f9b3 100644 --- a/tests/coaster_tests/statemanager_test.py +++ b/tests/coaster_tests/statemanager_test.py @@ -184,7 +184,7 @@ def roles_for( @pytest.mark.filterwarnings("ignore::coaster.utils.classes.LabeledEnumWarning") -def test_check_constraint_labeledenum(): +def test_check_constraint_labeledenum() -> None: """Test check_constraint with a LabeledEnum.""" class TestEnum1(LabeledEnum): @@ -228,7 +228,7 @@ class TestEnumStr(LabeledEnum): ) -def test_check_constraint_enum(): +def test_check_constraint_enum() -> None: """Test check_constraint with an Enum.""" class TestEnumInt(enum.Enum): @@ -309,19 +309,17 @@ def test_state_already_exists(self) -> None: """Conditional state with the name of an existing state will raise an error.""" state = MyPost.__dict__['state'] with pytest.raises(AttributeError): - state.add_conditional_state('PENDING', state.DRAFT, lambda post: True) + state.add_conditional_state('PENDING', state.DRAFT, lambda _: True) def test_conditional_state_unmanaged_state(self) -> None: """Conditional states require a managed state as base.""" state = MyPost.__dict__['state'] reviewstate = MyPost.__dict__['reviewstate'] with pytest.raises(TypeError): - state.add_conditional_state( - 'TEST_STATE1', MY_STATE.DRAFT, lambda post: True - ) + state.add_conditional_state('TEST_STATE1', MY_STATE.DRAFT, lambda _: True) with pytest.raises(ValueError, match="not associated with this state manager"): state.add_conditional_state( - 'TEST_STATE2', reviewstate.UNSUBMITTED, lambda post: True + 'TEST_STATE2', reviewstate.UNSUBMITTED, lambda _: True ) def test_conditional_state_label(self) -> None: @@ -367,7 +365,7 @@ def test_has_state(self) -> None: def test_has_nonstate(self) -> None: """Test that StateManagerWrapper is only a proxy to StateManager's attrs.""" with pytest.raises(AttributeError): - self.post.state.does_not_exist # pylint: disable=pointless-statement + _ = self.post.state.does_not_exist assert isinstance(self.post.state.transition, types.MethodType) def test_readonly(self) -> None: @@ -390,9 +388,7 @@ def test_change_state_invalid(self) -> None: state._set_state_value(self.post, 100) def test_conditional_state(self) -> None: - """ - Conditional states include custom validators which are called to confirm the state - """ + """Conditional states include custom validators which are called to confirm the state.""" assert self.post.state.DRAFT assert not self.post.state.RECENT self.post._state = MY_STATE.PUBLISHED @@ -401,9 +397,7 @@ def test_conditional_state(self) -> None: assert not self.post.state.RECENT def test_bestmatch_state(self) -> None: - """ - The best matching state prioritises conditional over direct - """ + """The best matching state prioritises conditional over direct.""" assert self.post.state.DRAFT assert self.post.state.bestmatch() == self.post.state.DRAFT assert not self.post.state.RECENT @@ -423,7 +417,7 @@ def test_bestmatch_state(self) -> None: assert self.post.state.label.name == 'published' def test_added_state_group(self) -> None: - """Added state groups can be tested""" + """Added state groups can be tested.""" assert self.post.state.DRAFT # True because DRAFT state matches assert self.post.state.REDRAFTABLE @@ -435,26 +429,24 @@ def test_added_state_group(self) -> None: assert not self.post.state.REDRAFTABLE def test_state_group_invalid(self) -> None: - """add_state_group validates the states being added""" + """add_state_group validates the states being added.""" state = MyPost.__dict__['state'] reviewstate = MyPost.__dict__['reviewstate'] # Can't add an existing state name with pytest.raises(AttributeError): state.add_state_group('DRAFT', state.PENDING) # Can't add a state from another state manager - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Invalid state .* for state group"): state.add_state_group('OTHER', reviewstate.UNSUBMITTED) # Can't group a conditional state with the main state - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="The value for state .* is already in"): state.add_state_group('MIXED1', state.PUBLISHED, state.RECENT) # Can't group a conditional state with group containing main state - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="The value for state .* is already in"): state.add_state_group('MIXED2', state.PUBLISHED_AND_AFTER, state.RECENT) def test_sql_query_single_value(self) -> None: - """ - Different queries with the same state value work as expected - """ + """Different queries with the same state value work as expected.""" post1 = MyPost.query.filter(MyPost.state.DRAFT).first() assert post1 is not None assert post1.id == self.post.id @@ -467,9 +459,7 @@ def test_sql_query_single_value(self) -> None: assert post4.id == self.post.id def test_sql_query_multi_value(self) -> None: - """ - Same queries with different state values work as expected - """ + """Same queries with different state values work as expected.""" post1 = MyPost.query.filter(MyPost.state.UNPUBLISHED).first() assert post1 is not None assert post1.id == self.post.id @@ -479,9 +469,7 @@ def test_sql_query_multi_value(self) -> None: assert post2 is None def test_sql_query_added_state(self) -> None: - """ - Querying for an added state works as expected (with two filter conditions combined with and_) - """ + """Querying for an added state works as expected (with two filter conditions combined with and_).""" post1 = MyPost.query.filter(MyPost.state.RECENT).first() assert post1 is None self.post._state = MY_STATE.PUBLISHED @@ -491,9 +479,7 @@ def test_sql_query_added_state(self) -> None: assert post2.id == self.post.id def test_sql_query_state_group(self) -> None: - """ - Querying for a state group works as expected (with multiple filter conditions combined with or_) - """ + """Querying for a state group works as expected (with multiple filter conditions combined with or_).""" post1 = MyPost.query.filter(MyPost.state.REDRAFTABLE).first() assert post1 is not None assert post1.id == self.post.id @@ -508,9 +494,7 @@ def test_sql_query_state_group(self) -> None: assert post3 is None def test_transition_submit(self) -> None: - """ - `submit` transition works - """ + """`submit` transition works.""" assert self.post.state.value == MY_STATE.DRAFT self.post.submit() assert self.post.state.value == MY_STATE.PENDING @@ -521,9 +505,7 @@ def test_transition_submit(self) -> None: assert self.post.state.value == MY_STATE.PENDING def test_transition_publish_invalid(self) -> None: - """ - An exception in the transition aborts it - """ + """An exception in the transition aborts it.""" assert self.post.state.DRAFT with pytest.raises(AssertionError): # publish() should raise AssertionError if we're a draft (custom exception, not decorator's) @@ -532,9 +514,7 @@ def test_transition_publish_invalid(self) -> None: assert self.post.state.DRAFT def test_transition_publish_datetime(self) -> None: - """ - `publish` transition amends `datetime` - """ + """`publish` transition amends `datetime`.""" assert self.post.state.DRAFT self.post.submit() assert self.post.state.PENDING @@ -543,9 +523,7 @@ def test_transition_publish_datetime(self) -> None: assert self.post.published_at is not None def test_requires(self) -> None: - """ - The `requires` decorator behaves similarly to a transition, but doesn't state change - """ + """The `requires` decorator behaves similarly to a transition, but doesn't state change.""" assert self.post.state.DRAFT with pytest.raises(StateTransitionError): # Can only be called in published state @@ -560,9 +538,7 @@ def test_requires(self) -> None: assert self.post.published_at < d def test_state_labels(self) -> None: - """ - The current state's label can be accessed from the `.label` attribute - """ + """The current state's label can be accessed from the `.label` attribute.""" assert self.post.state.DRAFT assert self.post.state.label == "Draft" self.post.submit() @@ -570,9 +546,7 @@ def test_state_labels(self) -> None: assert self.post.state.label.title == "Pending" def test_added_state_transition(self) -> None: - """ - Transition works with added states as a `from` state - """ + """Transition works with added states as a `from` state.""" assert self.post.state.DRAFT self.post.submit() # Change from DRAFT to PENDING self.post.publish() # Change from PENDING to PUBLISHED @@ -589,9 +563,7 @@ def test_added_state_transition(self) -> None: self.post.undo() def test_added_regular_state_transition(self) -> None: - """ - Transitions work with mixed use of regular and added states in the `from` state - """ + """Transitions work with mixed use of regular and added states in the `from` state.""" assert self.post.state.DRAFT self.post.submit() # Change from DRAFT to PENDING assert self.post.state.PENDING @@ -614,7 +586,7 @@ def test_added_regular_state_transition(self) -> None: self.post.redraft() def test_reviewstate_also_changes(self) -> None: - """Transitions with two decorators change state on both managers""" + """Transitions with two decorators change state on both managers.""" assert self.post.state.DRAFT assert self.post.reviewstate.UNSUBMITTED self.post.submit() # This changes only `state` @@ -628,7 +600,7 @@ def test_reviewstate_also_changes(self) -> None: assert self.post.reviewstate.PENDING def test_transition_state_lock(self) -> None: - """Both states must be in valid state for a transition to be available""" + """Both states must be in valid state for a transition to be available.""" self.post.submit() assert self.post.state.PENDING assert self.post.reviewstate.UNSUBMITTED @@ -676,7 +648,7 @@ def test_transition_abort(self) -> None: assert self.post.state.PUBLISHED # state has changed def test_transition_is_available(self) -> None: - """A transition's is_available property is reliable""" + """A transition's is_available property is reliable.""" assert self.post.state.DRAFT assert self.post.submit.is_available self.post.submit() @@ -691,7 +663,7 @@ def test_transition_is_available(self) -> None: assert not self.post.undo.is_available def test_transition_data(self) -> None: - """Additional data defined on a transition works regardless of decorator order""" + """Additional data defined on a transition works regardless of decorator order.""" # Titles are defined on different decorators on these: assert self.post.publish.data['title'] == "Publish" assert self.post.undo.data['title'] == "Undo" @@ -700,28 +672,28 @@ def test_transition_data(self) -> None: assert MyPost.undo.data['title'] == "Undo" def test_transition_data_name_invalid(self) -> None: - """The `name` data field on transitions is reserved and cannot be specified""" + """The `name` data field on transitions is reserved and cannot be specified.""" state = MyPost.__dict__['state'] with pytest.raises(TypeError): @state.transition(None, state.DRAFT, name='invalid_data_field') - def name_test(self): + def name_test(self) -> None: # noqa: ARG001 pass def test_duplicate_transition(self) -> None: - """Transitions can't be decorated twice with the same state manager""" + """Transitions can't be decorated twice with the same state manager.""" state = MyPost.__dict__['state'] with pytest.raises(TypeError): @state.transition(state.DRAFT, state.PENDING) @state.transition(state.PENDING, state.PUBLISHED) - def dupe_decorator(self): + def dupe_decorator(self) -> None: # noqa: ARG001 pass state.transitions.remove('dupe_decorator') def test_available_transitions(self) -> None: - """State managers indicate the currently available transitions""" + """State managers indicate the currently available transitions.""" assert self.post.state.DRAFT assert 'submit' in self.post.state.transitions(current=False) self.post.state.transitions(current=False)['submit']() @@ -729,7 +701,7 @@ def test_available_transitions(self) -> None: assert self.post.state.PENDING def test_available_transitions_order(self) -> None: - """State managers maintain the order of transitions from the class definition""" + """State managers maintain the order of transitions from the class definition.""" assert self.post.state.DRAFT # `submit` must come before `publish` assert list(self.post.state.transitions(current=False).keys())[:2] == [ @@ -738,7 +710,7 @@ def test_available_transitions_order(self) -> None: ] def test_currently_available_transitions(self) -> None: - """State managers indicate the currently available transitions (using current_auth)""" + """State managers indicate the currently available transitions (using current_auth).""" assert self.post.state.DRAFT assert 'submit' not in self.post.state.transitions() # Add a user using the string 'author' (see MyPost.roles_for) @@ -749,7 +721,7 @@ def test_currently_available_transitions(self) -> None: assert self.post.state.PENDING def test_available_transitions_for(self) -> None: - """State managers indicate the currently available transitions (using access_for)""" + """State managers indicate the currently available transitions (using access_for).""" assert self.post.state.DRAFT assert 'submit' not in self.post.state.transitions_for(roles={'reviewer'}) assert 'submit' in self.post.state.transitions_for(roles={'author'}) @@ -758,7 +730,7 @@ def test_available_transitions_for(self) -> None: assert self.post.state.PENDING def test_current_states(self) -> None: - """All states that are currently active""" + """All states that are currently active.""" current = self.post.state.current() assert set(current.keys()) == {'DRAFT', 'UNPUBLISHED', 'REDRAFTABLE'} assert current['DRAFT'] @@ -769,23 +741,23 @@ def test_current_states(self) -> None: MyPost.state.current() def test_managed_state_wrapper(self) -> None: - """ManagedStateWrapper will only wrap a managed state or group""" + """ManagedStateWrapper will only wrap a managed state or group.""" draft = MyPost.__dict__['state'].DRAFT wdraft = ManagedStateInstance(draft, self.post) assert draft.value == wdraft.value assert wdraft # Object is falsy - assert self.post.state.DRAFT == wdraft + assert wdraft == self.post.state.DRAFT self.post.submit() assert not wdraft # Object remains the same even if not active - assert self.post.state.DRAFT == wdraft - assert self.post.state.PENDING != wdraft # These objects don't match + assert wdraft == self.post.state.DRAFT + assert wdraft != self.post.state.PENDING # These objects don't match with pytest.raises(TypeError): ManagedStateInstance(MY_STATE.DRAFT, self.post) # type: ignore[arg-type] def test_role_proxy_transitions(self) -> None: - """with_roles works on the transition decorator""" + """with_roles works on the transition decorator.""" assert self.post.state.DRAFT # Create access proxies for each of these roles author = self.post.access_for(roles={'author'}) @@ -824,13 +796,13 @@ def test_group_by_state(self) -> None: self.session.commit() groups1 = MyPost.state.group(MyPost.query.all()) # Order is preserved. Draft before Published. No Pending. - assert [g.label for g in groups1.keys()] == [ + assert [g.label for g in groups1] == [ MY_STATE[MY_STATE.DRAFT], MY_STATE[MY_STATE.PUBLISHED], ] # Order is preserved. Draft before Pending before Published. groups2 = MyPost.state.group(MyPost.query.all(), keep_empty=True) - assert [g.label for g in groups2.keys()] == [ + assert [g.label for g in groups2] == [ MY_STATE[MY_STATE.DRAFT], MY_STATE[MY_STATE.PENDING], MY_STATE[MY_STATE.PUBLISHED], diff --git a/tests/coaster_tests/testing.py b/tests/coaster_tests/testing.py index d4a5ca3e..5e4eab92 100644 --- a/tests/coaster_tests/testing.py +++ b/tests/coaster_tests/testing.py @@ -1,5 +1,3 @@ -""" -Configuration used by coaster test suite -""" +"""Configuration used by coaster test suite.""" TEST_KEY = 'test' diff --git a/tests/coaster_tests/url_for_test.py b/tests/coaster_tests/url_for_test.py index 11ecc315..a189335b 100644 --- a/tests/coaster_tests/url_for_test.py +++ b/tests/coaster_tests/url_for_test.py @@ -25,19 +25,19 @@ @app1.route('/') @NamedDocument.is_url_for('view', doc='name') -def doc_view(doc): +def doc_view(doc) -> str: return f'view {doc}' @app1.route('//edit') @NamedDocument.is_url_for('edit', doc='name') -def doc_edit(doc): +def doc_edit(doc) -> str: return f'edit {doc}' @app1.route('//upper') @NamedDocument.is_url_for('upper', doc=lambda d: d.name.upper()) -def doc_upper(doc): +def doc_upper(doc) -> str: return f'upper {doc}' @@ -46,13 +46,13 @@ def doc_upper(doc): # to the parameter given to `NamedDocument.url_for` in the test below. @app1.route('//with/') @NamedDocument.is_url_for('with', doc='name', other='**other.name') -def doc_with(doc, other): +def doc_with(doc, other) -> str: return f'{doc} with {other}' @app1.route('//') @ScopedNamedDocument.is_url_for('view', container='parent.id', doc='name') -def sdoc_view(container, doc): +def sdoc_view(container, doc) -> str: return f'view {container} {doc}' @@ -60,25 +60,25 @@ def sdoc_view(container, doc): @ScopedNamedDocument.is_url_for( 'edit', _external=True, container=('parent', 'id'), doc='name' ) -def sdoc_edit(container, doc): +def sdoc_edit(container, doc) -> str: return f'edit {container} {doc}' @app1.route('//app_only') @NamedDocument.is_url_for('app_only', None, app1, doc='name') -def doc_app_only(doc): +def doc_app_only(doc) -> str: return f'app_only {doc}' @app1.route('//app1') @NamedDocument.is_url_for('per_app', None, app1, doc='name') -def doc_per_app1(doc): +def doc_per_app1(doc) -> str: return f'per_app {doc}' @app2.route('//app2') @NamedDocument.is_url_for('per_app', None, app2, doc='name') -def doc_per_app2(doc): +def doc_per_app2(doc) -> str: return f'per_app {doc}' @@ -138,7 +138,7 @@ def test_url_for(self) -> None: ) def test_absolute_url(self) -> None: - """The .absolute_url property is the same as .url_for(_external=True)""" + """The .absolute_url property is the same as .url_for(_external=True).""" # Make two documents doc1 = NamedDocument(name='document1', title="Document 1") self.session.add(doc1) @@ -154,18 +154,18 @@ def test_absolute_url(self) -> None: assert doc2.absolute_url != doc2.url_for(_external=False) def test_absolute_url_missing(self) -> None: - """The .absolute_url property exists on all UrlForMixin-models, even if there is no view""" + """The .absolute_url property exists on all UrlForMixin-models, even if there is no view.""" c1 = Container() assert c1.absolute_url is None def test_absolute_url_in_access_proxy(self) -> None: - """The .absolute_url property does not have a default access role""" + """The .absolute_url property does not have a default access role.""" c1 = Container() d = c1.access_for(roles={'all'}) assert 'absolute_url' not in d def test_per_app(self) -> None: - """Allow app-specific URLs for the same action name""" + """Allow app-specific URLs for the same action name.""" doc1 = NamedDocument(name='document1', title="Document 1") self.session.add(doc1) self.session.commit() @@ -174,7 +174,7 @@ def test_per_app(self) -> None: assert doc1.url_for('per_app') == '/document1/app1' def test_app_only(self) -> None: - """Allow URLs to only be available in one app""" + """Allow URLs to only be available in one app.""" doc1 = NamedDocument(name='document1', title="Document 1") self.session.add(doc1) self.session.commit() @@ -183,7 +183,7 @@ def test_app_only(self) -> None: assert doc1.url_for('app_only') == '/document1/app_only' def test_linked_doc(self) -> None: - """URLs linking two unrelated models are possible""" + """URLs linking two unrelated models are possible.""" doc1 = NamedDocument(name='document1', title="Document 1") doc2 = NamedDocument(name='document2', title="Document 2") self.session.add_all([doc1, doc2]) @@ -201,7 +201,7 @@ def test_url_dict(self) -> None: assert doc1.urls != {} assert doc1.urls['view'] == 'http://localhost/document1' with pytest.raises(KeyError): - doc1.urls['random'] # pylint: disable=pointless-statement + _ = doc1.urls['random'] # The len() count includes the doc_with view, but it is excluded from actual # enumeration because it requires additional keyword parameters, which cannot @@ -220,7 +220,7 @@ class TestUrlFor2(TestUrlForBase): app = app2 def test_per_app(self) -> None: - """Allow app-specific URLs for the same action name""" + """Allow app-specific URLs for the same action name.""" doc1 = NamedDocument(name='document1', title="Document 1") self.session.add(doc1) self.session.commit() @@ -229,7 +229,7 @@ def test_per_app(self) -> None: assert doc1.url_for('per_app') == '/document1/app2' def test_app_only(self) -> None: - """Allow URLs to only be available in one app""" + """Allow URLs to only be available in one app.""" doc1 = NamedDocument(name='document1', title="Document 1") self.session.add(doc1) self.session.commit() diff --git a/tests/coaster_tests/utils_classes_dataclass_test.py b/tests/coaster_tests/utils_classes_dataclass_test.py index eae03355..5a286ac8 100644 --- a/tests/coaster_tests/utils_classes_dataclass_test.py +++ b/tests/coaster_tests/utils_classes_dataclass_test.py @@ -2,6 +2,8 @@ # pylint: disable=redefined-outer-name,unused-variable +from __future__ import annotations + import pickle # nosec B403 from dataclasses import FrozenInstanceError, dataclass from enum import Enum @@ -14,16 +16,22 @@ @dataclass(frozen=True, eq=False) class StringMetadata(DataclassFromType, str): + """String with metadata.""" + description: str extra: Optional[str] = None @dataclass(frozen=True, eq=False) class IntMetadata(DataclassFromType, int): + """Int with metadata.""" + title: str class MetadataEnum(StringMetadata, Enum): + """Enum with metadata.""" + FIRST = "first", "First string" SECOND = "second", "Second string", "Optional extra" @@ -79,7 +87,7 @@ def test_immutable_data_type() -> None: """The data type must be immutable.""" class Immutable(DataclassFromType, tuple): # skipcq: PTC-W0065 - pass + __slots__ = () with pytest.raises(TypeError, match="data type must be immutable"): @@ -94,9 +102,9 @@ def test_annotated_str( assert a == 'a' assert b == 'b' assert b2 == 'b' - assert 'a' == a - assert 'b' == b - assert 'b' == b2 + assert 'a' == a # noqa: SIM300 + assert 'b' == b # noqa: SIM300 + assert 'b' == b2 # noqa: SIM300 assert a != b assert a != b2 assert b != a @@ -133,7 +141,7 @@ def test_dataclass_fields_set( a.self = 'b' # type: ignore[misc] -def test_dict_keys(a: StringMetadata, b: StringMetadata, b2: StringMetadata) -> None: +def test_dict_keys(a: StringMetadata, b: StringMetadata) -> None: """DataclassFromType-based dataclasses can be used as dict keys.""" d: dict[Any, Any] = {a: a.description, b: b.description} assert d['a'] == a.description @@ -151,16 +159,16 @@ def test_dict_overlap(a: StringMetadata) -> None: assert len(d1) == 1 assert d1['a'] == "Overlap" assert d2['a'] == "Overlap" - assert isinstance(list(d1.keys())[0], str) - assert isinstance(list(d2.keys())[0], str) - assert not isinstance(list(d1.keys())[0], StringMetadata) # Retained str - assert isinstance(list(d2.keys())[0], StringMetadata) # Retained StringMetadata + assert isinstance(next(iter(d1.keys())), str) + assert isinstance(next(iter(d2.keys())), str) + assert not isinstance(next(iter(d1.keys())), StringMetadata) # Retained str + assert isinstance(next(iter(d2.keys())), StringMetadata) # Retained StringMetadata def test_pickle(a: StringMetadata) -> None: """Pickle dump and load will reconstruct the full dataclass.""" p = pickle.dumps(a) - a2 = pickle.loads(p) # nosec B301 + a2 = pickle.loads(p) # nosec B301 # noqa: S301 assert isinstance(a2, StringMetadata) assert a2 == a assert a2.self == 'a' diff --git a/tests/coaster_tests/utils_markdown_test.py b/tests/coaster_tests/utils_markdown_test.py index 1727af2d..df68988b 100644 --- a/tests/coaster_tests/utils_markdown_test.py +++ b/tests/coaster_tests/utils_markdown_test.py @@ -1,4 +1,6 @@ -from coaster.gfm import markdown +"""Test Markdown.""" + +from coaster.utils import markdown sample_markdown = ''' This is a sample piece of text and represents a paragraph. diff --git a/tests/coaster_tests/utils_test.py b/tests/coaster_tests/utils_test.py index 4549748f..3801451f 100644 --- a/tests/coaster_tests/utils_test.py +++ b/tests/coaster_tests/utils_test.py @@ -1,3 +1,5 @@ +"""Test utility functions.""" + import datetime import unittest from collections.abc import Iterator, MutableSet @@ -198,13 +200,13 @@ def test_deobfuscate_email(self) -> None: assert deobfuscate_email(in_text) == out_text def test_isoweek_datetime_all_timezones(self) -> None: - """Test that isoweek_datetime works for all timezones""" + """Test that isoweek_datetime works for all timezones.""" for timezone in common_timezones: for week in range(53): isoweek_datetime(2017, week + 1, timezone) def test_midnight_to_utc_all_timezones(self) -> None: - """Test that midnight_to_utc works for all timezones""" + """Test that midnight_to_utc works for all timezones.""" for timezone in common_timezones: for day in range(365): midnight_to_utc( @@ -212,7 +214,7 @@ def test_midnight_to_utc_all_timezones(self) -> None: ) def test_utcnow(self) -> None: - """Test that Coaster's utcnow works correctly""" + """Test that Coaster's utcnow works correctly.""" # Get date from function being tested now1 = utcnow() # Get date from Python stdlib diff --git a/tests/coaster_tests/nlp_test.py b/tests/coaster_tests/utils_text_test.py similarity index 83% rename from tests/coaster_tests/nlp_test.py rename to tests/coaster_tests/utils_text_test.py index 5d8f3965..c43a38a3 100644 --- a/tests/coaster_tests/nlp_test.py +++ b/tests/coaster_tests/utils_text_test.py @@ -1,4 +1,4 @@ -import unittest +"""Tests for text utilities.""" from coaster.utils import text_blocks @@ -42,7 +42,6 @@ ] -class TestExtractText(unittest.TestCase): - def test_extract_text(self) -> None: - tb = text_blocks(sample_html, skip_pre=True) - assert tb == sample_text_blocks +def test_extract_text() -> None: + tb = text_blocks(sample_html, skip_pre=True) + assert tb == sample_text_blocks diff --git a/tests/coaster_tests/views_classview_test.py b/tests/coaster_tests/views_classview_test.py index 64922d5e..63cb1bb1 100644 --- a/tests/coaster_tests/views_classview_test.py +++ b/tests/coaster_tests/views_classview_test.py @@ -5,8 +5,8 @@ from __future__ import annotations import unittest -from collections.abc import Sequence -from typing import Any, Optional +from collections.abc import Mapping, Sequence +from typing import Any, ClassVar, Optional import pytest import sqlalchemy as sa @@ -57,7 +57,7 @@ class ViewDocument(BaseNameMixin, Model): __tablename__ = 'view_document' - __roles__ = {'all': {'read': {'name', 'title'}}} + __roles__: ClassVar = {'all': {'read': {'name', 'title'}}} children: Mapped[list[ScopedViewDocument]] = relationship( cascade='all, delete-orphan', back_populates='view_document' @@ -96,7 +96,7 @@ class ScopedViewDocument(BaseScopedNameMixin, Model): ) parent = sa.orm.synonym('view_document') - __roles__ = {'all': {'read': {'name', 'title', 'doctype'}}} + __roles__: ClassVar = {'all': {'read': {'name', 'title', 'doctype'}}} @property def doctype(self) -> str: @@ -106,7 +106,7 @@ def doctype(self) -> str: # Use serial int pkeys so that we can get consistent `1-` url_name in tests class RenameableDocument(BaseIdNameMixin[int, Any], Model): __tablename__ = 'renameable_document' - __roles__ = {'all': {'read': {'name', 'title'}}} + __roles__: ClassVar = {'all': {'read': {'name', 'title'}}} # --- Views ---------------------------------------------------------------------------- @@ -142,8 +142,7 @@ def current_method_is_bound(self) -> str: @route('view_args//') def view_args_are_received(self, **kwargs) -> str: - # pylint: disable=consider-using-f-string - return '{one}/{two}'.format(**self.view_args) + return '{one}/{two}'.format(**kwargs) IndexView.init_app(app) @@ -155,7 +154,7 @@ class DocumentView(ClassView): @route('') @render_with(json=True) - def view(self, name: str): + def view(self, name: str) -> Mapping[str, Any]: """View the document.""" document = ViewDocument.query.filter_by(name=name).first_or_404() return document.current_access() @@ -164,7 +163,7 @@ def view(self, name: str): @route('/edit/', methods=['POST']) # Maps to /edit/ @route('', methods=['POST']) # Maps to /doc/ @requestform('title') - def edit(self, name: str, title: str): + def edit(self, name: str, title: str) -> str: """Edit the document.""" document = ViewDocument.query.filter_by(name=name).first_or_404() document.title = title @@ -207,16 +206,16 @@ class SubView(BaseView): @viewdata(title="Still first") @BaseView.first.replace - def first(self): + def first(self) -> str: return 'replaced-first' @route('2') @BaseView.second.replace @viewdata(title="Not still second") - def second(self): + def second(self) -> str: return 'replaced-second' - def third(self): + def third(self) -> str: # type: ignore[override] return 'removed-third' also_inherited = BaseView.also_inherited.with_route('/inherited').with_route( @@ -234,7 +233,7 @@ class AnotherSubView(BaseView): @route('2-2') @BaseView.second.replace - def second(self): + def second(self) -> str: return 'also-replaced-second' @@ -242,30 +241,30 @@ def second(self): class ModelDocumentView(UrlForView, InstanceLoader, ModelView[ViewDocument]): """Test ModelView.""" - route_model_map = {'document': 'name'} + route_model_map: ClassVar = {'document': 'name'} @requestargs('access_token') def before_request( self, access_token: Optional[str] = None ) -> Optional[ResponseReturnValue]: - if access_token == 'owner-admin-secret': # nosec + if access_token == 'owner-admin-secret': # nosec B105 # noqa: S105 add_auth_attribute('permissions', InspectableSet({'siteadmin'})) # See ViewDocument.permissions add_auth_attribute('user', 'this-is-the-owner') - if access_token == 'owner-secret': # nosec + if access_token == 'owner-secret': # nosec B105 # noqa: S105 # See ViewDocument.permissions add_auth_attribute('user', 'this-is-the-owner') return super().before_request() @route('') @render_with(json=True) - def view(self): + def view(self) -> Mapping[str, Any]: return self.obj.current_access() @route('edit', methods=['GET', 'POST']) @route('', methods=['PUT']) @requires_permission('edit') - def edit(self): + def edit(self) -> str: return 'edit-called' @@ -279,7 +278,7 @@ class ScopedDocumentView(ModelDocumentView): """Test subclass of a ModelView.""" model = ScopedViewDocument # type: ignore[assignment,misc] - route_model_map = {'document': 'name', 'parent': 'parent.name'} + route_model_map: ClassVar = {'document': 'name', 'parent': 'parent.name'} @RenameableDocument.views('main') @@ -289,11 +288,11 @@ class RenameableDocumentView( ): """Test ModelView for a document that will auto-redirect if the URL changes.""" - route_model_map = {'document': 'url_name'} + route_model_map: ClassVar = {'document': 'url_name'} @route('') @render_with(json=True) - def view(self): + def view(self) -> Mapping[str, Any]: return self.obj.current_access() @@ -301,7 +300,7 @@ def view(self): class MultiDocumentView(UrlForView, ModelView[ViewDocument]): """Test ModelView that has multiple documents.""" - route_model_map = {'doc2': '**doc2.url_name'} + route_model_map: ClassVar = {'doc2': '**doc2.url_name'} obj: tuple[ViewDocument, RenameableDocument] # type: ignore[assignment] class GetAttr: @@ -318,7 +317,7 @@ def loader( # type: ignore[override] # pylint: disable=arguments-differ @route('') @requires_permission('view') - def linked_view(self): + def linked_view(self) -> str: return self.obj[0].url_for('linked_view', doc2=self.obj[1]) @@ -330,19 +329,19 @@ def linked_view(self): class GatedDocumentView(UrlForView, InstanceLoader, ModelView[ViewDocument]): """Test ModelView that has an intercept in before_request.""" - route_model_map = {'document': 'name'} + route_model_map: ClassVar = {'document': 'name'} @requestargs('access_token') def before_request( self, access_token: Optional[str] = None ) -> Optional[ResponseReturnValue]: - if access_token == 'owner-secret': # nosec + if access_token == 'owner-secret': # nosec B105 # noqa: S105 # See ViewDocument.permissions add_auth_attribute('user', 'this-is-the-owner') - if access_token == 'editor-secret': # nosec + if access_token == 'editor-secret': # nosec B105 # noqa: S105 # See ViewDocument.permissions add_auth_attribute('user', 'this-is-the-editor') - if access_token == 'another-owner-secret': # nosec + if access_token == 'another-owner-secret': # nosec B105 # noqa: S105 # See ViewDocument.permissions add_auth_attribute('user', 'this-is-another-owner') return super().before_request() @@ -404,12 +403,12 @@ def tearDown(self) -> None: self.ctx.pop() def test_index(self) -> None: - """Test index view (/)""" + """Test index view (/).""" rv = self.client.get('/') assert rv.data == b'index' def test_page(self) -> None: - """Test page view (/page)""" + """Test page view (/page).""" rv = self.client.get('/page') assert rv.data == b'page' @@ -437,7 +436,7 @@ def test_document_404(self) -> None: assert rv.status_code == 404 # This 404 came from DocumentView.view def test_document_view(self) -> None: - """Test document view (loaded from database)""" + """Test document view (loaded from database).""" doc = ViewDocument(name='test1', title="Test") self.session.add(doc) self.session.commit() @@ -473,8 +472,8 @@ def test_callable_view(self) -> None: assert data['name'] == 'test1' assert data['title'] == "Test" - rv = DocumentView().edit('test1', "Edited") - assert rv == 'edited!' + rv2 = DocumentView().edit('test1', "Edited") + assert rv2 == 'edited!' assert doc.title == "Edited" def test_replaced(self) -> None: @@ -657,6 +656,8 @@ def test_redirectablemodel_view(self) -> None: def test_multi_view(self) -> None: """ + Test ModelView with two objects. + A ModelView view can handle multiple objects and also construct URLs for objects that do not have a well defined relationship between each other. """ diff --git a/tests/coaster_tests/views_endpointfor_test.py b/tests/coaster_tests/views_endpointfor_test.py index 125aed12..80c3f7f1 100644 --- a/tests/coaster_tests/views_endpointfor_test.py +++ b/tests/coaster_tests/views_endpointfor_test.py @@ -1,150 +1,240 @@ """Tests for endpoint_for view helper.""" -import unittest -from typing import Optional +# pylint: disable=redefined-outer-name +from __future__ import annotations + +from collections.abc import AsyncIterator +from typing import Union + +import pytest from flask import Flask -from flask.ctx import RequestContext +from quart import Quart from coaster.views import endpoint_for def view() -> str: + """Test view, never actually called.""" return "view" -class TestScaffolding(unittest.TestCase): - server_name: Optional[str] = None - app: Flask - ctx: RequestContext +@pytest.fixture(params=['flask', 'quart']) +def app(request: pytest.FixtureRequest) -> Union[Flask, Quart]: + """Create a Flask or Quart app.""" + app: Union[Flask, Quart] + server_name = ( + 'example.com' if request.node.get_closest_marker('has_server_name') else None + ) + if request.param == 'quart': + app = Quart(__name__, subdomain_matching=bool(server_name)) + else: + app = Flask(__name__, subdomain_matching=bool(server_name)) + + if server_name: + app.config['SERVER_NAME'] = server_name + + # Use `view` as the view function for all routes as it's not actually called + app.add_url_rule('/', 'index', view) + app.add_url_rule('/slashed/', 'slashed', view) + app.add_url_rule('/sub', 'un_subdomained', view) + app.add_url_rule('/sub', 'subdomained', view, subdomain='') + + return app + + +@pytest.fixture +async def arequest_ctx(app: Union[Flask, Quart]) -> AsyncIterator[None]: + """Create an async test request context.""" + if isinstance(app, Flask): + with app.test_request_context(): + yield None + else: + async with app.test_request_context(path='/'): + yield None + + +@pytest.mark.usefixtures('arequest_ctx') +async def test_localhost_index() -> None: + assert endpoint_for('http://localhost/') == ('index', {}) + + +@pytest.mark.usefixtures('arequest_ctx') +async def test_localhost_slashed() -> None: + assert endpoint_for('http://localhost/slashed/') == ('slashed', {}) + + +@pytest.mark.usefixtures('arequest_ctx') +async def test_localhost_unslashed() -> None: + assert endpoint_for('http://localhost/slashed') == ('slashed', {}) + + +@pytest.mark.usefixtures('arequest_ctx') +async def test_localhost_unslashed_noredirect() -> None: + assert endpoint_for('http://localhost/slashed', follow_redirects=False) == ( + None, + {}, + ) + + +@pytest.mark.usefixtures('arequest_ctx') +async def test_localhost_sub() -> None: + assert endpoint_for('http://localhost/sub') == ('un_subdomained', {}) + + +@pytest.mark.usefixtures('arequest_ctx') +async def test_example_index() -> None: + assert endpoint_for('http://example.com/') == ('index', {}) + + +@pytest.mark.usefixtures('arequest_ctx') +async def test_example_slashed() -> None: + assert endpoint_for('http://example.com/slashed/') == ('slashed', {}) + + +@pytest.mark.usefixtures('arequest_ctx') +async def test_example_unslashed() -> None: + assert endpoint_for('http://example.com/slashed') == ('slashed', {}) + + +@pytest.mark.usefixtures('arequest_ctx') +async def test_example_unslashed_noredirect() -> None: + assert endpoint_for('http://example.com/slashed', follow_redirects=False) == ( + None, + {}, + ) + + +@pytest.mark.usefixtures('arequest_ctx') +async def test_example_sub() -> None: + assert endpoint_for('http://example.com/sub') == ('un_subdomained', {}) + + +@pytest.mark.usefixtures('arequest_ctx') +async def test_subexample_index() -> None: + assert endpoint_for('http://sub.example.com/') == ('index', {}) + + +@pytest.mark.usefixtures('arequest_ctx') +async def test_subexample_slashed() -> None: + assert endpoint_for('http://sub.example.com/slashed/') == ('slashed', {}) - def setUp(self) -> None: - self.app = Flask(__name__, subdomain_matching=bool(self.server_name)) - # Use `view` as the view function for all routes as it's not actually called - self.app.add_url_rule('/', 'index', view) - self.app.add_url_rule('/slashed/', 'slashed', view) - self.app.add_url_rule('/sub', 'un_subdomained', view) - self.app.add_url_rule('/sub', 'subdomained', view, subdomain='') - if self.server_name: - self.app.config['SERVER_NAME'] = self.server_name - self.ctx = self.app.test_request_context() - self.ctx.push() +@pytest.mark.usefixtures('arequest_ctx') +async def test_subexample_unslashed() -> None: + assert endpoint_for('http://sub.example.com/slashed') == ('slashed', {}) - def tearDown(self) -> None: - self.ctx.pop() +@pytest.mark.usefixtures('arequest_ctx') +async def test_subexample_unslashed_noredirect() -> None: + assert endpoint_for('http://sub.example.com/slashed', follow_redirects=False) == ( + None, + {}, + ) -class TestNoServerName(TestScaffolding): - def test_localhost_index(self) -> None: - assert endpoint_for('http://localhost/') == ('index', {}) - def test_localhost_slashed(self) -> None: - assert endpoint_for('http://localhost/slashed/') == ('slashed', {}) +@pytest.mark.usefixtures('arequest_ctx') +async def test_subexample_sub() -> None: + assert endpoint_for('http://sub.example.com/sub') == ('un_subdomained', {}) - def test_localhost_unslashed(self) -> None: - assert endpoint_for('http://localhost/slashed') == ('slashed', {}) - def test_localhost_unslashed_noredirect(self) -> None: - assert endpoint_for('http://localhost/slashed', follow_redirects=False) == ( - None, - {}, - ) +@pytest.mark.usefixtures('arequest_ctx') +@pytest.mark.has_server_name +async def test_named_localhost_index() -> None: + assert endpoint_for('http://localhost/') == (None, {}) - def test_localhost_sub(self) -> None: - assert endpoint_for('http://localhost/sub') == ('un_subdomained', {}) - def test_example_index(self) -> None: - assert endpoint_for('http://example.com/') == ('index', {}) +@pytest.mark.usefixtures('arequest_ctx') +@pytest.mark.has_server_name +async def test_named_localhost_slashed() -> None: + assert endpoint_for('http://localhost/slashed/') == (None, {}) - def test_example_slashed(self) -> None: - assert endpoint_for('http://example.com/slashed/') == ('slashed', {}) - def test_example_unslashed(self) -> None: - assert endpoint_for('http://example.com/slashed') == ('slashed', {}) +@pytest.mark.usefixtures('arequest_ctx') +@pytest.mark.has_server_name +async def test_named_localhost_unslashed() -> None: + assert endpoint_for('http://localhost/slashed') == (None, {}) - def test_example_unslashed_noredirect(self) -> None: - assert endpoint_for('http://example.com/slashed', follow_redirects=False) == ( - None, - {}, - ) - def test_example_sub(self) -> None: - assert endpoint_for('http://example.com/sub') == ('un_subdomained', {}) +@pytest.mark.usefixtures('arequest_ctx') +@pytest.mark.has_server_name +async def test_named_localhost_unslashed_noredirect() -> None: + assert endpoint_for('http://localhost/slashed', follow_redirects=False) == ( + None, + {}, + ) - def test_subexample_index(self) -> None: - assert endpoint_for('http://sub.example.com/') == ('index', {}) - def test_subexample_slashed(self) -> None: - assert endpoint_for('http://sub.example.com/slashed/') == ('slashed', {}) +@pytest.mark.usefixtures('arequest_ctx') +@pytest.mark.has_server_name +async def test_named_localhost_sub() -> None: + assert endpoint_for('http://localhost/sub') == (None, {}) - def test_subexample_unslashed(self) -> None: - assert endpoint_for('http://sub.example.com/slashed') == ('slashed', {}) - def test_subexample_unslashed_noredirect(self) -> None: - assert endpoint_for( - 'http://sub.example.com/slashed', follow_redirects=False - ) == (None, {}) +@pytest.mark.usefixtures('arequest_ctx') +@pytest.mark.has_server_name +async def test_named_example_index() -> None: + assert endpoint_for('http://example.com/') == ('index', {}) - def test_subexample_sub(self) -> None: - assert endpoint_for('http://sub.example.com/sub') == ('un_subdomained', {}) +@pytest.mark.usefixtures('arequest_ctx') +@pytest.mark.has_server_name +async def test_named_example_slashed() -> None: + assert endpoint_for('http://example.com/slashed/') == ('slashed', {}) -class TestWithServerName(TestScaffolding): - server_name = 'example.com' - def test_localhost_index(self) -> None: - assert endpoint_for('http://localhost/') == (None, {}) +@pytest.mark.usefixtures('arequest_ctx') +@pytest.mark.has_server_name +async def test_named_example_unslashed() -> None: + assert endpoint_for('http://example.com/slashed') == ('slashed', {}) - def test_localhost_slashed(self) -> None: - assert endpoint_for('http://localhost/slashed/') == (None, {}) - def test_localhost_unslashed(self) -> None: - assert endpoint_for('http://localhost/slashed') == (None, {}) +@pytest.mark.usefixtures('arequest_ctx') +@pytest.mark.has_server_name +async def test_named_example_unslashed_noredirect() -> None: + assert endpoint_for('http://example.com/slashed', follow_redirects=False) == ( + None, + {}, + ) - def test_localhost_unslashed_noredirect(self) -> None: - assert endpoint_for('http://localhost/slashed', follow_redirects=False) == ( - None, - {}, - ) - def test_localhost_sub(self) -> None: - assert endpoint_for('http://localhost/sub') == (None, {}) +@pytest.mark.usefixtures('arequest_ctx') +@pytest.mark.has_server_name +async def test_named_example_sub() -> None: + assert endpoint_for('http://example.com/sub') == ('un_subdomained', {}) - def test_example_index(self) -> None: - assert endpoint_for('http://example.com/') == ('index', {}) - def test_example_slashed(self) -> None: - assert endpoint_for('http://example.com/slashed/') == ('slashed', {}) +@pytest.mark.usefixtures('arequest_ctx') +@pytest.mark.has_server_name +async def test_named_subexample_index() -> None: + assert endpoint_for('http://sub.example.com/') == (None, {}) - def test_example_unslashed(self) -> None: - assert endpoint_for('http://example.com/slashed') == ('slashed', {}) - def test_example_unslashed_noredirect(self) -> None: - assert endpoint_for('http://example.com/slashed', follow_redirects=False) == ( - None, - {}, - ) +@pytest.mark.usefixtures('arequest_ctx') +@pytest.mark.has_server_name +async def test_named_subexample_slashed() -> None: + assert endpoint_for('http://sub.example.com/slashed/') == (None, {}) - def test_example_sub(self) -> None: - assert endpoint_for('http://example.com/sub') == ('un_subdomained', {}) - def test_subexample_index(self) -> None: - assert endpoint_for('http://sub.example.com/') == (None, {}) +@pytest.mark.usefixtures('arequest_ctx') +@pytest.mark.has_server_name +async def test_named_subexample_unslashed() -> None: + assert endpoint_for('http://sub.example.com/slashed') == (None, {}) - def test_subexample_slashed(self) -> None: - assert endpoint_for('http://sub.example.com/slashed/') == (None, {}) - def test_subexample_unslashed(self) -> None: - assert endpoint_for('http://sub.example.com/slashed') == (None, {}) +@pytest.mark.usefixtures('arequest_ctx') +@pytest.mark.has_server_name +async def test_named_subexample_unslashed_noredirect() -> None: + assert endpoint_for('http://sub.example.com/slashed', follow_redirects=False) == ( + None, + {}, + ) - def test_subexample_unslashed_noredirect(self) -> None: - assert endpoint_for( - 'http://sub.example.com/slashed', follow_redirects=False - ) == (None, {}) - def test_subexample_sub(self) -> None: - assert endpoint_for('http://sub.example.com/sub') == ( - 'subdomained', - {'subdomain': 'sub'}, - ) +@pytest.mark.usefixtures('arequest_ctx') +@pytest.mark.has_server_name +async def test_named_subexample_sub() -> None: + assert endpoint_for('http://sub.example.com/sub') == ( + 'subdomained', + {'subdomain': 'sub'}, + ) diff --git a/tests/coaster_tests/views_loadmodels_test.py b/tests/coaster_tests/views_loadmodels_test.py index 05c41481..6eb7d781 100644 --- a/tests/coaster_tests/views_loadmodels_test.py +++ b/tests/coaster_tests/views_loadmodels_test.py @@ -9,8 +9,8 @@ from flask import g from sqlalchemy.orm import Mapped from werkzeug.exceptions import Forbidden, NotFound -from werkzeug.wrappers import Response +from coaster.compat import BaseResponse from coaster.sqlalchemy import BaseMixin, BaseNameMixin, BaseScopedIdMixin, relationship from coaster.views import load_model, load_models @@ -65,10 +65,7 @@ class ChildDocument(BaseScopedIdMixin, Model): def permissions( self, actor: User, inherited: Optional[set[str]] = None ) -> set[str]: - if inherited is None: - perms = set() - else: - perms = inherited + perms = set() if inherited is None else inherited if actor.username == 'foo' and 'delete' in perms: perms.remove('delete') return perms @@ -105,7 +102,10 @@ def return_siteadmin_perms() -> set[str]: kwargs=True, addlperms=return_siteadmin_perms, ) -def t_container(container: Container, kwargs: dict[str, str]) -> Container: +def t_container( + container: Container, + kwargs: dict[str, str], # noqa: ARG001 +) -> Container: return container @@ -123,7 +123,10 @@ def t_single_model_in_loadmodels(user: User) -> User: (Container, {'name': 'container'}, 'container'), (NamedDocument, {'name': 'document', 'container': 'container'}, 'document'), ) -def t_named_document(container: Container, document: NamedDocument) -> NamedDocument: +def t_named_document( + container: Container, # noqa: ARG001 + document: NamedDocument, +) -> NamedDocument: return document @@ -135,7 +138,10 @@ def t_named_document(container: Container, document: NamedDocument) -> NamedDocu 'document', ), ) -def t_redirect_document(container: Container, document: NamedDocument) -> NamedDocument: +def t_redirect_document( + container: Container, # noqa: ARG001 + document: NamedDocument, +) -> NamedDocument: return document @@ -144,7 +150,8 @@ def t_redirect_document(container: Container, document: NamedDocument) -> NamedD (ScopedNamedDocument, {'name': 'document', 'container': 'container'}, 'document'), ) def t_scoped_named_document( - container: Container, document: ScopedNamedDocument + container: Container, # noqa: ARG001 + document: ScopedNamedDocument, ) -> ScopedNamedDocument: return document @@ -155,7 +162,8 @@ def t_scoped_named_document( urlcheck=['url_name'], ) def t_id_named_document( - container: Container, document: IdNamedDocument + container: Container, # noqa: ARG001 + document: IdNamedDocument, ) -> IdNamedDocument: return document @@ -164,12 +172,13 @@ def t_id_named_document( (Container, {'name': 'container'}, 'container'), ( ScopedIdDocument, - {'url_id': lambda r, p: int(p['document']), 'container': 'container'}, + {'url_id': lambda _r, p: int(p['document']), 'container': 'container'}, 'document', ), ) def t_scoped_id_document( - container: Container, document: ScopedIdDocument + container: Container, # noqa: ARG001 + document: ScopedIdDocument, ) -> ScopedIdDocument: return document @@ -184,7 +193,8 @@ def t_scoped_id_document( urlcheck=['url_name'], ) def t_scoped_id_named_document( - container: Container, document: ScopedIdNamedDocument + container: Container, # noqa: ARG001 + document: ScopedIdNamedDocument, ) -> ScopedIdNamedDocument: return document @@ -193,12 +203,13 @@ def t_scoped_id_named_document( (ParentDocument, {'name': 'document'}, 'document'), ( ChildDocument, - {'id': 'child', 'parent': lambda r, p: r['document'].middle}, + {'id': 'child', 'parent': lambda r, _p: r['document'].middle}, 'child', ), ) def t_callable_document( - document: ParentDocument, child: ChildDocument + document: ParentDocument, # noqa: ARG001 + child: ChildDocument, ) -> ChildDocument: return child @@ -207,7 +218,10 @@ def t_callable_document( (ParentDocument, {'name': 'document'}, 'document'), (ChildDocument, {'id': 'child', 'parent': 'document.middle'}, 'child'), ) -def t_dotted_document(document: ParentDocument, child: ChildDocument) -> ChildDocument: +def t_dotted_document( + document: ParentDocument, # noqa: ARG001 + child: ChildDocument, +) -> ChildDocument: return child @@ -217,7 +231,8 @@ def t_dotted_document(document: ParentDocument, child: ChildDocument) -> ChildDo permission='view', ) def t_dotted_document_view( - document: ParentDocument, child: ChildDocument + document: ParentDocument, # noqa: ARG001 + child: ChildDocument, ) -> ChildDocument: return child @@ -228,7 +243,8 @@ def t_dotted_document_view( permission='edit', ) def t_dotted_document_edit( - document: ParentDocument, child: ChildDocument + document: ParentDocument, # noqa: ARG001 + child: ChildDocument, ) -> ChildDocument: return child @@ -238,7 +254,10 @@ def t_dotted_document_edit( (ChildDocument, {'id': 'child', 'parent': 'document.middle'}, 'child'), permission='delete', ) -def t_dotted_document_delete(document, child): +def t_dotted_document_delete( + document: ParentDocument, # noqa: ARG001 + child: ChildDocument, +) -> ChildDocument: return child @@ -246,7 +265,7 @@ def t_dotted_document_delete(document, child): @pytest.fixture(scope='module', autouse=True) -def _app_extra(app): +def _app_extra(app) -> None: LoginManager(app) app.add_url_rule( '//', 'redirect_document', t_redirect_document @@ -340,12 +359,12 @@ def test_redirect_document(self) -> None: ) with self.app.test_request_context('/c/redirect-document'): response = t_redirect_document(container='c', document='redirect-document') - assert isinstance(response, Response) + assert isinstance(response, BaseResponse) assert response.status_code == 307 assert response.headers['Location'] == '/c/named-document' with self.app.test_request_context('/c/redirect-document?preserve=this'): response = t_redirect_document(container='c', document='redirect-document') - assert isinstance(response, Response) + assert isinstance(response, BaseResponse) assert response.status_code == 307 assert response.headers['Location'] == '/c/named-document?preserve=this' @@ -373,12 +392,12 @@ def test_id_named_document(self) -> None: ) with self.app.test_request_context('/c/1-wrong-name'): r = t_id_named_document(container='c', document='1-wrong-name') - assert isinstance(r, Response) + assert isinstance(r, BaseResponse) assert r.status_code == 302 assert r.location == '/c/1-id-named-document' with self.app.test_request_context('/c/1-wrong-name?preserve=this'): r = t_id_named_document(container='c', document='1-wrong-name') - assert isinstance(r, Response) + assert isinstance(r, BaseResponse) assert r.status_code == 302 assert r.location == '/c/1-id-named-document?preserve=this' with pytest.raises(NotFound): @@ -406,7 +425,7 @@ def test_scoped_id_named_document(self) -> None: ) with self.app.test_request_context('/c/1-wrong-name'): r = t_scoped_id_named_document(container='c', document='1-wrong-name') - assert isinstance(r, Response) + assert isinstance(r, BaseResponse) assert r.status_code == 302 assert r.location == '/c/1-scoped-id-named-document' with pytest.raises(NotFound): @@ -441,7 +460,7 @@ def test_inherited_permissions(self) -> None: } def test_unmutated_inherited_permissions(self) -> None: - """The inherited permission set should not be mutated by a permission check""" + """The inherited permission set should not be mutated by a permission check.""" user = User(username='admin') inherited = {'add-video'} assert self.pc.permissions(user, inherited=inherited) == {'add-video', 'view'} diff --git a/tests/coaster_tests/views_renderwith_test.py b/tests/coaster_tests/views_renderwith_test.py index 2eb8302e..6f221972 100644 --- a/tests/coaster_tests/views_renderwith_test.py +++ b/tests/coaster_tests/views_renderwith_test.py @@ -1,3 +1,5 @@ +"""Test `renderwith` view decorator.""" + import unittest from collections.abc import Mapping from typing import Any @@ -5,8 +7,8 @@ import pytest from flask import Flask, Response, jsonify from jinja2 import TemplateNotFound -from werkzeug.wrappers import Response as BaseResponse +from coaster.compat import BaseResponse from coaster.views import render_with # --- Test setup ----------------------------------------------------------------------- @@ -83,12 +85,8 @@ def test_render(self) -> None: # with the correct template name. Since the templates don't actually exist, # we'll get a TemplateNotFound exception, so our "test" is to confirm that the # missing template is the one that was supposed to be rendered. - try: - rv = self.client.get('/renderedview1') - except TemplateNotFound as e: - assert str(e) == 'renderedview1.html' - else: - pytest.fail(f"Unexpected response: {rv.headers!r} {rv.data!r}") + with pytest.raises(TemplateNotFound, match='renderedview1.html'): + self.client.get('/renderedview1') for acceptheader, template in [ ('text/html;q=0.9,text/xml;q=0.8,*/*', 'renderedview2.html'), @@ -98,16 +96,8 @@ def test_render(self) -> None: 'renderedview2.html', ), ]: - try: - rv = self.client.get( - '/renderedview2', headers=[('Accept', acceptheader)] - ) - except TemplateNotFound as e: - assert str(e) == template - else: - pytest.fail( - f"Accept: {acceptheader} Response: {rv.headers!r} {rv.data!r}" - ) + with pytest.raises(TemplateNotFound, match=template): + self.client.get('/renderedview2', headers=[('Accept', acceptheader)]) # The application/json and text/plain renderers do exist, so we should get # a valid return value from them. diff --git a/tests/coaster_tests/views_test.py b/tests/coaster_tests/views_test.py index 03e1aec6..68ca0bb1 100644 --- a/tests/coaster_tests/views_test.py +++ b/tests/coaster_tests/views_test.py @@ -1,3 +1,5 @@ +"""Test view helpers.""" + import unittest from typing import Any, Optional @@ -148,7 +150,7 @@ def test_jsonp(self) -> None: r = jsonp(**kwargs) # pylint: disable=consider-using-f-string response = ( - 'callback({\n "%s": "%s",\n "%s": "%s"\n});' + 'callback({\n "%s": "%s",\n "%s": "%s"\n});' # noqa: UP031 % ('lang', kwargs['lang'], 'query', kwargs['query']) ).encode('utf-8')