Skip to content

Commit

Permalink
Don't force type definition of ModelBase.query (#377)
Browse files Browse the repository at this point in the history
  • Loading branch information
jace committed May 28, 2023
1 parent 72cdd65 commit 93ef03f
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 72 deletions.
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
unmaintained, and usage is shifting from Webassets to Webpack
* ``for_tsquery`` has been removed as PostgreSQL>=12 has native functions
* ``render_with`` no longer offers a shorthand for JSONP responses
* ``coaster.sqlalchemy.ModelBase`` now replaces Flask-SQLAlchemy's db.Model
with full support for type hinting

0.6.1 - 2021-01-06
------------------
Expand Down
1 change: 1 addition & 0 deletions src/coaster/sqlalchemy/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def _configure_annotations(_mapper: t.Any, cls: t.Type) -> None:
):
data = attr.column._coaster_annotations
elif hasattr(attr, '_coaster_annotations'):
# pylint: disable=protected-access
data = attr._coaster_annotations
elif isinstance(
attr, (QueryableAttribute, RelationshipProperty, MapperProperty)
Expand Down
91 changes: 41 additions & 50 deletions src/coaster/sqlalchemy/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class MyModel(BaseMixin, db.Model):
import typing as t

from flask import Flask, current_app, url_for
from sqlalchemy import event
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import Mapped, declarative_mixin, declared_attr, synonym
from sqlalchemy.sql import func, select
Expand All @@ -56,7 +57,7 @@ class MyModel(BaseMixin, db.Model):
)
from .functions import auto_init_default, failsafe_add
from .immutable_annotation import immutable
from .model import Query
from .model import Query, QueryProperty
from .registry import RegistryMixin
from .roles import RoleMixin, with_roles

Expand Down Expand Up @@ -97,11 +98,10 @@ class MyModel(IdMixin, db.Model):
"""

query_class: t.ClassVar[t.Type[Query]] = Query
query: t.ClassVar[Query[te.Self]]
__column_annotations__: dict
query: t.ClassVar[QueryProperty]
#: Use UUID primary key? If yes, UUIDs are automatically generated without
#: the need to commit to the database
__uuid_primary_key__ = False
__uuid_primary_key__: t.ClassVar[bool] = False

@immutable
@declared_attr
Expand Down Expand Up @@ -172,6 +172,8 @@ class UuidMixin:
code.
"""

__uuid_primary_key__: t.ClassVar[bool]

@immutable
@with_roles(read={'all'})
@declared_attr
Expand Down Expand Up @@ -238,9 +240,8 @@ class TimestampMixin:
"""Provides the :attr:`created_at` and :attr:`updated_at` audit timestamps."""

query_class: t.ClassVar[t.Type[Query]] = Query
query: t.ClassVar[Query[te.Self]]
query: t.ClassVar[QueryProperty]
__with_timezone__: t.ClassVar[bool] = False
__column_annotations__: t.ClassVar[dict]

@immutable
@declared_attr
Expand Down Expand Up @@ -375,11 +376,13 @@ class UrlForMixin:
#: strings. The same action can point to different endpoints in different apps. The
#: app may also be None as fallback. Each subclass will get its own dictionary.
#: This particular dictionary is only used as an inherited fallback.
url_for_endpoints: t.Dict[t.Optional[Flask], t.Dict[str, UrlEndpointData]] = {
None: {}
}
url_for_endpoints: t.ClassVar[
t.Dict[t.Optional[Flask], t.Dict[str, UrlEndpointData]]
] = {None: {}}
#: Mapping of {app: {action: (classview, attr)}}
view_for_endpoints: t.Dict[t.Optional[Flask], t.Dict[str, t.Tuple[t.Any, str]]] = {}
view_for_endpoints: t.ClassVar[
t.Dict[t.Optional[Flask], t.Dict[str, t.Tuple[t.Any, str]]]
] = {}

