from __future__ import annotations
from collections import UserDict
import json
from typing import Any, Dict, Generic, Iterator, Tuple, Type, TypeVar, Union
# import orjson
from pydantic import Field, PrivateAttr, ValidationError
from pydantic.generics import GenericModel
from pydantic.utils import lenient_issubclass
from omnipy.data.model import Model
ModelT = TypeVar('ModelT', bound=Model)

Undefined = object()
DATA_KEY = 'data'


# def orjson_dumps(v, *, default):
#     # orjson.dumps returns bytes; to match standard json.dumps we need to decode
#     return orjson.dumps(v, default=default).decode()


class Dataset(GenericModel, Generic[ModelT], UserDict):
    """
    Dict-based container of data files that follow a specific Model

    Dataset is a generic class that cannot be instantiated directly. Instead, a Dataset class
    needs to be specialized with a data model before Dataset objects can be instantiated. A
    data model functions as a data parser and guarantees that the parsed data follows the
    specified model.

    The specialization must be done through the use of Model, either directly, e.g.::

        MyDataset = Dataset[Model[Dict[str, List[int]]]]

    ... or indirectly, using a Model subclass, e.g.::

        class MyModel(Model[Dict[str, List[int]]]):
            pass

        MyDataset = Dataset[MyModel]

    ... alternatively through the specification of a Dataset subclass::

        class MyDataset(Dataset[MyModel]):
            pass

    The specialization can also be done in a more deeply nested structure, e.g.::

        class MyNumberList(Model[List[int]]):
            pass

        class MyToplevelDict(Model[Dict[str, MyNumberList]]):
            pass

        class MyDataset(Dataset[MyToplevelDict]):
            pass

    Once instantiated, a dataset object functions as a dict of data files, with the keys
    referring to the data file names and the values to the data file contents, e.g.::

        MyNumberListDataset = Dataset[Model[List[int]]]

        my_dataset = MyNumberListDataset({'file_1': [1, 2, 3]})
        my_dataset['file_2'] = [2, 3, 4]

        print(my_dataset.keys())

    The Dataset class is a wrapper class around the powerful `GenericModel` class from pydantic.
    """
    class Config:
        validate_assignment = True
        # json_loads = orjson.loads
        # json_dumps = orjson_dumps

    data: Dict[str, ModelT] = Field(default={})

    def __class_getitem__(cls, model: ModelT) -> ModelT:
        # TODO: change model type to params: Union[Type[Any], Tuple[Type[Any], ...]]
        #       as in GenericModel

        # For now, only singular model types are allowed. These lines are needed for
        # interoperability with pydantic GenericModel, which internally stores the model
        # as a tuple:
        if isinstance(model, tuple) and len(model) == 1:
            model = model[0]

        if not isinstance(model, TypeVar) and not lenient_issubclass(model, Model):
            raise TypeError('Invalid model: {}! '.format(model)
                            + 'omnipy Dataset models must be a specialization of the omnipy '
                              'Model class.')

        return super().__class_getitem__(model)

    def __init__(self,
                 value: Union[Dict[str, Any], Iterator[Tuple[str, Any]]] = Undefined,
                 **input_data: Any) -> None:
        # Undefined is a sentinel object, so check by identity
        if value is not Undefined:
            input_data[DATA_KEY] = value

        if self.get_model_class() == ModelT:
            self._raise_no_model_exception()

        GenericModel.__init__(self, **input_data)
        UserDict.__init__(self, self.data)  # noqa

        if not self.__doc__:
            self._set_standard_field_description()

    # TODO: Add test for get_model_class
    def get_model_class(self) -> ModelT:
        return self.__fields__.get(DATA_KEY).type_

    # TODO: Update _raise_no_model_exception() text. Model is now a requirement
    @staticmethod
    def _raise_no_model_exception() -> None:
        raise TypeError(
            'Note: The Dataset class requires a concrete model to be specified as '
            'a type hierarchy within brackets either directly, e.g.:\n\n'
            '\tmodel = Dataset[List[int]]()\n\n'
            'or indirectly in a subclass definition, e.g.:\n\n'
            '\tclass MyNumberListDataset(Dataset[List[int]]): ...\n\n'
            'In both cases, the use of the Model class or a subclass is encouraged for anything '
            'other than the simplest cases, e.g.:\n\n'
            '\tclass MyNumberListModel(Model[List[int]]): ...\n'
            '\tclass MyDataset(Dataset[MyNumberListModel]): ...\n\n'
            'Usage of Dataset without a type specification results in this exception. '
            'Similar use of the Model class does not currently result in an exception, only '
            'a warning message the first time this is done. However, this is just a '
            '"poor man\'s exception" due to complex technicalities in that class. Please '
            'explicitly specify types in both cases.')

    def _set_standard_field_description(self) -> None:
        self.__fields__[DATA_KEY].field_info.description = self._get_standard_field_description()

    @classmethod
    def _get_standard_field_description(cls) -> str:
        return ('This class represents a dataset in the `omnipy` Python package and contains '
                'a set of named data items that follow the same data model. '
                'It is a statically typed specialization of the Dataset class according to a '
                'particular specialization of the Model class. Both main classes are wrapping '
                'the excellent Python package named `pydantic`.')

    def __setitem__(self, obj_type: str, data_obj: Any) -> None:
        has_prev_value = obj_type in self.data
        prev_value = self.data.get(obj_type)

        try:
            self.data[obj_type] = data_obj
            self._validate(obj_type)
        except:  # noqa
            # Roll back to the previous state if validation fails
            if has_prev_value:
                self.data[obj_type] = prev_value
            else:
                del self.data[obj_type]
            raise

    def __getitem__(self, obj_type: str) -> Any:
        if obj_type in self.data:
            return self.data[obj_type].contents
        else:
            return self.data[obj_type]  # raises KeyError for missing keys

    def _validate(self, _obj_type: str) -> None:
        self.data = self.data  # Triggers pydantic validation, as validate_assignment=True

    def __iter__(self) -> Iterator:
        return UserDict.__iter__(self)

    def __setattr__(self, attr: str, value: Any) -> None:
        if attr in self.__dict__ or attr == DATA_KEY or attr.startswith('__'):
            super().__setattr__(attr, value)
        else:
            raise RuntimeError('Dataset does not allow setting of extra attributes')

    def to_data(self) -> Dict[str, Any]:
        return GenericModel.dict(self).get(DATA_KEY)

    def from_data(self,
                  data: Union[Dict[str, Any], Iterator[Tuple[str, Any]]],
                  update: bool = True) -> None:
        if not isinstance(data, dict):
            data = dict(data)

        if not update:
            self.clear()

        for obj_type, obj_val in data.items():
            new_model = self.get_model_class()()  # noqa
            new_model.from_data(obj_val)
            self[obj_type] = new_model

    def to_json(self, pretty=False) -> Dict[str, str]:
        result = {}

        for key, val in self.to_data().items():
            result[key] = self._pretty_print_json(val) if pretty else json.dumps(val)

        return result

    def from_json(self,
                  data: Union[Dict[str, str], Iterator[Tuple[str, str]]],
                  update: bool = True) -> None:
        if not isinstance(data, dict):
            data = dict(data)

        if not update:
            self.clear()

        for obj_type, obj_val in data.items():
            new_model = self.get_model_class()()  # noqa
            new_model.from_json(obj_val)
            self[obj_type] = new_model

    # @classmethod
    # def get_type_args(cls):
    #     return cls.__fields__.get(DATA_KEY).type_
    #
    # @classmethod
    # def create_from_json(cls, data: Union[str, Tuple[str]]):
    #     if isinstance(data, tuple):
    #         data = data[0]
    #
    #     obj = cls()
    #     obj.from_json(data, update=False)
    #     return obj
    #
    # def __reduce__(self):
    #     return self.__class__.create_from_json, (self.to_json(),)

    @classmethod
    def to_json_schema(cls, pretty=False) -> Union[str, Dict[str, str]]:
        result = {}
        schema = cls.schema()

        for key, val in schema['properties']['data'].items():
            result[key] = val
        result['title'] = schema['title']
        result['definitions'] = schema['definitions']

        if pretty:
            return cls._pretty_print_json(result)
        else:
            return json.dumps(result)

    @staticmethod
    def _pretty_print_json(json_content: Any) -> str:
        return json.dumps(json_content, indent=4)

    def as_multi_model_dataset(self) -> MultiModelDataset[ModelT]:
        multi_model_dataset = MultiModelDataset[self.get_model_class()]()
        for obj_type in self:
            multi_model_dataset.data[obj_type] = self.data[obj_type]
        return multi_model_dataset
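
# A minimal usage sketch (illustrative only, not part of the library) of the JSON
# round-trip provided by to_json() and from_json() above. The name
# `NumberListDataset` is hypothetical:
#
#     from typing import List
#
#     NumberListDataset = Dataset[Model[List[int]]]
#
#     dataset = NumberListDataset({'file_1': [1, 2, 3]})
#     as_json = dataset.to_json()  # {'file_1': '[1, 2, 3]'}
#
#     restored = NumberListDataset()
#     restored.from_json(as_json, update=False)
#     assert restored.to_data() == {'file_1': [1, 2, 3]}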

# TODO: Use json serializer package from the pydantic config instead of 'json'


class MultiModelDataset(Dataset[ModelT], Generic[ModelT]):
    """
    Variant of Dataset that allows custom models to be set on individual data files

    Note that the general model still needs to hold for all data files, in addition to any
    custom models.
    """
    _custom_field_models: Dict[str, ModelT] = PrivateAttr(default={})

    def set_model(self, obj_type: str, model: ModelT) -> None:
        try:
            self._custom_field_models[obj_type] = model
            if obj_type in self.data:
                self._validate(obj_type)
            else:
                self.data[obj_type] = model()
        except ValidationError:
            del self._custom_field_models[obj_type]
            raise

    def get_model(self, obj_type: str) -> ModelT:
        if obj_type in self._custom_field_models:
            return self._custom_field_models[obj_type]
        else:
            return self.get_model_class()

    def _validate(self, obj_type: str) -> None:
        if obj_type in self._custom_field_models:
            model = self._custom_field_models[obj_type]
            if not lenient_issubclass(model, Model):
                # Wrap plain types (e.g. `List[int]`) in a Model specialization
                model = Model[model]
            data_obj = self._to_data_if_model(self.data[obj_type])
            parsed_data = self._to_data_if_model(model(data_obj))
            self.data[obj_type] = parsed_data
        super()._validate(obj_type)  # validates all data according to ModelT

    @staticmethod
    def _to_data_if_model(data_obj: Any) -> Any:
        if isinstance(data_obj, Model):
            data_obj = data_obj.to_data()
        return data_obj
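
# A minimal usage sketch (illustrative only, not part of the library) of per-file
# custom models with MultiModelDataset. `NumberList` and `PositiveNumberList` are
# hypothetical models:
#
#     from typing import List
#
#     class NumberList(Model[List[int]]):
#         pass
#
#     class PositiveNumberList(NumberList):
#         pass  # assumed to add validation that all numbers are positive
#
#     dataset = MultiModelDataset[NumberList]()
#     dataset.set_model('file_1', PositiveNumberList)  # custom model for 'file_1' only
#     dataset['file_1'] = [1, 2, 3]  # parsed by PositiveNumberList, then NumberList
#     dataset['file_2'] = [-4, 5]   # parsed by the general NumberList only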