Skip to content

Commit

Permalink
Make overriding InternalSearch parts easier
Browse files Browse the repository at this point in the history
  • Loading branch information
ThiefMaster committed Jan 31, 2023
1 parent 53c4297 commit 922d3c2
Showing 1 changed file with 63 additions and 59 deletions.
122 changes: 63 additions & 59 deletions indico/modules/search/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,63 @@ def search(self, query, user=None, page=None, object_types=(), *, admin_override
'results': results,
}

def _preload_categories(self, objs, preloaded_categories):
obj_types = {type(o) for o in objs}
assert len(obj_types) == 1
obj_type = obj_types.pop()

if obj_type == Event:
chain_query = db.session.query(Event.category_chain).filter(Event.id.in_(o.id for o in objs))
elif obj_type == Category:
chain_query = db.session.query(Category.chain_ids).filter(Category.id.in_(o.id for o in objs))
elif obj_type == Contribution:
chain_query = (db.session.query(Event.category_chain)
.join(Contribution.event)
.filter(Contribution.id.in_(o.id for o in objs)))
elif obj_type == Attachment:
chain_query = (db.session.query(Event.category_chain)
.join(Attachment.folder)
.join(AttachmentFolder.event)
.filter(Attachment.id.in_(o.id for o in objs)))
elif obj_type == EventNote:
chain_query = (db.session.query(Event.category_chain)
.join(EventNote.event)
.filter(EventNote.id.in_(o.id for o in objs)))
else:
raise Exception(f'Unhandled object type: {obj_type}')
category_ids = set(itertools.chain.from_iterable(id for id, in chain_query))
query = (
Category.query
.filter(Category.id.in_(category_ids))
.options(load_only('id', 'parent_id', 'protection_mode'))
)
Category.preload_relationships(query, 'acl_entries',
strategy=lambda rel: _apply_acl_entry_strategy(subqueryload(rel),
CategoryPrincipal))
preloaded_categories |= set(query)

def _can_access(self, user, obj, allow_effective_protection_mode=True, admin_override_enabled=False):
if isinstance(obj, (Category, Event, Session, Contribution)):
# more efficient for events/categories/contribs since it avoids climbing up the chain
protection_mode = (obj.effective_protection_mode if allow_effective_protection_mode
else obj.protection_mode)
elif isinstance(obj, Attachment):
# attachments don't have it so we can only skip access checks if they
# are public themselves
protection_mode = obj.protection_mode
elif isinstance(obj, EventNote):
# notes inherit from their parent
return self._can_access(user, obj.object, allow_effective_protection_mode=False,
admin_override_enabled=admin_override_enabled)
elif isinstance(obj, SubContribution):
# subcontributions inherit from their contribution
return self._can_access(user, obj.contribution, allow_effective_protection_mode=False,
admin_override_enabled=admin_override_enabled)
else:
raise Exception(f'Unexpected object: {obj}')
return (protection_mode == ProtectionMode.public or
obj.can_access(user, allow_admin=admin_override_enabled))

def _paginate(self, query, page, column, user, admin_override_enabled):
reverse = False
pagenav = {'prev': None, 'next': None}
Expand All @@ -104,65 +161,12 @@ def _paginate(self, query, page, column, user, admin_override_enabled):
reverse = True

preloaded_categories = set()

def _preload_categories(objs):
nonlocal preloaded_categories
obj_types = {type(o) for o in objs}
assert len(obj_types) == 1
obj_type = obj_types.pop()

if obj_type == Event:
chain_query = db.session.query(Event.category_chain).filter(Event.id.in_(o.id for o in objs))
elif obj_type == Category:
chain_query = db.session.query(Category.chain_ids).filter(Category.id.in_(o.id for o in objs))
elif obj_type == Contribution:
chain_query = (db.session.query(Event.category_chain)
.join(Contribution.event)
.filter(Contribution.id.in_(o.id for o in objs)))
elif obj_type == Attachment:
chain_query = (db.session.query(Event.category_chain)
.join(Attachment.folder)
.join(AttachmentFolder.event)
.filter(Attachment.id.in_(o.id for o in objs)))
elif obj_type == EventNote:
chain_query = (db.session.query(Event.category_chain)
.join(EventNote.event)
.filter(EventNote.id.in_(o.id for o in objs)))
else:
raise Exception(f'Unhandled object type: {obj_type}')
category_ids = set(itertools.chain.from_iterable(id for id, in chain_query))
query = (
Category.query
.filter(Category.id.in_(category_ids))
.options(load_only('id', 'parent_id', 'protection_mode'))
)
Category.preload_relationships(query, 'acl_entries',
strategy=lambda rel: _apply_acl_entry_strategy(subqueryload(rel),
CategoryPrincipal))
preloaded_categories |= set(query)

def _can_access(obj, allow_effective_protection_mode=True):
if isinstance(obj, (Category, Event, Session, Contribution)):
# more efficient for events/categories/contribs since it avoids climbing up the chain
protection_mode = (obj.effective_protection_mode if allow_effective_protection_mode
else obj.protection_mode)
elif isinstance(obj, Attachment):
# attachments don't have it so we can only skip access checks if they
# are public themselves
protection_mode = obj.protection_mode
elif isinstance(obj, EventNote):
# notes inherit from their parent
return _can_access(obj.object, allow_effective_protection_mode=False)
elif isinstance(obj, SubContribution):
# subcontributions inherit from their contribution
return _can_access(obj.contribution, allow_effective_protection_mode=False)
else:
raise Exception(f'Unexpected object: {obj}')
return (protection_mode == ProtectionMode.public or
obj.can_access(user, allow_admin=admin_override_enabled))

res = get_n_matching(query, self.RESULTS_PER_PAGE + 1, _can_access, prefetch_factor=20,
preload_bulk=_preload_categories)
res = get_n_matching(
query, self.RESULTS_PER_PAGE + 1,
lambda obj: self._can_access(user, obj, admin_override_enabled=admin_override_enabled),
prefetch_factor=20,
preload_bulk=lambda objs: self._preload_categories(objs, preloaded_categories)
)

if len(res) > self.RESULTS_PER_PAGE:
# we queried 1 more so we can see if there are more results available
Expand Down

0 comments on commit 922d3c2

Please sign in to comment.