-
Notifications
You must be signed in to change notification settings - Fork 46
/
results.py
391 lines (326 loc) · 12.2 KB
/
results.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
# -*- coding: utf-8 -*-
#
# Copyright (C) 2020-2022 CERN.
# Copyright (C) 2020-2022 Northwestern University.
#
# Invenio-Records-Resources is free software; you can redistribute it and/or
# modify it under the terms of the MIT License; see LICENSE file for more
# details.
"""Service results."""
from abc import ABC, abstractmethod
from invenio_records.dictutils import dict_lookup, dict_merge, dict_set
from ...pagination import Pagination
from ..base import ServiceItemResult, ServiceListResult
class RecordItem(ServiceItemResult):
    """Single record result."""

    def __init__(
        self,
        service,
        identity,
        record,
        errors=None,
        links_tpl=None,
        schema=None,
        expandable_fields=None,
        expand=False,
    ):
        """Constructor.

        :params service: the service instance; its ``schema`` is used when
            no ``schema`` is given, and its ``check_permission`` backs
            ``has_permissions_to``.
        :params identity: the identity that performed the service request.
        :params record: the record to project.
        :params errors: optional errors, added under ``"errors"`` by
            ``to_dict``.
        :params links_tpl: optional links template used to build ``links``.
        :params schema: optional schema used to dump the record.
        :params expandable_fields: optional list of ``ExpandableField``.
        :params expand: when true, ``data`` gets an ``"expanded"`` key.
        """
        self._errors = errors
        self._identity = identity
        self._links_tpl = links_tpl
        self._record = record
        self._service = service
        self._schema = schema or service.schema
        self._fields_resolver = FieldsResolver(expandable_fields)
        self._expand = expand
        # Lazily computed projection, see ``data``.
        self._data = None

    @property
    def id(self):
        """Get the record id."""
        return self._record.pid.pid_value

    def __getitem__(self, key):
        """Get a key from the data."""
        return self.data[key]

    @property
    def links(self):
        """Get links for this result item."""
        return self._links_tpl.expand(self._identity, self._record)

    @property
    def _obj(self):
        """Return the object to dump."""
        return self._record

    @property
    def data(self):
        """Property to get the record projection.

        The record is dumped with the schema on first access. Links and
        (when ``expand`` is set) the expanded fields are then added, and the
        resulting dict is cached for later accesses.
        """
        # Compare against None rather than truthiness so that an empty dump
        # is cached as well instead of being recomputed on every access.
        if self._data is not None:
            return self._data
        self._data = self._schema.dump(
            self._obj,
            context=dict(
                identity=self._identity,
                record=self._record,
            ),
        )
        if self._links_tpl:
            self._data["links"] = self.links
        if self._expand and self._fields_resolver:
            self._fields_resolver.resolve(self._identity, [self._data])
            fields = self._fields_resolver.expand(self._identity, self._data)
            self._data["expanded"] = fields
        return self._data

    @property
    def errors(self):
        """Get the errors."""
        return self._errors

    def to_dict(self):
        """Get a dictionary for the record.

        NOTE: this returns the cached ``data`` dict itself, so the
        ``"errors"`` key is written into the cache as well.
        """
        res = self.data
        if self._errors:
            res["errors"] = self._errors
        return res

    def has_permissions_to(self, actions):
        """Returns dict of "can_<action>": bool.

        Placing this functionality here because it is a projection of the
        record item's characteristics and allows us to re-use the
        underlying data layer record. Because it is selective about the actions
        it checks for performance reasons, it is not embedded in the `to_dict`
        method.

        :params actions: list of action strings
        :returns dict:

        Example:

        record_item.has_permissions_to(["update_draft", "read_files"])
        {
            "can_update_draft": False,
            "can_read_files": True
        }
        """
        return {
            f"can_{action}": self._service.check_permission(
                self._identity, action, record=self._record
            )
            for action in actions
        }
