Skip to content

Commit

Permalink
expanded fields: add a non-resolvable system record
Browse files Browse the repository at this point in the history
* introduce a base ghost schema
* closes zenodo/zenodo-rdm#284
  • Loading branch information
anikachurilova authored and kpsherva committed Jun 2, 2023
1 parent 5335294 commit 967adcd
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 7 deletions.
6 changes: 3 additions & 3 deletions invenio_records_resources/references/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from abc import ABC, abstractmethod

from invenio_access.permissions import system_process
from invenio_access.permissions import system_identity, system_user_id


class ResolverRegistryBase(ABC):
Expand Down Expand Up @@ -101,7 +101,7 @@ def reference_entity(cls, entity, raise_=False):
@classmethod
def reference_identity(cls, identity, raise_=False):
"""Create a reference dict for the user behind the given identity."""
if system_process in identity.provides:
return None
if identity == system_identity:
return {"user": str(system_user_id)}

return {"user": str(identity.id)}
20 changes: 16 additions & 4 deletions invenio_records_resources/services/records/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
"""Service results."""
from abc import ABC, abstractmethod

from invenio_access.permissions import system_user_id
from invenio_records.dictutils import dict_lookup, dict_merge, dict_set

from ...pagination import Pagination
from ..base import ServiceItemResult, ServiceListResult
from .schema import BaseGhostSchema


class RecordItem(ServiceItemResult):
Expand Down Expand Up @@ -277,6 +279,14 @@ def ghost_record(self, value):
"""
raise NotImplementedError()

@abstractmethod
def system_record(self):
"""Return the representation of a system user.
This is used for the user with id = 'system'.
"""
raise NotImplementedError()

def has(self, service, value):
"""Return true if field has given value for given service."""
try:
Expand All @@ -292,10 +302,12 @@ def add_service_value(self, service, value):

def add_dereferenced_record(self, service, value, resolved_rec):
"""Save the dereferenced record."""
# mark the record as a "ghost" or "system" record i.e not resolvable
if resolved_rec is None:
resolved_rec = self.ghost_record({"id": value})
# mark the record as a "ghost" record i.e not resolvable
resolved_rec["is_ghost"] = True
if value == system_user_id:
resolved_rec = self.system_record()
else:
resolved_rec = self.ghost_record({"id": value})
self._service_values[service][value] = resolved_rec

def get_dereferenced_record(self, service, value):
Expand Down Expand Up @@ -389,7 +401,7 @@ def _add_dereferenced_record(service, value, resolved_rec):
if ghost_values:
for value in ghost_values:
# set dereferenced record to None. That will trigger eventually
# the field.ghost_recor() to be called
# the field.ghost_record() to be called
_add_dereferenced_record(service, value, None)

def resolve(self, identity, hits):
Expand Down
6 changes: 6 additions & 0 deletions invenio_records_resources/services/records/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ def clean(self, data, **kwargs):
return data


class BaseGhostSchema(Schema):
"""Base ghost schema."""

is_ghost = fields.Constant(True, dump_only=True)


class ServiceSchemaWrapper:
"""Schema wrapper that enhances load/dump of wrapped schema.
Expand Down
12 changes: 12 additions & 0 deletions tests/services/test_results_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ def ghost_record(self, value):
"""Override default."""
return {}

def system_record(self):
"""Override default."""
raise NotImplementedError()

def get_value_service(self, value):
"""Override default."""
if value.get("user"):
Expand All @@ -75,6 +79,10 @@ def ghost_record(self, value):
"""Override default."""
return {}

def system_record(self):
"""Override default."""
raise NotImplementedError()

def get_value_service(self, value):
"""Override default."""
return value, mocked_simple_service
Expand All @@ -93,6 +101,10 @@ def ghost_record(self, value):
"""Override default."""
return {}

def system_record(self):
"""Override default."""
raise NotImplementedError()

def get_value_service(self, value):
"""Override default."""
return value, mocked_other_service
Expand Down

0 comments on commit 967adcd

Please sign in to comment.