-
Notifications
You must be signed in to change notification settings - Fork 30
/
read_resource.py
291 lines (242 loc) · 11.3 KB
/
read_resource.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
from inspect import signature
from typing import Any, Dict, List, Optional, Type, Union
import orjson
from fastapi import Depends, HTTPException, Path, Request, Response
from pydantic import BaseModel
from pymongo import timeout as query_timeout
from pymongo.errors import NetworkTimeout, PyMongoError
from maggma.api.models import Meta
from maggma.api.models import Response as ResponseModel
from maggma.api.query_operator import PaginationQuery, QueryOperator, SparseFieldsQuery
from maggma.api.resource import HeaderProcessor, HintScheme, Resource
from maggma.api.resource.utils import attach_query_ops, generate_query_pipeline
from maggma.api.utils import STORE_PARAMS, merge_queries, serialization_helper
from maggma.core import Store
from maggma.stores import MongoStore, S3Store
class ReadOnlyResource(Resource):
    """
    Implements a REST Compatible Resource as a GET URL endpoint.

    This class provides a number of convenience features
    including full pagination, field projection, optional
    get-by-key routes, MongoDB index hints, and response
    header post-processing.
    """

    def __init__(
        self,
        store: Store,
        model: Type[BaseModel],
        tags: Optional[List[str]] = None,
        query_operators: Optional[List[QueryOperator]] = None,
        key_fields: Optional[List[str]] = None,
        hint_scheme: Optional[HintScheme] = None,
        header_processor: Optional[HeaderProcessor] = None,
        timeout: Optional[int] = None,
        enable_get_by_key: bool = False,
        enable_default_search: bool = True,
        disable_validation: bool = False,
        query_disk_use: bool = False,
        include_in_schema: Optional[bool] = True,
        sub_path: Optional[str] = "/",
    ):
        """
        Args:
            store: The Maggma Store to get data from
            model: The pydantic model this Resource represents
            tags: List of tags for the Endpoint
            query_operators: Operators for the query language
            hint_scheme: The hint scheme to use for this resource
            header_processor: The header processor to use for this resource
            timeout: Time in seconds Pymongo should wait when querying MongoDB
                before raising a timeout error
            key_fields: List of fields to always project. Default uses SparseFieldsQuery
                to allow user to define these on-the-fly.
            enable_get_by_key: Enable get by key route for endpoint.
            enable_default_search: Enable default endpoint search behavior.
            query_disk_use: Whether to use temporary disk space in large MongoDB queries.
            disable_validation: Whether to use ORJSON and provide a direct FastAPI response.
                Note this will disable auto JSON serialization and response validation with the
                provided model.
            include_in_schema: Whether the endpoint should be shown in the documented schema.
            sub_path: sub-URL path for the resource.

        Raises:
            ValueError: If a hint scheme is supplied for a non-MongoDB store.
        """
        self.store = store
        self.tags = tags or []
        self.hint_scheme = hint_scheme
        self.header_processor = header_processor
        self.key_fields = key_fields
        self.versioned = False
        self.enable_get_by_key = enable_get_by_key
        self.enable_default_search = enable_default_search
        self.timeout = timeout
        self.disable_validation = disable_validation
        self.include_in_schema = include_in_schema
        self.sub_path = sub_path
        self.query_disk_use = query_disk_use

        self.response_model = ResponseModel[model]  # type: ignore

        # Hint schemes map to MongoDB index hints and are meaningless for
        # any other store backend, so reject the combination up front.
        if not isinstance(store, MongoStore) and self.hint_scheme is not None:
            raise ValueError("Hint scheme is only supported for MongoDB stores")

        # Default operators provide pagination plus sparse-field projection
        # keyed on the store's key and last-updated fields.
        self.query_operators = (
            query_operators
            if query_operators is not None
            else [
                PaginationQuery(),
                SparseFieldsQuery(
                    model,
                    default_fields=[self.store.key, self.store.last_updated_field],
                ),
            ]
        )

        super().__init__(model)

    def prepare_endpoint(self):
        """
        Internal method to prepare the endpoint by setting up default handlers
        for routes.
        """
        if self.enable_get_by_key:
            self.build_get_by_key()
        if self.enable_default_search:
            self.build_dynamic_model_search()

    def build_get_by_key(self):
        """Register a GET route that fetches a single document by its primary key."""
        key_name = self.store.key
        model_name = self.model.__name__

        if self.key_fields is None:
            # Let the client choose projected fields on-the-fly.
            field_input = SparseFieldsQuery(self.model, [self.store.key, self.store.last_updated_field]).query
        else:

            def field_input():
                # Fixed projection supplied at construction time.
                return {"properties": self.key_fields}

        def get_by_key(
            request: Request,
            temp_response: Response,
            key: str = Path(
                ...,
                alias=key_name,
                title=f"The {key_name} of the {model_name} to get",
            ),
            _fields: STORE_PARAMS = Depends(field_input),
        ):
            # Gets a document by the primary key in the store and returns it
            # wrapped as {"data": [doc]}.
            # NOTE: the original placed an f-string here; f-strings are not
            # bound to __doc__, so a plain comment is behaviorally identical.
            self.store.connect()

            try:
                with query_timeout(self.timeout):
                    item = [
                        self.store.query_one(
                            criteria={self.store.key: key},
                            properties=_fields["properties"],
                        )
                    ]
            except (NetworkTimeout, PyMongoError) as e:
                # pymongo marks timeout-related errors with a `timeout` flag;
                # surface those as 504 and everything else as a bare 500.
                if e.timeout:
                    raise HTTPException(
                        status_code=504,
                        detail="Server timed out trying to obtain data. Try again with a smaller request.",
                    )
                else:
                    raise HTTPException(
                        status_code=500,
                    )

            if item == [None]:
                raise HTTPException(
                    status_code=404,
                    detail=f"Item with {self.store.key} = {key} not found",
                )

            # Give every query operator a chance to post-process the result.
            for operator in self.query_operators:
                item = operator.post_process(item, {})

            response = {"data": item}  # type: ignore

            if self.disable_validation:
                # Bypass pydantic validation and serialize directly with orjson.
                response = Response(orjson.dumps(response, default=serialization_helper))  # type: ignore

            if self.header_processor is not None:
                if self.disable_validation:
                    self.header_processor.process_header(response, request)
                else:
                    self.header_processor.process_header(temp_response, request)

            return response

        self.router.get(
            f"{self.sub_path}{{{key_name}}}/",
            # Fixed duplicated "by by" in the summary text.
            summary=f"Get a {model_name} document by {key_name}",
            response_description=f"Get a {model_name} document by {key_name}",
            response_model=self.response_model,
            response_model_exclude_unset=True,
            tags=self.tags,
            include_in_schema=self.include_in_schema,
        )(get_by_key)

    def build_dynamic_model_search(self):
        """Register the default GET search route built from the query operators."""
        model_name = self.model.__name__

        def search(**queries: Dict[str, STORE_PARAMS]) -> Union[Dict, Response]:
            request: Request = queries.pop("request")  # type: ignore
            temp_response: Response = queries.pop("temp_response")  # type: ignore

            # Collect every query parameter name the operators understand,
            # so unknown client parameters can be rejected explicitly.
            query_params = [
                entry for _, i in enumerate(self.query_operators) for entry in signature(i.query).parameters
            ]

            overlap = [key for key in request.query_params if key not in query_params]
            # Truthiness test rather than any(): any([""]) is False, which
            # would silently accept an empty-string parameter name.
            if overlap:
                if "limit" in overlap or "skip" in overlap:
                    raise HTTPException(
                        status_code=400,
                        detail="'limit' and 'skip' parameters have been renamed. "
                        "Please update your API client to the newest version.",
                    )
                else:
                    raise HTTPException(
                        status_code=400,
                        detail="Request contains query parameters which cannot be used: {}".format(", ".join(overlap)),
                    )

            query: Dict[Any, Any] = merge_queries(list(queries.values()))  # type: ignore

            if self.hint_scheme is not None:  # pragma: no cover
                hints = self.hint_scheme.generate_hints(query)
                query.update(hints)

            self.store.connect()

            try:
                with query_timeout(self.timeout):
                    if isinstance(self.store, S3Store):
                        # S3-backed stores cannot run aggregation pipelines;
                        # fall back to a plain query.
                        count = self.store.count(criteria=query.get("criteria"))  # type: ignore
                        if self.query_disk_use:
                            data = list(self.store.query(**query, allow_disk_use=True))  # type: ignore
                        else:
                            data = list(self.store.query(**query))
                    else:
                        count = self.store.count(
                            criteria=query.get("criteria"), hint=query.get("count_hint")
                        )  # type: ignore

                        pipeline = generate_query_pipeline(query, self.store)

                        agg_kwargs = {}
                        if query.get("agg_hint"):
                            agg_kwargs["hint"] = query["agg_hint"]

                        data = list(self.store._collection.aggregate(pipeline, **agg_kwargs))
            except (NetworkTimeout, PyMongoError) as e:
                if e.timeout:
                    raise HTTPException(
                        status_code=504,
                        detail="Server timed out trying to obtain data. Try again with a smaller request.",
                    )
                else:
                    # NOTE(review): this 500 detail also claims a timeout even
                    # for non-timeout errors — message preserved for API
                    # compatibility; confirm whether it should be reworded.
                    raise HTTPException(
                        status_code=500,
                        detail="Server timed out trying to obtain data. Try again with a smaller request,"
                        " or remove sorting fields and sort data locally.",
                    )

            operator_meta = {}
            for operator in self.query_operators:
                data = operator.post_process(data, query)
                operator_meta.update(operator.meta())

            meta = Meta(total_doc=count)
            response = {"data": data, "meta": {**meta.dict(), **operator_meta}}  # type: ignore

            if self.disable_validation:
                # Bypass pydantic validation and serialize directly with orjson.
                response = Response(orjson.dumps(response, default=serialization_helper))  # type: ignore

            if self.header_processor is not None:
                if self.disable_validation:
                    self.header_processor.process_header(response, request)
                else:
                    self.header_processor.process_header(temp_response, request)

            return response

        self.router.get(
            self.sub_path,
            tags=self.tags,
            summary=f"Get {model_name} documents",
            response_model=self.response_model,
            response_description=f"Search for a {model_name}",
            response_model_exclude_unset=True,
            # Consistency fix: honor include_in_schema like build_get_by_key does.
            include_in_schema=self.include_in_schema,
        )(attach_query_ops(search, self.query_operators))