diff --git a/invenio_rdm_records/services/config.py b/invenio_rdm_records/services/config.py index 812b658e7..ebb2af5ce 100644 --- a/invenio_rdm_records/services/config.py +++ b/invenio_rdm_records/services/config.py @@ -64,6 +64,7 @@ FromConfigRequiredPIDs, ) from .permissions import RDMRecordPermissionPolicy +from .results import RDMRecordList from .result_items import GrantItem, GrantList, SecretLinkItem, SecretLinkList from .schemas import RDMParentSchema, RDMRecordSchema from .schemas.community_records import CommunityRecordsSchema @@ -251,6 +252,7 @@ class RDMRecordServiceConfig(RecordServiceConfig, ConfiguratorMixin): link_result_list_cls = SecretLinkList grant_result_item_cls = GrantItem grant_result_list_cls = GrantList + result_list_cls = RDMRecordList default_files_enabled = FromConfig("RDM_DEFAULT_FILES_ENABLED", default=True) diff --git a/invenio_rdm_records/services/results.py b/invenio_rdm_records/services/results.py index 9ac0076b9..96b7a6adb 100644 --- a/invenio_rdm_records/services/results.py +++ b/invenio_rdm_records/services/results.py @@ -11,7 +11,7 @@ from invenio_communities.communities.entity_resolvers import pick_fields from invenio_communities.communities.schema import CommunityGhostSchema from invenio_communities.proxies import current_communities -from invenio_records_resources.services.records.results import ExpandableField +from invenio_records_resources.services.records.results import ExpandableField, RecordList from invenio_users_resources.proxies import current_user_resources from .dummy import DummyExpandingService @@ -69,3 +69,32 @@ def get_value_service(self, value): def pick(self, identity, resolved_rec): """Pick fields defined in the entity resolver.""" return resolved_rec + +class RDMRecordList(RecordList): + """Record list with custom fields.""" + + @property + def hits(self): + """Iterator over the hits.""" + for hit in self._results: + # Load dump + record_dict = hit.to_dict() + if record_dict["is_published"]: + record = self._service.record_cls.loads(record_dict) + else: + record = self._service.draft_cls.loads(record_dict) + + # Project the record + projection = self._schema.dump( + record, + context=dict( + identity=self._identity, + record=record, + ), + ) + if self._links_item_tpl: + projection["links"] = self._links_item_tpl.expand( + self._identity, record + ) + + yield projection