-
Notifications
You must be signed in to change notification settings - Fork 5
/
core.py
225 lines (195 loc) · 6.67 KB
/
core.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
from __future__ import annotations
import gzip
from abc import ABCMeta, abstractmethod
from io import IOBase
from typing import Any, Generic, Iterable, Type, overload
import requests
from typing_extensions import Literal, get_args
from unipressed.dataset.search import Search
from unipressed.dataset.type_vars import (
FieldsType,
FormatType,
JsonResultType,
QueryType,
)
class DatasetClient(
    Generic[QueryType, JsonResultType, FieldsType, FormatType],
    metaclass=ABCMeta,
):
    """
    The base class for all UniProt dataset clients. All methods documented here are available in any of the subclasses. This is a static class that you will never need to instantiate.
    """

    @classmethod
    def _id_field(cls, record: JsonResultType) -> str:
        """
        Given a record, extracts the accession/ID field from it.

        The base implementation reads the `"id"` key of the record; subclasses
        can override this if their records store the identifier elsewhere.
        """
        return record["id"]

    @classmethod
    @abstractmethod
    def name(cls) -> str:
        """
        The dataset name, used as the first path segment of this dataset's REST
        URLs (`https://rest.uniprot.org/{name}/...`) and as the `dataset` of searches.
        """
        ...

    @classmethod
    def search(
        cls,
        query: QueryType,
        format: FormatType | Literal["json"] = "json",
        fields: Iterable[FieldsType] | None = None,
        size: int = 500,
    ) -> Search[QueryType, JsonResultType, FieldsType, FormatType]:
        """
        Creates an object that can be used to perform a search query over this dataset.
        Refer to the [unipressed.dataset.search.Search][] reference for more information on how to use it.

        No request is made here; this only constructs the `Search` object.
        """
        return Search[QueryType, JsonResultType, FieldsType, FormatType](
            query=query, format=format, dataset=cls.name(), fields=fields, size=size
        )

    @overload
    @classmethod
    def fetch_one(
        cls, id: str, format: Literal["json"] = "json", parse: Literal[True] = True
    ) -> JsonResultType:
        ...

    @overload
    @classmethod
    def fetch_one(
        cls, id: str, format: FormatType, parse: Literal[False] = False
    ) -> IOBase:
        ...

    @classmethod
    def fetch_one(
        cls, id: str, format: str = "json", parse: bool = True
    ) -> JsonResultType | IOBase:
        """
        Fetches a single record from this dataset using its ID.

        Args:
            id : The ID of the record to fetch. The format of this will depend on the dataset.
            format : The format of the result. The available options will depend on the subclass you are using, but the type checker/autocomplete will enforce available options.
            parse : If true and `format` is `"json"`, parse the result into a JSON dictionary. Has no effect for other formats. Defaults to True.

        Returns:
            : If parse is True and format is "json", a dictionary. Otherwise, a file object containing the results in the specified format; the caller is responsible for closing it.

        Raises:
            requests.HTTPError: If UniProt responds with an error status code.
        """
        res = requests.get(
            f"https://rest.uniprot.org/{cls.name()}/{id}.{format}", stream=True
        )
        try:
            res.raise_for_status()
        except requests.HTTPError:
            # With stream=True the body is not read eagerly, so the pooled
            # connection is only released once the response is consumed or
            # explicitly closed. Close it before propagating the error.
            res.close()
            raise
        if parse and format == "json":
            return res.json()
        if res.headers.get("Content-Encoding") == "gzip":
            # `res.raw` is the undecoded transport stream, so gzip-encoded
            # bodies have to be decompressed here.
            return gzip.open(res.raw, "rb")
        return res.raw

    @classmethod
    def _type_args(cls) -> tuple[Type, ...]:
        """
        Returns the type arguments this dataset client was parameterised with,
        i.e. the arguments of its first generic base class.
        """
        return get_args(cls.__orig_bases__[0])  # type: ignore

    @classmethod
    def _allowed_query_fields(cls) -> set[str]:
        """
        Returns the names of the fields that can be used in a search query for
        this dataset, excluding the `and_`, `or_` and `not_` logical operators.
        """
        query_dict, _ = get_args(cls._query_type())
        # Combine the keys first, then drop the operators. `-` binds tighter
        # than `|`, so a single unparenthesised expression would only strip the
        # operators from the required keys, and they would leak through from
        # the optional keys.
        all_keys = query_dict.__optional_keys__ | query_dict.__required_keys__
        return set(all_keys - {"and_", "or_", "not_"})

    @classmethod
    def _allowed_return_fields(cls) -> set[str]:
        """
        Returns the set of field names allowed in the `fields` argument of
        `search`, taken from this dataset's fields type argument.
        """
        return set(get_args(cls._field_type()))

    @classmethod
    def _query_type(cls):
        """
        Returns the type of queries to this dataset (the first type argument).
        """
        return cls._type_args()[0]

    @classmethod
    def _result_type(cls):
        """
        Returns the type of query results to this dataset (the second type argument).
        """
        return cls._type_args()[1]

    @classmethod
    def _field_type(cls):
        """
        Returns the type of allowed fields for search queries (the third type argument).
        """
        return cls._type_args()[2]

    @classmethod
    def _format_type(cls):
        """
        Returns the type of allowed formats for search queries (the fourth type argument).
        """
        return cls._type_args()[3]

    @classmethod
    def _allowed_formats(cls) -> set[str]:
        """
        Returns the set of allowed format names for this dataset.
        """
        return set(get_args(cls._format_type()))
