Skip to content

Commit

Permalink
Merge 2039e18 into df9bd2c
Browse files Browse the repository at this point in the history
  • Loading branch information
jace committed Aug 11, 2023
2 parents df9bd2c + 2039e18 commit 5e85797
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 37 deletions.
27 changes: 27 additions & 0 deletions src/coaster/sqlalchemy/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import cast, overload

import sqlalchemy as sa
from sqlalchemy import inspect
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm import DeclarativeBase

Expand All @@ -19,6 +20,7 @@
'failsafe_add',
'add_primary_relationship',
'auto_init_default',
'idfilters',
]

T = t.TypeVar('T')
Expand Down Expand Up @@ -325,3 +327,28 @@ def init_scalar(
dict_[column.key] = value
return value
return None


def idfilters(obj: DeclarativeBase) -> t.Optional[t.List[sa.BinaryExpression]]:
"""
Return SQLAlchemy expressions for the identity of the given object.
This is useful when querying for membership in a lazy relationship. With
DynamicMapped (``lazy='dynamic'``)::
filtered_query = parent.children.filter(*idfilters(child))
Or with WriteOnlyMapped (``lazy='write_only'``)::
filtered_select = parent.children.select().where(*idfilters(child))
Returns None when the object has no persistent identity.
"""
insp = inspect(obj)
identity = insp.identity
if identity is None:
return None
pkeys = insp.mapper.primary_key
if len(pkeys) == 1:
return [pkeys[0] == identity[0]]
return [column == value for column, value in zip(pkeys, identity)]
120 changes: 84 additions & 36 deletions src/coaster/sqlalchemy/roles.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def roles_for(

import sqlalchemy as sa
from flask import g
from sqlalchemy import event, inspect
from sqlalchemy import event, select
from sqlalchemy.ext.orderinglist import OrderingList
from sqlalchemy.orm import (
ColumnProperty,
Expand All @@ -160,6 +160,7 @@ def roles_for(

from ..auth import current_auth
from ..utils import InspectableSet, is_collection, nary_op
from .functions import idfilters
from .model import AppenderQuery

__all__ = [
Expand Down Expand Up @@ -255,21 +256,18 @@ def _actor_in_relationship(actor: t.Any, relationship: t.Any) -> bool:
if actor == relationship:
return True
if isinstance(relationship, QueryBase):
identity = inspect(actor).identity
if identity is not None: # If None, do container test next
pkeys = inspect(actor.__class__).primary_key
filters = idfilters(actor)
if filters is not None:
return (
relationship.session.scalar(
sa.select(
relationship.filter(
*[column == value for column, value in zip(pkeys, identity)]
).exists()
)
select(relationship.filter(*filters).exists())
)
or False
or False # The or clause is needed as .scalar is typed Optional
)
# If actor does not have an identity yet, check in the session's collection
return actor in relationship
if isinstance(relationship, abc.Container):
# Regular Python container
return actor in relationship
return False

Expand Down Expand Up @@ -713,23 +711,33 @@ class RoleAccessProxy(abc.Mapping, t.Generic[RoleMixinType]):
__slots__ = (
'_obj',
'current_roles',
'_roles',
'_actor',
'_anchors',
'_datasets',
'_dataset_attrs',
'_call',
'_read',
'_write',
'_no_call',
'_no_read',
'_no_write',
'_all_read_cache',
)
_obj: RoleMixinType
current_roles: InspectableSet[t.Union[LazyRoleSet, t.Set[str]]]
_roles: t.Union[LazyRoleSet, t.Set[str]]
_actor: t.Any
_anchors: t.Sequence[t.Any]
_datasets: t.Optional[t.Sequence[str]]
_dataset_attrs: t.Optional[t.Set[str]]
_call: t.Set[str]
_read: t.Set[str]
_write: t.Set[str]
_no_call: t.Set[str]
_no_read: t.Set[str]
_no_write: t.Set[str]
_all_read_cache: t.Optional[t.Set[str]]

@property # type: ignore[override]
def __class__(self) -> t.Type[RoleMixinType]:
Expand All @@ -749,6 +757,7 @@ def __init__(
) -> None:
object.__setattr__(self, '_obj', obj)
object.__setattr__(self, 'current_roles', InspectableSet(roles))
object.__setattr__(self, '_roles', roles)
object.__setattr__(self, '_actor', actor)
object.__setattr__(self, '_anchors', anchors)
if datasets is None:
Expand All @@ -768,19 +777,13 @@ def __init__(
object.__setattr__(self, '_datasets', datasets[1:])
object.__setattr__(self, '_dataset_attrs', dataset_attrs)

# Call, read and write access attributes for the given roles
call = set()
read = set()
write = set()

for role in roles:
call.update(obj.__roles__.get(role, {}).get('call', set()))
read.update(obj.__roles__.get(role, {}).get('read', set()))
write.update(obj.__roles__.get(role, {}).get('write', set()))

object.__setattr__(self, '_call', call)
object.__setattr__(self, '_read', read)
object.__setattr__(self, '_write', write)
object.__setattr__(self, '_call', set())
object.__setattr__(self, '_read', set())
object.__setattr__(self, '_write', set())
object.__setattr__(self, '_no_call', set())
object.__setattr__(self, '_no_read', set())
object.__setattr__(self, '_no_write', set())
object.__setattr__(self, '_all_read_cache', None)

def __repr__(self) -> str:
return f'RoleAccessProxy(obj={self._obj!r}, roles={self.current_roles!r})'
Expand Down Expand Up @@ -822,9 +825,60 @@ def __get_processed_attr(self, name: str) -> t.Any:
)
return attr

def __attr_available(
self, attr: str, action: te.Literal['call', 'read', 'write']
) -> bool:
"""Check for attr availability using a cache."""
if action == 'read':
present, absent = self._read, self._no_read
elif action == 'call':
present, absent = self._call, self._no_call
elif action == 'write':
present, absent = self._write, self._no_write

if attr in present:
return True
if attr in absent:
return False
# Not cached. Check for roles that grant access, then check for role
# availability
granting_roles = {
role
for role, roledict in self._obj.__roles__.items()
if attr in roledict.get(action, ())
}
if (
isinstance(self._roles, LazyRoleSet) and self._roles.has_any(granting_roles)
) or (self._roles & granting_roles):
present.add(attr)
return True
absent.add(attr)
return False

@property
def _all_read(self) -> t.Set[str]:
"""All readable attributes."""
if self._all_read_cache is not None:
return self._all_read_cache
all_read_attrs = {
attr
for roledict in self._obj.__roles__.values()
for attr in roledict.get('read', ())
}
if self._dataset_attrs is not None:
# If a dataset is specified, drop all attrs that don't appear in the dataset
all_read_attrs = all_read_attrs & self._dataset_attrs
# Next, filter for attr availability
available_read_attrs = {
attr for attr in all_read_attrs if self.__attr_available(attr, 'read')
}
# Save to cache and return
object.__setattr__(self, '_all_read_cache', available_read_attrs)
return available_read_attrs

def __getattr__(self, attr: str) -> t.Any:
# See also __getitem__, which doesn't consult _call
if attr in self._read or attr in self._call:
if self.__attr_available(attr, 'read') or self.__attr_available(attr, 'call'):
return self.__get_processed_attr(attr)
raise AttributeError(
f"{self._obj.__class__.__qualname__}.{attr};"
Expand All @@ -833,7 +887,7 @@ def __getattr__(self, attr: str) -> t.Any:

def __setattr__(self, attr: str, value: t.Any) -> None:
# See also __setitem__
if attr in self._write:
if self.__attr_available(attr, 'write'):
return setattr(self._obj, attr, value)
raise AttributeError(
f"{self._obj.__class__.__qualname__}.{attr};"
Expand All @@ -842,36 +896,30 @@ def __setattr__(self, attr: str, value: t.Any) -> None:

def __getitem__(self, key: str) -> t.Any:
# See also __getattr__, which also looks in _call
if key in self._read:
if self.__attr_available(key, 'read'):
return self.__get_processed_attr(key)
raise KeyError(
f"{self._obj.__class__.__qualname__}.{key};"
f" current roles {set(self.current_roles)!r}"
)

def __len__(self) -> int:
if self._dataset_attrs is not None:
return len(self._read & self._dataset_attrs)
return len(self._read)
return len(self._all_read)

def __contains__(self, key: t.Any) -> bool:
return key in self._read or key in self._call
return self.__attr_available(key, 'read') or self.__attr_available(key, 'call')

def __setitem__(self, key: str, value: str) -> None:
# See also __setattr__
if key in self._write:
if self.__attr_available(key, 'write'):
return setattr(self._obj, key, value)
raise KeyError(
f"{self._obj.__class__.__qualname__}.{key};"
f" current roles {set(self.current_roles)!r}"
)

def __iter__(self) -> t.Iterator[str]:
if self._dataset_attrs is not None:
source = self._read & self._dataset_attrs
else:
source = self._read
yield from source
yield from self._all_read

def __eq__(self, other: t.Any) -> bool:
if other == self._obj:
Expand Down
14 changes: 13 additions & 1 deletion tests/coaster_tests/sqlalchemy_roles_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,12 +807,24 @@ def test_role_grant(self) -> None:

m1.secondary_users.extend([u1, u2])

# Test that roles are discovered from lazy=dynamic relationships before commit
rm1u1 = m1.roles_for(u1)
rm1u2 = m1.roles_for(u2)
rm2u1 = m2.roles_for(u1)
rm2u2 = m2.roles_for(u2)
assert 'primary_role' in rm1u1
assert 'primary_role' not in rm1u2
assert 'primary_role' not in rm2u1
assert 'primary_role' in rm2u2

self.session.add_all([m1, m2, u1, u2])
self.session.commit()

# Test that roles are discovered from lazy=dynamic relationships
# Test that roles are discovered from lazy=dynamic relationships after commit
rm1u1 = m1.roles_for(u1)
rm1u2 = m1.roles_for(u2)
rm2u1 = m2.roles_for(u1)
rm2u2 = m2.roles_for(u2)
assert 'primary_role' in rm1u1
assert 'primary_role' not in rm1u2
assert 'primary_role' not in rm2u1
Expand Down

0 comments on commit 5e85797

Please sign in to comment.