-
Notifications
You must be signed in to change notification settings - Fork 62
/
utils.py
158 lines (126 loc) · 4.37 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
""" Utilities to enable stubbing in JSON schema into pydantic for Monty """
from typing import (
Any,
Dict,
List,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Union,
get_type_hints,
)
from monty.json import MontyDecoder, MSONable
from numpy import ndarray
from pydantic import BaseModel, Field, create_model
from pydantic.fields import FieldInfo, ModelField
# Annotations that ``__make_pydantic`` hands back unchanged.
built_in_primitives = (
    bool,
    int,
    float,
    complex,
    range,
    str,
    bytes,
    None,
)

# Bare builtin containers and the ``typing`` alias each one is swapped for.
prim_to_type_hint: Dict[type, Any] = dict(
    zip(
        (list, tuple, dict, set),
        (List, Tuple, Dict, Set),
    )
)

# Registry of generated/registered pydantic stub models, keyed by MSONable class.
STUBS: Dict[type, type] = dict()
def patch_msonable(monty_cls: type):
    """
    Patch an MSONable class so it can be used as a field type in pydantic models.

    Three class attributes are attached to ``monty_cls``:

    * ``validate_monty`` - classmethod validator that accepts an instance of the
      class, or a dict which is decoded with ``MontyDecoder`` and must produce an
      instance of the class;
    * ``__get_validators__`` - classmethod yielding ``validate_monty``;
    * ``__pydantic_model__`` - the stub model stored in ``STUBS`` for this class.
      If no stub has been registered yet, one is generated with
      ``MSONable_to_pydantic`` and cached in ``STUBS``.

    monty_cls: A MSONable class

    Raises:
        ValueError: if ``monty_cls`` is not a subclass of ``MSONable``.
    """
    if not issubclass(monty_cls, MSONable):
        raise ValueError("Must provide an MSONable class to wrap")

    def __get_validators__(cls):
        yield cls.validate_monty

    def validate_monty(cls, v):
        """
        Stub validator for MSONable
        """
        if isinstance(v, cls):
            return v
        elif isinstance(v, dict):
            # Relegate to Monty
            new_obj = MontyDecoder().process_decoded(v)
            if not isinstance(new_obj, cls):
                raise ValueError(f"Wrong dict for {cls.__name__}")
            return new_obj
        else:
            raise ValueError(f"Must provide {cls.__name__} or Dict version")

    if monty_cls not in STUBS:
        # Called directly rather than through use_model(): build a stub from the
        # class's __init__ annotations instead of failing with a bare KeyError.
        # Done before any setattr so a failure leaves the class unpatched.
        STUBS[monty_cls] = MSONable_to_pydantic(monty_cls)

    setattr(monty_cls, "validate_monty", classmethod(validate_monty))
    setattr(monty_cls, "__get_validators__", classmethod(__get_validators__))
    setattr(monty_cls, "__pydantic_model__", STUBS[monty_cls])
def use_model(monty_cls: type, pydantic_model: type, add_monty: bool = True):
    """
    Register ``pydantic_model`` as the schema description of ``monty_cls``.

    Args:
        monty_cls: the MSONable class being described.
        pydantic_model: pydantic model supplying the field definitions.
        add_monty: when True (the default), the stored stub is built by
            ``MSONable_to_pydantic`` from ``pydantic_model``, which also adds
            ``@class``/``@module`` field definitions; when False,
            ``pydantic_model`` itself is stored.

    The stub is stored in ``STUBS`` under ``monty_cls``, after which the class
    is patched with ``patch_msonable``.
    """
    stub = (
        MSONable_to_pydantic(monty_cls, pydantic_model=pydantic_model)
        if add_monty
        else pydantic_model
    )
    STUBS[monty_cls] = stub
    patch_msonable(monty_cls)
