Skip to content

Commit

Permalink
Support joins in DynamicAssociationProxy (#461)
Browse files Browse the repository at this point in the history
  • Loading branch information
jace committed Apr 16, 2024
1 parent 007776a commit a294be6
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 21 deletions.
21 changes: 11 additions & 10 deletions src/coaster/sqlalchemy/model.py
Expand Up @@ -421,7 +421,7 @@ def __repr__(self) -> str:
else:
pk = ", ".join(map(str, state.identity))

return f'<{type(self).__name__} {pk}>'
return f'<{type(self).__qualname__} {pk}>'


# --- `relationship` and `backref` wrappers for `lazy='dynamic'` -----------------------
Expand All @@ -430,13 +430,12 @@ def __repr__(self) -> str:
# a type hint to Coaster's AppenderQuery, which in turn wraps Coaster's Query with its
# additional methods

if TYPE_CHECKING:

class DynamicMapped(DynamicMappedBase[_T_co]):
"""Represent the ORM mapped attribute type for a "dynamic" relationship."""
class DynamicMapped(DynamicMappedBase[_T_co]):
"""Represent the ORM mapped attribute type for a "dynamic" relationship."""

__slots__ = ()

if TYPE_CHECKING:
__slots__ = ()

@overload # type: ignore[override]
def __get__(
Expand All @@ -452,11 +451,13 @@ def __get__(

def __set__(self, instance: Any, value: Collection[_T_co]) -> None: ...

class Relationship(RelationshipBase[_T], DynamicMapped[_T]): # type: ignore[misc]
"""Wraps Relationship with the updated version of DynamicMapped."""

class Relationship(RelationshipBase[_T], DynamicMapped[_T]): # type: ignore[misc]
"""Wraps Relationship with the updated version of DynamicMapped."""

__slots__ = ()
else:
# Avoid the overhead of empty subclasses at runtime
DynamicMapped = DynamicMappedBase
Relationship = RelationshipBase


_P = ParamSpec('_P')
Expand Down
40 changes: 29 additions & 11 deletions src/coaster/sqlalchemy/roles.py
Expand Up @@ -163,11 +163,11 @@ def roles_for(
MappedColumn,
MapperProperty,
Query as QueryBase,
QueryableAttribute,
RelationshipProperty,
SynonymProperty,
declarative_mixin,
)
from sqlalchemy.orm.attributes import QueryableAttribute
from sqlalchemy.orm.collections import (
InstrumentedDict,
InstrumentedList,
Expand Down Expand Up @@ -679,16 +679,25 @@ class DynamicAssociationProxy(Generic[_V, _R]):
document.child_attributes[value] == relationship_object
:param str rel: Relationship name (must use ``lazy='dynamic'``)
:param str attr: Attribute on the target of the relationship
:param rel: Relationship name (must use ``lazy='dynamic'``)
:param attr: Attribute on the target of the relationship
:param qattr: Optional callable that returns a
:class:`~sqlalchemy.orm.QueryableAttribute` to use in the query filter, for use
when the relationship includes joins
"""

__slots__ = ('rel', 'attr', 'name')
__slots__ = ('rel', 'attr', 'qattr', 'name')
name: Optional[str]

def __init__(self, rel: str, attr: str) -> None:
def __init__(
self,
rel: str,
attr: str,
qattr: Optional[Callable[[], QueryableAttribute]] = None,
) -> None:
self.rel = rel
self.attr = attr
self.qattr = qattr
self.name = None

def __set_name__(self, owner: _T, name: str) -> None:
Expand All @@ -710,7 +719,10 @@ def __get__(
) -> Union[Self, DynamicAssociationProxyBind[_T, _V, _R]]:
if obj is None:
return self
wrapper = DynamicAssociationProxyBind(obj, self.rel, self.attr)
qattr = self.qattr
wrapper = DynamicAssociationProxyBind(
obj, self.rel, self.attr, qattr() if qattr is not None else None
)
if name := self.name:
# Cache it for repeat access. SQLAlchemy models cannot use __slots__ or
# must include __dict__ in slots for state management, so we don't need to
Expand All @@ -722,19 +734,18 @@ def __get__(
class DynamicAssociationProxyBind(abc.Mapping, Generic[_T, _V, _R]):
""":class:`DynamicAssociationProxy` bound to an instance."""

__slots__ = ('obj', 'rel', 'relattr', 'attr')
__slots__ = ('obj', 'rel', 'relattr', 'attr', 'qattr')
relattr: AppenderQuery
qattr: Optional[QueryableAttribute]

def __init__(
self,
obj: _T,
rel: str,
attr: str,
self, obj: _T, rel: str, attr: str, qattr: Optional[QueryableAttribute]
) -> None:
self.obj = obj
self.rel = rel
self.relattr: AppenderQuery[_R] = getattr(obj, rel)
self.attr = attr
self.qattr = qattr

def __repr__(self) -> str:
return f'DynamicAssociationProxyBind({self.obj!r}, {self.rel!r}, {self.attr!r})'
Expand All @@ -743,12 +754,18 @@ def __contains__(self, value: Any) -> bool:
relattr = self.relattr
if TYPE_CHECKING:
assert relattr.session is not None # nosec B101
if (qattr := self.qattr) is not None:
return relattr.session.query(
relattr.filter(qattr == value).exists()
).scalar()
return relattr.session.query(
relattr.filter_by(**{self.attr: value}).exists()
).scalar()

def __getitem__(self, value: Any) -> _R:
try:
if (qattr := self.qattr) is not None:
return self.relattr.filter(qattr == value).one()
return self.relattr.filter_by(**{self.attr: value}).one()
except NoResultFound:
raise KeyError(value) from None
Expand All @@ -772,6 +789,7 @@ def __eq__(self, other: Any) -> bool:
self.obj == other.obj
and self.rel == other.rel
and self.attr == other.attr
and self.qattr == other.qattr
)
return NotImplemented

Expand Down
65 changes: 65 additions & 0 deletions tests/coaster_tests/sqlalchemy_roles_test.py
Expand Up @@ -208,6 +208,12 @@ class RelationshipParent(BaseNameMixin, Model):
children_names = DynamicAssociationProxy[str, RelationshipChild](
'children_list_lazy', 'name'
)
# Another instance of DynamicAssociationProxy, this time using a QueryableAttribute
children_namesq: DynamicAssociationProxy[str, RelationshipChild] = (
DynamicAssociationProxy(
'children_list_lazy', 'name', lambda: RelationshipChild.name
)
)

__roles__ = {
'all': {
Expand Down Expand Up @@ -987,6 +993,65 @@ def test_dynamic_association_proxy(self) -> None:
assert not p1a != p1b # Test __ne__
assert p1a != parent2.children_names # Cross-check with an unrelated proxy

def test_dynamic_association_proxy_qattr(self) -> None:
parent1 = RelationshipParent(title="Proxy Parent 1")
parent2 = RelationshipParent(title="Proxy Parent 2")
parent3 = RelationshipParent(title="Proxy Parent 3")
child1 = RelationshipChild(name='child1', title="Proxy Child 1", parent=parent1)
child2 = RelationshipChild(name='child2', title="Proxy Child 2", parent=parent1)
child3 = RelationshipChild(name='child3', title="Proxy Child 3", parent=parent2)
self.session.add_all([parent1, parent2, parent3, child1, child2, child3])
self.session.commit()

assert isinstance(RelationshipParent.children_namesq, DynamicAssociationProxy)

assert child1.name in parent1.children_namesq
assert child2.name in parent1.children_namesq
assert child3.name not in parent1.children_namesq

assert child1.name not in parent2.children_namesq
assert child2.name not in parent2.children_namesq
assert child3.name in parent2.children_namesq

assert child1.name not in parent3.children_namesq
assert child2.name not in parent3.children_namesq
assert child3.name not in parent3.children_namesq

assert len(parent1.children_namesq) == 2
assert set(parent1.children_namesq) == {child1.name, child2.name}

assert len(parent2.children_namesq) == 1
assert set(parent2.children_namesq) == {child3.name}

assert len(parent3.children_namesq) == 0
assert set(parent3.children_namesq) == set()

assert bool(parent1.children_namesq) is True
assert bool(parent2.children_namesq) is True
assert bool(parent3.children_namesq) is False

assert parent1.children_namesq[child1.name] == child1
assert parent1.children_namesq[child2.name] == child2
with pytest.raises(KeyError):
parent1.children_namesq[child3.name] # pylint: disable=pointless-statement
assert parent1.children_namesq.get(child3.name) is None
assert dict(parent1.children_namesq) == {
child1.name: child1,
child2.name: child2,
}
assert sorted(parent1.children_namesq.items()) == [
(child1.name, child1),
(child2.name, child2),
]

# Repeat access returns proxy wrappers from instance cache
p1a = parent1.children_namesq
p1b = parent1.children_namesq
assert p1a is p1b
assert p1a == p1b # Test __eq__
assert not p1a != p1b # Test __ne__
assert p1a != parent2.children_namesq # Cross-check with an unrelated proxy

def test_granted_via(self) -> None:
"""
Roles can be granted via related objects
Expand Down

0 comments on commit a294be6

Please sign in to comment.