class RecordList(ServiceListResult):
    """List of records result."""

    def __init__(
        self,
        service,
        identity,
        results,
        params=None,
        links_tpl=None,
        links_item_tpl=None,
        schema=None,
        expandable_fields=None,
        expand=False,
    ):
        """Constructor.

        :params service: a service instance
        :params identity: an identity that performed the service request
        :params results: the search results
        :params params: dictionary of the query parameters (``size``,
            ``page`` and ``sort`` are read from it)
        :params links_tpl: optional template for the list-level links,
            expanded against the pagination object
        :params links_item_tpl: optional template for per-hit links
        :params schema: optional schema used to dump each hit (defaults to
            the service schema)
        :params expandable_fields: optional list of ``ExpandableField``
        :params expand: when true, each hit gets an ``"expanded"`` key
        """
        self._identity = identity
        self._results = results
        self._service = service
        self._schema = schema or service.schema
        self._params = params
        self._links_tpl = links_tpl
        self._links_item_tpl = links_item_tpl
        self._fields_resolver = FieldsResolver(expandable_fields)
        self._expand = expand

    def __len__(self):
        """Return the total number of hits.

        NOTE: ``total`` is None for scan() results, in which case ``len()``
        raises a TypeError.
        """
        return self.total

    def __iter__(self):
        """Iterator over the hits."""
        return self.hits

    @property
    def total(self):
        """Get total number of hits, or None when it is not available."""
        if hasattr(self._results, "hits"):
            return self._results.hits.total["value"]
        else:
            # handle scan(): returns a generator
            return None

    @property
    def aggregations(self):
        """Get the search result aggregations, or None if there are none."""
        # TODO: have a way to label or not label
        try:
            return self._results.labelled_facets.to_dict()
        except AttributeError:
            return None

    @property
    def hits(self):
        """Iterator over the hits, yielding one projected dict per hit."""
        for hit in self._results:
            # Load dump
            record = self._service.record_cls.loads(hit.to_dict())

            # Project the record
            projection = self._schema.dump(
                record,
                context=dict(
                    identity=self._identity,
                    record=record,
                ),
            )
            if self._links_item_tpl:
                projection["links"] = self._links_item_tpl.expand(
                    self._identity, record
                )

            yield projection

    @property
    def pagination(self):
        """Create a pagination object."""
        return Pagination(
            self._params["size"],
            self._params["page"],
            self.total,
        )

    def to_dict(self):
        """Return result as a dictionary."""
        # TODO: This part should imitate the result item above. I.e. add a
        # "data" property which uses a ServiceSchema to dump the entire object.
        hits = list(self.hits)

        if self._expand and self._fields_resolver:
            self._fields_resolver.resolve(self._identity, hits)
            for hit in hits:
                fields = self._fields_resolver.expand(self._identity, hit)
                hit["expanded"] = fields

        res = {
            "hits": {
                "hits": hits,
                "total": self.total,
            }
        }

        # Read the property once: every access converts the facets again.
        aggregations = self.aggregations
        if aggregations:
            res["aggregations"] = aggregations

        if self._params:
            res["sortBy"] = self._params["sort"]
            if self._links_tpl:
                res["links"] = self._links_tpl.expand(self._identity, self.pagination)

        return res
class ExpandableField(ABC):
    """Field referencing to another record that can be expanded."""

    def __init__(self, field_name):
        """Constructor.

        :params field_name: the name of the field containing the value to
            resolve the referenced record (may be a dotted path to a nested
            field)
        """
        self._field_name = field_name
        # Maps service -> {value: dereferenced record, or None until the
        # record has been fetched}.
        self._service_values = dict()

    @property
    def field_name(self):
        """Return field name."""
        return self._field_name

    @abstractmethod
    def get_value_service(self, value):
        """Return the value and the service to fetch the referenced record.

        :params value: the raw value found at ``field_name`` in a hit.
        :returns: a ``(value, service)`` tuple.
        """
        return None, None

    def has(self, service, value):
        """Return true if field has given value for given service."""
        return value in self._service_values.get(service, {})

    def add_service_value(self, service, value):
        """Store each value in the list of results for this field."""
        self._service_values.setdefault(service, dict())
        # Register the value without overwriting an already-fetched record.
        self._service_values[service].setdefault(value, None)

    def add_dereferenced_record(self, service, value, resolved_rec):
        """Save the dereferenced record."""
        self._service_values[service][value] = resolved_rec

    def get_dereferenced_record(self, service, value):
        """Return the dereferenced record.

        Returns None when the value was registered but nothing was fetched
        for it. Raises KeyError when the service/value pair was never
        registered.
        """
        return self._service_values[service][value]

    @abstractmethod
    def pick(self, identity, resolved_rec):
        """Pick the fields to return from the resolved record dict."""
        return {"id": resolved_rec["id"]}