class FetchManyClient(DatasetClient[QueryType, JsonResultType, FieldsType, FormatType]):
    """
    Dataset subclass for datasets that can be queried by multiple IDs. Not all datasets support this.
    """

    @classmethod
    @abstractmethod
    def _bulk_id_param(cls) -> str:
        """
        The name of the GET query parameter used to define the list of IDs in a bulk query.

        The default implementation reuses the bulk endpoint name (`_bulk_endpoint()`)
        as the parameter name. Although this is marked abstract, ABCMeta only
        enforces abstract methods on instantiation. These clients are used
        without instantiating them, so a subclass that does not override this
        falls back to the default.
        """
        return cls._bulk_endpoint()

    @classmethod
    @abstractmethod
    def _bulk_endpoint(cls) -> str:
        """
        The URL path segment, relative to `https://rest.uniprot.org/{name}/`,
        used to query a bulk list of IDs.
        """
        ...

    @overload
    @classmethod
    def fetch_many(
        cls,
        ids: Iterable[str],
        format: Literal["json"] = "json",
        parse: Literal[True] = True,
    ) -> Iterable[JsonResultType]:
        ...

    @overload
    @classmethod
    def fetch_many(
        cls, ids: Iterable[str], format: FormatType, parse: Literal[False] = False
    ) -> IOBase:
        ...

    @classmethod
    def fetch_many(
        cls,
        ids: Iterable[str],
        format: FormatType | Literal["json"] = "json",
        parse: bool = True,
    ) -> Iterable[JsonResultType] | IOBase:
        """
        Fetches multiple records using their accessions.

        Args:
            ids : The accessions to query. A single string is treated as one accession.
            format : The format to return the records. Defaults to "json".
            parse : Only supported for JSON. If True, parses the result instead of returning a raw file. Defaults to True.

        Returns:
            : If parse is True and format is "json", the list of records from the response's "results" entry. Otherwise, a file object containing the results in the specified format; the caller is responsible for closing it.

        Raises:
            requests.HTTPError: If UniProt responds with an error status code.
        """
        if isinstance(ids, str):
            # A bare string is itself an Iterable[str]. Joining it would send
            # one "ID" per character, so treat it as a single ID instead.
            ids = [ids]
        res = requests.get(
            f"https://rest.uniprot.org/{cls.name()}/{cls._bulk_endpoint()}",
            params={cls._bulk_id_param(): ",".join(ids), "format": format},
            stream=True,
        )
        try:
            res.raise_for_status()
        except requests.HTTPError:
            # With stream=True the body is not read eagerly, so the pooled
            # connection is only released once the response is consumed or
            # explicitly closed. Close it before propagating the error.
            res.close()
            raise
        if parse and format == "json":
            return res.json()["results"]
        if res.headers.get("Content-Encoding") == "gzip":
            # `res.raw` is the undecoded transport stream, so gzip-encoded
            # bodies have to be decompressed here.
            return gzip.open(res.raw, "rb")
        return res.raw