#: Dictionary of URLs available on this object
urls = UrlDictStub()
Expand Down Expand Up @@ -578,12 +581,12 @@ class BaseNameMixin(BaseMixin):
"""

#: Prevent use of these reserved names
reserved_names: t.Collection[str] = []
reserved_names: t.ClassVar[t.Collection[str]] = []
#: Allow blank names after all?
__name_blank_allowed__ = False
__name_blank_allowed__: t.ClassVar[bool] = False
#: How long are names and title allowed to be? `None` for unlimited length
__name_length__: t.Optional[int] = 250
__title_length__: t.Optional[int] = 250
__name_length__: t.ClassVar[t.Optional[int]] = 250
__title_length__: t.ClassVar[t.Optional[int]] = 250

@declared_attr
def name(cls) -> Mapped[str]:
Expand Down Expand Up @@ -627,11 +630,7 @@ def __repr__(self) -> str:
@classmethod
def get(cls, name: str) -> t.Optional[te.Self]:
"""Get an instance matching the name."""
# Mypy is confused by te.Self: Incompatible return value type
# (got "Optional[BaseNameMixin]", expected "Optional[Self]") [return-value]
return cls.query.filter_by( # type: ignore[return-value]
name=name
).one_or_none()
return cls.query.filter_by(name=name).one_or_none()

@classmethod
def upsert(cls, name: str, **fields) -> te.Self:
Expand Down Expand Up @@ -722,15 +721,15 @@ class Event(BaseScopedNameMixin, db.Model):
"""

#: Prevent use of these reserved names
reserved_names: t.Collection[str] = []
reserved_names: t.ClassVar[t.Collection[str]] = []
#: Allow blank names after all?
__name_blank_allowed__ = False
__name_blank_allowed__: t.ClassVar[bool] = False
#: How long are names and title allowed to be? `None` for unlimited length
__name_length__: t.Optional[int] = 250
__title_length__: t.Optional[int] = 250
__name_length__: t.ClassVar[t.Optional[int]] = 250
__title_length__: t.ClassVar[t.Optional[int]] = 250

#: Specify expected type for a 'parent' attr
parent: Mapped[t.Any]
parent: t.Any

@declared_attr
def name(cls) -> Mapped[str]:
Expand Down Expand Up @@ -767,9 +766,7 @@ def __repr__(self) -> str:
@classmethod
def get(cls, parent: t.Any, name: str) -> t.Optional[te.Self]:
"""Get an instance matching the parent and name."""
return cls.query.filter_by( # type: ignore[return-value]
parent=parent, name=name
).one_or_none()
return cls.query.filter_by(parent=parent, name=name).one_or_none()

@classmethod
def upsert(cls, parent: t.Any, name: str, **fields) -> te.Self:
Expand Down Expand Up @@ -886,10 +883,10 @@ class BaseIdNameMixin(BaseMixin):
"""

#: Allow blank names after all?
__name_blank_allowed__ = False
__name_blank_allowed__: t.ClassVar[bool] = False
#: How long are names and title allowed to be? `None` for unlimited length
__name_length__: t.Optional[int] = 250
__title_length__: t.Optional[int] = 250
__name_length__: t.ClassVar[t.Optional[int]] = 250
__title_length__: t.ClassVar[t.Optional[int]] = 250

@declared_attr
def name(cls) -> Mapped[str]:
Expand Down Expand Up @@ -992,7 +989,7 @@ class Issue(BaseScopedIdMixin, db.Model):
"""

#: Specify expected type for a 'parent' attr
parent: Mapped[t.Any]
parent: t.Any

# FIXME: Rename this to `scoped_id` and provide a migration guide.
@with_roles(read={'all'})
Expand All @@ -1012,9 +1009,7 @@ def __repr__(self) -> str:
@classmethod
def get(cls, parent: t.Any, url_id: t.Union[str, int]) -> t.Optional[te.Self]:
"""Get an instance matching the parent and url_id."""
return cls.query.filter_by( # type: ignore[return-value]
parent=parent, url_id=url_id
).one_or_none()
return cls.query.filter_by(parent=parent, url_id=url_id).one_or_none()