def __make_pydantic(cls):
    """
    Temporary wrapper function to convert an MSONable class (or a type
    annotation containing MSONable classes) into something pydantic can build
    a schema from.

    Conversion rules, applied in order:

    * built-in primitives, ``Any`` and ``TypeVar`` instances are returned as-is;
    * bare ``list``/``tuple``/``dict``/``set`` become their ``typing`` aliases;
    * subscripted generics have every argument converted recursively and are
      rebuilt for ``Union`` (which also covers ``Optional``), ``List``,
      ``Tuple``, ``Set``, ``Sequence`` and ``Dict``, and for builtin generics
      such as ``list[int]``; any other generic is returned unchanged;
    * MSONable classes are replaced by their stub in ``STUBS``, generating one
      with ``MSONable_to_pydantic`` only when none is registered yet;
    * ``numpy.ndarray`` becomes ``List[Any]``;
    * everything else (including non-class objects such as the ``Ellipsis`` in
      ``Tuple[int, ...]``) is returned unchanged.
    """
    if any(cls == T for T in built_in_primitives):
        return cls
    if cls in prim_to_type_hint:
        return prim_to_type_hint[cls]
    if cls == Any:
        return Any
    if isinstance(cls, TypeVar):
        return cls
    if hasattr(cls, "__origin__") and hasattr(cls, "__args__"):
        args = tuple(__make_pydantic(arg) for arg in cls.__args__)
        # Optional[X] is Union[X, None] (its __origin__ is Union), so this
        # branch handles Optional too.
        if cls.__origin__ == Union:
            return Union[args]
        by_name = {
            "List": List,
            "Tuple": Tuple,
            "Set": Set,
            "Sequence": Sequence,
            "Dict": Dict,
        }
        # PEP 585 aliases such as ``list[int]`` have no ``_name``; fall back to
        # the typing alias of their builtin origin.
        alias = by_name.get(getattr(cls, "_name", None)) or prim_to_type_hint.get(
            cls.__origin__
        )
        if alias is not None:
            return alias[args]
        # Unrecognised generic (Callable, Literal, Mapping, ...): hand it to
        # pydantic unchanged instead of crashing in issubclass() below.
        return cls
    # issubclass() raises TypeError for non-classes (e.g. Ellipsis), so guard it.
    if isinstance(cls, type) and issubclass(cls, MSONable):
        # STUBS is keyed by class, not by class name: a name-based check never
        # matched and silently replaced models registered through use_model().
        if cls not in STUBS:
            STUBS[cls] = MSONable_to_pydantic(cls)
        return STUBS[cls]
    if cls == ndarray:
        return List[Any]
    return cls
def MSONable_to_pydantic(monty_cls: type, pydantic_model=None):
    """
    Build a pydantic model describing the MSONable class ``monty_cls``.

    The model always defines ``@class`` and ``@module`` string fields, defaulting
    to the class name and the module the class is defined in.

    Args:
        monty_cls: the MSONable class to describe.
        pydantic_model: optional pydantic model whose fields are copied onto the
            generated model. When omitted, one required field is generated per
            type-annotated parameter of ``monty_cls.__init__``, with each
            annotation converted by ``__make_pydantic``.

    Returns:
        A new pydantic model class named after ``monty_cls``. It carries
        ``monty_cls``'s docstring when that docstring is non-empty.
    """
    monty_props = {
        "@class": (
            str,
            Field(
                default=monty_cls.__name__,
                title="MSONable Class",
                description="The formal class name for serialization lookup",
            ),
        ),
        "@module": (
            str,
            Field(
                default=monty_cls.__module__,
                title="Python Module",
                description="The module this class is defined in",
            ),
        ),
    }
    if pydantic_model:
        # ``field.type_`` is only the inner type of containers (``X`` for
        # ``List[X]``); ``outer_type_`` keeps the full annotation. pydantic
        # strips ``Optional`` from both and records it as ``allow_none``, so it
        # is re-applied here.
        props = {
            name: (
                Optional[field.outer_type_] if field.allow_none else field.outer_type_,
                field.field_info,
            )
            for name, field in pydantic_model.__fields__.items()
        }
    else:
        type_hints = get_type_hints(monty_cls.__init__)  # type: ignore
        props = {
            field_name: (__make_pydantic(field_type), FieldInfo(...))
            for field_name, field_type in type_hints.items()
            # An annotated return (``-> None``) is reported under "return",
            # which is not a constructor argument.
            if field_name != "return"
        }
    # Field definitions must be passed as keyword arguments: a literal
    # ``field_definitions=`` keyword would define a single field of that name.
    # Merging first lets ``props`` override ``monty_props`` instead of raising
    # "got multiple values for keyword argument".
    model = create_model(monty_cls.__name__, **{**monty_props, **props})
    if monty_cls.__doc__:
        model.__doc__ = monty_cls.__doc__
    return model