class FieldsResolver:
    """Resolve the reference record for each of the configured field.

    Given a list of fields referencing other records/objects,
    it fetches and returns the dereferenced record/obj.

    To minimize the performance impact of resolving reference record, this
    object will:
    - first, collect all the possible values of each fields, grouping them
      by service to be called to fetch the referenced record/obj
    - it will then call the `service.read_many([ids])` method so that all
      reference records are retrieved with one search per service type
    - for each of the result to be returned, it will call the `pick` method
      of each configured field to allow to choose what fields should be
      selected and returned from the resolved record.

    It supports resolution of nested fields out of the box.
    """

    def __init__(self, expandable_fields):
        """Constructor.

        :params expandable_fields: list of ExpandableField obj, or None for
            no expandable fields.
        """
        # Result classes pass None by default. Normalise it to an empty list
        # so that resolve()/expand() are no-ops instead of failing on
        # `for field in None`. Materialise the iterable because the fields
        # are iterated several times.
        self._fields = list(expandable_fields or ())

    def _collect_values(self, hits):
        """Collect all field values to be expanded.

        :returns: dict mapping each service to the set of values to fetch.
        """
        grouped_values = dict()
        for hit in hits:
            for field in self._fields:
                try:
                    value = dict_lookup(hit, field.field_name)
                except KeyError:
                    continue
                else:
                    # value is not None
                    v, service = field.get_value_service(value)
                    field.add_service_value(service, v)
                    # collect values (ids) and group by service e.g.:
                    # service_1: (13, 4),
                    # service_2: (uuid1, uuid2, ...)
                    grouped_values.setdefault(service, set())
                    grouped_values[service].add(v)

        return grouped_values

    def _find_fields(self, service, value):
        """Find all fields matching service and value.

        The `id` field used to match the resolved record is hardcoded,
        as in the `read_many` method.
        """
        return [field for field in self._fields if field.has(service, value)]

    def _fetch_referenced(self, grouped_values, identity):
        """Search and fetch referenced recs by ids."""
        for service, values in grouped_values.items():
            results = service.read_many(identity, list(values))
            for hit in results.hits:
                value = hit.get("id", None)
                for field in self._find_fields(service, value):
                    field.add_dereferenced_record(service, value, hit)

    def resolve(self, identity, hits):
        """Collect field values and resolve referenced records.

        :params identity: the identity passed to ``service.read_many``.
        :params hits: an iterable of hit dicts, or a single hit dict.
        """
        if isinstance(hits, dict):
            # A single hit: wrap it. Calling list() on it would iterate keys.
            _hits = [hits]
        else:
            _hits = list(hits)
        grouped_values = self._collect_values(_hits)
        self._fetch_referenced(grouped_values, identity)

    def expand(self, identity, hit):
        """Return the expanded fields for the given hit.

        Must be called after ``resolve`` has seen ``hit``. Fields missing from
        the hit, or whose referenced record was not fetched, are skipped.
        """
        results = dict()
        for field in self._fields:
            try:
                value = dict_lookup(hit, field.field_name)
            except KeyError:
                continue
            else:
                # value is not None
                v, service = field.get_value_service(value)
                resolved_rec = field.get_dereferenced_record(service, v)
                if not resolved_rec:
                    continue
                output = field.pick(identity, resolved_rec)

                # transform field name (potentially dotted) to nested dicts
                # to keep the nested structure of the field
                d = dict()
                dict_set(d, field.field_name, output)
                # merge dict with previous results
                dict_merge(results, d)

        return results