def make_id(self) -> None:
"""Create a new URL id that is unique to the parent container."""
Expand Down Expand Up @@ -1076,10 +1071,10 @@ class Event(BaseScopedIdNameMixin, db.Model):
"""

#: Allow blank names after all?
__name_blank_allowed__ = False
__name_blank_allowed__: t.ClassVar[bool] = False
#: How long are names and title allowed to be? `None` for unlimited length
__name_length__: t.Optional[int] = 250
__title_length__: t.Optional[int] = 250
__name_length__: t.ClassVar[t.Optional[int]] = 250
__title_length__: t.ClassVar[t.Optional[int]] = 250

@declared_attr
def name(cls) -> Mapped[str]:
Expand Down Expand Up @@ -1128,9 +1123,7 @@ def __repr__(self) -> str:
@classmethod
def get(cls, parent: t.Any, url_id: t.Union[int, str]) -> t.Optional[te.Self]:
"""Get an instance matching the parent and name."""
return cls.query.filter_by( # type: ignore[return-value]
parent=parent, url_id=url_id
).one_or_none()
return cls.query.filter_by(parent=parent, url_id=url_id).one_or_none()

def make_name(self) -> None:
"""Autogenerate :attr:`name` from :attr:`title` (via :attr:`title_for_name)."""
Expand Down Expand Up @@ -1230,10 +1223,8 @@ def _configure_uuid_listener(mapper: t.Any, class_: UuidMixin) -> None:
auto_init_default(mapper.column_attrs.uuid)


sa.event.listen(IdMixin, 'mapper_configured', _configure_id_listener, propagate=True)
sa.event.listen(
UuidMixin, 'mapper_configured', _configure_uuid_listener, propagate=True
)
event.listen(IdMixin, 'mapper_configured', _configure_id_listener, propagate=True)
event.listen(UuidMixin, 'mapper_configured', _configure_uuid_listener, propagate=True)


# Populate name and url_id columns
Expand All @@ -1256,9 +1247,9 @@ def _make_scoped_id(
target.make_id() # type: ignore[unreachable]


sa.event.listen(BaseNameMixin, 'before_insert', _make_name, propagate=True)
sa.event.listen(BaseIdNameMixin, 'before_insert', _make_name, propagate=True)
sa.event.listen(BaseScopedIdMixin, 'before_insert', _make_scoped_id, propagate=True)
sa.event.listen(BaseScopedNameMixin, 'before_insert', _make_scoped_name, propagate=True)
sa.event.listen(BaseScopedIdNameMixin, 'before_insert', _make_scoped_id, propagate=True)
sa.event.listen(BaseScopedIdNameMixin, 'before_insert', _make_name, propagate=True)
event.listen(BaseNameMixin, 'before_insert', _make_name, propagate=True)
event.listen(BaseIdNameMixin, 'before_insert', _make_name, propagate=True)
event.listen(BaseScopedIdMixin, 'before_insert', _make_scoped_id, propagate=True)
event.listen(BaseScopedNameMixin, 'before_insert', _make_scoped_name, propagate=True)
event.listen(BaseScopedIdNameMixin, 'before_insert', _make_scoped_id, propagate=True)
event.listen(BaseScopedIdNameMixin, 'before_insert', _make_name, propagate=True)
22 changes: 7 additions & 15 deletions src/coaster/sqlalchemy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ class Other(Model):
from __future__ import annotations

from functools import wraps
from typing import overload
import datetime
import typing as t
import uuid
Expand Down Expand Up @@ -241,32 +240,25 @@ class AppenderQuery( # type: ignore[misc] # pylint: disable=abstract-method
query_class = Query # type: ignore[assignment]


class QueryProperty(t.Generic[_T]):
class QueryProperty:
"""A class property that creates a query object for a model."""

@overload
def __get__(self, obj: None, cls: t.Type[_T]) -> Query[_T]:
...

@overload
def __get__(self, obj: _T, cls: t.Type[_T]) -> Query[_T]:
...

def __get__(self, obj: t.Optional[_T], cls: t.Type[_T]) -> Query[_T]:
def __get__(self, _obj: t.Optional[_T], cls: t.Type[_T]) -> Query[_T]:
return cls.query_class(cls, session=cls.__fsa__.session())


class ModelBase:
"""Flask-SQLAlchemy compatible base class that supports PEP 484 type hinting."""

__fsa__: t.ClassVar[SQLAlchemy]
__bind_key__: t.Optional[str]
__bind_key__: t.ClassVar[t.Optional[str]]
metadata: t.ClassVar[sa.MetaData]
query_class: t.ClassVar[type[Query]] = Query
query: t.ClassVar[Query[te.Self]] = QueryProperty() # type: ignore[assignment]
query: t.ClassVar[QueryProperty] = QueryProperty()
# Added by Coaster annotations
__column_annotations__: t.ClassVar[t.Dict[str, t.List[str]]]
__column_annotations_by_attr__: t.ClassVar[t.Dict[str, t.List[str]]]

# FIXME: Drop bind_key arg, Mypy cannot understand it when there's a missing import,
# including in pre-commit checks within this repository itself
def __init_subclass__(cls) -> None:
"""Configure a declarative base class."""
if ModelBase in cls.__bases__:
Expand Down
26 changes: 22 additions & 4 deletions tests/coaster_tests/sqlalchemy_annotations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from coaster.sqlalchemy import (
BaseMixin,
ImmutableColumnError,
ModelBase,
UuidMixin,
cached,
immutable,
Expand Down Expand Up @@ -134,22 +135,39 @@ def test_has_annotations() -> None:
def test_annotation_in_annotations() -> None:
"""Annotations were discovered."""
for model in (IdOnly, IdUuid, UuidOnly):
assert issubclass(model, ModelBase)
for annotation in (immutable, cached):
assert annotation.__name__ in model.__column_annotations__
assert (
annotation.__name__
in model.__column_annotations__ # type: ignore[attr-defined]
)


def test_attr_in_annotations() -> None:
"""Annotated attributes were discovered and documented."""
for model in (IdOnly, IdUuid, UuidOnly):
assert 'is_immutable' in model.__column_annotations__['immutable']
assert 'is_cached' in model.__column_annotations__['cached']
assert issubclass(model, ModelBase)
assert (
'is_immutable'
in model.__column_annotations__['immutable'] # type: ignore[attr-defined]
)
assert (
'is_cached'
in model.__column_annotations__['cached'] # type: ignore[attr-defined]
)


def test_base_attrs_in_annotations() -> None:
"""Annotations in the base class were also discovered and added to subclass."""
for model in (IdOnly, IdUuid, UuidOnly):
assert issubclass(model, ModelBase)
for attr in ('created_at', 'id'):
assert attr in model.__column_annotations__['immutable']
assert (
attr
in model.__column_annotations__[ # type: ignore[attr-defined]
'immutable'
]
)
assert 'uuid' in IdUuid.__column_annotations__['immutable']


Expand Down
6 changes: 3 additions & 3 deletions tests/coaster_tests/sqlalchemy_modelbase_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,12 @@ class Model(ModelBase, DeclarativeBase):
class BindModel(ModelBase, DeclarativeBase):
"""Test bind model base."""

__bind_key__: t.Optional[str] = 'test'
__bind_key__: t.ClassVar[t.Optional[str]] = 'test'

class Mixin:
"""Mixin that replaces bind_key."""

__bind_key__: t.Optional[str] = 'other'
__bind_key__: t.ClassVar[t.Optional[str]] = 'other'

assert Model.__bind_key__ is None
with pytest.raises(TypeError, match="__bind_key__.*does not match base class"):
Expand Down Expand Up @@ -108,7 +108,7 @@ class Model(ModelBase, DeclarativeBase):
class BindModel(ModelBase, DeclarativeBase):
"""Test bind model base."""

__bind_key__: t.Optional[str] = 'test'
__bind_key__ = 'test'

assert Model.__bind_key__ is None
assert BindModel.__bind_key__ == 'test'
Expand Down

0 comments on commit 93ef03f

Please sign in to comment.