From 922d3c2f5c8547a377fc42e6b19c5f17e0b052b9 Mon Sep 17 00:00:00 2001 From: Adrian Moennich Date: Tue, 31 Jan 2023 14:39:00 +0100 Subject: [PATCH] Make overriding InternalSearch parts easier --- indico/modules/search/internal.py | 122 +++++++++++++++--------------- 1 file changed, 63 insertions(+), 59 deletions(-) diff --git a/indico/modules/search/internal.py b/indico/modules/search/internal.py index a3fc230efbe..60cf4a905f3 100644 --- a/indico/modules/search/internal.py +++ b/indico/modules/search/internal.py @@ -88,6 +88,63 @@ def search(self, query, user=None, page=None, object_types=(), *, admin_override 'results': results, } + def _preload_categories(self, objs, preloaded_categories): + obj_types = {type(o) for o in objs} + assert len(obj_types) == 1 + obj_type = obj_types.pop() + + if obj_type == Event: + chain_query = db.session.query(Event.category_chain).filter(Event.id.in_(o.id for o in objs)) + elif obj_type == Category: + chain_query = db.session.query(Category.chain_ids).filter(Category.id.in_(o.id for o in objs)) + elif obj_type == Contribution: + chain_query = (db.session.query(Event.category_chain) + .join(Contribution.event) + .filter(Contribution.id.in_(o.id for o in objs))) + elif obj_type == Attachment: + chain_query = (db.session.query(Event.category_chain) + .join(Attachment.folder) + .join(AttachmentFolder.event) + .filter(Attachment.id.in_(o.id for o in objs))) + elif obj_type == EventNote: + chain_query = (db.session.query(Event.category_chain) + .join(EventNote.event) + .filter(EventNote.id.in_(o.id for o in objs))) + else: + raise Exception(f'Unhandled object type: {obj_type}') + category_ids = set(itertools.chain.from_iterable(id for id, in chain_query)) + query = ( + Category.query + .filter(Category.id.in_(category_ids)) + .options(load_only('id', 'parent_id', 'protection_mode')) + ) + Category.preload_relationships(query, 'acl_entries', + strategy=lambda rel: _apply_acl_entry_strategy(subqueryload(rel), + CategoryPrincipal)) + preloaded_categories |= set(query) + + def _can_access(self, user, obj, allow_effective_protection_mode=True, admin_override_enabled=False): + if isinstance(obj, (Category, Event, Session, Contribution)): + # more efficient for events/categories/contribs since it avoids climbing up the chain + protection_mode = (obj.effective_protection_mode if allow_effective_protection_mode + else obj.protection_mode) + elif isinstance(obj, Attachment): + # attachments don't have it so we can only skip access checks if they + # are public themselves + protection_mode = obj.protection_mode + elif isinstance(obj, EventNote): + # notes inherit from their parent + return self._can_access(user, obj.object, allow_effective_protection_mode=False, + admin_override_enabled=admin_override_enabled) + elif isinstance(obj, SubContribution): + # subcontributions inherit from their contribution + return self._can_access(user, obj.contribution, allow_effective_protection_mode=False, + admin_override_enabled=admin_override_enabled) + else: + raise Exception(f'Unexpected object: {obj}') + return (protection_mode == ProtectionMode.public or + obj.can_access(user, allow_admin=admin_override_enabled)) + def _paginate(self, query, page, column, user, admin_override_enabled): reverse = False pagenav = {'prev': None, 'next': None} @@ -104,65 +161,12 @@ def _paginate(self, query, page, column, user, admin_override_enabled): reverse = True preloaded_categories = set() - - def _preload_categories(objs): - nonlocal preloaded_categories - obj_types = {type(o) for o in objs} - assert len(obj_types) == 1 - obj_type = obj_types.pop() - - if obj_type == Event: - chain_query = db.session.query(Event.category_chain).filter(Event.id.in_(o.id for o in objs)) - elif obj_type == Category: - chain_query = db.session.query(Category.chain_ids).filter(Category.id.in_(o.id for o in objs)) - elif obj_type == Contribution: - chain_query = (db.session.query(Event.category_chain) - .join(Contribution.event) - .filter(Contribution.id.in_(o.id for o in objs))) - elif obj_type == Attachment: - chain_query = (db.session.query(Event.category_chain) - .join(Attachment.folder) - .join(AttachmentFolder.event) - .filter(Attachment.id.in_(o.id for o in objs))) - elif obj_type == EventNote: - chain_query = (db.session.query(Event.category_chain) - .join(EventNote.event) - .filter(EventNote.id.in_(o.id for o in objs))) - else: - raise Exception(f'Unhandled object type: {obj_type}') - category_ids = set(itertools.chain.from_iterable(id for id, in chain_query)) - query = ( - Category.query - .filter(Category.id.in_(category_ids)) - .options(load_only('id', 'parent_id', 'protection_mode')) - ) - Category.preload_relationships(query, 'acl_entries', - strategy=lambda rel: _apply_acl_entry_strategy(subqueryload(rel), - CategoryPrincipal)) - preloaded_categories |= set(query) - - def _can_access(obj, allow_effective_protection_mode=True): - if isinstance(obj, (Category, Event, Session, Contribution)): - # more efficient for events/categories/contribs since it avoids climbing up the chain - protection_mode = (obj.effective_protection_mode if allow_effective_protection_mode - else obj.protection_mode) - elif isinstance(obj, Attachment): - # attachments don't have it so we can only skip access checks if they - # are public themselves - protection_mode = obj.protection_mode - elif isinstance(obj, EventNote): - # notes inherit from their parent - return _can_access(obj.object, allow_effective_protection_mode=False) - elif isinstance(obj, SubContribution): - # subcontributions inherit from their contribution - return _can_access(obj.contribution, allow_effective_protection_mode=False) - else: - raise Exception(f'Unexpected object: {obj}') - return (protection_mode == ProtectionMode.public or - obj.can_access(user, allow_admin=admin_override_enabled)) - - res = get_n_matching(query, self.RESULTS_PER_PAGE + 1, _can_access, prefetch_factor=20, - preload_bulk=_preload_categories) + res = get_n_matching( + query, self.RESULTS_PER_PAGE + 1, + lambda obj: self._can_access(user, obj, admin_override_enabled=admin_override_enabled), + prefetch_factor=20, + preload_bulk=lambda objs: self._preload_categories(objs, preloaded_categories) + ) if len(res) > self.RESULTS_PER_PAGE: # we queried 1 more so we can see if there are more results available